Commit bc9114c9 authored by Thor Johnsen
Browse files

Bug fixes

Parent d934eca3
......@@ -233,7 +233,9 @@ class SpatialBottleneckFunction(torch.autograd.Function):
fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
# do halo exchange for outputs[0] (out1)
# compute halo cells for outputs[1]
if spatial_group_size > 1:
out1 = outputs[0]
N,Hs,W,C = list(out1.shape)
......@@ -245,17 +247,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
send_halos[:,1:,:,:].copy_(out1[:,Hs-1:,:,:])
all_halos = torch.empty((N,2*spatial_group_size,W,C),dtype=out1.dtype,device=out1.device)
all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
dist.all_gather(all_halos,send_halos)
dist.all_gather(all_halos,send_halos,group=comm)
fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_out1_halo = all_halos[(spatial_group_size+local_rank-1)%spatial_group_size][:,1:,:,:]
if local_rank > 0:
top_halo = all_halos[local_rank-1][:,1:,:,:]
fat_halo[:,:1,:,:].copy_(top_halo)
fat_halo[:,:1,:,:].copy_(top_out1_halo)
fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
btm_out1_halo = all_halos[(local_rank+1)%spatial_group_size][:,:1,:,:]
if local_rank < spatial_group_size-1:
btm_halo = all_halos[local_rank+1][:,:1,:,:]
fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
fat_halo[:,2:,:,:].copy_(btm_halo)
fat_halo[:,2:,:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
torch.cuda.current_stream().wait_stream(stream1)
out2 = outputs[1]
......@@ -265,8 +267,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
out2[:,Hs-1:,:,:].copy_(btm_out2)
fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
# TODO: save halos for backward pass
ctx.save_for_backward(*(args+outputs))
# save halos for backward pass
if spatial_group_size > 1:
ctx.save_for_backward(*(args+outputs+[top_out1_halo,btm_out1_halo]))
else:
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
ctx.nhwc = nhwc
ctx.stride_1x1 = stride_1x1
......@@ -280,7 +285,12 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# only support dgrad
@staticmethod
def backward(ctx, grad_o):
outputs = ctx.saved_tensors[-3:]
if ctx.spatial_group_size > 1:
top_out1_halo = ctx.saved_tensors[-2]
btm_out1_halo = ctx.saved_tensors[-1]
outputs = ctx.saved_tensors[-5:-2]
else:
outputs = ctx.saved_tensors[-3:]
if ctx.downsample:
grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
......@@ -302,22 +312,69 @@ class SpatialBottleneckFunction(torch.autograd.Function):
grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
# do halo exchange of grad_out2 here
# need fast_bottleneck.backward_grad_out2_halo
# testing
N,H,W,C = grad_out2.shape
grad_out2_halo = torch.empty([N,3,W,C],dtype=grad_out2.dtype,device=grad_out2.device)
grad_out2_halo[:,:1,:,:].zero_()
grad_out2_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2_halo)
# print("grad_out2_halo.shape = %s -> grad_out1_halo.shape = %s" % (str(list(grad_out2_halo.shape)), str(list(grad_out1_halo.shape))))
# compute wgrad2 for internal cells
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply wgrad2 halos here
# no need for custom wgrad2_halo function, this is just a backwards data convolution
# apply wgrad2 halos
if ctx.spatial_group_size > 1:
top_grad2_halo = grad_out2[:,:1,:,:]
btm_grad2_halo = grad_out2[:,-1:,:,:]
top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
#print("wgrad2.shape = %s, top_wgrad2_halo.shape = %s, btm_wgrad2_halo = %s" % (str(list(wgrad2.shape)), str(list(top_wgrad2_halo.shape)), str(list(btm_wgrad2_halo.shape))))
wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
if ctx.spatial_group_size > 1:
N,Hs,W,C = list(grad_out2.shape)
ctx.stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ctx.stream1):
# copy halos to send buffer
send_halos = torch.empty((N,2,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
send_halos[:,:1,:,:].copy_(grad_out2[:,:1,:,:])
send_halos[:,1:,:,:].copy_(grad_out2[:,Hs-1:,:,:])
all_halos = torch.empty((N,2*ctx.spatial_group_size,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(ctx.spatial_group_size)]
dist.all_gather(all_halos,send_halos,group=ctx.comm)
relu1 = t_list[12]
fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
if ctx.local_rank > 0:
top_halo = all_halos[ctx.local_rank-1][:,1:,:,:]
fat_halo[:,:1,:,:].copy_(top_halo)
fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
relu_halo[:,:1,:,:].zero_()
relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
if ctx.local_rank < ctx.spatial_group_size-1:
btm_halo = all_halos[ctx.local_rank+1][:,:1,:,:]
fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
fat_halo[:,2:,:,:].copy_(btm_halo)
relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2,:,:])
relu_halo[:,2:,:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
# compute grad_out1 for internal cells
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply grad_out1 halos here
# apply halo cells to grad_out1
if ctx.spatial_group_size > 1:
w = t_list[2]
z = t_list[4]
relu1 = t_list[12]
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
torch.cuda.current_stream().wait_stream(ctx.stream1)
if ctx.local_rank > 0:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
#print("ctx.local_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
if ctx.local_rank < ctx.spatial_group_size-1:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
#print("ctx.local_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
......
......@@ -1931,6 +1931,7 @@ struct bottleneck_backward_state {
int64_t filterdimA2[4];
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t filterdimA2hh[4]; // Cin,Cout,1,3
int axis[4];
......@@ -1939,6 +1940,8 @@ struct bottleneck_backward_state {
int64_t outdimA3[4];
int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3)
int64_t outdimA1hh[4]; // input: grad_out2 halo (H=1)
int64_t outdimA2hh[4]; // input: out1 halo (H=1)
int64_t padA[2];
int64_t padA1[2];
......@@ -1947,11 +1950,12 @@ struct bottleneck_backward_state {
int64_t convstrideA[2];
int64_t convstride1X1[2];
int64_t filterdim2hh[4]; // Cin,1,3,Cout
int64_t outdim1[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim1h[4];
int64_t outdim2hh[4];
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// setup dimensions
......@@ -1960,6 +1964,7 @@ struct bottleneck_backward_state {
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
......@@ -1986,16 +1991,27 @@ struct bottleneck_backward_state {
}
}
for (int dim=0;dim<4;dim++) {
if (dim == 2) {
filterdimA2hh[dim] = 1;
} else {
filterdimA2hh[dim] = filterdimA2[dim];
}
}
// output dim in n,c,h,w used by backend
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0;
outdimA1hh[0] = outdimA1hh[1] = outdimA1hh[2] = outdimA1hh[3] = 0;
outdimA2hh[0] = outdimA2hh[1] = outdimA2hh[2] = outdimA2hh[3] = 0;
// use these fixed value for test run
padA[0] = 0; padA[1] = 0;
padA1[0] = 1; padA1[1] = 1;
padA2[0] = 0; padA2[1] = 1;
dilationA[0] = 1; dilationA[1] = 1;
convstrideA[0] = 1; convstrideA[1] = 1;
convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;
......@@ -2023,9 +2039,13 @@ struct bottleneck_backward_state {
if (dim == 2) {
outdimA1h[dim] = 3;
outdimA2h[dim] = 3;
outdimA1hh[dim] = 1;
outdimA2hh[dim] = 1;
} else {
outdimA1h[dim] = outdimA1[dim];
outdimA2h[dim] = outdimA2[dim];
outdimA1hh[dim] = outdimA1[dim];
outdimA2hh[dim] = outdimA2[dim];
}
}
......@@ -2034,6 +2054,7 @@ struct bottleneck_backward_state {
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
......@@ -2045,6 +2066,7 @@ struct bottleneck_backward_state {
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim1h[dim] = outdimA1h[axis[dim]];
filterdim2hh[dim] = filterdimA2hh[axis[dim]];
}
}
};
......@@ -2153,6 +2175,8 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
run_dconv_drelu_dscale(backward_state.outdimA1,
backward_state.padA1,
......@@ -2171,7 +2195,7 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
}
// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo) {
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) {
bool requires_grad = inputs[0].requires_grad();
......@@ -2187,10 +2211,12 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
// fused dgrad
//printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
//printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
//printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
run_dconv_drelu_dscale(backward_state.outdimA1h,
backward_state.padA1,
backward_state.convstrideA,
......@@ -2202,7 +2228,7 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1
w,
dy2h,
z,
relu1);
relu1h);
return grad_out1_halo;
}
......@@ -2221,9 +2247,9 @@ at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::v
at::Half* conv_in = inputs[12].data_ptr<at::Half>();
// wgrad
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
//printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
run_dconv(backward_state.outdimA1,
backward_state.padA1,
......@@ -2240,6 +2266,46 @@ at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::v
return wgrad2;
}
// Compute the wgrad (weight-gradient) contribution of one halo row for the 3x3 conv2.
// Takes an input activation halo of dimension [N,1,W,C] and a grad_out2 halo of the same
// shape, and runs a backward-filter convolution with padding=(0,1) (no H padding, W padding 1).
// The result is a partial filter gradient of shape [Cin,1,3,Cout] (regular filter dims are
// [Cin,3,3,Cout]); the caller accumulates it into the matching H-row of the full wgrad2 tensor.
at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2_halo) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
// Output layout follows the rest of the extension: explicit-NHWC tensors are stored
// Contiguous, otherwise the ChannelsLast memory format is used.
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad: incoming gradient w.r.t. conv2 output, restricted to the halo row
at::Half* dy2 = grad_out2_halo.data_ptr<at::Half>();
// dconv2+drelu1+dscale1: activation halo row that feeds conv2
at::Half* conv_in = input.data_ptr<at::Half>();
// wgrad: allocate the [Cin,1,3,Cout] partial filter-gradient output
// (dims precomputed in backward_state.filterdim2hh by bottleneck_backward_state::init)
auto wgrad2_halo = at::empty(backward_state.filterdim2hh, input.type(), output_format);
at::Half* dw2 = wgrad2_halo.data_ptr<at::Half>();
//printf("backward_state.outdimA1hh = {%d,%d,%d,%d}\n",backward_state.outdimA1hh[0],backward_state.outdimA1hh[1],backward_state.outdimA1hh[2],backward_state.outdimA1hh[3]);
//printf("backward_state.outdimA2hh = {%d,%d,%d,%d}\n",backward_state.outdimA2hh[0],backward_state.outdimA2hh[1],backward_state.outdimA2hh[2],backward_state.outdimA2hh[3]);
//printf("backward_state.filterdim2hh = {%d,%d,%d,%d}\n",backward_state.filterdim2hh[0],backward_state.filterdim2hh[1],backward_state.filterdim2hh[2],backward_state.filterdim2hh[3]);
//printf("backward_state.filterdimA2hh = {%d,%d,%d,%d}\n",backward_state.filterdimA2hh[0],backward_state.filterdimA2hh[1],backward_state.filterdimA2hh[2],backward_state.filterdimA2hh[3]);
//printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
// Backward-filter convolution over the single halo row; dimension arrays below are in
// the n,c,h,w order the backend expects (see bottleneck_backward_state::init).
run_dconv(backward_state.outdimA1hh, // N,C,1,W
backward_state.padA2, // 0, 1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2hh, // Cin,Cout,1,3
backward_state.outdimA2hh, // N,C,1,W
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
return wgrad2_halo;
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1, at::Tensor wgrad2) {
bool requires_grad = inputs[0].requires_grad();
......@@ -2415,5 +2481,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册