Commit 8c007904 authored by Carl Case

Hard ban on fp16 BCELoss

Parent dbbad668
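The replacement the new error message steers users toward is `torch.nn.BCEWithLogitsLoss`, which fuses Sigmoid and BCELoss into a single layer that is compatible with amp. A minimal sketch of the before/after pattern (plain PyTorch; the tensor shapes are illustrative only):

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 1)   # raw model outputs, no Sigmoid applied
targets = torch.rand(8, 1)   # ground-truth probabilities in [0, 1]

# Pattern this commit hard-bans when it runs in fp16 under amp:
#   probs = torch.sigmoid(logits)
#   loss = nn.BCELoss()(probs, targets)

# amp-compatible equivalent: Sigmoid + BCELoss fused into one layer.
loss = nn.BCEWithLogitsLoss()(logits, targets)
```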
@@ -54,7 +54,7 @@ def register_promote_function(module, name):
     _USER_PROMOTE_REGISTRY.add((module, name))

 # Top-level function to insert _all_ the hooks.
-def init(enabled=True, enable_caching=True, verbose=False):
+def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     global _DECORATOR_HANDLE

     if not enabled:
@@ -145,5 +145,10 @@ def init(enabled=True, enable_caching=True, verbose=False):
     # 5.5) Extra-special handling of RNN backend
     wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)

+    # 6) Place error+print message on banned functions
+    if not allow_banned:
+        for fn, err_msg in functional_overrides.BANNED_FUNCS:
+            wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)
+
     _DECORATOR_HANDLE = handle
     return handle
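In terms of the user-facing `init()` call, the new flag works roughly as sketched below, assuming the usual `from apex import amp` entry point (the handle name is illustrative):

```python
from apex import amp

# Default: binary_cross_entropy is wrapped with err_if_any_half and raises
# NotImplementedError as soon as it sees a HalfTensor argument.
amp_handle = amp.init()

# Escape hatch added by this commit, for users who accept the fp16 risk:
# amp_handle = amp.init(allow_banned=True)
```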
@@ -42,7 +42,6 @@ FP32_FUNCS = [

     # Loss functions
     # TODO: which of these can be fp16?
-    'binary_cross_entropy',
     'poisson_nll_loss',
     'cosine_embedding_loss',
     'cross_entropy',
@@ -60,3 +59,15 @@ FP32_FUNCS = [
     'soft_margin_loss',
     'triplet_margin_loss'
 ]
+
+BANNED_FUNCS = [
+    ('binary_cross_entropy',
+     ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
+      "It requires that the output of the previous function be already a FloatTensor. \n\n"
+      "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
+      "    torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
+      "that is compatible with amp.\nAnother option is to add\n"
+      "    amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
+      "If you _really_ know what you are doing, you can disable this warning by passing "
+      "allow_banned=True to `amp.init()`."))
+]
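The second workaround named in the message, keeping the Sigmoid output in fp32 so that BCELoss already receives a FloatTensor, would look roughly like this (a sketch built from the calls quoted in the error text):

```python
import torch
from apex import amp

# Force torch.sigmoid to run in (and return) fp32, so the tensor reaching
# F.binary_cross_entropy / nn.BCELoss is already a FloatTensor.
amp.register_float_function(torch, 'sigmoid')
amp_handle = amp.init()
```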
@@ -94,7 +94,7 @@ def promote_match_arg0(mod, fn, verbose=False):
             return orig_fn(arg0, *new_args, **kwargs)
     utils.set_func(mod, fn, wrapper)

-def err_if_any_half(mod, fn):
+def err_if_any_half(mod, fn, custom_err_msg=None):
     if not utils.has_func(mod, fn):
         return
@@ -103,8 +103,11 @@ def err_if_any_half(mod, fn):
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)
         if 'HalfTensor' in types:
-            raise NotImplementedError('Cannot call in-place function ' +
-                                      '{} with fp16 arguments.'.format(fn))
+            if custom_err_msg:
+                raise NotImplementedError(custom_err_msg)
+            else:
+                raise NotImplementedError('Cannot call in-place function ' +
+                                          '{} with fp16 arguments.'.format(fn))
         else:
             return orig_fn(*args, **kwargs)
     utils.set_func(mod, fn, wrapper)