Commit ec3cd13d authored by Carl Case

merge in some changes to use amp pre-0.4

Parent db6ae13a
 from . import amp
-from . import RNN
-from . import reparameterization
-from . import fp16_utils
-from . import parallel
+#from . import RNN
+#from . import reparameterization
+#from . import fp16_utils
+#from . import parallel
@@ -145,6 +145,14 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     # 5.5) Extra-special handling of RNN backend
     wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
 
+    # And even more special handling of `backward` for fused gru / lstm
+    # The `backward` method calls Tensor.sum() (blacklist) internally,
+    # and then the resulting grad_input has the wrong type.
+    # TODO: where else is this a problem?
+    for rnn_type in ['GRUFused', 'LSTMFused']:
+        mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
+        wrap.disable_casts(mod, 'backward', handle)
+
     # 6) Place error+print message on banned functions
     if not allow_banned:
         for fn, err_msg in functional_overrides.BANNED_FUNCS:
@@ -11,9 +11,16 @@ class AmpHandle(object):
         self._verbose = verbose
         self._cache = dict()
         self._default_scaler = LossScaler()
+        self._is_active = True
 
     def is_active(self):
-        return True
+        return self._is_active
+
+    @contextlib.contextmanager
+    def _disable_casts(self):
+        self._is_active = False
+        yield
+        self._is_active = True
 
     def wrap_optimizer(self, optimizer, num_loss=1):
         self._default_scaler = None
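For reference, a minimal standalone sketch (not apex code) of the flag-flipping pattern this hunk introduces: `is_active()` now reports `False` exactly for the dynamic extent of a `_disable_casts()` block.

```python
import contextlib

class Handle:
    """Toy stand-in for AmpHandle, reduced to the new state flag."""
    def __init__(self):
        self._is_active = True

    def is_active(self):
        return self._is_active

    @contextlib.contextmanager
    def _disable_casts(self):
        # Suspend casting only for the duration of the with-block.
        self._is_active = False
        yield
        self._is_active = True

h = Handle()
assert h.is_active()
with h._disable_casts():
    assert not h.is_active()   # wrappers skip casting here
assert h.is_active()           # flag restored on exit
```

One observable property of this sketch (and of the hunk itself): the flag is restored only on normal exit; a `try/finally` around the `yield` would also restore it if the wrapped call raised.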
@@ -76,6 +83,10 @@ class NoOpHandle(object):
     def is_active(self):
         return False
 
+    @contextlib.contextmanager
+    def _disable_casts(self):
+        yield
+
     def wrap_optimizer(self, optimizer, num_loss=1):
         return OptimWrapper(optimizer, self, num_loss)
@@ -49,7 +49,7 @@ def maybe_half(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_half(y) for y in x])
 
-    if type_string(x) == 'HalfTensor':
+    if not x.is_cuda or type_string(x) == 'HalfTensor':
         return x
     else:
         if verbose:
@@ -60,7 +60,7 @@ def maybe_float(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_float(y) for y in x])
 
-    if type_string(x) == 'FloatTensor':
+    if not x.is_cuda or type_string(x) == 'FloatTensor':
         return x
     else:
         if verbose:
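A quick sketch of what the new guard changes in practice. The `type_string` helper below is a hand-rolled approximation of apex's (assumed to return the suffix of the tensor's type name); the point is that CPU tensors now pass through untouched, so only CUDA tensors are ever cast.

```python
import torch

def type_string(x):
    # Approximation: 'torch.FloatTensor' -> 'FloatTensor',
    # 'torch.cuda.HalfTensor' -> 'HalfTensor'.
    return x.type().split('.')[-1]

def maybe_half(x):
    # New behavior: CPU tensors are returned unchanged; half-precision
    # casting is only applied to CUDA float tensors.
    if not x.is_cuda or type_string(x) == 'HalfTensor':
        return x
    return x.half()

cpu = torch.randn(4)
assert maybe_half(cpu) is cpu                      # untouched on CPU
if torch.cuda.is_available():
    gpu = torch.randn(4, device='cuda')
    assert maybe_half(gpu).dtype == torch.float16  # cast on CUDA
```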
@@ -9,6 +9,9 @@ def make_cast_wrapper(orig_fn, cast_fn, handle,
                       try_caching=False):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
+        if not handle.is_active():
+            return orig_fn(*args, **kwargs)
+
         if try_caching and handle.has_cache:
             args = list(args)
             for i in range(len(args)):
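This early return is the other half of the `_disable_casts` machinery: every cast wrapper consults the shared handle, so entering `_disable_casts()` instantly turns all of them into pass-throughs. A simplified sketch (caching and the `utils` plumbing elided):

```python
import functools

def make_cast_wrapper(orig_fn, cast_fn, handle):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if not handle.is_active():
            # Taken while inside handle._disable_casts(): the wrapped
            # function sees its arguments exactly as passed, uncast.
            return orig_fn(*args, **kwargs)
        return orig_fn(*[cast_fn(a) for a in args], **kwargs)
    return wrapper
```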
@@ -70,7 +73,7 @@ def sequence_promote(mod, fn, verbose=False):
                                          seq, {})
             return orig_fn(cast_seq, *args, **kwargs)
         else:
-            # TODO: other mixed-type cases aren't due to autohalf.
+            # TODO: other mixed-type cases aren't due to amp.
             #       Just pass through?
             return orig_fn(seq, *args, **kwargs)
     utils.set_func(mod, fn, wrapper)
@@ -201,3 +204,14 @@ def rnn_cast(backend, fn, verbose=False):
             return forward(*new_args, **fkwargs)
         return fwd_wrapper
     utils.set_func(backend, fn, rnn_wrapper)
+
+def disable_casts(mod, fn, handle):
+    if not utils.has_func(mod, fn):
+        return
+
+    orig_fn = utils.get_func(mod, fn)
+    @functools.wraps(orig_fn)
+    def wrapper(*args, **kwargs):
+        with handle._disable_casts():
+            return orig_fn(*args, **kwargs)
+    utils.set_func(mod, fn, wrapper)
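Putting the pieces together, a toy end-to-end run (the `Fused` class here is hypothetical, standing in for the `GRUFused`/`LSTMFused` backends that `init` passes in, and plain `getattr`/`setattr` replace the `utils` helpers): wrapping `backward` with `disable_casts` means any cast wrapper that checks `handle.is_active()` leaves tensors alone while `backward` executes, which is the fix for the wrong-typed `grad_input` described in the `init` hunk above.

```python
import contextlib
import functools

class Handle:
    def __init__(self):
        self._is_active = True
    def is_active(self):
        return self._is_active
    @contextlib.contextmanager
    def _disable_casts(self):
        self._is_active = False
        yield
        self._is_active = True

def disable_casts(mod, fn, handle):
    # Same shape as the new wrap.disable_casts above.
    orig_fn = getattr(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        with handle._disable_casts():
            return orig_fn(*args, **kwargs)
    setattr(mod, fn, wrapper)

class Fused:
    @staticmethod
    def backward(grad):
        # The real backends call Tensor.sum() (a blacklist function) here;
        # with casts disabled, grad keeps its incoming type.
        return grad

handle = Handle()
disable_casts(Fused, 'backward', handle)
assert handle.is_active()
Fused.backward(1.0)        # casts are off only inside this call
assert handle.is_active()  # and restored afterwards
```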