提交 b6980a0d 编辑于 作者: Thor Johnsen's avatar Thor Johnsen
浏览文件

Add functions to compute grad_out1, grad_out1_halo

上级 ed713c84
......@@ -237,8 +237,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if spatial_group_size > 1:
out1 = outputs[0]
N,Hs,W,C = list(out1.shape)
padded_out1 = torch.empty((N,Hs+2,W,C),dtype=out1.dtype,device=out1.device)
padded_out1[:,1:Hs+1,:,:].copy_(out1)
stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
# copy halos to send buffer
......@@ -248,22 +246,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
all_halos = torch.empty((N,2*spatial_group_size,W,C),dtype=out1.dtype,device=out1.device)
all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
dist.all_gather(all_halos,send_halos)
padded_out1_top_halo = padded_out1[:,:1,:,:]
fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
if local_rank > 0:
top_halo = all_halos[local_rank-1][:,1:,:,:]
padded_out1_top_halo.copy_(top_halo)
fat_top_halo = padded_out1[:,:3,:,:]
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_top_halo, args)
else:
padded_out1_top_halo.zero_()
padded_out1_btm_halo = padded_out1[:,Hs+1:,:,:]
fat_halo[:,:1,:,:].copy_(top_halo)
fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
if local_rank < spatial_group_size-1:
btm_halo = all_halos[local_rank+1][:,:1,:,:]
padded_out1_btm_halo.copy_(btm_halo)
fat_btm_halo = padded_out1[:,Hs-1:,:,:]
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_btm_halo, args)
else:
padded_out1_btm_halo.zero_()
fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
fat_halo[:,2:,:,:].copy_(btm_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
torch.cuda.current_stream().wait_stream(stream1)
out2 = outputs[1]
if local_rank > 0:
......@@ -272,10 +265,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
out2[:,Hs-1:,:,:].copy_(btm_out2)
fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
if spatial_group_size > 1:
ctx.save_for_backward(*(args+outputs+[padded_out1]))
else:
ctx.save_for_backward(*(args+outputs))
# TODO: save halos for backward pass
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
ctx.nhwc = nhwc
ctx.stride_1x1 = stride_1x1
......@@ -289,10 +280,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# only support dgrad
@staticmethod
def backward(ctx, grad_o):
if ctx.spatial_group_size > 1:
outputs = ctx.saved_tensors[-4:-1]
else:
outputs = ctx.saved_tensors[-3:]
outputs = ctx.saved_tensors[-3:]
if ctx.downsample:
grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
......@@ -315,7 +303,23 @@ class SpatialBottleneckFunction(torch.autograd.Function):
grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
# do halo exchange of grad_out2 here
fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# need fast_bottleneck.backward_grad_out2_halo
# testing
N,H,W,C = grad_out2.shape
grad_out2_halo = torch.empty([N,3,W,C],dtype=grad_out2.dtype,device=grad_out2.device)
grad_out2_halo[:,:1,:,:].zero_()
grad_out2_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2_halo)
# print("grad_out2_halo.shape = %s -> grad_out1_halo.shape = %s" % (str(list(grad_out2_halo.shape)), str(list(grad_out1_halo.shape))))
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply wgrad2 halos here
# no need for custom wgrad2_halo function, this is just a backwards data convolution
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply grad_out1 halos here
fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
return (None, None, None, None, None, None, None, None, *grads)
......
......@@ -1746,7 +1746,7 @@ std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1
std::vector<at::Tensor> outputs;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);
//printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);
auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format);
auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format);
auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format);
......@@ -1837,12 +1837,12 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation(forward_state.outdimA1,
forward_state.padA1,
forward_state.convstrideA,
......@@ -1934,12 +1934,15 @@ struct bottleneck_backward_state {
int axis[4];
int64_t outdimA1[4];
int64_t outdimA2[4];
int64_t outdimA1[4]; // grad_out1
int64_t outdimA2[4]; // grad_out2
int64_t outdimA3[4];
int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3)
int64_t padA[2];
int64_t padA1[2];
int64_t padA2[2];
int64_t dilationA[2];
int64_t convstrideA[2];
int64_t convstride1X1[2];
......@@ -1947,6 +1950,8 @@ struct bottleneck_backward_state {
int64_t outdim1[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim1h[4];
int64_t outdim2hh[4];
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// setup dimensions
......@@ -1985,6 +1990,8 @@ struct bottleneck_backward_state {
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0;
// use these fixed value for test run
padA[0] = 0; padA[1] = 0;
......@@ -2012,10 +2019,21 @@ struct bottleneck_backward_state {
outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1h[dim] = 3;
outdimA2h[dim] = 3;
} else {
outdimA1h[dim] = outdimA1[dim];
outdimA2h[dim] = outdimA2[dim];
}
}
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
......@@ -2026,6 +2044,7 @@ struct bottleneck_backward_state {
outdim1[dim] = outdimA1[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim1h[dim] = outdimA1h[axis[dim]];
}
}
};
......@@ -2117,7 +2136,78 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
return grad_out2;
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
// Backward data pass for the bottleneck's 3x3 convolution: computes
// grad_out1 from grad_out2 with a single fused dconv + drelu + dscale
// kernel, using the dimensions cached in the module-level backward_state.
//
// inputs      : saved forward tensors; inputs[2] is presumably the 3x3 conv
//               weight, inputs[4] the scale tensor and inputs[12] the saved
//               relu1 output used as the drelu mask — TODO confirm against
//               the python caller's t_list layout.
// outputs     : unused here (kept for signature parity with the sibling
//               bottleneck_backward_* entry points).
// grad_out2   : gradient w.r.t. the 3x3 conv output, fp16.
// returns     : freshly allocated grad_out1 tensor of shape
//               backward_state.outdim1 in the requested memory format.
at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
  std::cout << std::fixed;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // Incoming gradient (dgrad input).
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();

  // Allocate the result in pytorch's view of the conv1 output shape.
  auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();

  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();
  at::Half* relu1 = inputs[12].data_ptr<at::Half>();

  // Fused backward-data convolution + drelu + dscale.
  run_dconv_drelu_dscale(backward_state.outdimA1,
                         backward_state.padA1,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA2,
                         backward_state.outdimA2,
                         CUDNN_DATA_HALF,
                         dy1,
                         w,
                         dy2,
                         z,
                         relu1);

  return grad_out1;
}
// Halo variant of the grad_out1 computation for spatial parallelism:
// runs the backward data 3x3 convolution (grad_out * w_rot180) on a
// grad_out2 halo slab of shape [N,3,W,C] with padding (1,1), yielding a
// grad_out1 halo slab of the same [N,3,W,C] shape (backward_state.outdim1h).
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo) {
  bool requires_grad = inputs[0].requires_grad();
  std::cout << std::fixed;

  const auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // Result buffer sized to the halo output dims computed in init().
  auto grad_out1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format);

  // Raw fp16 pointers consumed by the fused kernel.
  at::Half* grad_in_ptr  = grad_out2_halo.data_ptr<at::Half>();
  at::Half* grad_out_ptr = grad_out1_halo.data_ptr<at::Half>();
  at::Half* weight_ptr   = inputs[2].data_ptr<at::Half>();
  at::Half* scale_ptr    = inputs[4].data_ptr<at::Half>();
  // NOTE(review): the full-size relu1 tensor is passed while the conv dims
  // are halo-sized — confirm the kernel only touches the matching slab.
  at::Half* relu1_ptr    = inputs[12].data_ptr<at::Half>();

  // Fused backward-data convolution + drelu + dscale over halo dimensions.
  run_dconv_drelu_dscale(backward_state.outdimA1h,
                         backward_state.padA1,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA2,
                         backward_state.outdimA2h,
                         CUDNN_DATA_HALF,
                         grad_out_ptr,
                         weight_ptr,
                         grad_in_ptr,
                         scale_ptr,
                         relu1_ptr);

  return grad_out1_halo;
}
at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
......@@ -2134,7 +2224,7 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
//printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
run_dconv(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
......@@ -2147,26 +2237,19 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
return wgrad2;
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1, at::Tensor wgrad2) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
// fused dgrad
run_dconv_drelu_dscale(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1);
/*
// backward strided conv cannot be fused
......@@ -2215,6 +2298,8 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
// x used for dconv1 and dconv4 wgrad
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = NULL;
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]){
w = inputs[14].data_ptr<at::Half>();
at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
......@@ -2327,5 +2412,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册