Commit 1cf636d2 authored by Carl Case

readme

Parent 8c007904
See the top-level README for more on installation.
## Usage and Getting Started
In the common case, using amp requires adding two lines of code (and an import). The first enables amp, so that it can hook into all the relevant PyTorch functions. The second tells it where backpropagation occurs so that it can properly scale the loss and clear internal per-iteration state.
#### 1. Enable amp

```python
from apex import amp
amp_handle = amp.init()
```
`amp.init()` takes three (optional) arguments. The most useful is `enabled` (default=True), which makes it easy to toggle amp from a command-line argument: if it is False, everything amp does is a zero-overhead pass-through -- i.e., your code runs as-is.

For the other two options, the defaults are _highly_ recommended. The first, `enable_caching` (default=True), indicates whether amp should cache fp16 casts of model parameters on a per-iteration basis. This prevents things like RNN cells used inside a loop from casting their weight matrices over and over. The second, `verbose` (default=False), toggles whether to print out every cast that occurs -- useful mostly for debugging.
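For example, wiring `enabled` to a command-line flag makes it easy to run the same script with and without mixed precision. A minimal sketch (the `--fp16` flag and the argparse setup are illustrative, not part of amp):

```python
import argparse
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true', help='enable mixed precision')
args = parser.parse_args()

# With enabled=False, amp is a zero-overhead pass-through and the
# script runs in ordinary fp32 with no other changes.
amp_handle = amp.init(enabled=args.fp16)
```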
#### 2. Wrap backpropagation

Nearly all PyTorch training scripts have a loop that looks like:

```python
# ... do a bunch of stuff to compute a loss
loss.backward()
optimizer.step()
```
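With amp initialized, the only change to this loop is to route the backward pass through the amp handle so the loss can be scaled. A minimal sketch of the wrapped loop, using the `scale_loss` API described below:

```python
# ... do a bunch of stuff to compute a loss

# backward() runs on the scaled loss; the handle manages the loss
# scale and clears its per-iteration state on exit.
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```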
If you do not wrap the backward pass this way, you will not get automatic loss scaling, nor is it safe to `enable_caching`. (Power user note: you can manually clear the cache after each optimizer step with `amp_handle._clear_cache()`.)
## Multiple Optimizers or Backward Passes
Step (2) from the previous section works when you have one PyTorch optimizer and a single `loss.backward()` for each iteration. Some models are more complex, with:
- Multiple optimizer objects (over different parameters)
- Multiple backward passes for each iteration, taking advantage of
PyTorch's gradient accumulation
To work with such models, amp requires you to explicitly wrap each optimizer and indicate whether it will have more than one backward pass per iteration.
#### Explicitly wrapping optimizers
If you have more than one optimizer, then you must explicitly wrap
each. (You can also do so with a single optimizer.) First, wrap the
optimizer after initializing amp:
```python
optimizer = ...  # construct your optimizer, e.g. torch.optim.SGD(...)
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer)
```
Second, use `optimizer.scale_loss(...)` to indicate where backprop
occurs:
```python
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
# ...
```
In essence, `amp_handle.scale_loss(loss, optimizer)` is syntactic
sugar for first wrapping the optimizer and then calling
`optimizer.scale_loss(loss)` in the single-optimizer case. But in the
multi-optimizer case, you must wrap each optimizer individually.
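For instance, with two optimizers over different parameter groups, the same pattern applies to each one. A minimal sketch (the two sub-models, losses, and optimizer choices are placeholders):

```python
import torch
from apex import amp

amp_handle = amp.init()

# Hypothetical setup: separate optimizers for two sub-models;
# each one must be wrapped individually.
opt_a = amp_handle.wrap_optimizer(torch.optim.Adam(model_a.parameters()))
opt_b = amp_handle.wrap_optimizer(torch.optim.SGD(model_b.parameters(), lr=0.01))

# Each wrapped optimizer scales its own loss and steps independently.
with opt_a.scale_loss(loss_a) as scaled_loss:
    scaled_loss.backward()
opt_a.step()

with opt_b.scale_loss(loss_b) as scaled_loss:
    scaled_loss.backward()
opt_b.step()
```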
#### Handling multiple backward passes
PyTorch accumulates parameter gradients between calls to
`zero_grad()`, so it is possible to perform multiple backward passes
before making a parameter update:
```python
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
loss1.backward()
# ...
loss2 = ComputeLoss2(model)
loss2.backward()
# ...
optimizer.step() # has gradient contributions from both backward passes
```
The amp optimizer wrapper supports an additional argument `num_loss`
to work with code like this:
```python
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)
# ...
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
with optimizer.scale_loss(loss1) as scaled_loss:
    scaled_loss.backward()
# ...
loss2 = ComputeLoss2(model)
with optimizer.scale_loss(loss2) as scaled_loss:
    scaled_loss.backward()
# ...
optimizer.step()
```
## Annotating User Functions

Nearly all PyTorch user code needs nothing more than the two steps above to use amp. After all, custom layers are built out of simpler PyTorch components, and amp can already see those.
The exception is user code that amp cannot see into. Suppose, for example, you have a custom recurrent cell called a "forgetful recurrent unit" that calls directly into a CUDA backend:

```python
from backend import FRUBackend

def fru(input, hidden, weight, bias):
    # call to CUDA code
    FRUBackend(input, hidden, weight, bias)
```
In this case, it is possible to get a runtime type mismatch. For example, you might have `input` in fp16, and `weight` in fp32, and amp doesn't have the visibility to insert an appropriate cast.

amp exposes two ways to handle "invisible" backend code: function annotations and explicit registration.
#### Function annotation
The first way to handle backend code is a set of function annotations:
- `@amp.half_function`
- `@amp.float_function`
- `@amp.promote_function`
These correspond to:
- Cast all arguments to fp16
- Cast all arguments to fp32
- If there are any type mismatches, cast everything to the widest type (see the sketch after this list)
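As a hypothetical illustration of the third option, an elementwise op that is numerically safe in either precision, but requires its inputs to share a dtype, could use the promote annotation:

```python
import torch
from apex import amp

@amp.promote_function
def fused_add_relu(x, y):
    # Hypothetical user op: if x and y arrive with mismatched dtypes
    # (say fp16 and fp32), amp casts both to the wider type (fp32)
    # before this function runs.
    return torch.relu(x + y)
```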
In our example, we believe that the FRU unit is fp16-safe and will get
performance gains from casting its arguments to fp16, so we write:
```python
@amp.half_function
def fru(input, hidden, weight, bias):
    # ...
```
#### Explicit registration
The other way to handle backend code is with explicit function
registration:
- `amp.register_half_function(module, function_name)`
- `amp.register_float_function(module, function_name)`
- `amp.register_promote_function(module, function_name)`
When using this API, `module` is the containing class or module for
the function, and `function_name` is the _string_ name of the
function. Note that the function must be registered before the call to
`amp.init()`.
For our FRU unit, we can register the backend function directly:
```python
import backend

amp.register_half_function(backend, 'FRUBackend')
amp.init()
```