Commit 333da806 authored by Thorsten Kurth


Wrote a small wrapper function for flat view creation in _lazy_init_stage2 to support channels last data formats
Parent d6b5ae5d
@@ -332,6 +332,17 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._param_order.order.reverse()
+        def _get_flat_view(param):
+            if param.is_contiguous(memory_format=torch.channels_last):
+                K, C, H, W = param.shape
+                pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
+            elif param.is_contiguous(memory_format=torch.channels_last_3d):
+                K, C, D, H, W = param.shape
+                pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
+            else:
+                pv = param
+            return pv.view(-1)
         # re-order model_params, grad_accs, group_properties lists
         self._model_params = [self._model_params[i] for i in self._param_order.order]
         self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
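Why the wrapper is needed: a parameter stored in channels_last layout has strides that are not compatible with a plain .view(-1), so the wrapper first re-describes the same storage in NHWC (or NDHWC) order with as_strided and flattens that instead. A minimal standalone sketch of the behavior (the tensor shape and demo variable names are illustrative, not part of the commit):

import torch

# Same logic as the committed _get_flat_view, reproduced so the demo is self-contained.
def _get_flat_view(param):
    if param.is_contiguous(memory_format=torch.channels_last):
        K, C, H, W = param.shape
        pv = param.as_strided(size=(K, H, W, C), stride=(H*W*C, W*C, C, 1))
    elif param.is_contiguous(memory_format=torch.channels_last_3d):
        K, C, D, H, W = param.shape
        pv = param.as_strided(size=(K, D, H, W, C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
    else:
        pv = param
    return pv.view(-1)

# A conv-style weight stored in channels_last layout (shape chosen only for the demo).
w = torch.randn(8, 3, 5, 5).to(memory_format=torch.channels_last)

try:
    w.view(-1)  # fails: channels_last strides cannot be flattened by a plain view
except RuntimeError as err:
    print("view(-1) failed:", err)

flat = _get_flat_view(w)                 # flat view in physical (NHWC) storage order
assert flat.data_ptr() == w.data_ptr()   # still a view of the same storage, no copy
print(flat.shape)                        # torch.Size([600])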
@@ -392,7 +403,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                             grad_offset = clipped_start - flat_grad_start
                             grad_length = clipped_end - clipped_start
                             shard_offset = clipped_start - flat_shard_start
-                            model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
+                            pf = _get_flat_view(p)
+                            model_param_fragment = pf[grad_offset:grad_offset+grad_length]
                             new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                             if model_param_fragment.dtype == torch.float16:
                                 self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
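In this second hunk the flat view replaces the old p.view(-1) while the (new_param_packed_fragment, model_param_fragment) pairs are collected; the optimizer later copies updated values from the packed flat buffers back into the model parameters through these fragments. A hedged, self-contained sketch of that copy-back pattern, assuming a single pair that covers a whole parameter (everything except _get_flat_view is illustrative and not apex's API):

import torch

# Same wrapper as in the commit, repeated so this sketch runs on its own.
def _get_flat_view(param):
    if param.is_contiguous(memory_format=torch.channels_last):
        K, C, H, W = param.shape
        pv = param.as_strided(size=(K, H, W, C), stride=(H*W*C, W*C, C, 1))
    elif param.is_contiguous(memory_format=torch.channels_last_3d):
        K, C, D, H, W = param.shape
        pv = param.as_strided(size=(K, D, H, W, C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
    else:
        pv = param
    return pv.view(-1)

# One channels_last parameter and a stand-in for a packed chunk of new values.
p = torch.randn(8, 3, 5, 5).to(memory_format=torch.channels_last)
packed_new_values = torch.randn(p.numel())

# Collect (packed_fragment, model_param_fragment) pairs; a single pair covering
# the whole parameter stands in for the sharded/chunked fragments in apex.
pairs = [(packed_new_values, _get_flat_view(p))]

# Copy-back step: writing through the flat view updates p in place, in its
# actual (channels_last) storage order.
for packed_fragment, model_fragment in pairs:
    model_fragment.copy_(packed_fragment)

assert torch.equal(_get_flat_view(p), packed_new_values)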