Skip to content
GitLab
菜单
项目
群组
代码片段
帮助
帮助
支持
社区论坛
快捷键
?
提交反馈
登录/注册
切换导航
菜单
打开侧边栏
wanggh
apex
提交
9e295728
提交
9e295728
编辑于
9月 02, 2021
作者:
Thor Johnsen
浏览文件
Bug fix in wgrad
上级
8c4a0075
变更
1
Hide whitespace changes
Inline
Side-by-side
apex/contrib/bottleneck/bottleneck.py
浏览文件 @
9e295728
...
...
@@ -318,13 +318,14 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# apply wgrad2 halos
if
ctx
.
spatial_group_size
>
1
:
top_grad2_halo
=
grad_out2
[:,:
1
,:,:]
btm_grad2_halo
=
grad_out2
[:,
-
1
:,:,:]
top_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
top_out1_halo
,
top_grad2_halo
)
btm_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
btm_out1_halo
,
btm_grad2_halo
)
#print("wgrad2.shape = %s, top_wgrad2_halo.shape = %s, btm_wgrad2_halo = %s" % (str(list(wgrad2.shape)), str(list(top_wgrad2_halo.shape)), str(list(btm_wgrad2_halo.shape))))
wgrad2
[:,:
1
,:,:].
add_
(
top_wgrad2_halo
)
wgrad2
[:,
-
1
:,:,:].
add_
(
btm_wgrad2_halo
)
if
ctx
.
local_rank
>
0
:
top_grad2_halo
=
grad_out2
[:,:
1
,:,:]
top_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
top_out1_halo
,
top_grad2_halo
)
wgrad2
[:,:
1
,:,:].
add_
(
top_wgrad2_halo
)
if
ctx
.
local_rank
<
ctx
.
spatial_group_size
-
1
:
btm_grad2_halo
=
grad_out2
[:,
-
1
:,:,:]
btm_wgrad2_halo
=
fast_bottleneck
.
backward_wgrad2_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
btm_out1_halo
,
btm_grad2_halo
)
wgrad2
[:,
-
1
:,:,:].
add_
(
btm_wgrad2_halo
)
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
...
...
编辑
预览
Supports
Markdown
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录