Unverified commit 2205cff2, authored by eqy, committed by GitHub

check in (#1210)

Parent fa8bd7e6
@@ -49,10 +49,6 @@ def _forward_backward_pipelining_with_interleaving(
     """
     if not isinstance(model, list):
         raise RuntimeError("`model` must be a list of `nn.Module`'s'")
-    # TODO (mkozuki): Sanity check the following condition.
-    if len(batch) != len(model):
-        msg = f"`batch` and `model` must have the same number of elements. Actual {len(batch)} and {len(model)}"
-        raise RuntimeError(msg)
     num_model_chunks = len(model)
     input_tensors = [[] for _ in range(num_model_chunks)]
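
The deleted guard assumed `batch` was a list with one entry per model chunk. After this commit a single batch is shared across every chunk of the interleaved pipeline, so the length check no longer applies. A minimal sketch of the two calling conventions, using a hypothetical `make_batch` helper (not part of apex) and illustrative sizes:

import torch

def make_batch(global_batch_size=8, hidden_size=16):
    # Hypothetical helper, for illustration only.
    return (torch.randn(global_batch_size, hidden_size),)

num_model_chunks = 2

# Before this commit: one batch per model chunk, which the deleted
# check enforced via len(batch) == len(model).
per_chunk_batches = [make_batch() for _ in range(num_model_chunks)]

# After this commit: a single batch, reused by every model chunk.
batch = make_batch()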
@@ -122,7 +118,7 @@ def _forward_backward_pipelining_with_interleaving(
         input_tensor = input_tensors[model_chunk_id][-1]
         output_tensor = forward_step(
             forward_step_func,
-            get_kth_microbatch(batch[model_chunk_id], curr_iters[model_chunk_id]),
+            get_kth_microbatch(batch, curr_iters[model_chunk_id]),
             model[model_chunk_id],
             input_tensor,
             losses_reduced,
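
With `batch[model_chunk_id]` replaced by `batch`, every model chunk now slices its microbatches from the same batch; only the per-chunk counter `curr_iters[model_chunk_id]` changes which slice is taken. A hedged sketch of the kind of slicing `get_kth_microbatch` performs — the real helper lives in apex.transformer.pipeline_parallel.utils and may differ in detail (e.g. in how it obtains the micro batch size):

import torch
from typing import List, Optional

def get_kth_microbatch_sketch(
    batch: Optional[List[torch.Tensor]], k: int, micro_batch_size: int
) -> Optional[List[torch.Tensor]]:
    # Slice every tensor in the batch along dim 0 to extract microbatch k.
    if batch is None:
        return batch
    start = k * micro_batch_size
    return [x[start : start + micro_batch_size] for x in batch]

batch = [torch.randn(8, 16)]
microbatch = get_kth_microbatch_sketch(batch, k=1, micro_batch_size=2)
assert microbatch[0].shape == (2, 16)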
@@ -125,10 +125,7 @@ def forward_backward_func_template(
     torch.optim.Adam(_param_groups)
     tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size]
-    if virtual_pipeline_model_parallel_size is None:
-        batch = (torch.randn(tensor_shape).cuda(),)
-    else:
-        batch = [(torch.randn(tensor_shape).cuda(),) for _ in range(virtual_pipeline_model_parallel_size)]
+    batch = (torch.randn(tensor_shape).cuda(),)
     tensor_shape[0] = micro_batch_size
     update_num_microbatches(0)
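
The test no longer builds a list of per-chunk batches when interleaving is enabled. A CPU-only sketch of the simplified setup — sizes are illustrative, `.cuda()` from the real test is dropped so the snippet runs anywhere, and the data-parallel world size is stubbed to 1:

import torch

batch_size, hidden_size, micro_batch_size = 8, 16, 2
data_parallel_world_size = 1  # stand-in for parallel_state.get_data_parallel_world_size()

tensor_shape = [batch_size // data_parallel_world_size, hidden_size]
# One batch, whether or not virtual (interleaved) pipelining is used.
batch = (torch.randn(tensor_shape),)
# Downstream, tensors are communicated at microbatch granularity.
tensor_shape[0] = micro_batch_size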