未验证 提交 3303b3e7 编辑于 作者: Masaki Kozuki 提交者: GitHub
浏览文件

Use out-of-place to avoid D2D copy in tensor parallel cross entropy (#1198)



* switch from clone to out-of-place subtract

* Update apex/mpu/cross_entropy.py

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Eddie Yan <eddiey@nvidia.com>
上级 0c7d8e3f
......@@ -30,7 +30,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
# Get the partition's vocab indecies
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
......@@ -100,4 +100,4 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
    """Compute cross entropy over vocab-parallel (tensor-parallel sharded) logits.

    Thin wrapper around ``_VocabParallelCrossEntropy.apply``.

    The logits are passed through WITHOUT a defensive ``torch.clone``: the
    autograd function subtracts the per-row max out-of-place
    (``logits - max`` rather than ``logits.sub_(max)``, per the hunk above),
    so the caller's tensor is not mutated and the device-to-device copy the
    clone incurred is unnecessary.

    Args:
        vocab_parallel_logits: logits sharded along the vocab dimension
            across the tensor-model-parallel group.
        target: integer class targets.  # presumably indexes the full (unsharded) vocab — confirm against _VocabParallelCrossEntropy

    Returns:
        The per-element cross-entropy loss produced by
        ``_VocabParallelCrossEntropy``.
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
......@@ -51,8 +51,11 @@ def tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scal
logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
logits_parallel_ = logits_parallel.clone().detach()
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
# check for mutation
assert torch.equal(logits_parallel_, logits_parallel)
return loss, identity.weight.grad
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册