Commit 1cf636d2
Authored May 30, 2018 by Carl Case

readme

Parent: 8c007904
Changes: 1
apex/amp/README.md
...
...
@@ -41,7 +41,7 @@ top-level README for more on installation.
## Usage and Getting Started
In the common case, using amp requires adding two lines of code (and
an import). The first enables amp, so that it can hook into all the
relevant PyTorch functions. The second tells it where backpropagation
occurs so that it can properly scale the loss and clear internal
...
...
@@ -50,20 +50,25 @@ per-iteration state.
#### 1. Enable amp
```python
from apex import amp
amp_handle = amp.init()
```
`amp.init()` takes three (optional) arguments. The most useful is
`enabled` (default=True), which simplifies command-line arguments. If
False, then everything amp does will be a zero-overhead pass-through
-- i.e., your code will run as-is.

For the other two options, the defaults are _highly_ recommended. The
first, `enable_caching` (default=True), indicates whether amp should
cache fp16 casts of model parameters on a per-iteration basis. This
prevents things like RNN cells used inside a loop from casting their
weight matrices over and over. The second, `verbose` (default=False)
toggles whether to print out every cast that occurs. Useful for
debugging, mostly.
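For example, a minimal sketch of wiring `enabled` to a command-line flag (the `--fp16` flag name is just an illustration, not part of amp):

```python
import argparse

from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true',
                    help='train in mixed precision via amp')
args = parser.parse_args()

# With --fp16 omitted, enabled=False makes every amp hook a
# zero-overhead pass-through, so the same script runs in plain fp32.
amp_handle = amp.init(enabled=args.fp16, enable_caching=True, verbose=False)
```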
#### 2. Wrap backpropagation
Nearly all PyTorch training scripts have a loop that looks like:
```python
# ... do a bunch of stuff to compute a loss
```
...
...
@@ -91,9 +96,86 @@ you will not get automatic loss scaling, nor is it safe to
`enable_caching`. (Power user note: you can manually clear the cache
after each optimizer step with `amp_handle._clear_cache()`.)
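A minimal sketch of that power-user escape hatch, assuming `optimizer` and the `amp_handle` from step 1 (and using the private `_clear_cache()` helper named above):

```python
# Without the scale_loss wrapper there is no automatic loss scaling,
# but the fp16 parameter-cast cache can still be cleared by hand
# after each parameter update.
loss.backward()
optimizer.step()
amp_handle._clear_cache()
```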
## Multiple Optimizers or Backward Passes
Step (2) from the previous section works when you have one PyTorch
optimizer and a single `loss.backward()` for each iteration. Some
models are more complex with:
- Multiple optimizer objects (over different parameters)
- Multiple backward passes for each iteration, taking advantage of
  PyTorch's gradient accumulation
To work with such models, amp requires you to explicitly wrap each
optimizer and indicate if it will have more than one backward pass
per-iteration.
#### Explicitly wrapping optimizers
If you have more than one optimizer, then you must explicitly wrap
each. (You can also do so with a single optimizer.) First, wrap the
optimizer after initializing amp:
```python
optimizer = ...  # some optimizer
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer)
```
Second, use `optimizer.scale_loss(...)` to indicate where backprop
occurs:
```python
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
# ...
```
In essence, `amp_handle.scale_loss(loss, optimizer)` is syntactic
sugar for first wrapping the optimizer and then calling
`optimizer.scale_loss(loss)` in the single-optimizer case. But in the
multi-optimizer case, you must wrap each optimizer individually.
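A minimal sketch of that equivalence in the single-optimizer case (assuming `loss`, `optimizer`, and `amp_handle` already exist; use one form or the other, not both in the same iteration):

```python
# Shorthand: pass the optimizer to the handle directly.
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()

# Explicit: wrap the optimizer once up front, then scale through the wrapper.
optimizer = amp_handle.wrap_optimizer(optimizer)
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```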
#### Handling multiple backward passes
PyTorch accumulates parameter gradients between calls to
`zero_grad()`, so it is possible to perform multiple backward passes
before making a parameter update:
```python
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
loss1.backward()
# ...
loss2 = ComputeLoss2(model)
loss2.backward()
# ...
optimizer.step()  # has gradient contributions from both backward passes
```
The amp optimizer wrapper supports an additional argument `num_loss`
to work with code like this:
```python
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)
# ...
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
with optimizer.scale_loss(loss1) as scaled_loss:
    scaled_loss.backward()
# ...
loss2 = ComputeLoss2(model)
with optimizer.scale_loss(loss2) as scaled_loss:
    scaled_loss.backward()
# ...
optimizer.step()
```
## Annotating User Functions
Nearly all PyTorch user code needs nothing more than the two steps
above to use amp. After all, custom layers are built out of simpler
PyTorch components, and amp already can see those.
...
...
@@ -103,27 +185,62 @@ cell called a "forgetful recurrent unit" that calls directly into a
CUDA backend:
```python
from backend import FRUBackend

def fru(input, hidden, weight, bias):
    # call to CUDA code
    FRUBackend(input, hidden, weight, bias)
```
In this case, it is possible to get a runtime type mismatch. For
example, you might have `input` in fp16, and `weight` in fp32, and amp
doesn't have the visibility to insert an appropriate cast.

amp exposes two ways to handle "invisible" backend code: function
annotations and explicit registration.
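To make the failure mode concrete, here is a tiny sketch in plain PyTorch (no amp involved, purely illustrative) of the kind of dtype mismatch an opaque backend call can hit:

```python
import torch

x = torch.randn(8, 16).half()  # activation already cast to fp16
w = torch.randn(16, 16)        # parameter still in fp32

# A kernel that assumes matching dtypes fails here: a matmul between
# Half and Float tensors raises a RuntimeError instead of casting.
y = x @ w
```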
#### Function annotation
The first way to handle backend code is a set of function annotations:
- `@amp.half_function`
- `@amp.float_function`
- `@amp.promote_function`
These correspond to:
- Cast all arguments to fp16
- Cast all arguments to fp32
- If there are any type mismatches, cast everything to the widest type
In our example, we believe that the FRU unit is fp16-safe and will get
performance gains from casting its arguments to fp16, so we write:
```python
@amp.half_function
def fru(input, hidden, weight, bias):
    #...
```
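The other two annotations are used the same way; a short sketch (the function names below are hypothetical, not part of amp):

```python
@amp.float_function
def stable_norm(x):
    # numerically sensitive reduction; keep all arguments in fp32
    return (x * x).sum().sqrt()

@amp.promote_function
def fused_add(x, y):
    # mixed fp16/fp32 inputs are first cast to the widest type
    return x + y
```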
#### Explicit registration
The other way to handle backend code is with explicit function
registration:
- `amp.register_half_function(module, function_name)`
- `amp.register_float_function(module, function_name)`
- `amp.register_promote_function(module, function_name)`
When using this API, `module` is the containing class or module for
the function, and `function_name` is the _string_ name of the
function. Note that the function must be registered before the call to
`amp.init()`.
For our FRU unit, we can register the backend function directly:
```python
from apex import amp
import backend

amp.register_half_function(backend, 'FRUBackend')
amp.init()
```
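Analogously, a backend op that is only numerically safe in fp32 would go on the float list instead; a hedged sketch (the `StatsBackend` name is hypothetical):

```python
from apex import amp
import backend

# Hypothetical fp32-only kernel: keep its arguments in fp32.
amp.register_float_function(backend, 'StatsBackend')
amp.init()
```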