Commit 8cdcc821 authored by Thor Johnsen

Bug fixes

Parent 67a0ffcb
......@@ -354,7 +354,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
btm_halo = all_halos[ctx.local_rank+1][:,:1,:,:]
fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
fat_halo[:,2:,:,:].copy_(btm_halo)
relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2,:,:])
relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
relu_halo[:,2:,:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
......
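The one-line change above replaces `relu1[:,Hs-2,:,:]` with `relu1[:,Hs-2:,:,:]`. A minimal sketch of why the trailing `:` matters (shapes below are made up for illustration; the real tensors are NHWC activations from the spatially parallel backward pass): indexing with `Hs-2` drops the H dimension, so `copy_()` either broadcasts a single row into both halo rows or fails for other batch sizes, while slicing with `Hs-2:` keeps the last two rows intact.

```python
import torch

# Illustrative NHWC-style shapes only.
N, Hs, W, C = 1, 8, 4, 4
relu1 = torch.randn(N, Hs, W, C)
relu_halo = torch.empty(N, 3, W, C)

rows = relu1[:, Hs-2:, :, :]   # shape [N, 2, W, C]: the last two rows
row  = relu1[:, Hs-2, :, :]    # shape [N, W, C]: the H dimension is gone

relu_halo[:, :2, :, :].copy_(rows)  # correct: fills the two halo rows
relu_halo[:, :2, :, :].copy_(row)   # wrong: broadcasts row Hs-2 into both slots (for N=1)
```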
......@@ -2158,6 +2158,73 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
return grad_out2;
}
// compute dgrad of 3x3 convolution without fusing with drelu and dscale
at::Tensor bottleneck_backward_dgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// gradient w.r.t. the 3x3 convolution output (dy input to dgrad)
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// output buffer for the data gradient of the 3x3 convolution
auto dgrad1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy1 = dgrad1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// backward-data (dgrad) pass of the 3x3 convolution through cuDNN
run_dconv(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
return dgrad1;
}
at::Tensor bottleneck_backward_dgrad1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// gradient w.r.t. the 3x3 convolution output, restricted to the halo rows
at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
// output buffer for the halo data gradient
auto dgrad1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format);
at::Half* dy1h = dgrad1_halo.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
// backward-data (dgrad) pass on the halo slab through cuDNN
run_dconv(backward_state.outdimA1h,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2h,
CUDNN_DATA_HALF,
dy1h,
w,
dy2h,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
return dgrad1_halo;
}
at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
......@@ -2480,6 +2547,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_dgrad1", &bottleneck_backward_dgrad1, "Bottleneck block backward");
m.def("backward_dgrad1_halo", &bottleneck_backward_dgrad1_halo, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
......
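For reference, a rough PyTorch sketch of what the new `backward_dgrad1` / `backward_dgrad1_halo` bindings appear to compute: the plain backward-data (dgrad) pass of the 3x3 convolution, without the fused drelu/dscale epilogue. Tensor names, shapes, and the 3-row halo slab below are illustrative assumptions; the extension itself runs FP16 NHWC convolutions through the cuDNN backend rather than `torch.nn.grad`.

```python
import torch

N, C, H, W = 2, 64, 56, 56
grad_out2 = torch.randn(N, C, H, W)   # gradient w.r.t. the 3x3 conv output
weight = torch.randn(C, C, 3, 3)      # the 3x3 conv filter

# Unfused dgrad of the 3x3 convolution -- roughly what backward_dgrad1 returns.
dgrad1 = torch.nn.grad.conv2d_input(
    input_size=(N, C, H, W), weight=weight, grad_output=grad_out2,
    stride=1, padding=1)

# The *_halo variant repeats the same dgrad on a thin slab of rows (the halo
# exchanged with the neighbouring rank), so only that strip is recomputed.
halo_rows = 3                          # assumed halo height
grad_out2_halo = grad_out2[:, :, :halo_rows, :]
dgrad1_halo = torch.nn.grad.conv2d_input(
    input_size=(N, C, halo_rows, W), weight=weight, grad_output=grad_out2_halo,
    stride=1, padding=1)
```

On the Python side these would presumably be invoked as `fast_bottleneck.backward_dgrad1(...)` and `fast_bottleneck.backward_dgrad1_halo(...)`, mirroring the existing `backward_grad_out1_halo` call in the spatial bottleneck backward pass.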