wanggh / apex
Commit 14ccf598 (unverified)
Author: Masaki Kozuki · Committed via GitHub on Oct 08, 2021 (edited Oct 09, 2021)
Remove `custom_fwd`/`custom_bwd` from fused softmax (#1188)
* run backward
* remove custom_fwd/custom_bwd
Parent: 3ad9db2a
2 changed files
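For context on the change: `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd` are PyTorch's decorators for making a hand-written `torch.autograd.Function` cooperate with autocast; `custom_fwd(cast_inputs=torch.half)` casts floating-point CUDA inputs to fp16 when autocast is enabled, and `custom_bwd` runs `backward` under the same autocast state as the matching `forward`. The sketch below shows that pattern on a toy softmax so the removed lines are easier to place; `ToyScaledSoftmax` is a made-up example class, not apex code.

```python
import torch


class ToyScaledSoftmax(torch.autograd.Function):
    """Hypothetical example of the decorator pattern this commit removes."""

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)  # under autocast, CUDA inputs arrive as fp16
    def forward(ctx, inputs, scale):
        probs = torch.softmax(inputs * scale, dim=-1)
        ctx.save_for_backward(probs)
        ctx.scale = scale
        return probs

    @staticmethod
    @torch.cuda.amp.custom_bwd  # backward runs with the same autocast state as forward
    def backward(ctx, output_grads):
        (probs,) = ctx.saved_tensors
        # Standard softmax backward, then chain rule through z = scale * x.
        grad_z = probs * (output_grads - (output_grads * probs).sum(dim=-1, keepdim=True))
        return grad_z * ctx.scale, None
```

With the decorators gone, the fused Functions no longer force fp16 inputs or pin the backward's autocast state, which is consistent with the tests below now driving them explicitly under `torch.cuda.amp.autocast(dtype=dtype)` and calling `backward`.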
apex/transformer/functional/fused_softmax.py
@@ -37,7 +37,6 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         return softmax_results
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
     def backward(ctx, output_grads):
         import scaled_upper_triang_masked_softmax_cuda
@@ -68,7 +67,6 @@ def scaled_upper_triang_masked_softmax(inputs, _, scale):
 class ScaledMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
     def forward(ctx, inputs, mask, scale):
         import scaled_masked_softmax_cuda
 
         scale_t = torch.tensor([scale])
@@ -78,7 +76,6 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return softmax_results
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
     def backward(ctx, output_grads):
         import scaled_masked_softmax_cuda
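A side note on behaviour (my assumption about autocast semantics, not something the diff states): without `custom_fwd`, a custom `autograd.Function` sees exactly the dtype its caller passes, even inside an autocast region, because autocast only re-dispatches the ATen ops on its cast lists and does not look inside custom Functions. A minimal, hypothetical illustration:

```python
import torch


class Passthrough(torch.autograd.Function):
    """Hypothetical undecorated Function: it receives whatever dtype the caller provides."""

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


if torch.cuda.is_available():
    x = torch.randn(1, 8, device="cuda")      # fp32 input
    w = torch.randn(8, 8, device="cuda")
    with torch.cuda.amp.autocast(dtype=torch.half):
        y = torch.nn.functional.linear(x, w)  # linear is on the autocast fp16 list
        z = Passthrough.apply(y)              # no extra casting happens here
    print(y.dtype, z.dtype)                   # torch.float16 torch.float16
```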
tests/L0/run_transformer/test_fused_softmax.py
@@ -12,8 +12,7 @@ from apex.transformer.functional import FusedScaleMaskSoftmax
 def attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores
+    return attention_scores.masked_fill(attention_mask, -10000.0)
 
 
 autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
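The helper change in this hunk swaps the in-place `masked_fill_` for the out-of-place `masked_fill`. A likely reason (my reading, not stated in the commit message): the updated tests call `backward()` through `torch_fn`, and autograd rejects in-place modification of a leaf tensor that requires grad. A small standalone illustration of the difference:

```python
import torch

scores = torch.randn(2, 3, requires_grad=True)
mask = torch.randint(0, 2, (2, 3)).bool()

try:
    scores.masked_fill_(mask, -10000.0)   # in-place on a leaf that requires grad
except RuntimeError as err:
    print("in-place version fails:", err)

masked = scores.masked_fill(mask, -10000.0)  # out-of-place keeps the graph intact
masked.sum().backward()
print(scores.grad)  # 0.0 at masked positions, 1.0 elsewhere
```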
@@ -61,11 +60,19 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
                 return
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
-            attention_scores = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype)
+            attention_scores_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
+            with torch.no_grad():
+                attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
             mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
-            reference = fused_fn(attention_scores, mask)
-            actual = torch_fn(attention_scores, mask)
-            torch.testing.assert_allclose(actual, reference)
+            expected = fused_fn(attention_scores_0, mask)
+            actual = torch_fn(attention_scores_1, mask)
+            torch.testing.assert_allclose(actual, expected)
+            g0 = torch.rand_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            expected.backward(g0)
+            actual.backward(g1)
 
     def test_autocast_fused_scale_mask_softmax(self):
         for dtype in autocast_dtypes:
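The `attention_scores_0`/`attention_scores_1` duplication introduced above gives the fused and reference paths independent autograd graphs over identical values: cloning inside `torch.no_grad()` detaches the copy from any existing graph, and `requires_grad_(True)` turns it into a fresh leaf. A minimal sketch of the same pattern, with plain `torch.softmax` standing in for the `fused_fn`/`torch_fn` pair:

```python
import torch

x0 = torch.randn(4, 4, requires_grad=True)
with torch.no_grad():
    x1 = x0.clone().requires_grad_(True)   # independent leaf with the same values

y0 = torch.softmax(x0, dim=-1)             # stand-in for fused_fn(...)
y1 = torch.softmax(x1, dim=-1)             # stand-in for torch_fn(...)

g0 = torch.rand_like(y0)
with torch.no_grad():
    g1 = g0.clone()

y0.backward(g0)
y1.backward(g1)
torch.testing.assert_close(x0.grad, x1.grad)  # modern spelling of the assert_allclose used in the diff
```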
@@ -74,16 +81,24 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             input_in_bf16 = dtype == torch.bfloat16
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding)
-            attention_scores = torch.randn((4, 12, 24, 24)).cuda()
+            attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
+            with torch.no_grad():
+                attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True)
             mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()
 
+            expected = torch_fn(attention_scores_1, mask)
             with torch.cuda.amp.autocast(dtype=dtype):
-                actual = fused_fn(attention_scores, mask)
+                actual = fused_fn(attention_scores_0, mask)
                 self.assertEqual(actual.dtype, dtype)
-            with torch.no_grad():
-                expected = torch_fn(attention_scores.to(dtype), mask)
             torch.testing.assert_allclose(actual, expected)
+            g0 = torch.rand_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            expected.backward(g0)
+            actual.backward(g1)
 
     def test_fused_upper_triangle_mask_softmax(self):
         """
         attn_weights.shape: [4, 12, 24, 24]
@@ -108,14 +123,22 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
-            attn_weights = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype)
+            attn_weights_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
+            with torch.no_grad():
+                attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
             total_mask = (~(torch.tril(torch.randn((24, 24), device="cuda")).bool()).unsqueeze(0).unsqueeze(0))
             total_mask = total_mask.repeat((4, 1, 1, 1))
-            reference = fused_fn(attn_weights, total_mask)
-            actual = torch_fn(attn_weights, total_mask)
-            torch.testing.assert_allclose(actual, reference)
+            expected = fused_fn(attn_weights_0, total_mask)
+            actual = torch_fn(attn_weights_1, total_mask)
+            torch.testing.assert_allclose(actual, expected)
+            g0 = torch.randn_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            actual.backward(g0)
+            expected.backward(g1)
 
     def test_autocast_fused_upper_triangle_mask_softmax(self):
         for dtype in autocast_dtypes:
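The `total_mask` built in this hunk is the causal mask: `torch.tril(...)` keeps the lower triangle, `.bool()` marks it True, and `~` flips it so the strictly upper-triangular (future) positions become the ones `attention_mask_func` fills with -10000.0. A simplified sketch using `torch.ones` instead of `randn`, which yields the same Boolean pattern without depending on the random draws being nonzero:

```python
import torch

n = 5
# True marks positions to mask out: column j > row i, i.e. future tokens.
total_mask = ~torch.tril(torch.ones(n, n)).bool()
print(total_mask.int())
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)
# In the test, .unsqueeze(0).unsqueeze(0) gives [1, 1, n, n] and
# .repeat((4, 1, 1, 1)) tiles it to [4, 1, n, n].
```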
@@ -124,14 +147,22 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             input_in_bf16 = dtype == torch.bfloat16
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal)
-            attn_weights = torch.randn((4, 12, 24, 24)).cuda()
+            attn_weights_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
+            with torch.no_grad():
+                attn_weights_1 = attn_weights_0.clone().to(dtype).requires_grad_(True)
             total_mask = (~(torch.tril(torch.randn((24, 24), device="cuda")).bool()).unsqueeze(0).unsqueeze(0))
 
             with torch.cuda.amp.autocast(dtype=dtype):
-                actual = fused_fn(attn_weights, total_mask)
+                actual = fused_fn(attn_weights_0, total_mask)
                 self.assertEqual(actual.dtype, dtype)
-            with torch.no_grad():
-                expected = torch_fn(attn_weights.to(dtype), total_mask)
+            expected = torch_fn(attn_weights_1, total_mask)
             torch.testing.assert_allclose(actual, expected)
+            g0 = torch.randn_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            actual.backward(g0)
+            expected.backward(g1)
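Both autocast tests follow the same shape after this change: compute the fused result inside `torch.cuda.amp.autocast(dtype=...)`, assert its dtype, compare against a reference computed at that dtype, then drive `backward` with cloned gradients. A stripped-down, hypothetical version of that loop, with `torch.nn.functional.linear` standing in for the fused op:

```python
import torch

if torch.cuda.is_available():
    dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
    for dtype in dtypes:
        x = torch.randn(4, 8, device="cuda", requires_grad=True)
        w = torch.randn(8, 8, device="cuda", requires_grad=True)

        with torch.cuda.amp.autocast(dtype=dtype):
            actual = torch.nn.functional.linear(x, w)  # stand-in for fused_fn(...)
        assert actual.dtype == dtype                   # mirrors self.assertEqual(actual.dtype, dtype)

        expected = torch.nn.functional.linear(x.to(dtype), w.to(dtype))  # stand-in for torch_fn(...)
        torch.testing.assert_close(actual, expected)

        g0 = torch.randn_like(actual)
        with torch.no_grad():
            g1 = g0.clone()
        actual.backward(g0)        # gradients accumulate on the fp32 leaves x and w
        expected.backward(g1)
```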