未验证 提交 9b880665 编辑于 作者: Thor Johnsen's avatar Thor Johnsen 提交者: GitHub
浏览文件

Merge pull request #1160 from NVIDIA/bug_fix_in_wgrad

Bug fix in wgrad
......@@ -318,13 +318,14 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# apply wgrad2 halos
if ctx.spatial_group_size > 1:
top_grad2_halo = grad_out2[:,:1,:,:]
btm_grad2_halo = grad_out2[:,-1:,:,:]
top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
#print("wgrad2.shape = %s, top_wgrad2_halo.shape = %s, btm_wgrad2_halo = %s" % (str(list(wgrad2.shape)), str(list(top_wgrad2_halo.shape)), str(list(btm_wgrad2_halo.shape))))
wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
if ctx.local_rank > 0:
top_grad2_halo = grad_out2[:,:1,:,:]
top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
if ctx.local_rank < ctx.spatial_group_size-1:
btm_grad2_halo = grad_out2[:,-1:,:,:]
btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册