Unverified commit 17eec271, authored by Burc Eryilmaz and committed by GitHub

option to set param views to flat buffer (#1152)



* option to set param views to flat buffer

* remove redundant variables in init_stage1
Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
Co-authored-by: ptrblck <ptrblck@users.noreply.github.com>
Parent 2e98baa7
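
For reference, a hedged usage sketch of the two flags this commit introduces. The flag names come from the diff below; the import path, model, and hyperparameters are illustrative assumptions, and the optimizer expects an initialized distributed process group.

```python
# Hedged usage sketch: set_param_views_to_flat_buffer and skip_allgather come
# from this commit; the import path, model, and hyperparameters are assumptions.
import torch
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

# Requires an initialized NCCL process group, e.g. under torchrun:
# torch.distributed.init_process_group(backend="nccl")

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = DistributedFusedLAMB(
    model.parameters(),
    lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
    # New flag: model params become views into the optimizer's flat buffer,
    # so the post-step copy back to model params is skipped (see last hunk).
    set_param_views_to_flat_buffer=True,
    # New flag: skip the post-step all-gather of updated parameter shards.
    skip_allgather=False,
)
```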
@@ -88,7 +88,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
e5m2_allgather=False, verbose=False, clip_after_ar=True,
full_ar=False, fuse_scale=False):
full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
fuse_scale=False):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
@@ -123,6 +124,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._full_ar = full_ar
self._fuse_scale = fuse_scale
self._L2_grad_norm = None
self._set_flat_param_view = set_param_views_to_flat_buffer
self._skip_ag = skip_allgather
self._fused_norm = fused_norm
self._current_process_group = c10d._get_default_group()
self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
@@ -238,6 +241,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
if self._verbose:
print(f"creating AG group : {ranks}")
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.barrier(group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream()
self._completion_st = torch.cuda.Stream()
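
The barrier added above on each all-gather process group ensures every rank has finished creating the new groups before they are first used. A self-contained sketch of that pattern, assuming a hypothetical make_ag_groups helper:

```python
# Hedged sketch of the pattern added above: create the all-gather process
# groups collectively, then barrier on each group this rank belongs to.
# make_ag_groups is a hypothetical helper, not part of the optimizer.
import torch.distributed as dist

def make_ag_groups(rank_lists):
    my_rank = dist.get_rank()
    my_groups = []
    for ranks in rank_lists:
        pg = dist.new_group(ranks=ranks)  # collective: every rank must call this
        if my_rank in ranks:
            my_groups.append(pg)
    for pg in my_groups:
        dist.barrier(group=pg)  # synchronize before the group is used
    return my_groups
```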
@@ -252,9 +257,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
self._param_order = self.AtomicCounter()
def _lazy_init_stage1(self):
if self._lazy_init_stage1_done: return
p_offset = 0
p_i = 0
self._model_params = []
@@ -281,19 +283,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
eps
))
p_grads_size = p.numel()
def wrapper(param, param_i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if self._first_step:
# first time
self._param_order.add(param_i)
else:
idx = self._param_order.order.index(param_i)
self._do_overlapped_reduction(idx, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
wrapper(p, p_i)
if self._set_flat_param_view:
self._param_order.add(p_i)
p_offset += p_grads_size
# Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
# RNN is one example of consecutive parameters:
@@ -311,13 +302,46 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
def _lazy_init_stage1(self):
if self._lazy_init_stage1_done: return
p_i = 0
#self._model_params = []
#self._grad_accs = []
#self._group_properties = []
for group in self.param_groups:
for p in group['params']:
torch.distributed.broadcast(p, 0)
if not p.requires_grad:
continue
def wrapper(param, param_i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if not self._set_flat_param_view:
if self._first_step:
# first time
self._param_order.add(param_i)
else:
idx = self._param_order.order.index(param_i)
self._do_overlapped_reduction(idx, param)
else:
if not self._first_step:
self._do_overlapped_reduction(param_i, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
wrapper(p, p_i)
p_i += 1
self._block_size = self._total_param_size // self._num_blocks
self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size
#print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
# initialize master weights, moments buffers if not loaded from checkpoint
if self._fp32_p is None:
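
The wrapper relocated into _lazy_init_stage1 registers a hook on each parameter's gradient accumulator; with the flat-view option the parameter index is used directly instead of the recorded order. A minimal standalone sketch of that hook mechanism, with a print callback standing in for _do_overlapped_reduction:

```python
# Minimal sketch of the grad-accumulator hook pattern used above; a print
# callback stands in for _do_overlapped_reduction.
import torch

def register_grad_ready_hook(param, param_i, on_grad_ready):
    # expand_as creates a view whose grad_fn leads back to the parameter's
    # AccumulateGrad node; hooks on that node fire once the grad is ready.
    param_tmp = param.expand_as(param)
    grad_acc = param_tmp.grad_fn.next_functions[0][0]

    def hook(*unused):
        on_grad_ready(param_i, param)

    grad_acc.register_hook(hook)
    return grad_acc  # keep a reference so the node is not garbage collected

model = torch.nn.Linear(4, 4)
grad_accs = [
    register_grad_ready_hook(p, i, lambda i, p: print("grad ready for param", i))
    for i, p in enumerate(model.parameters()) if p.requires_grad
]
model(torch.randn(2, 4)).sum().backward()  # hooks fire during backward
```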
@@ -392,8 +416,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
def _lazy_init_stage2(self):
if self._lazy_init_stage2_done: return
if not self._set_flat_param_view:
self._param_order.order.reverse()
self._param_order.order.reverse()
# re-order model_params, grad_accs, group_properties lists
self._model_params = [self._model_params[i] for i in self._param_order.order]
self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
self._group_properties = [self._group_properties[i] for i in self._param_order.order]
def _get_flat_view(param):
if param.is_contiguous(memory_format=torch.channels_last):
@@ -406,11 +435,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
pv = param
return pv.view(-1)
# re-order model_params, grad_accs, group_properties lists
self._model_params = [self._model_params[i] for i in self._param_order.order]
self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
self._group_properties = [self._group_properties[i] for i in self._param_order.order]
# re-collect grads info (size, offset) after ordering
prev = None
p_offset = 0
@@ -749,7 +773,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._contrib_weight_decay,
global_grad_norm,
self._use_nvlamb)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
if not self._skip_ag:
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if len(self._grads_fp16) > 0:
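
With skip_allgather set, the updated fp16 shards are not all-gathered after the step; that is only safe when the updated parameters reach the model by another path. A hedged sketch of the guard using the stock all_gather signature (the no_copy keyword in the diff above is not part of stock torch.distributed.all_gather):

```python
# Hedged sketch of the guard added above; maybe_allgather_shards is a
# hypothetical helper using the standard all_gather signature.
import torch.distributed as dist

def maybe_allgather_shards(output_shards, local_shard, group, skip_allgather=False):
    # output_shards[i] receives rank i's local_shard; skipping is only safe
    # when updated parameters are propagated to the model some other way.
    if skip_allgather:
        return
    dist.all_gather(output_shards, local_shard, group=group)
```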
@@ -851,21 +876,21 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
optimizer_state["found_inf_per_device"][current_device] = found_inf
self._completion_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._completion_st):
# Copy self._new_params to model params
with torch.no_grad():
if self._packed_flat_to_model_params_fp16 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp16)
if self._packed_flat_to_model_params_fp32 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp32)
if not self._set_flat_param_view:
with torch.cuda.stream(self._completion_st):
# Copy self._new_params to model params
with torch.no_grad():
if self._packed_flat_to_model_params_fp16 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp16)
if self._packed_flat_to_model_params_fp32 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp32)
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
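
Conceptually, set_param_views_to_flat_buffer lets model parameters alias one contiguous buffer, which is why the copy-back guarded above becomes unnecessary. A hedged, framework-level sketch of that aliasing idea (not code from this commit):

```python
# Conceptual sketch (not from this commit): point each parameter's storage at
# a slice of one flat buffer, so updates written to the buffer are visible to
# the model without a separate copy-back step.
import torch

def point_params_at_flat_buffer(params, flat_buffer):
    offset = 0
    for p in params:
        n = p.numel()
        view = flat_buffer[offset:offset + n].view_as(p)
        view.copy_(p.data)   # seed the buffer with the current values
        p.data = view        # the parameter now aliases the flat buffer
        offset += n
    return offset            # total number of elements consumed

model = torch.nn.Linear(8, 8)
flat = torch.zeros(sum(p.numel() for p in model.parameters()))
point_params_at_flat_buffer(list(model.parameters()), flat)
assert model.weight.data_ptr() == flat.data_ptr()  # weight starts at offset 0
```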