Unverified commit 1d5f7e55, authored by Burc Eryilmaz and committed by GitHub


change chunking scheme for full-allreduce case, add parameter order argument, both to enable contiguous chunking of allgather (#1190)
Parent d9a46fde
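The change is easiest to see on a flat buffer: the optimizer splits its flat gradient/parameter buffers into blocks, each block into chunks, and each chunk into per-rank shards, and this commit switches the full-allreduce path to a split in which each rank's all-gather chunks are contiguous slices. The snippet below is a minimal standalone sketch of that block/chunk/shard slicing; the helper name `flat_split` and all sizes are made up for illustration and are not the optimizer's actual helpers.

```python
import torch

def flat_split(flat, num_blocks, num_chunks, group_size):
    # Contiguous hierarchy: buffer -> blocks -> chunks -> per-rank shards.
    block_size = flat.numel() // num_blocks
    chunk_size = block_size // num_chunks
    shard_size = chunk_size // group_size
    blocks = [flat[b * block_size:(b + 1) * block_size] for b in range(num_blocks)]
    chunks = [[blk[c * chunk_size:(c + 1) * chunk_size] for c in range(num_chunks)]
              for blk in blocks]
    shards = [[[ch[s * shard_size:(s + 1) * shard_size] for s in range(group_size)]
               for ch in chs] for chs in chunks]
    return blocks, chunks, shards

flat = torch.arange(32.)
blocks, chunks, shards = flat_split(flat, num_blocks=2, num_chunks=2, group_size=4)
# shards[b][c][r] is a contiguous view into `flat`, so rank r can all-gather
# chunk (b, c) without any copy or reshuffle at gather time.
assert shards[1][0][2].data_ptr() == flat[20:].data_ptr()
```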
+import os
import math
import torch
import importlib
@@ -89,7 +90,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
e5m2_allgather=False, verbose=False, clip_after_ar=True,
full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
-fuse_scale=False):
+fuse_scale=False, param_order=None, nccl_allgather_channels=0):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
@@ -170,7 +171,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
+if nccl_allgather_channels > 0:
+os.putenv('NCCL_MAX_NCHANNELS', str(nccl_allgather_channels))
if self._num_ag_pg == 0:
self._ag_pg = self._ar_pg
self._ag_st = self._ar_st
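The new `nccl_allgather_channels` argument caps the number of NCCL channels via the `NCCL_MAX_NCHANNELS` environment variable. A hedged sketch of the same idea is below; note that `os.putenv` updates the process environment that NCCL's C library reads but does not update `os.environ`, and the variable has to be set before the NCCL communicators for the affected groups are created. The value used here is made up.

```python
import os

# Illustration only: cap NCCL channels before any NCCL communicator is built.
# os.environ[...] = ... also calls putenv() under the hood and additionally
# keeps os.environ consistent for anything else that inspects it.
nccl_allgather_channels = 8  # hypothetical value
if nccl_allgather_channels > 0:
    os.environ['NCCL_MAX_NCHANNELS'] = str(nccl_allgather_channels)
```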
@@ -283,7 +285,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
))
p_grads_size = p.numel()
if self._set_flat_param_view:
-self._param_order.add(p_i)
+if param_order:
+# this is executed when param_order is specified by the user
+self._param_order.add(param_order[p])
+else:
+self._param_order.add(p_i)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
@@ -292,6 +298,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
+if param_order:
+self._param_order.order = torch.argsort(torch.tensor(self._param_order.order)).tolist()
self._grads_generated = [False]*len(self._model_params)
self._grads_fp16, self._grads_fp32 = [], []
if self._overlap_reductions:
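When the caller passes `param_order`, the recorded list holds each parameter's requested position in registration order, and `torch.argsort` inverts that mapping so `order[k]` becomes the registration index of the parameter that should occupy slot `k`; `order.index(param_i)` (used in the hook change below) then recovers a parameter's slot. A tiny standalone example of that inversion, with made-up values:

```python
import torch

# Suppose three params are registered as p0, p1, p2 and the user requests
# positions {p0: 2, p1: 0, p2: 1}; the list recorded during registration is:
recorded = [2, 0, 1]

# argsort inverts the permutation: order[k] = registration index for slot k.
order = torch.argsort(torch.tensor(recorded)).tolist()
assert order == [1, 2, 0]

# Recover the slot of the first-registered parameter (p0), matching its
# requested position 2.
assert order.index(0) == 2
```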
@@ -330,7 +338,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._do_overlapped_reduction(idx, param)
else:
if not self._first_step:
-self._do_overlapped_reduction(param_i, param)
+idx = self._param_order.order.index(param_i)
+self._do_overlapped_reduction(idx, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
wrapper(p, p_i)
@@ -359,7 +368,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
-list_of_blocks = __blockify(self._flat_grads)
+list_of_blocks = __blockify(p)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
@@ -392,8 +401,20 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
+def _split_assign(shards):
+packed_block_size = self._num_chunks*self._shard_size
+list_of_list_of_chunks=[]
+for block_id in range(self._num_blocks):
+list_of_chunks=[]
+for chunk_id in range(self._num_chunks):
+#self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]
+list_of_chunks.append( shards[block_id][chunk_id][self._rank_in_group])
+list_of_list_of_chunks.append(list_of_chunks)
+return list_of_list_of_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
+# this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
+self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
@@ -403,10 +424,9 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
if self._full_ar:
# for gradient all-reduce
-self._flat_grads_blocks, self._flat_grads_chunks = _flat_split_no_shards(self._flat_grads)
+self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
# for weight update
-self._flat_grads_shards, _, _ = _full_packed_split(self._flat_grads)
-self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._flat_grads_shards[self._rank_in_group])
+self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
else:
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
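`_split_assign` walks the fully sharded view `shards[block][chunk][rank]` and keeps only the local rank's shard for every (block, chunk) pair, so the weight-update path sees the same chunk layout as before while the storage stays the contiguous all-reduced gradient buffer. A standalone sketch of that selection, reusing the hypothetical `flat_split` helper sketched earlier (sizes and the rank are invented):

```python
import torch

num_blocks, num_chunks, group_size, shard_size = 2, 2, 4, 2
rank_in_group = 1  # hypothetical local rank
flat = torch.arange(float(num_blocks * num_chunks * group_size * shard_size))
_, _, shards = flat_split(flat, num_blocks, num_chunks, group_size)

# Keep only this rank's shard of every (block, chunk): the analogue of the
# list-of-lists returned by _split_assign above.
fp16_g_chunks = [[shards[b][c][rank_in_group] for c in range(num_chunks)]
                 for b in range(num_blocks)]
assert fp16_g_chunks[0][1].numel() == shard_size
```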
@@ -416,12 +436,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
def _lazy_init_stage2(self):
if self._lazy_init_stage2_done: return
if not self._set_flat_param_view:
# reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
self._param_order.order.reverse()
# re-order model_params, grad_accs, group_properties lists
-self._model_params = [self._model_params[i] for i in self._param_order.order]
-self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
-self._group_properties = [self._group_properties[i] for i in self._param_order.order]
+self._model_params = [self._model_params[i] for i in self._param_order.order]
+self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
+self._group_properties = [self._group_properties[i] for i in self._param_order.order]
def _get_flat_view(param):
if param.is_contiguous(memory_format=torch.channels_last):
@@ -570,13 +591,21 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
def _full_all_reduce_scale(self, block_id, scale):
works = [None]*self._num_chunks
-for chunk_id in range(self._num_chunks):
-glob_chunk_id = block_id * self._num_chunks + chunk_id
+if self._clip_after_ar:
+for chunk_id in range(self._num_chunks):
+glob_chunk_id = block_id * self._num_chunks + chunk_id
+ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+ar_stream.wait_stream(torch.cuda.current_stream())
+with torch.cuda.stream(ar_stream):
+works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
+else:
+glob_chunk_id = block_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
-works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
+works0 = torch.distributed.all_reduce(self._flat_grads_blocks[block_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
+for i in range(self._num_chunks):
+works[i]=works0
self._reductions_works[block_id] = works
def _full_all_reduce(self, block_id):
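With `clip_after_ar=False`, the whole block is reduced with a single collective whose async work handle is copied into every chunk slot, so later per-chunk `wait()` calls all wait on that one collective; with `clip_after_ar=True` the per-chunk reductions are kept. A rough sketch of just that control flow is below; it assumes an already-initialized process group, drops the CUDA stream handling, and uses a plain SUM reduction instead of the pre-multiplied-sum op in the real code.

```python
import torch.distributed as dist

def full_all_reduce_sketch(flat_grads_blocks, flat_grads_chunks, block_id,
                           num_chunks, clip_after_ar, group):
    works = [None] * num_chunks
    if clip_after_ar:
        # one async all-reduce per chunk of the block
        for chunk_id in range(num_chunks):
            works[chunk_id] = dist.all_reduce(
                flat_grads_chunks[block_id][chunk_id], group=group, async_op=True)
    else:
        # one async all-reduce over the whole contiguous block; every chunk
        # slot shares the same work handle
        work = dist.all_reduce(flat_grads_blocks[block_id], group=group, async_op=True)
        works = [work] * num_chunks
    return works
```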
@@ -653,7 +682,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda')
-if self._full_ar:
+if 0:#self._full_ar:
l2_grad_norm_sq = self._flat_grads_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
else:
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
@@ -775,7 +804,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
global_grad_norm,
self._use_nvlamb)
if not self._skip_ag:
-torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
+# allgather chunking is currently not supported for clip after allreduce
+if not self._clip_after_ar:
+for block in range(self._num_blocks):
+for chunk in range(self._num_chunks):
+torch.distributed.all_gather(self._new_params2_shards[block][chunk], self._fp16_p_chunks[block][chunk], group=self._ag_pg[0], no_copy=True)
+else:
+torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if len(self._grads_fp16) > 0:
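In the new path (only taken when `clip_after_ar` is False), the all-gather is issued per (block, chunk) into the contiguous `new_params2` shard views instead of one large gather into the packed mega-shards. The sketch below shows the chunked loop using the stock `torch.distributed.all_gather` signature; the `no_copy=True` argument in the diff is not part of upstream PyTorch and appears to rely on a patched distributed backend, and the container names here follow the splits sketched earlier.

```python
import torch.distributed as dist

def chunked_allgather_sketch(new_params2_shards, fp16_p_chunks,
                             num_blocks, num_chunks, group):
    # new_params2_shards[b][c] is a list with one contiguous destination view
    # per rank; fp16_p_chunks[b][c] is this rank's updated chunk.
    for block in range(num_blocks):
        for chunk in range(num_chunks):
            dist.all_gather(new_params2_shards[block][chunk],
                            fp16_p_chunks[block][chunk], group=group)
```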
......