Unverified commit ae1cdd64 authored by Thor Johnsen, committed by GitHub

Merge pull request #1161 from NVIDIA/optional_caller_supplied_communicator

Optional NCCL communicator argument to init method
@@ -393,7 +393,7 @@ class SpatialBottleneck(torch.nn.Module):
     def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                  dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
-                 spatial_group_size=1):
+                 spatial_group_size=1, communicator=None):
         super(SpatialBottleneck, self).__init__()
         if groups != 1:
             raise RuntimeError('Only support groups == 1')
@@ -454,11 +454,14 @@ class SpatialBottleneck(torch.nn.Module):
             assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size"
             rank = dist.get_rank()
             self.local_rank = rank % spatial_group_size
-            for group in range(num_groups):
-                ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
-                comm = torch.distributed.new_group(ranks=ranks)
-                if rank in ranks:
-                    self.communicator = comm
+            if communicator is None:
+                for group in range(num_groups):
+                    ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
+                    comm = torch.distributed.new_group(ranks=ranks)
+                    if rank in ranks:
+                        self.communicator = comm
+            else:
+                self.communicator = communicator
             self.stream1 = torch.cuda.Stream()
             self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
         else:
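The change lets a caller that already manages its own spatial process groups pass one in via the new `communicator` argument, instead of having each SpatialBottleneck instance build groups itself with `torch.distributed.new_group`. A minimal usage sketch, assuming the class is importable from `apex.contrib.bottleneck` and that `torch.distributed` has been initialized with the NCCL backend; the import path, constructor argument values, and the `my_comm` name are illustrative assumptions, not taken from this diff:

```python
import torch
import torch.distributed as dist
from apex.contrib.bottleneck import SpatialBottleneck  # import path assumed

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

world_size = dist.get_world_size()
rank = dist.get_rank()
spatial_group_size = 2                                   # illustrative value
num_groups = world_size // spatial_group_size

# new_group() is collective: every rank must create every group,
# then each rank keeps only the group it belongs to.
my_comm = None
for group in range(num_groups):
    ranks = list(range(group * spatial_group_size, (group + 1) * spatial_group_size))
    comm = dist.new_group(ranks=ranks)
    if rank in ranks:
        my_comm = comm

# With a communicator supplied, the `if communicator is None` branch above is
# skipped and the block reuses the caller's process group.
block = SpatialBottleneck(in_channels=256, bottleneck_channels=64, out_channels=256,
                          spatial_group_size=spatial_group_size,
                          communicator=my_comm).cuda()
```

Reusing one set of caller-built groups also avoids creating duplicate process groups when a model contains many spatial bottleneck blocks.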