Unverified commit b3da6036 authored by Masaki Kozuki and committed by GitHub

ColumnParallelLinearWithAsyncAllreduce autocast support (#1183)

* [ColumnParallelLinear] Test behavior in autocast

* fix test

* casts manually to autocast dtype
Parent 365fdc18
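For context: operations inside a custom `torch.autograd.Function` are not auto-cast by `torch.cuda.amp.autocast`, so this commit casts the Function's inputs to the autocast dtype by hand and then runs the Function with autocast disabled (see the `column_parallel_linear` wrapper in the diff below). The imported `_cast_if_autocast_enabled` helper itself is not shown in this diff; a rough approximation of what such a helper does (an assumption based on its usage here, not the verbatim apex implementation) is:

```python
import torch

def _cast_if_autocast_enabled(*args):
    # Approximate sketch, not the verbatim apex implementation.
    # When autocast is off, pass the arguments through unchanged; when it is
    # on, cast floating-point tensors to the active CUDA autocast dtype so the
    # wrapped autograd.Function sees uniformly typed inputs.
    if not torch.is_autocast_enabled():
        return args
    dtype = torch.get_autocast_gpu_dtype()
    return tuple(
        a.to(dtype) if torch.is_tensor(a) and a.is_floating_point() else a
        for a in args
    )
```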
@@ -21,6 +21,7 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from ..._autocast_utils import _cast_if_autocast_enabled
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_world_size
@@ -221,7 +222,7 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Asyncronous all-reduce
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
@@ -233,6 +234,12 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
return grad_input, grad_weight, grad_bias
def column_parallel_linear(input, weight, bias):
args = _cast_if_autocast_enabled(input, weight, bias)
with torch.cuda.amp.autocast(enabled=False):
return ColumnParallelLinearWithAsyncAllreduce.apply(*args)
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
@@ -336,8 +343,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Matrix multiply with asynchronous all-reduce execution
output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
input_, self.weight, bias)
output_parallel = column_parallel_linear(input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
......
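Putting the pieces above together, the pattern is: cast the inputs manually, then call `.apply()` under `autocast(enabled=False)` so nothing inside the Function gets re-cast. A minimal, self-contained sketch of the same pattern with a toy single-device Function (the names `ToyColumnLinear` and `toy_column_linear` are illustrative, not from apex):

```python
import torch

class ToyColumnLinear(torch.autograd.Function):
    # Stand-in for ColumnParallelLinearWithAsyncAllreduce without the
    # tensor-parallel all-reduce: y = x @ W^T + b on a single device.
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = input.matmul(weight.t())
        return output + bias if ctx.use_bias else output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)        # (batch, in_features)
        grad_weight = grad_output.t().matmul(input)    # (out_features, in_features)
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        return grad_input, grad_weight, grad_bias

def toy_column_linear(input, weight, bias):
    if torch.is_autocast_enabled():
        # Cast the inputs to the autocast dtype by hand, then disable autocast
        # so the whole Function (forward and backward) runs in that dtype.
        dtype = torch.get_autocast_gpu_dtype()
        input, weight, bias = [
            a.to(dtype) if torch.is_tensor(a) and a.is_floating_point() else a
            for a in (input, weight, bias)
        ]
        with torch.cuda.amp.autocast(enabled=False):
            return ToyColumnLinear.apply(input, weight, bias)
    return ToyColumnLinear.apply(input, weight, bias)
```

Under `with torch.cuda.amp.autocast(dtype=torch.float16):`, `toy_column_linear(x, w, b)` then returns a float16 output even if `w` and `b` are stored in float32, which is the behavior the new `test_column_parallel_linear_with_async_allreduce_autocast` test below asserts for the real layer.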
@@ -28,6 +28,16 @@ from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
class IdentityLayer3D(torch.nn.Module):
def __init__(self, m, n, k):
super(IdentityLayer3D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n, k))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
@@ -265,6 +275,86 @@ def test_column_parallel_linear(tensor_model_parallel_size):
print(' >> passed the test :-)')
def test_column_parallel_linear_with_async_allreduce_autocast(tensor_model_parallel_size):
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
assert linear_layer.async_tensor_model_parallel_allreduce or tensor_model_parallel_size == 1
# Forward
for dtype in autocast_dtypes:
loss_weight = torch.randn([batch_size, output_size]).cuda()
with torch.cuda.amp.autocast(dtype=dtype):
output, _ = linear_layer(identity_layer())
loss = torch.mul(output, loss_weight).sum()
assert output.dtype == dtype
# Backward
loss.backward()
torch.distributed.barrier()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size):
dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
for dtype in dtypes:
# Network
identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).to(device="cuda", dtype=dtype)
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).to(device="cuda", dtype=dtype)
# Forward
loss_weight = torch.randn([batch_size, output_size]).cuda()
output, _ = linear_layer(identity_layer())
loss = torch.mul(output, loss_weight).sum()
loss.backward()
torch.distributed.barrier()
assert output.dtype == dtype
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_row_parallel_linear(tensor_model_parallel_size):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
@@ -333,16 +423,6 @@ def test_row_parallel_linear(tensor_model_parallel_size):
print(' >> passed the test :-)')
class IdentityLayer3D(torch.nn.Module):
def __init__(self, m, n, k):
super(IdentityLayer3D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n, k))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
@@ -511,19 +591,35 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
exceptions = []
print_separator('test initialize affine weight cpu')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_initialize_affine_weight(tensor_model_parallel_size, 'cpu')
tensor_model_parallel_size *= 2
try:
test_initialize_affine_weight(tensor_model_parallel_size, 'cpu')
except Exception as e:
exceptions.append(f"test_initialize_affine_weight-cpu with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
# Reset groups
parallel_state.destroy_model_parallel()
print_separator('test initialize affine weight gpu')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_initialize_affine_weight(tensor_model_parallel_size, 'gpu')
tensor_model_parallel_size *= 2
try:
test_initialize_affine_weight(tensor_model_parallel_size, 'gpu')
except Exception as e:
exceptions.append(f"test_initialize_affine_weight-gpu with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
# Deleted, replaced with vocab parallel embedding?
#tensor_model_parallel_size = 1
@@ -535,15 +631,57 @@ if __name__ == '__main__':
print_separator('test column-parallel linear')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_column_parallel_linear(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
try:
test_column_parallel_linear(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator('test row-parallel linear')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_row_parallel_linear(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
try:
test_row_parallel_linear(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_row_parallel_linear with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator("test ColumnParallelLinearWithAsyncAllreduce - autocast")
tensor_model_parallel_size = 2
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear_with_async_allreduce_autocast(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear_with_async_allreduce_autocast with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator("test ColumnParallelLinearWithAsyncAllreduce - custom AMP")
tensor_model_parallel_size = 2
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear_with_async_allreduce_custom_amp with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
if exceptions:
raise RuntimeError("\n".join(exceptions))
# Deleted
#print_separator('test parallel self-attention')
#tensor_model_parallel_size = 1
......