Unverified · Commit 63d5dd63 · authored by Masaki Kozuki, committed by GitHub

Pipeline Model Parallel (#1202)



* Init apex.ppu (pipeline model parallel utility)

Reference commit:

```
commit 5ab646376d67831601d5552c193241d017f1b35c (HEAD -> main, internal/main)
Merge: 14f2c684 7b293d9b
Author: Mohammad Shoeybi <mshoeybi@nvidia.com>
Date:   Wed Sep 22 22:57:54 2021 -0700

    Merge branch 'add_BOS' into 'main'

    Add Beginning of Sentence token option and adding semaphore while multi-threading to prevent crashes and hangs due to connection keep-alives

    See merge request ADLR/megatron-lm!328
```

* removing get_args and replace import - phase 1

* removing get_args and replace import - phase 2

* move ppu to apex.transformer.pipeline_parallel

* update two __init__.py

* update READMEs

* mpu -> parallel_state & tensor_parallel

* fix

* remove non-pipeline files

* separate schedules.py - phase 1

* dissect schedules.py

* data_iterators -> batch

* remove optimizer from forward_backward_step funcs

* init test

* Apply 2 suggestion(s) to 2 file(s)

* fix cyclic import

* fix syntax of Callable

* fix - 1

* move directory since `testing` is used for the pp test as well

* add some functions for num microbatches calculator

* model is a list in pipeline parallel

* skip build num microbatch calculator

* fix test

* assert -> raise

* skip args printing

* specify tensor shape everywhere even if None - phase 1

* private timers

* passing tensor shape & dtype around

* update dtype handling by introducing helper func

* write helper func to reduce cyclomatic complexity

* remove duplicate

* update

* move split_tensor_into_1d_equal_chunks to avoid cyclic import

* tmp

* cosmetic

* move gather_split_1d_tensor to avoid cyclic imports

* remove debug print

* add outer loop

* early return if possible

* cosmetic

* passing around tensor shape

* refactor test

* add script to learn batch sampler behavior

* update

* minibatch splitter

* add minibatch splitter

* split minibatch into microbatches

* minor changes

* uncomment split batch for test sake

* set as attribute

* study the behavior of no pipelining

* debug 1

* reflect test util namespace change

* update readme

* cosmetic in test

* add model build helper func for interleaving sched

* adding model builder from megatron

* potential cyclic import

* fix

* enable interleaving test, but failing even if forward only

* fix batch preparation

* add explanation

* print data parallel size

* fix typo

* Add Megatron style GPT model by Rishi
Co-authored-by: Rishi Puri <riship@nvidia.com>

* update

* type hint for jit

* fix forward_backward_no_pipelining test

* pipeline forward backward seem to hang if not forward only

* fix typo

* debug

* add p2p test

* simplify

* fix

* tentative

* set both tmp and pmp to 1

* init

* fix typo

* fix

* fix path of divide

* set seed for tmp

* update upon Eddie comment

* fix typo

* adding failing data loader test

* fix

* megatron still failing

* check in

* with the new order of the nested loop, interleaving seems fine

* cosmetic change

* make `forward_backward_pipelining_with_interleaving` private

* warn users that interleaving sched is unstable

* move noop handler to no pipelining

* comment out rank_print

* make `build_model` more flexible

* skip megatron test tentatively

* correctly comment out rank_print

* correctly comment out rank_print

* correctly comment out rank_print

* skip appropriately

* remove wip p2p comm test

* update type hint of model_provider_func

* disable tf32 in each test script

* skip interleaving w/ backward

* rename as mpu is the old name

* remove broken case

* expose build_model func

* delete `dist.ring_exchange` func call and `use_ring_exchange` argument

* nit fixes

* check in

* remove unused file

* update the list

* update tensor shape

* remove mixed dtype case

* use torch.distributed.run

* 2020 -> 2021

* another 2020 -> 2021

* docstring & type hint

* fix teardown

* update

* change to experimental

* check if warned
Co-authored-by: Rishi Puri <riship@nvidia.com>
Co-authored-by: Eddie Yan <eddiey@nvidia.com>
parent 3303b3e7
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
import warnings
if torch.distributed.is_available():
    from . import parallel
......
from typing import Optional
import torch
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
    if not torch.is_autocast_enabled():
        return torch.float or dtype
    else:
        return torch.get_autocast_gpu_dtype()


def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
......
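For context, `_cast_if_autocast_enabled` is consumed by the fused kernels elsewhere in this commit (e.g. the fused softmax shown further below). A minimal sketch of that pattern, assuming a Megatron-style `torch.autograd.Function` such as `ScaledUpperTriangMaskedSoftmax`; the wrapper name here is illustrative:

```python
import torch

def scaled_upper_triang_masked_softmax(inputs, scale):
    # Cast the arguments to the active autocast dtype (if any), then run the fused
    # kernel with autocast disabled so it sees a single, consistent dtype.
    args = _cast_if_autocast_enabled(inputs, scale)
    with torch.cuda.amp.autocast(enabled=False):
        return ScaledUpperTriangMaskedSoftmax.apply(*args)
```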
......@@ -2,4 +2,80 @@
`apex.transformer` is a module which enables efficient large Transformer models at scale.
`apex.transformer.tensor_parallel` is based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s `megatron.mpu` module.
`apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s module.
The former is based on `megatron.mpu` and the latter is based on `megatron.schedules` and `megatron.p2p_communication`.
## Tensor Model Parallel (TP)
APEX's tensor model parallel utilities provide some `torch.nn.Module` subclasses, custom fused kernels, and PRNG state handling.
See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of
PRNG state handling.
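A hedged sketch of these utilities, assuming a default process group has already been initialized (e.g. via `torch.distributed.run`); the layer sizes and seed are illustrative, and the exported names follow the Megatron-LM port:

```python
import torch
from apex.transformer import parallel_state, tensor_parallel

# 2-way tensor model parallelism (first positional argument of initialize_model_parallel).
parallel_state.initialize_model_parallel(2)

# Megatron-style PRNG handling: part of the seed is shared across the tensor-parallel
# group, part is unique per tensor-parallel rank (see Appendix B.2 referenced above).
tensor_parallel.model_parallel_cuda_manual_seed(1234)

# Column-parallel linear layer: the weight is partitioned along the output dimension
# across tensor-parallel ranks; gather_output=True returns the full output on each rank.
layer = tensor_parallel.ColumnParallelLinear(
    1024, 4096, gather_output=True, init_method=torch.nn.init.xavier_normal_,
)
```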
## Pipeline Model Parallel (PP)
APEX's pipeline model parallel functions require models to implement a `.set_input_tensor` method because the input tensor of the `.forward` method can be `None` (intermediate stages receive their input from the previous stage instead).
The following is a rough sketch of a training script using apex pipeline parallelism.
```python
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func


class Model(nn.Module):

    ...

    def __init__(self, *args, **kwargs):
        super().__init__()
        pre_process = kwargs.pop("pre_process")
        post_process = kwargs.pop("post_process")

    def set_input_tensor(self, tensor):
        self.input_tensor = tensor

    def forward(self, x, ...):
        if parallel_state.is_pipeline_first_stage():
            input = x
        else:
            input = self.input_tensor
        ...


def model_provider_func(*args, **kwargs):
    return Model(*args, **kwargs)


def loss_func(pred, label):
    loss = ...
    averaged_loss = average_losses_across_data_parallel_group([loss])
    return loss, {'nice_loss': averaged_loss}


def forward_step_func(batch, model):
    input, label = process_batch(batch)
    out = model(input)
    return out, partial(loss_func, label=label)


forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size,
    pipeline_model_parallel_size,
    virtual_pipeline_model_parallel_size,
)
# The following line is basically equivalent to
# `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)`.
model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)
optimizer = ...
data_loader = ...
for epoch in range(num_epochs):
    for batch in data_loader:
        forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape)
        optimizer.step()
```
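The `tensor_shape` above has to be spelled out explicitly. A minimal sketch, assuming the common `(seq_length, micro_batch_size, hidden_size)` shape that `p2p_communication` expects (see the check later in this diff); the concrete numbers are illustrative:

```python
seq_length, micro_batch_size, hidden_size = 1024, 4, 1536
tensor_shape = (seq_length, micro_batch_size, hidden_size)

forward_backward_func(
    forward_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
)
```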
from . import tensor_parallel
from . import functional
from .enums import LayerType
from .enums import AttnType
from .enums import AttnMaskType
from .parallel_state import (
    is_unitialized,
    destroy_model_parallel,
    get_data_parallel_group,
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_embedding_group,
    get_model_parallel_group,
    get_tensor_model_parallel_group,
    get_pipeline_model_parallel_group,
    get_tensor_model_parallel_rank,
    set_tensor_model_parallel_rank,
    get_pipeline_model_parallel_rank,
    set_pipeline_model_parallel_rank,
    is_pipeline_first_stage,
    is_pipeline_last_stage,
    get_tensor_model_parallel_src_rank,
    get_pipeline_model_parallel_first_rank,
    get_pipeline_model_parallel_last_rank,
    get_pipeline_model_parallel_next_rank,
    get_pipeline_model_parallel_prev_rank,
    get_tensor_model_parallel_world_size,
    set_tensor_model_parallel_world_size,
    get_pipeline_model_parallel_world_size,
    set_pipeline_model_parallel_world_size,
    get_virtual_pipeline_model_parallel_rank,
    set_virtual_pipeline_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from apex.transformer import functional
from apex.transformer import parallel_state
from apex.transformer import pipeline_parallel
from apex.transformer import tensor_parallel
from apex.transformer import utils
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
__all__ = [
"functional",
"parallel_state",
"pipeline_parallel",
"tensor_parallel",
"utils",
# enums.py
"LayerType",
"AttnType",
"AttnMaskType",
]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
from .fused_softmax import FusedScaleMaskSoftmax
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
__all__ = [
"FusedScaleMaskSoftmax",
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,7 +15,7 @@
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from ..enums import AttnMaskType
from apex.transformer.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,35 +15,41 @@
"""Megatron number of micro-batches calculators."""
from abc import ABC
from abc import abstractmethod
from typing import Optional, List
def build_num_microbatches_calculator(args):
def build_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
):
    # Constant num micro-batches.
    if args.rampup_batch_size is None:
    if rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatches(
            args.global_batch_size, args.micro_batch_size, args.data_parallel_size
            global_batch_size, micro_batch_size, data_parallel_size
        )
        if args.rank == 0:
        if rank == 0:
            print(
                "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
            )
    else:
        assert len(args.rampup_batch_size) == 3, (
        assert len(rampup_batch_size) == 3, (
            "expected the following "
            "format: --rampup-batch-size <start batch size> "
            "<batch size incerement> <ramp-up samples>"
        )
        start_batch_size = int(args.rampup_batch_size[0])
        batch_size_increment = int(args.rampup_batch_size[1])
        ramup_samples = int(args.rampup_batch_size[2])
        if args.rank == 0:
        start_batch_size = int(rampup_batch_size[0])
        batch_size_increment = int(rampup_batch_size[1])
        ramup_samples = int(rampup_batch_size[2])
        if rank == 0:
            print(
                "will use batch size rampup starting from global batch "
                "size {} to global batch size {} with batch size increments "
                "{} over {} samples.".format(
                    start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples
                    start_batch_size, global_batch_size, batch_size_increment, ramup_samples
                ),
                flush=True,
            )
......@@ -51,9 +57,9 @@ def build_num_microbatches_calculator(args):
            start_batch_size,
            batch_size_increment,
            ramup_samples,
            args.global_batch_size,
            args.micro_batch_size,
            args.data_parallel_size,
            global_batch_size,
            micro_batch_size,
            data_parallel_size,
        )
    return num_microbatches_calculator
......@@ -86,6 +92,8 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
        assert self.num_micro_batches >= 1
        self.current_global_batch_size = global_batch_size
        self.micro_batch_size = micro_batch_size

    def update(self, consumed_samples, consistency_check):
        pass
......
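A hedged usage sketch of the refactored signature above; the import path is an assumption for illustration only:

```python
# Hypothetical import path; in this commit the calculator lives in a microbatches module.
from apex.transformer.microbatches import build_num_microbatches_calculator

calculator = build_num_microbatches_calculator(
    rank=0,
    rampup_batch_size=None,  # no ramp-up -> ConstantNumMicroBatches
    global_batch_size=256,
    micro_batch_size=4,
    data_parallel_size=8,
)
# 256 / (4 * 8) = 8 micro-batches per global batch.
assert calculator.get() == 8
```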
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,7 +17,7 @@ import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
# only for ensure_divisibility
from .tensor_parallel import utils
from apex.transformer.utils import ensure_divisibility
# Intra-layer model parallel group that the current rank belongs to.
......@@ -76,17 +76,17 @@ def initialize_model_parallel(
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    if torch.distributed.get_rank() == 0:
        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_))
        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_))
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
    # TODO (mkozuki): Consider moving `ensure_divisibility` to this file.
    utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
        print("> initializing data parallel with size {}".format(data_parallel_size))
    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
......@@ -344,3 +344,15 @@ def destroy_model_parallel():
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model
__all__ = [
"get_forward_backward_func",
"build_model",
]
import time
import torch
class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, "timer has already been started"
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started"
        torch.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If timing is in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class _Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer"""
        # currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + "-time", value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = "time (ms)"
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            string += " | {}: {:.2f}".format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
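A small usage sketch of the private timers above; the timer name is illustrative:

```python
timers = _Timers()

timers("forward").start()
# ... run the forward pass ...
timers("forward").stop()

# Prints e.g. "time (ms) | forward: 12.34" on the last rank
# (or unconditionally if torch.distributed is not initialized).
timers.log(["forward"])
```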
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import torch
from apex._autocast_utils import _get_current_dtype
from apex.transformer import parallel_state
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers
def _run_p2pops(
    tensor_send_prev: Union[torch.Tensor, None],
    tensor_send_next: Union[torch.Tensor, None],
    tensor_recv_prev: Union[torch.Tensor, None],
    tensor_recv_next: Union[torch.Tensor, None],
):
    ops = []
    if tensor_send_prev is not None:
        send_prev_op = torch.distributed.P2POp(
            torch.distributed.isend,
            tensor_send_prev,
            parallel_state.get_pipeline_model_parallel_prev_rank(),
        )
        ops.append(send_prev_op)
    if tensor_recv_prev is not None:
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv,
            tensor_recv_prev,
            parallel_state.get_pipeline_model_parallel_prev_rank(),
        )
        ops.append(recv_prev_op)
    if tensor_send_next is not None:
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend,
            tensor_send_next,
            parallel_state.get_pipeline_model_parallel_next_rank(),
        )
        ops.append(send_next_op)
    if tensor_recv_next is not None:
        recv_next_op = torch.distributed.P2POp(
            torch.distributed.irecv,
            tensor_recv_next,
            parallel_state.get_pipeline_model_parallel_next_rank(),
        )
        ops.append(recv_next_op)
    if len(ops) > 0:
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
# NOTE (mkozuki): Leaving `params_dtype` as it is for future development in PyTorch, especially APEX O2 style AMP.
# But as of v1.10, basically all tensors are torch.float32 except for output tensors of `autocast` compatible layers.
def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: torch.dtype = torch.float,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
"""Base function for communication of tensors between stages.
Args:
tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
recv_prev: boolean for whether tensor should be received from previous rank.
recv_next: boolean for whether tensor should be received from next rank.
tensor_shape: optional, use when the input sequence contains less tokens than the default sequence length
override_scatter_gather_tensors_in_pipeline:
optional, this is used when tensor_shape is provided to override scatter gather tensors
dtype_: This is used when tensor_shape is provided and what is the type of tensor_shape
Keyword args:
scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
your model deliberately, pass this argument.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
Returns:
tuple containing
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
"""
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
    else:
        tensor_chunk_shape = tensor_shape
    dtype = params_dtype or torch.float
    if fp32_residual_connection:
        dtype = torch.float

    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = (
                gather_split_1d_tensor(tensor_recv_prev)
                .view(tensor_shape)
                .requires_grad_()