wanggh / apex / Commits / 2f164a2a

Commit 2f164a2a, authored Aug 31, 2021 by Thor Johnsen

First release

parent: d6b5ae5d
Changes: 3
apex/contrib/bottleneck/__init__.py

```diff
-from .bottleneck import Bottleneck
+from .bottleneck import Bottleneck, SpatialBottleneck
```
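With this change the package re-exports the new module, so callers can import it directly; a minimal usage sketch (assumes apex is installed with the contrib extensions built):

```python
# Hypothetical downstream import enabled by this commit.
from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck
```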
apex/contrib/bottleneck/bottleneck.py

```python
import torch
import torch.distributed as dist
from torch import nn
import fast_bottleneck
```

...

@@ -212,3 +213,235 @@ class Bottleneck(torch.nn.Module):

```python
        out = self.relu(out)
        return out
```
```python
class SpatialBottleneckFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
        # TODO: clean up order of tensors
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in nhwc while shape can be nhwc or channels_last
        # here we pass in flag and let c++ handle it
        # alternatively, we can put all sizes into a fixed format and pass it in
        outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
        fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)

        # do halo exchange for outputs[0] (out1)
        if spatial_group_size > 1:
            out1 = outputs[0]
            N, Hs, W, C = list(out1.shape)
            padded_out1 = torch.empty((N, Hs + 2, W, C), dtype=out1.dtype, device=out1.device)
            padded_out1[:, 1:Hs + 1, :, :].copy_(out1)
            stream1.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream1):
                # copy halos to send buffer
                send_halos = torch.empty((N, 2, W, C), dtype=out1.dtype, device=out1.device)
                send_halos[:, :1, :, :].copy_(out1[:, :1, :, :])
                send_halos[:, 1:, :, :].copy_(out1[:, Hs - 1:, :, :])
                all_halos = torch.empty((N, 2 * spatial_group_size, W, C), dtype=out1.dtype, device=out1.device)
                all_halos = [all_halos[:, i * 2:(i + 1) * 2, :, :] for i in range(spatial_group_size)]
                dist.all_gather(all_halos, send_halos)
                padded_out1_top_halo = padded_out1[:, :1, :, :]
                if local_rank > 0:
                    top_halo = all_halos[local_rank - 1][:, 1:, :, :]
                    padded_out1_top_halo.copy_(top_halo)
                    fat_top_halo = padded_out1[:, :3, :, :]
                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_top_halo, args)
                else:
                    padded_out1_top_halo.zero_()
                padded_out1_btm_halo = padded_out1[:, Hs + 1:, :, :]
                if local_rank < spatial_group_size - 1:
                    btm_halo = all_halos[local_rank + 1][:, :1, :, :]
                    padded_out1_btm_halo.copy_(btm_halo)
                    fat_btm_halo = padded_out1[:, Hs - 1:, :, :]
                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_btm_halo, args)
                else:
                    padded_out1_btm_halo.zero_()
            torch.cuda.current_stream().wait_stream(stream1)
            out2 = outputs[1]
            if local_rank > 0:
                out2[:, :1, :, :].copy_(top_out2)
            if local_rank < spatial_group_size - 1:
                out2[:, Hs - 1:, :, :].copy_(btm_out2)

        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
        if spatial_group_size > 1:
            ctx.save_for_backward(*(args + outputs + [padded_out1]))
        else:
            ctx.save_for_backward(*(args + outputs))
        # save relu outputs for drelu
        ctx.nhwc = nhwc
        ctx.stride_1x1 = stride_1x1
        ctx.spatial_group_size = spatial_group_size
        ctx.local_rank = local_rank
        ctx.comm = comm
        ctx.stream1 = stream1
        return outputs[2]
```
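The halo exchange above is done with a single `dist.all_gather` per bottleneck: each rank contributes its top and bottom rows, and `all_halos` is rebound from one preallocated buffer to a list of per-rank views into it, so every rank's contribution lands in that one buffer. A minimal standalone sketch of the pattern (assumes `torch.distributed` is already initialized; the function name is illustrative, not part of apex):

```python
import torch
import torch.distributed as dist

def gather_halos(out1, spatial_group_size):
    # Each rank owns a horizontal slab of shape (N, Hs, W, C) and sends
    # its first and last rows, which are its neighbors' halo rows.
    N, Hs, W, C = out1.shape
    send_halos = torch.empty((N, 2, W, C), dtype=out1.dtype, device=out1.device)
    send_halos[:, :1].copy_(out1[:, :1])        # top row
    send_halos[:, 1:].copy_(out1[:, Hs - 1:])   # bottom row
    # One flat buffer; all_gather receives a list of per-rank views into it.
    buf = torch.empty((N, 2 * spatial_group_size, W, C), dtype=out1.dtype, device=out1.device)
    views = [buf[:, 2 * i:2 * (i + 1)] for i in range(spatial_group_size)]
    dist.all_gather(views, send_halos)
    # views[r][:, :1] is rank r's top row; views[r][:, 1:] is its bottom row.
    return views
```

`SpatialBottleneckFunction.backward` continues below.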
```python
    # backward relu is not exposed, MUL with mask used now
    # only support dgrad
    @staticmethod
    def backward(ctx, grad_o):
        if ctx.spatial_group_size > 1:
            outputs = ctx.saved_tensors[-4:-1]
        else:
            outputs = ctx.saved_tensors[-3:]

        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
        else:
            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])

        # create input vector for backward
        t_list = [*ctx.saved_tensors[0:10]]
        t_list.append(grad_conv3)
        t_list.append(grad_conv4)

        # outputs used for wgrad and generating drelu mask
        t_list.append(outputs[0])
        t_list.append(outputs[1])

        # in case there is downsample
        if ctx.downsample:
            t_list.append(ctx.saved_tensors[10])

        grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
        # do halo exchange of grad_out2 here
        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)

        return (None, None, None, None, None, None, None, None, *grads)

spatial_bottleneck_function = SpatialBottleneckFunction.apply
```
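`drelu_dscale1` and `drelu_dscale2` are helpers defined earlier in this file, outside the hunk shown. Per the comment on `backward`, the ReLU backward is realized as a multiply with a mask taken from the saved ReLU output rather than a dedicated dReLU kernel. A plausible pure-PyTorch equivalent, for orientation only (the signature and exact scale handling here are assumptions):

```python
def drelu_dscale_sketch(grad_o, relu_out, scale, scale_dual=None):
    # dReLU: gradient passes only where the forward ReLU output was positive.
    dgrad = grad_o * (relu_out > 0).to(grad_o.dtype)
    # Fold the frozen-BN per-channel scale of the main branch into the gradient;
    # with a downsample branch, the same masked gradient is scaled separately.
    grad_main = dgrad * scale
    grad_dual = dgrad * scale_dual if scale_dual is not None else dgrad
    return grad_main, grad_dual
```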
```python
class SpatialBottleneck(torch.nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution (self.conv2)
    # while original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    # here we put it at 1x1

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
                 spatial_group_size=1):
        super(SpatialBottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError('Only support groups == 1')
        if dilation != 1:
            raise RuntimeError('Only support dilation == 1')
        if norm_func == None:
            norm_func = FrozenBatchNorm2d
        else:
            raise RuntimeError('Only support frozen BN now.')

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)
        self.use_cudnn = use_cudnn

        # setup conv weights
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weight in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        # TODO: prevent unsupported case usage
        # support cases
        #                   native   cudnn
        #   normal          yes      no
        #   channel_last    yes      yes
        #   explicit_nhwc   no       yes
        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            for p in self.parameters():
                with torch.no_grad():
                    p.data = p.data.permute(0, 2, 3, 1).contiguous()

        # spatial communicator
        self.spatial_group_size = spatial_group_size
        if spatial_group_size > 1:
            world_size = dist.get_world_size()
            num_groups = world_size // spatial_group_size
            assert num_groups * spatial_group_size == world_size, "torch.distributed.get_world_size() must be multiple of group_size"
            rank = dist.get_rank()
            self.local_rank = rank % spatial_group_size
            for group in range(num_groups):
                ranks = list(range(group * spatial_group_size, (group + 1) * spatial_group_size))
                comm = torch.distributed.new_group(ranks=ranks)
                if rank in ranks:
                    self.communicator = comm
            self.stream1 = torch.cuda.Stream()
            self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
        else:
            self.spatial_args = 1, 0, None, None

        return
```
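Note that every rank runs the whole `new_group` loop (torch.distributed requires all ranks to participate in creating every group) and keeps only the communicator whose ranks include its own. Concretely, under the assumption of a world size of 8 and a spatial group size of 4:

```python
import torch.distributed as dist

# Illustrative partitioning: world_size == 8, spatial_group_size == 4.
world_size, spatial_group_size = 8, 4
num_groups = world_size // spatial_group_size            # 2 groups
for group in range(num_groups):
    ranks = list(range(group * spatial_group_size, (group + 1) * spatial_group_size))
    # group 0 -> ranks [0, 1, 2, 3]; group 1 -> ranks [4, 5, 6, 7]
    comm = dist.new_group(ranks=ranks)                   # every rank must make this call
# e.g. rank 6 keeps the [4, 5, 6, 7] communicator and gets local_rank = 6 % 4 == 2
```

The forward method below continues the SpatialBottleneck class.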
```python
    def forward(self, x):
        if self.use_cudnn:
            # calculate scale/bias from registered buffers
            # TODO: make this better
            s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
            s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
            s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
            w_scale = [s1, s2, s3]
            w_bias = [b1, b2, b3]
            if self.downsample is not None:
                s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
                w_scale.append(s4)
                w_bias.append(b4)
            out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
            return out

        if self.explicit_nhwc:
            raise RuntimeError('explicit nhwc with native ops is not supported.')

        # fallback to native ops
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
```
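`get_scale_bias` is provided by `FrozenBatchNorm2d`, which is defined elsewhere in apex.contrib. The standard frozen-BN folding it would correspond to is scale = gamma / sqrt(running_var + eps) and bias = beta - running_mean * scale, which reduces each batch norm to a per-channel multiply-add that the fused kernel can absorb. A sketch under that assumption:

```python
import torch

def get_scale_bias_sketch(gamma, beta, running_mean, running_var, eps=1e-5):
    # Fold frozen batch-norm statistics into a per-channel affine transform:
    #   y = (x - mean) / sqrt(var + eps) * gamma + beta  ==  x * scale + bias
    scale = gamma / torch.sqrt(running_var + eps)
    bias = beta - running_mean * scale
    return scale, bias
```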
apex/contrib/csrc/bottleneck/bottleneck.cpp

This diff is collapsed.
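Putting it together, a hypothetical single-GPU use of the new module (spatial_group_size=1 skips the halo exchange; requires the `fast_bottleneck` CUDA extension built from the bottleneck.cpp sources above):

```python
import torch
from apex.contrib.bottleneck import SpatialBottleneck

block = SpatialBottleneck(in_channels=256, bottleneck_channels=64, out_channels=256,
                          stride=1, use_cudnn=True, explicit_nhwc=True,
                          spatial_group_size=1).cuda()
x = torch.randn(2, 56, 56, 256, device='cuda')  # explicit NHWC layout: (N, H, W, C)
y = block(x)                                    # same layout out: (2, 56, 56, 256)
```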