Unverified commit 12117715, authored by Jerry Jiarui XU and committed by GitHub

add optimizer registry (#2139)

* add optimizer registry

* move under core, add doc

* unexpose TORCH_OPTIMIZERS
parent 4c21f7ff
@@ -422,6 +422,31 @@ There are two ways to work with custom datasets.
a pickle or json file, like [pascal_voc.py](https://github.com/open-mmlab/mmdetection/blob/master/tools/convert_datasets/pascal_voc.py).
Then you can simply use `CustomDataset`.
### Customize optimizer
An example of a customized optimizer, `CopyOfSGD`, is defined in `mmdet/core/optimizer/copy_of_sgd.py`.
More generally, a customized optimizer can be defined as follows.
In `mmdet/core/optimizer/my_optimizer.py`:
```python
from torch.optim import Optimizer

from .registry import OPTIMIZERS


@OPTIMIZERS.register_module
class MyOptimizer(Optimizer):

    def __init__(self, params, lr=0.01):  # minimal constructor so the example runs
        super(MyOptimizer, self).__init__(params, dict(lr=lr))
```
In `mmdet/core/optimizer/__init__.py`:
```python
from .my_optimizer import MyOptimizer
```
Then you can use `MyOptimizer` in the `optimizer` field of config files.
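For instance, assuming the `MyOptimizer` above takes only an `lr` argument (a hypothetical signature), the config entry could look like:

```python
# config file: refer to the registered optimizer by its class name
optimizer = dict(type='MyOptimizer', lr=0.01)
```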
### Develop new components
We basically categorize model components into 4 types.
import random
import re
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
+from mmcv.runner import DistSamplerSeedHook, Runner
-from mmdet.core import DistEvalHook, DistOptimizerHook, Fp16OptimizerHook
+from mmdet.core import (DistEvalHook, DistOptimizerHook, Fp16OptimizerHook,
+                        build_optimizer)
from mmdet.datasets import build_dataloader
from mmdet.utils import get_root_logger
@@ -111,86 +111,6 @@ def train_detector(model,
meta=meta)
def build_optimizer(model, optimizer_cfg):
"""Build optimizer from configs.
Args:
model (:obj:`nn.Module`): The model with parameters to be optimized.
optimizer_cfg (dict): The config dict of the optimizer.
Positional fields are:
- type: class name of the optimizer.
- lr: base learning rate.
Optional fields are:
- any arguments of the corresponding optimizer type, e.g.,
weight_decay, momentum, etc.
- paramwise_options: a dict with 3 accepted fields
(bias_lr_mult, bias_decay_mult, norm_decay_mult).
`bias_lr_mult` and `bias_decay_mult` scale the lr and weight decay
respectively for all bias parameters (except those of
normalization layers), and `norm_decay_mult` scales the weight
decay for all weight and bias parameters of normalization layers.
Returns:
torch.optim.Optimizer: The initialized optimizer.
Example:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
...                      weight_decay=0.0001)
>>> optimizer = build_optimizer(model, optimizer_cfg)
"""
if hasattr(model, 'module'):
model = model.module
optimizer_cfg = optimizer_cfg.copy()
paramwise_options = optimizer_cfg.pop('paramwise_options', None)
# if no paramwise option is specified, just use the global setting
if paramwise_options is None:
return obj_from_dict(optimizer_cfg, torch.optim,
dict(params=model.parameters()))
else:
assert isinstance(paramwise_options, dict)
# get base lr and weight decay
base_lr = optimizer_cfg['lr']
base_wd = optimizer_cfg.get('weight_decay', None)
# weight_decay must be explicitly specified if mult is specified
if ('bias_decay_mult' in paramwise_options
or 'norm_decay_mult' in paramwise_options):
assert base_wd is not None
# get param-wise options
bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
# set param-wise lr and weight decay
params = []
for name, param in model.named_parameters():
param_group = {'params': [param]}
if not param.requires_grad:
# FP16 training needs to copy gradient/weight between master
# weight copy and model weight, it is convenient to keep all
# parameters here to align with model.parameters()
params.append(param_group)
continue
# for norm layers, overwrite the weight decay of weight and bias
# TODO: obtain the norm layer prefixes dynamically
if re.search(r'(bn|gn)(\d+)?.(weight|bias)', name):
if base_wd is not None:
param_group['weight_decay'] = base_wd * norm_decay_mult
# for other layers, overwrite both lr and weight decay of bias
elif name.endswith('.bias'):
param_group['lr'] = base_lr * bias_lr_mult
if base_wd is not None:
param_group['weight_decay'] = base_wd * bias_decay_mult
# otherwise use the global settings
params.append(param_group)
optimizer_cls = getattr(torch.optim, optimizer_cfg.pop('type'))
return optimizer_cls(params, **optimizer_cfg)
def _dist_train(model,
dataset,
cfg,
@@ -3,5 +3,6 @@ from .bbox import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .fp16 import * # noqa: F401, F403
from .mask import * # noqa: F401, F403
from .optimizer import * # noqa: F401, F403
from .post_processing import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
from .builder import build_optimizer
from .copy_of_sgd import CopyOfSGD
from .registry import OPTIMIZERS
__all__ = ['OPTIMIZERS', 'build_optimizer', 'CopyOfSGD']
import re
from mmdet.utils import build_from_cfg
from .registry import OPTIMIZERS
def build_optimizer(model, optimizer_cfg):
"""Build optimizer from configs.
Args:
model (:obj:`nn.Module`): The model with parameters to be optimized.
optimizer_cfg (dict): The config dict of the optimizer.
Positional fields are:
- type: class name of the optimizer.
- lr: base learning rate.
Optional fields are:
- any arguments of the corresponding optimizer type, e.g.,
weight_decay, momentum, etc.
- paramwise_options: a dict with 3 accepted fields
(bias_lr_mult, bias_decay_mult, norm_decay_mult).
`bias_lr_mult` and `bias_decay_mult` scale the lr and weight decay
respectively for all bias parameters (except those of
normalization layers), and `norm_decay_mult` scales the weight
decay for all weight and bias parameters of normalization layers.
Returns:
torch.optim.Optimizer: The initialized optimizer.
Example:
>>> import torch
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
...                      weight_decay=0.0001)
>>> optimizer = build_optimizer(model, optimizer_cfg)
"""
if hasattr(model, 'module'):
model = model.module
optimizer_cfg = optimizer_cfg.copy()
paramwise_options = optimizer_cfg.pop('paramwise_options', None)
# if no paramwise option is specified, just use the global setting
if paramwise_options is None:
params = model.parameters()
else:
assert isinstance(paramwise_options, dict)
# get base lr and weight decay
base_lr = optimizer_cfg['lr']
base_wd = optimizer_cfg.get('weight_decay', None)
# weight_decay must be explicitly specified if mult is specified
if ('bias_decay_mult' in paramwise_options
or 'norm_decay_mult' in paramwise_options):
assert base_wd is not None
# get param-wise options
bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
# set param-wise lr and weight decay
params = []
for name, param in model.named_parameters():
param_group = {'params': [param]}
if not param.requires_grad:
# FP16 training needs to copy gradient/weight between master
# weight copy and model weight, it is convenient to keep all
# parameters here to align with model.parameters()
params.append(param_group)
continue
# for norm layers, overwrite the weight decay of weight and bias
# TODO: obtain the norm layer prefixes dynamically
if re.search(r'(bn|gn)(\d+)?\.(weight|bias)', name):
if base_wd is not None:
param_group['weight_decay'] = base_wd * norm_decay_mult
# for other layers, overwrite both lr and weight decay of bias
elif name.endswith('.bias'):
param_group['lr'] = base_lr * bias_lr_mult
if base_wd is not None:
param_group['weight_decay'] = base_wd * bias_decay_mult
# otherwise use the global settings
params.append(param_group)
optimizer_cfg['params'] = params
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
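To illustrate how `build_optimizer` consumes `paramwise_options`, here is a hedged usage sketch; the tiny model and the multiplier values are illustrative assumptions, not part of this commit:

```python
import torch.nn as nn

from mmdet.core import build_optimizer


class TinyNet(nn.Module):

    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn1(self.conv1(x))


# conv1.bias gets 2x the base lr and zero weight decay;
# bn1.weight and bn1.bias match the norm-layer regex and get zero weight decay
optimizer_cfg = dict(
    type='SGD',
    lr=0.02,
    momentum=0.9,
    weight_decay=0.0001,
    paramwise_options=dict(
        bias_lr_mult=2.0, bias_decay_mult=0.0, norm_decay_mult=0.0))
optimizer = build_optimizer(TinyNet(), optimizer_cfg)
```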
from torch.optim import SGD
from .registry import OPTIMIZERS
@OPTIMIZERS.register_module
class CopyOfSGD(SGD):
"""A clone of torch.optim.SGD.
A customized optimizer could be defined like CopyOfSGD.
You may derive from built-in optimizers in torch.optim,
or directly implement a new optimizer.
"""
import inspect
import torch
from mmdet.utils import Registry
OPTIMIZERS = Registry('optimizer')
def register_torch_optimizers():
torch_optimizers = []
for module_name in dir(torch.optim):
if module_name.startswith('__'):
continue
_optim = getattr(torch.optim, module_name)
if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer):
OPTIMIZERS.register_module(_optim)
torch_optimizers.append(module_name)
return torch_optimizers
TORCH_OPTIMIZERS = register_torch_optimizers()
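Because `register_torch_optimizers()` walks `torch.optim` and registers every `Optimizer` subclass it finds, the built-in optimizers remain addressable by name. A minimal sketch of what this enables, using the exports shown in the `__init__.py` files above (normally you would go through `build_optimizer`; the direct `build_from_cfg` call is only for illustration):

```python
import torch.nn as nn

from mmdet.core import OPTIMIZERS
from mmdet.utils import build_from_cfg

model = nn.Linear(4, 2)
# 'Adam' was auto-registered, so it can be built straight from a config dict
optimizer = build_from_cfg(
    dict(type='Adam', lr=1e-3, params=model.parameters()), OPTIMIZERS)
```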