Unverified commit d9a46fde, authored by Nan Zheng, committed by GitHub

Fix dist lamb (#1185)

1. Remove the weight broadcast in the constructor.
2. Disable the unnecessary allreduces when clip-after-ar is enabled.
Parent 4e9fae9b
@@ -270,7 +270,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             eps = group['eps']
             weight_decay = group['weight_decay']
             for p in group['params']:
-                torch.distributed.broadcast(p, 0)
                 if not p.requires_grad:
                     continue
                 self._model_params.append(p)
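Since the constructor no longer broadcasts the parameters, callers that relied on that implicit rank-0 synchronization now need to do it themselves, e.g. before constructing the optimizer. Below is a minimal sketch of such a broadcast, assuming `torch.distributed` is already initialized; the helper name is illustrative and not part of apex.

```python
import torch
import torch.distributed as dist

def broadcast_model_from_rank0(model: torch.nn.Module, group=None):
    # Broadcast every parameter from rank 0 so all ranks start from identical
    # weights. This replaces the broadcast the constructor used to perform.
    with torch.no_grad():
        for p in model.parameters():
            dist.broadcast(p, src=0, group=group)
```

Keeping the broadcast outside the optimizer leaves the synchronization under the caller's control, so it can be skipped when, for example, a checkpoint loader already guarantees identical weights on every rank.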
@@ -729,12 +728,14 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         # check global_grad_norm and fill overflow_buf
         is_finite = (global_grad_norm + 1 > global_grad_norm).int()
         self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
-        torch.distributed.all_reduce(is_finite,
-                                     op=torch.distributed.ReduceOp.MIN,
-                                     group=self._current_process_group)
-        torch.distributed.all_reduce(self._overflow_buf,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=self._current_process_group)
+        if not self._clip_after_ar:
+            torch.distributed.all_reduce(is_finite,
+                                         op=torch.distributed.ReduceOp.MIN,
+                                         group=self._current_process_group)
+            torch.distributed.all_reduce(self._overflow_buf,
+                                         op=torch.distributed.ReduceOp.MAX,
+                                         group=self._current_process_group)
         # increment step counter if no overflow
         self._step += is_finite
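The overflow check above relies on `x + 1 > x` evaluating to False when `x` is `inf` or `NaN`, so `is_finite` is 1 only for a usable global gradient norm and `_overflow_buf` is its complement. After this change, the MIN/MAX allreduces that combine those flags across ranks run only when `_clip_after_ar` is False, presumably because with clip-after-ar the norm is computed from already all-reduced gradients and is therefore identical on every rank, making the extra collectives redundant. A standalone sketch of the finiteness trick (the helper name is illustrative, not part of the optimizer):

```python
import torch

def norm_is_finite(global_grad_norm: torch.Tensor) -> torch.Tensor:
    # x + 1 > x is False for +inf and NaN, so the comparison doubles as an
    # overflow/NaN detector; for a typical finite norm it is True.
    return (global_grad_norm + 1 > global_grad_norm).int()

print(norm_is_finite(torch.tensor(3.5)))           # -> 1
print(norm_is_finite(torch.tensor(float("inf"))))  # -> 0
print(norm_is_finite(torch.tensor(float("nan"))))  # -> 0
```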