Unverified commit 1cb9c5c3, authored by Kexin Yu, committed by GitHub

Add full all-reduce code path for DistributedFusedAdam (#1146)



* add full all-reduce code path

* debug

* debug
Co-authored-by: ptrblck <ptrblck@users.noreply.github.com>
Parent d934eca3
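
For context, a minimal usage sketch of the new option. The hyperparameter values and the model are illustrative only, and the import path is the usual apex contrib location, so check it against your apex version:

    import torch
    from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

    # Assumes torch.distributed has already been initialized with the NCCL backend
    # and this process is pinned to its CUDA device.
    model = torch.nn.Linear(4096, 4096).cuda()
    optimizer = DistributedFusedLAMB(
        model.parameters(),
        lr=2e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
        dwu_num_blocks=4, dwu_num_chunks=4, dwu_num_ar_pg=4,
        full_ar=True,  # new flag added by this commit: use the full all-reduce path
    )

With full_ar=True the optimizer only builds all-reduce and all-gather process groups; with the default full_ar=False it keeps the existing reduce-scatter + all-reduce path (see the group-creation changes below).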
@@ -87,7 +87,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
step_supports_amp_scaling=True, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
e5m2_allgather=False, verbose=False, clip_after_ar=True):
e5m2_allgather=False, verbose=False, clip_after_ar=True,
full_ar=False):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
@@ -119,6 +120,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._e5m2_allgather = e5m2_allgather
self._verbose = verbose
self._clip_after_ar = clip_after_ar
self._full_ar = full_ar
self._L2_grad_norm = None
self._fused_norm = fused_norm
self._current_process_group = c10d._get_default_group()
@@ -142,58 +144,100 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
if self._verbose:
print(f"creating new group {i}: {ranks}")
grp = torch.distributed.new_group(ranks=ranks)
if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
if self._verbose:
print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
if self._verbose:
print(f"created new group {i}")
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
#for ar_pg in self._ar_pg:
# torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._full_ar: # full all reduce, only need AR and AG groups
self._ar_pg = []
# consider all the ranks
ranks = list(range(0, self._world_size))
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
#torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
#for rs_pg in self._rs_pg:
# torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for i in range(self._num_ar_pg):
if self._verbose:
print(f"creating new AR group {i}: {ranks}")
grp = torch.distributed.new_group(ranks=ranks)
if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
if self._verbose:
print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
if self._verbose:
print(f"created new AR group {i}: {ranks}")
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
if self._num_ag_pg == 0:
self._ag_pg = self._ar_pg
self._ag_st = self._ar_st
self._num_ag_pg = self._num_ar_pg
else:
self._ag_pg = []
ranks = []
stride = torch.cuda.device_count()
for i in range(self._num_groups):
rs = list(range(i*stride, (i+1)*stride))
ranks.append(rs)
for rs in ranks:
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=rs)
if torch.distributed.get_rank() in rs:
if self._verbose:
print(f"creating AG group {i}: {rs}")
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
else: # reduce-scatter + all-reduce, need RS, AR, AG groups
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
if self._verbose:
print(f"creating new AR group {i}: {ranks}")
grp = torch.distributed.new_group(ranks=ranks)
if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
if self._verbose:
print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
if self._verbose:
print(f"created new AR group {i}: {ranks}")
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
#for ag_pg in self._ag_pg:
# torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._rs_pg.append(grp)
if self._verbose:
print(f"creating RS group : {ranks}")
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
if self._verbose:
print(f"creating AG group : {ranks}")
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
self._l2_grad_norm_st = torch.cuda.Stream()
self._completion_st = torch.cuda.Stream()
self._step.record_stream(self._completion_st)
@@ -295,7 +339,14 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
def _flat_split_no_shards(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
@@ -307,7 +358,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._shard_size
@@ -318,12 +368,24 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
if self._full_ar:
# for gradient all-reduce
self._flat_grads_blocks, self._flat_grads_chunks = _flat_split_no_shards(self._flat_grads)
# for weight update
self._flat_grads_shards, _, _ = _full_packed_split(self._flat_grads)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._flat_grads_shards[self._rank_in_group])
else:
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
self._lazy_init_stage1_done = True
@@ -482,30 +544,48 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
return flush_block
def _full_all_reduce(self, block_id):
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
ar_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ar_stream):
works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
def _reduce_scatter_and_all_reduce(self, block_id):
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
rs_stream.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
def _pipeline_block_reductions(self, block_id):
if self._clip_after_ar:
self._flatten_grad_mt(1.0/self._world_size)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
if self._full_ar:
self._full_all_reduce(block_id)
else:
self._reduce_scatter_and_all_reduce(block_id)
# Compute L2 grad norm
if block_id == 0:
@@ -515,7 +595,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda')
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
if self._full_ar:
l2_grad_norm_sq = self._flat_grads_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
else:
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt()
else:
@@ -539,27 +622,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
scale = tmp.index_select(0, index).half()/self._world_size
self._flat_grads.mul_(scale)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
rs_stream.wait_stream(self._l2_grad_norm_st)
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
if self._full_ar:
self._full_all_reduce(block_id)
else:
self._reduce_scatter_and_all_reduce(block_id)
if block_id == 0:
for block_id in range(self._num_blocks):
......
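
To summarize the two reduction paths that _pipeline_block_reductions now dispatches between, here is a simplified, self-contained sketch. Function names and signatures are illustrative, and stream management, gradient scaling, and apex's no_copy reduce-scatter extension are omitted:

    import torch.distributed as dist

    def full_all_reduce(flat_grads_chunks, ar_pg, block_id, num_chunks):
        # full_ar path: all-reduce each flat-gradient chunk over all ranks, so every
        # rank ends up with the fully reduced gradient and no per-shard layout is needed.
        return [dist.all_reduce(flat_grads_chunks[block_id][c], group=ar_pg, async_op=True)
                for c in range(num_chunks)]

    def reduce_scatter_then_all_reduce(fp16_g_chunks, flat_grads_shards, rs_pg, ar_pg,
                                       block_id, num_chunks, num_groups):
        # Default path: reduce-scatter within each node turns the [block * chunk * shard]
        # gradient layout into each rank's [shard * block * chunk] piece; for multi-node
        # jobs an all-reduce across nodes then finishes the reduction of that piece.
        works = []
        for c in range(num_chunks):
            # flat_grads_shards[block_id][c] is the list of per-rank input shards
            works.append(dist.reduce_scatter(fp16_g_chunks[block_id][c],
                                             flat_grads_shards[block_id][c],
                                             group=rs_pg, async_op=True))
        if num_groups > 1:
            for c in range(num_chunks):
                works[c].wait()
                works[c] = dist.all_reduce(fp16_g_chunks[block_id][c],
                                           group=ar_pg, async_op=True)
        return works

The L2 gradient norm follows the same split: with full_ar each rank takes the norm of its packed shard of the already-reduced gradient, squares it, all-reduces the squares, and takes the square root, which is what the change to the norm computation above does.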