Unverified commit ed713c84, authored by Thor Johnsen and committed by GitHub

Merge pull request #1151 from NVIDIA/spatial_fast_bottleneck

Spatially Distributed Fast Bottleneck block
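The new SpatialBottleneck module shards the input along H across the ranks of a spatial group and exchanges one-row halos so the 3x3 convolution sees its neighbors' border rows. A minimal, untested usage sketch (not part of this diff; it assumes the fast_bottleneck CUDA extension is built and one process per GPU, and mirrors the fp16/explicit-NHWC setup used by the test script below):

import torch
import torch.distributed as dist
from apex.contrib.bottleneck import SpatialBottleneck

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# full input is 1x200x336x256 in explicit NHWC; with 4 ranks each holds a 50-row H-shard
block = SpatialBottleneck(in_channels=256, bottleneck_channels=64, out_channels=256,
                          use_cudnn=True, explicit_nhwc=True,
                          spatial_group_size=dist.get_world_size()).cuda().half()
x = torch.randn(1, 50, 336, 256, dtype=torch.float16, device='cuda', requires_grad=True)
y = block(x)                     # halo exchange happens inside forward
y.backward(torch.randn_like(y))  # per-shard wgrads are summed across the spatial group in the test below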
from .bottleneck import Bottleneck
from .bottleneck import Bottleneck, SpatialBottleneck
import torch
import torch.distributed as dist
from torch import nn
import fast_bottleneck
@@ -212,3 +213,235 @@ class Bottleneck(torch.nn.Module):
out = self.relu(out)
return out
class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
# TODO: clean up order of tensors
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
ctx.downsample = len(conv) > 3
if ctx.downsample:
args.append(conv[3])
args.append(scale[3])
args.append(bias[3])
# weight buffers are always in nhwc while shape can be nhwc or channels_last
# here we pass in flag and let c++ handle it
# alternatively, we can put all sizes into a fixed format and pass it in
outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
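# outputs[0] (out1) is the activation feeding the 3x3 conv and outputs[1] (out2) its result;
# with the input sharded along H, out2's top/bottom rows were computed without neighbor data
# and are corrected below once halos arrive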
# do halo exchange for outputs[0] (out1)
if spatial_group_size > 1:
out1 = outputs[0]
N,Hs,W,C = list(out1.shape)
padded_out1 = torch.empty((N,Hs+2,W,C),dtype=out1.dtype,device=out1.device)
padded_out1[:,1:Hs+1,:,:].copy_(out1)
stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
# copy halos to send buffer
send_halos = torch.empty((N,2,W,C),dtype=out1.dtype,device=out1.device)
send_halos[:,:1,:,:].copy_(out1[:,:1,:,:])
send_halos[:,1:,:,:].copy_(out1[:,Hs-1:,:,:])
all_halos = torch.empty((N,2*spatial_group_size,W,C),dtype=out1.dtype,device=out1.device)
all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
dist.all_gather(all_halos,send_halos)
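# stitch received halos into the padded buffer: row 0 gets the bottom row of the rank above,
# row Hs+1 the top row of the rank below; the 3-row "fat" halos are then pushed through the
# 3x3 conv to recompute the border rows of out2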
padded_out1_top_halo = padded_out1[:,:1,:,:]
if local_rank > 0:
top_halo = all_halos[local_rank-1][:,1:,:,:]
padded_out1_top_halo.copy_(top_halo)
fat_top_halo = padded_out1[:,:3,:,:]
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_top_halo, args)
else:
padded_out1_top_halo.zero_()
padded_out1_btm_halo = padded_out1[:,Hs+1:,:,:]
if local_rank < spatial_group_size-1:
btm_halo = all_halos[local_rank+1][:,:1,:,:]
padded_out1_btm_halo.copy_(btm_halo)
fat_btm_halo = padded_out1[:,Hs-1:,:,:]
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_btm_halo, args)
else:
padded_out1_btm_halo.zero_()
torch.cuda.current_stream().wait_stream(stream1)
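# halo-corrected rows computed on stream1 are now safe to read; overwrite the border rows
# of out2 that were produced without neighbor data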
out2 = outputs[1]
if local_rank > 0:
out2[:,:1,:,:].copy_(top_out2)
if local_rank < spatial_group_size-1:
out2[:,Hs-1:,:,:].copy_(btm_out2)
fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
if spatial_group_size > 1:
ctx.save_for_backward(*(args+outputs+[padded_out1]))
else:
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
ctx.nhwc = nhwc
ctx.stride_1x1 = stride_1x1
ctx.spatial_group_size = spatial_group_size
ctx.local_rank = local_rank
ctx.comm = comm
ctx.stream1 = stream1
return outputs[2]
# backward relu is not exposed, MUL with mask used now
# only support dgrad
@staticmethod
def backward(ctx, grad_o):
if ctx.spatial_group_size > 1:
outputs = ctx.saved_tensors[-4:-1]
else:
outputs = ctx.saved_tensors[-3:]
if ctx.downsample:
grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
else:
grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])
# create input vector for backward
t_list = [*ctx.saved_tensors[0:10]]
t_list.append(grad_conv3)
t_list.append(grad_conv4)
# outputs used for wgrad and generating drelu mask
t_list.append(outputs[0])
t_list.append(outputs[1])
# in case there is downsample
if ctx.downsample:
t_list.append(ctx.saved_tensors[10])
grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
# do halo exchange of grad_out2 here
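# (the grad_out2 halo exchange is left as a TODO in this commit)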
fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
return (None, None, None, None, None, None, None, None, *grads)
spatial_bottleneck_function = SpatialBottleneckFunction.apply
class SpatialBottleneck(torch.nn.Module):
# Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
# while the original implementation places the stride at the first 1x1 convolution (self.conv1)
# according to "Deep residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
# here we put it at 1x1
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
spatial_group_size=1):
super(SpatialBottleneck, self).__init__()
if groups != 1:
raise RuntimeError('Only support groups == 1')
if dilation != 1:
raise RuntimeError('Only support dilation == 1')
if norm_func is None:
norm_func = FrozenBatchNorm2d
else:
raise RuntimeError('Only support frozen BN now.')
if stride != 1 or in_channels != out_channels:
self.downsample = nn.Sequential(
conv1x1(in_channels, out_channels, stride),
norm_func(out_channels),
)
else:
self.downsample = None
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
self.conv3 = conv1x1(bottleneck_channels, out_channels)
self.relu = nn.ReLU(inplace=True)
self.stride = stride
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.use_cudnn = use_cudnn
# setup conv weights
self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
if self.downsample is not None:
self.w_conv.append(self.downsample[0].weight)
# init weight in nchw format before possible transpose
for w in self.w_conv:
kaiming_uniform_(w, a=1)
# TODO: prevent unsupported case usage
# supported cases:
#                  native   cudnn
# normal            yes       no
# channel_last      yes       yes
# explicit_nhwc     no        yes
self.explicit_nhwc = explicit_nhwc
if self.explicit_nhwc:
for p in self.parameters():
with torch.no_grad():
p.data = p.data.permute(0,2,3,1).contiguous()
# spatial communicator
self.spatial_group_size = spatial_group_size
if spatial_group_size > 1:
world_size = dist.get_world_size()
num_groups = world_size // spatial_group_size
assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be a multiple of spatial_group_size"
rank = dist.get_rank()
self.local_rank = rank % spatial_group_size
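# torch.distributed.new_group must be called by every rank for every group, so create all
# groups and keep only the one this rank belongs to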
for group in range(num_groups):
ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
comm = torch.distributed.new_group(ranks=ranks)
if rank in ranks:
self.communicator = comm
self.stream1 = torch.cuda.Stream()
self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
else:
self.spatial_args = 1, 0, None, None
return
def forward(self, x):
if self.use_cudnn:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
return out
if self.explicit_nhwc:
raise RuntimeError('explicit nhwc with native ops is not supported.')
# fallback to native ops
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
import os
import torch
from maskrcnn_benchmark.modeling.backbone.resnet import Bottleneck
from maskrcnn_benchmark.layers.nhwc import nhwc_to_nchw_transform, nchw_to_nhwc_transform
from maskrcnn_benchmark.layers.nhwc.batch_norm import FrozenBatchNorm2d_NHWC
from apex.contrib.bottleneck import Bottleneck as FastBottleneck
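# This test compares a spatially distributed bottleneck (one H-shard per rank) against a
# single-GPU reference run. It expects one process per GPU with WORLD_SIZE set (e.g. via a
# torch.distributed launcher); with a single process it falls back to a non-distributed run.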
def single_module_test(ref, rank, world_size, numtype, device, shape, fast, spatial_group_size, in_channels, bottleneck_channels, out_channels, num_groups, stride_in_1x1, stride, dilation, norm_func, nhwc):
# inputs + modules
with torch.no_grad():
input_shape = [1, in_channels] + list(shape)
x = torch.randn(input_shape, dtype=numtype, device=device)
if nhwc:
x = nchw_to_nhwc_transform(x).contiguous()
x.requires_grad = True
print(x.shape, x.stride())
#if spatial_group_size > 1:
# fast = False # hack so fast bottleneck can be run against distributed bottleneck
#if spatial_group_size == 1:
# fast = False
if fast:
bottleneck = FastBottleneck(
in_channels=in_channels,
bottleneck_channels=bottleneck_channels,
out_channels=out_channels,
stride=stride,
dilation=dilation,
explicit_nhwc=nhwc,
use_cudnn=True)
if spatial_group_size > 1:
print("WARNING! spatial_group_size ignored by FastBottleneck")
else:
bottleneck = Bottleneck(
in_channels,
bottleneck_channels,
out_channels,
num_groups,
stride_in_1x1,
stride,
dilation,
norm_func,
nhwc,
spatial_group_size)
bottleneck = bottleneck.to(dtype=numtype,device=device)
weights = dict(bottleneck.named_parameters())
if ref is not None:
ref_x, _, ref_weights = ref
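# slice the full-height reference input to this rank's H-shard and copy the reference
# weights so every rank computes against the same data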
Hs,H = x.shape[1], ref_x.shape[1]
assert(Hs*spatial_group_size == H), "Reference input H must equal Hs*spatial_group_size"
ref_x = ref_x[:,rank*Hs:(rank+1)*Hs,:,:]
x.copy_(ref_x)
assert(len(weights) == len(ref_weights)), "Reference weights and weights don't match"
for k in weights.keys():
weights[k].copy_(ref_weights[k])
# forward
out = bottleneck(x)
# gradient output
with torch.no_grad():
grad_out = torch.randn_like(out)
if ref is not None:
_, ref_grad_out, _ = ref
Hs,H = grad_out.shape[1], ref_grad_out.shape[1]
assert(Hs*spatial_group_size == H), "Reference grad_out H must equal Hs*spatial_group_size"
ref_grad_out = ref_grad_out[:,rank*Hs:(rank+1)*Hs,:,:]
grad_out.copy_(ref_grad_out)
# backward
out.backward(grad_out)
with torch.no_grad():
dgrad = x.grad.detach()
wgrad = {}
for n,p in bottleneck.named_parameters():
wgrad[n] = p.grad.detach()
if world_size > 1:
if spatial_group_size == 1:
# broadcast x, grad_out and weights from rank 0
with torch.no_grad():
torch.distributed.broadcast(x,0)
torch.distributed.broadcast(grad_out,0)
for k in weights.keys():
torch.distributed.broadcast(weights[k],0)
else:
# gather dgrad (x.grad), sum wgrad (weights)
N,Hs,W,C = dgrad.shape
H = Hs * spatial_group_size
dgrad_gathered = torch.empty((N,H,W,C),dtype=dgrad.dtype,device=dgrad.device)
dgrad_tensors = [dgrad_gathered[:,i*Hs:(i+1)*Hs,:,:] for i in range(spatial_group_size)]
torch.distributed.all_gather(dgrad_tensors, dgrad)
dgrad = dgrad_gathered
for k in wgrad.keys():
torch.distributed.all_reduce(wgrad[k])
return x, out, grad_out, weights, dgrad, wgrad
def module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args):
r = []
for ia in init_args:
shape = ia[0:4]
args = ia[4:]
rr = []
ref = None
for spatial_group_size in spatial_group_sizes:
N,H,W,C = shape
H = H//spatial_group_size
x, out, grad_out, weights, dgrad, wgrad = single_module_test(ref, rank, world_size, numtype, device, [H,W], fast, spatial_group_size, *args)
if ref is None:
assert(spatial_group_size == 1), "Reference run must use spatial_group_size == 1"
ref = x, grad_out, weights
if rank == 0:
rr.append( (out, dgrad, wgrad) )
if world_size > 1:
torch.distributed.barrier()
r.append(rr)
return r
def main():
total_num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
distributed = total_num_gpus > 1
ngpus = torch.cuda.device_count()
if distributed:
torch.distributed.init_process_group("nccl")
rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
is_master = rank == 0
local_rank = rank % ngpus
torch.cuda.set_device(local_rank)
spatial_group_size = total_num_gpus
else:
rank, local_rank, is_master, world_size, spatial_group_size = 0, 0, True, 1, 1
#torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = True
#torch.backends.cudnn.deterministic = True
#torch.backends.cuda.matmul.allow_tf32 = False
#torch.backends.cudnn.allow_tf32 = False
norm_func = FrozenBatchNorm2d_NHWC
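# each tuple is (N, H, W, C, in_channels, bottleneck_channels, out_channels, num_groups,
#                stride_in_1x1, stride, dilation, norm_func, nhwc)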
init_args = [
(1, 200, 336, 64, 64, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 200, 336, 256, 256, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 200, 336, 256, 256, 128, 512, 1, True, 2, 1, norm_func, True),
(1, 100, 168, 512, 512, 128, 512, 1, True, 1, 1, norm_func, True),
(1, 100, 168, 512, 512, 256, 1024, 1, True, 2, 1, norm_func, True),
(1, 50, 84, 1024, 1024, 256, 1024, 1, True, 1, 1, norm_func, True),
(1, 50, 84, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
(1, 25, 42, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 64, 64, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 256, 256, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 256, 256, 128, 512, 1, True, 2, 1, norm_func, True),
(1, 168, 100, 512, 512, 128, 512, 1, True, 1, 1, norm_func, True),
(1, 168, 100, 512, 512, 256, 1024, 1, True, 2, 1, norm_func, True),
(1, 84, 50, 1024, 1024, 256, 1024, 1, True, 1, 1, norm_func, True),
(1, 84, 50, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
(1, 42, 25, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
]
# pad H to account for spatial distribution
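# round H up so every stage's height stays divisible by spatial_group_size while consecutive
# stages keep their exact 2x ratios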
padded_init_args = []
for ia in init_args:
N,H,W,C = ia[0:4]
m = spatial_group_size * H // (25 if H < W else 42)
H = ((H + m - 1) // m) * m
args = tuple( [N,H,W,C] + list(ia[4:]) )
padded_init_args.append(args)
init_args = padded_init_args
if rank == 0:
for ia in init_args:
print(ia)
spatial_group_sizes = [1]
if spatial_group_size > 1:
spatial_group_sizes.append(spatial_group_size)
numtype, device, fast = torch.float16, 'cuda', False
r = module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args)
if world_size > 1:
torch.distributed.barrier()
if rank == 0:
for rr in r:
print("***")
for out, dgrad, wgrad in rr:
gr = [("dgrad",dgrad.norm(p=2,dtype=torch.float64).item())] + [(k+".wgrad",wgrad[k].norm(p=2,dtype=torch.float64).item()) for k in wgrad.keys()]
print(gr)
if world_size > 1:
torch.distributed.barrier()
if __name__ == "__main__":
main()