## Usage and Getting Started
In the common case, using amp requires adding two lines of code (and
an import). The first enables amp, so that it can hook into all the
relevant PyTorch functions. The second tells it where backpropagation
occurs so that it can properly scale the loss and clear internal
per-iteration state.

#### 1. Enable amp
```python
from apex import amp
amp_handle = amp.init()
```
`amp.init()` takes three (optional) arguments. The most useful is
`enabled` (default=True), which simplifies command-line arguments. If
False, then everything amp does will be a zero-overhead pass-through
-- i.e., your code will run as-is.
For the other two options, the defaults are _highly_ recommended. The
first, `enable_caching` (default=True), indicates whether amp should
cache fp16 casts of model parameters on a per-iteration basis. This
prevents things like RNN cells used inside a loop from casting their
weight matrices over and over. The second, `verbose` (default=False)
toggles whether to print out every cast that occurs. Useful for
debugging, mostly.
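For example, the `enabled` argument makes it easy to tie mixed
precision to a command-line flag (a minimal sketch; the `--fp16` flag
name is an assumption for illustration, not part of the amp API):
```python
import argparse
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true')
args = parser.parse_args()

# All three amp.init() arguments are optional; the defaults for the
# latter two are highly recommended.
amp_handle = amp.init(enabled=args.fp16,    # False: zero-overhead pass-through
                      enable_caching=True,  # cache fp16 parameter casts
                      verbose=False)        # log every cast (for debugging)
```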
#### 2. Wrap backpropagation
Nearly all PyTorch training scripts have a loop that looks like:
```python
# ... do a bunch of stuff to compute a loss
loss.backward()
optimizer.step()
```

If you do not wrap the backward pass (as shown below), you will not
get automatic loss scaling, nor is it safe to use `enable_caching`.
(Power user note: you can manually clear the cache after each
optimizer step with `amp_handle._clear_cache()`.)
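With amp enabled, the same loop changes only in how the backward pass
is invoked, using the `amp_handle.scale_loss(loss, optimizer)` context
manager described in the next section (a sketch):
```python
# ... do a bunch of stuff to compute a loss
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # backprop through the scaled loss
optimizer.step()
# ... finish the iteration
```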
## Multiple Optimizers or Backward Passes
Step (2) from the previous section works when you have one PyTorch
optimizer and a single `loss.backward()` for each iteration. Some
models are more complex with:
- Multiple optimizer objects (over different parameters)
- Multiple backward passes for each iteration, taking advantage of
PyTorch's gradient accumulation
To work with such models, amp requires you to explicitly wrap each
optimizer and indicate if it will have more than one backward pass
per iteration.
#### Explicitly wrapping optimizers
If you have more than one optimizer, then you must explicitly wrap
each. (You can also do so with a single optimizer.) First, wrap the
optimizer after initializing amp:
```python
optimizer = ...  # construct any PyTorch optimizer here
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer)
```
Second, use `optimizer.scale_loss(...)` to indicate where backprop
occurs:
```python
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
# ...
```
In essence, `amp_handle.scale_loss(loss, optimizer)` is syntactic
sugar for first wrapping the optimizer and then calling
`optimizer.scale_loss(loss)` in the single-optimizer case. But in the
multi-optimizer case, you must wrap each optimizer individually.
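That is, these two forms behave the same for a single optimizer (a
sketch; `loss` and `optimizer` come from the surrounding training
code):
```python
# Sugar form:
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# Explicit form:
optimizer = amp_handle.wrap_optimizer(optimizer)
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
```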
#### Handling multiple backward passes
PyTorch accumulates parameter gradients between calls to
`zero_grad()`, so it is possible to perform multiple backward passes
before making a parameter update:
```python
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
loss1.backward()
# ...
loss2 = ComputeLoss2(model)
loss2.backward()
# ...
optimizer.step() # has gradient contributions from both backward passes
```
The amp optimizer wrapper supports an additional argument `num_loss`
to work with code like this:
```python
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)
# ...
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
with optimizer.scale_loss(loss1) as scaled_loss:
    scaled_loss.backward()
# ...
loss2 = ComputeLoss2(model)
with optimizer.scale_loss(loss2) as scaled_loss:
    scaled_loss.backward()
# ...
optimizer.step()
```
## Annotating User Functions
Nearly all PyTorch user code needs nothing more than the two steps
above to use amp. After all, custom layers are built out of simpler
PyTorch components, and amp can already see those.
The exception is code that amp cannot see into. Suppose, for example,
you have a custom recurrent
cell called a "forgetful recurrent unit" that calls directly into a
CUDA backend:
```python
from backend import FRUBackend
def fru(input, hidden, weight, bias):
    # call to CUDA code
    FRUBackend(input, hidden, weight, bias)
```
In this case, it is possible to get a runtime type mismatch. For
example, you might have `input` in fp16 and `weight` in fp32, and amp
doesn't have the visibility to insert an appropriate cast.
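As a plain-PyTorch illustration of this failure mode (independent of
amp and of the FRU example), feeding mixed dtypes to a single kernel
fails at runtime:
```python
import torch

a = torch.randn(4, 8).half()  # fp16 activation, as amp might produce
w = torch.randn(8, 8)         # fp32 weight that was never cast
out = torch.mm(a, w)          # raises RuntimeError: mixed Half/Float dtypes
```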
amp exposes two ways to handle "invisible" backend code: function
annotations and explicit registration.
#### Function annotation
The first way to handle backend code is a set of function annotations:
- `@amp.half_function`
- `@amp.float_function`
- `@amp.promote_function`
These correspond to:
- Cast all arguments to fp16
- Cast all arguments to fp32
- If there are any type mismatches, cast everything to the widest type
In our example, we believe that the FRU unit is fp16-safe and will get
performance gains from casting its arguments to fp16, so we write:
```python
@amp.half_function
def fru(input, hidden, weight, bias):
    # ...
```
#### Explicit registration
The other way to handle backend code is with explicit function
registration:
- `amp.register_half_function(module, function_name)`
- `amp.register_float_function(module, function_name)`
- `amp.register_promote_function(module, function_name)`
When using this API, `module` is the containing class or module for
the function, and `function_name` is the _string_ name of the
function. Note that the function must be registered before the call to
`amp.init()`.
For our FRU unit, we can register the backend function directly:
```python
from apex import amp
import backend

amp.register_half_function(backend, 'FRUBackend')
amp.init()
```