Unverified commit 14ccf598 authored by Masaki Kozuki, committed by GitHub

Remove `custom_fwd`/`custom_bwd` from fused softmax (#1188)

* run backward

* remove custom_fwd/custom_bwd
Parent 3ad9db2a
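
The decorators being removed are the torch.cuda.amp helpers that cast a custom autograd.Function's forward inputs to a fixed dtype under autocast and re-enter the matching autocast state in backward. A minimal sketch of that pattern, shown on a plain PyTorch softmax rather than the apex CUDA kernels and intended for orientation only (the class and its shapes are hypothetical):

import torch


class ToyScaledSoftmax(torch.autograd.Function):
    """Illustrative stand-in for the fused softmax functions touched by this commit."""

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)  # the kind of decorator this commit removes
    def forward(ctx, inputs, scale):
        probs = torch.softmax(inputs * scale, dim=-1)
        ctx.save_for_backward(probs)
        ctx.scale = scale
        return probs

    @staticmethod
    @torch.cuda.amp.custom_bwd  # also removed by this commit
    def backward(ctx, output_grads):
        (probs,) = ctx.saved_tensors
        # Softmax backward: dL/dx = scale * p * (g - sum(g * p, dim=-1))
        grad = probs * (output_grads - (output_grads * probs).sum(dim=-1, keepdim=True))
        return grad * ctx.scale, None
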
@@ -37,7 +37,6 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         return softmax_results
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
     def backward(ctx, output_grads):
         import scaled_upper_triang_masked_softmax_cuda
@@ -68,7 +67,6 @@ def scaled_upper_triang_masked_softmax(inputs, _, scale):
 class ScaledMaskedSoftmax(torch.autograd.Function):
     @staticmethod
-    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
     def forward(ctx, inputs, mask, scale):
         import scaled_masked_softmax_cuda
 
         scale_t = torch.tensor([scale])
@@ -78,7 +76,6 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return softmax_results
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
     def backward(ctx, output_grads):
         import scaled_masked_softmax_cuda

@@ -12,8 +12,7 @@ from apex.transformer.functional import FusedScaleMaskSoftmax
 def attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores
+    return attention_scores.masked_fill(attention_mask, -10000.0)
 
 autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
@@ -61,11 +60,19 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
                 return
             fused_fn, torch_fn = self._setup_fused_softmax(input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
-            attention_scores = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype)
+            attention_scores_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
+            with torch.no_grad():
+                attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
             mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
-            reference = fused_fn(attention_scores, mask)
-            actual = torch_fn(attention_scores, mask)
-            torch.testing.assert_allclose(actual, reference)
+            expected = fused_fn(attention_scores_0, mask)
+            actual = torch_fn(attention_scores_1, mask)
+            torch.testing.assert_allclose(actual, expected)
+            g0 = torch.rand_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            expected.backward(g0)
+            actual.backward(g1)
 
     def test_autocast_fused_scale_mask_softmax(self):
         for dtype in autocast_dtypes:
@@ -74,16 +81,24 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             input_in_bf16 = dtype == torch.bfloat16
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding)
-            attention_scores = torch.randn((4, 12, 24, 24)).cuda()
+            attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
+            with torch.no_grad():
+                attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True)
             mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()
+            expected = torch_fn(attention_scores_1, mask)
             with torch.cuda.amp.autocast(dtype=dtype):
-                actual = fused_fn(attention_scores, mask)
+                actual = fused_fn(attention_scores_0, mask)
                 self.assertEqual(actual.dtype, dtype)
-            with torch.no_grad():
-                expected = torch_fn(attention_scores.to(dtype), mask)
             torch.testing.assert_allclose(actual, expected)
+            g0 = torch.rand_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            expected.backward(g0)
+            actual.backward(g1)
 
     def test_fused_upper_triangle_mask_softmax(self):
         """
         attn_weights.shape: [4, 12, 24, 24]
@@ -108,14 +123,22 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
-            attn_weights = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype)
+            attn_weights_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
+            with torch.no_grad():
+                attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
             total_mask = (~(
                 torch.tril(torch.randn((24, 24), device="cuda")).bool()
             ).unsqueeze(0).unsqueeze(0))
             total_mask = total_mask.repeat((4, 1, 1, 1))
-            reference = fused_fn(attn_weights, total_mask)
-            actual = torch_fn(attn_weights, total_mask)
-            torch.testing.assert_allclose(actual, reference)
+            expected = fused_fn(attn_weights_0, total_mask)
+            actual = torch_fn(attn_weights_1, total_mask)
+            torch.testing.assert_allclose(actual, expected)
+            g0 = torch.randn_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            actual.backward(g0)
+            expected.backward(g1)
 
     def test_autocast_fused_upper_triangle_mask_softmax(self):
         for dtype in autocast_dtypes:
@@ -124,14 +147,22 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             input_in_bf16 = dtype == torch.bfloat16
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal)
-            attn_weights = torch.randn((4, 12, 24, 24)).cuda()
+            attn_weights_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
+            with torch.no_grad():
+                attn_weights_1 = attn_weights_0.clone().to(dtype).requires_grad_(True)
             total_mask = (~(
                 torch.tril(torch.randn((24, 24), device="cuda")).bool()
             ).unsqueeze(0).unsqueeze(0))
             with torch.cuda.amp.autocast(dtype=dtype):
-                actual = fused_fn(attn_weights, total_mask)
+                actual = fused_fn(attn_weights_0, total_mask)
                 self.assertEqual(actual.dtype, dtype)
-            with torch.no_grad():
-                expected = torch_fn(attn_weights.to(dtype), total_mask)
+            expected = torch_fn(attn_weights_1, total_mask)
             torch.testing.assert_allclose(actual, expected)
+            g0 = torch.randn_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            actual.backward(g0)
+            expected.backward(g1)
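
For reference, a self-contained sketch of the forward/backward comparison pattern the updated tests follow, using a plain PyTorch softmax in place of the apex fused kernel; the names and the -10000.0 fill value mirror the test, while the closing gradient check is an extra illustration not present in the diff:

import torch


def reference_softmax(x, mask):
    # Mirrors attention_mask_func followed by softmax in the test.
    return torch.softmax(x.masked_fill(mask, -10000.0), dim=-1)


x0 = torch.randn(4, 12, 24, 24, device="cuda", requires_grad=True)
with torch.no_grad():
    x1 = x0.clone().requires_grad_(True)   # independent leaf with the same values
mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()

expected = reference_softmax(x0, mask)     # stands in for fused_fn
actual = reference_softmax(x1, mask)       # stands in for torch_fn
torch.testing.assert_close(actual, expected)

g0 = torch.rand_like(actual)
with torch.no_grad():
    g1 = g0.clone()                        # independent upstream gradient
expected.backward(g0)
actual.backward(g1)
torch.testing.assert_close(x0.grad, x1.grad)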