Unverified commit 2e98baa7, authored by Burc Eryilmaz, committed by GitHub

use prescaling for collective (#1157)
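In short: instead of multiplying the flattened gradients by the scale factor (the clipping coefficient folded together with 1/world_size) in a separate elementwise kernel and then launching the collective, the scale is handed to NCCL as a premul-sum reduce op, so the multiplication happens inside the all-reduce / reduce-scatter itself. A minimal sketch of the difference, assuming a PyTorch build that exposes torch.distributed.make_nccl_premul_sum as called in this commit (upstream PyTorch exposes the equivalent torch.distributed._make_nccl_premul_sum); grads, scale, and fuse_scale below are illustrative names, not the optimizer's API:

import torch
import torch.distributed as dist

def all_reduce_scaled(grads: torch.Tensor, scale: float, fuse_scale: bool):
    # grads: flat gradient buffer already on the GPU; scale: e.g. clip coeff / world size
    if not fuse_scale:
        # Baseline: scale with a separate elementwise kernel, then a plain sum all-reduce.
        grads.mul_(scale)
        dist.all_reduce(grads, op=dist.ReduceOp.SUM)
    else:
        # Prescaled collective: NCCL multiplies each input by `scale` while reducing,
        # saving one extra read/write pass over the gradient buffer.
        premul_sum = dist.make_nccl_premul_sum((scale,))  # assumed available in this build
        dist.all_reduce(grads, op=premul_sum)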

Parent 1cb9c5c3
@@ -88,7 +88,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
                  e5m2_allgather=False, verbose=False, clip_after_ar=True,
-                 full_ar=False):
+                 full_ar=False, fuse_scale=False):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
@@ -121,6 +121,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._verbose = verbose
         self._clip_after_ar = clip_after_ar
         self._full_ar = full_ar
+        self._fuse_scale = fuse_scale
         self._L2_grad_norm = None
         self._fused_norm = fused_norm
         self._current_process_group = c10d._get_default_group()
@@ -544,6 +545,17 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         return flush_block

+    def _full_all_reduce_scale(self, block_id, scale):
+        works = [None]*self._num_chunks
+        for chunk_id in range(self._num_chunks):
+            glob_chunk_id = block_id * self._num_chunks + chunk_id
+            ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+            ar_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(ar_stream):
+                works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
+        self._reductions_works[block_id] = works

     def _full_all_reduce(self, block_id):
         works = [None]*self._num_chunks
@@ -555,6 +567,29 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
         self._reductions_works[block_id] = works

+    def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
+        # Reduction within each node
+        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
+        # The output format is the same as the fp32 master parameters
+        works = [None]*self._num_chunks
+        for chunk_id in range(self._num_chunks):
+            glob_chunk_id = block_id * self._num_chunks + chunk_id
+            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
+            rs_stream.wait_stream(torch.cuda.current_stream())
+            rs_stream.wait_stream(self._l2_grad_norm_st)
+            with torch.cuda.stream(rs_stream):
+                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True,op=torch.distributed.make_nccl_premul_sum((scale,)))

+        # Reduction across nodes for each rank
+        if self._num_groups > 1:
+            for chunk_id in range(self._num_chunks):
+                glob_chunk_id = block_id * self._num_chunks + chunk_id
+                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+                with torch.cuda.stream(ar_stream):
+                    works[chunk_id].wait()
+                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
+        self._reductions_works[block_id] = works

     def _reduce_scatter_and_all_reduce(self, block_id):
         # Reduction within each node
         # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
@@ -620,12 +655,19 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             tmp = torch.cat(((self._one), (coeff)))
             index = (coeff+1>coeff).int()
             scale = tmp.index_select(0, index).half()/self._world_size
-            self._flat_grads.mul_(scale)
+            if not self._fuse_scale:
+                self._flat_grads.mul_(scale)
             if self._full_ar:
-                self._full_all_reduce(block_id)
+                if self._fuse_scale:
+                    self._full_all_reduce_scale(block_id, scale)
+                else:
+                    self._full_all_reduce(block_id)
             else:
-                self._reduce_scatter_and_all_reduce(block_id)
+                if self._fuse_scale:
+                    self._reduce_scatter_and_all_reduce_scale(block_id, scale)
+                else:
+                    self._reduce_scatter_and_all_reduce(block_id)
             if block_id == 0:
                 for block_id in range(self._num_blocks):
......
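A note on the scale selection in the last hunk: index = (coeff+1>coeff).int() acts as a finiteness check, because x + 1 > x is False for NaN and inf; tmp.index_select(0, index) therefore picks the clipping coefficient when it is finite and 1.0 otherwise, and the division by self._world_size folds gradient averaging into the same factor. A small standalone illustration of that selection (not the optimizer's code; the .half() cast from the diff is omitted so it runs on CPU):

import torch

def pick_scale(coeff: torch.Tensor, world_size: int) -> torch.Tensor:
    # Use coeff when it is finite, otherwise fall back to 1.0,
    # then fold in the 1/world_size gradient-averaging factor.
    one = torch.ones(1)
    tmp = torch.cat((one, coeff))
    index = (coeff + 1 > coeff).int()   # 1 if coeff is finite, 0 for NaN/inf
    return tmp.index_select(0, index) / world_size

print(pick_scale(torch.tensor([0.5]), 8))           # tensor([0.0625])
print(pick_scale(torch.tensor([float('inf')]), 8))  # tensor([0.1250]) -> falls back to 1.0/8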