From c449d02534af05d7fe4487d65f021dc75e28bd9f Mon Sep 17 00:00:00 2001
From: impiga <xjtu_lyt@163.com>
Date: Tue, 13 Apr 2021 00:23:55 +0800
Subject: [PATCH] Add amp support

---
 mmcv_custom/runner/__init__.py           |  11 +
 mmcv_custom/runner/base_runner.py        | 447 +++++++++++++++++++++
 mmcv_custom/runner/checkpoint.py         | 473 +++++++++++++++++++++++
 mmcv_custom/runner/epoch_based_runner.py | 171 ++++++++
 mmdet/apis/train.py                      |  43 ++-
 mmdet/utils/__init__.py                  |   3 +-
 mmdet/utils/optimizer.py                 |  33 ++
 7 files changed, 1170 insertions(+), 11 deletions(-)
 create mode 100644 mmcv_custom/runner/__init__.py
 create mode 100644 mmcv_custom/runner/base_runner.py
 create mode 100644 mmcv_custom/runner/checkpoint.py
 create mode 100644 mmcv_custom/runner/epoch_based_runner.py
 create mode 100644 mmdet/utils/optimizer.py

diff --git a/mmcv_custom/runner/__init__.py b/mmcv_custom/runner/__init__.py
new file mode 100644
index 00000000..1f0c98aa
--- /dev/null
+++ b/mmcv_custom/runner/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+from .base_runner import BaseRunner
+from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
+                         save_checkpoint, weights_to_cpu)
+from .epoch_based_runner import EpochBasedRunnerAmp
+
+
+__all__ = [
+    'BaseRunner', 'EpochBasedRunnerAmp', '_load_checkpoint', 'load_checkpoint',
+    'load_state_dict', 'save_checkpoint', 'weights_to_cpu'
+]
diff --git a/mmcv_custom/runner/base_runner.py b/mmcv_custom/runner/base_runner.py
new file mode 100644
index 00000000..adda0794
--- /dev/null
+++ b/mmcv_custom/runner/base_runner.py
@@ -0,0 +1,447 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import logging
+import os.path as osp
+import warnings
+from abc import ABCMeta, abstractmethod
+
+import torch
+from torch.optim import Optimizer
+
+import mmcv
+from mmcv.parallel import is_module_wrapper
+from mmcv.runner import (load_checkpoint, get_dist_info, get_priority, get_time_str,
+                         HOOKS, Hook, IterTimerHook, LogBuffer)
+try:
+    import apex
+except ImportError:
+    print('apex is not installed')
+
+
+class BaseRunner(metaclass=ABCMeta):
+    """The base class of Runner, a training helper for PyTorch.
+
+    All subclasses should implement the following APIs:
+
+    - ``run()``
+    - ``train()``
+    - ``val()``
+    - ``save_checkpoint()``
+
+    Args:
+        model (:obj:`torch.nn.Module`): The model to be run.
+        batch_processor (callable): A callable method that processes a data
+            batch. The interface of this method should be
+            `batch_processor(model, data, train_mode) -> dict`
+        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
+            optimizer (in most cases) or a dict of optimizers (in models that
+            require more than one optimizer, e.g., GAN).
+        work_dir (str, optional): The working directory to save checkpoints
+            and logs. Defaults to None.
+        logger (:obj:`logging.Logger`): Logger used during training.
+             Defaults to None. (The default value is just for backward
+             compatibility)
+        meta (dict | None): A dict that records important information such as
+            environment info and seed, which will be logged in logger hook.
+            Defaults to None.
+        max_epochs (int, optional): Total training epochs.
+        max_iters (int, optional): Total training iterations.
+        amp (bool): Whether to use apex mixed precision (AMP). Default: False.
+    """
+
+    def __init__(self,
+                 model,
+                 batch_processor=None,
+                 optimizer=None,
+                 work_dir=None,
+                 logger=None,
+                 meta=None,
+                 max_iters=None,
+                 max_epochs=None,
+                 amp=False):
+        if batch_processor is not None:
+            if not callable(batch_processor):
+                raise TypeError('batch_processor must be callable, '
+                                f'but got {type(batch_processor)}')
+            warnings.warn('batch_processor is deprecated, please implement '
+                          'train_step() and val_step() in the model instead.')
+            # raise an error if `batch_processor` is not None and
+            # `model.train_step()` exists.
+            if is_module_wrapper(model):
+                _model = model.module
+            else:
+                _model = model
+            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
+                raise RuntimeError(
+                    'batch_processor and model.train_step()/model.val_step() '
+                    'cannot be both available.')
+        else:
+            assert hasattr(model, 'train_step')
+
+        # check the type of `optimizer`
+        if isinstance(optimizer, dict):
+            for name, optim in optimizer.items():
+                if not isinstance(optim, Optimizer):
+                    raise TypeError(
+                        f'optimizer must be a dict of torch.optim.Optimizers, '
+                        f'but optimizer["{name}"] is a {type(optim)}')
+        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
+            raise TypeError(
+                f'optimizer must be a torch.optim.Optimizer object '
+                f'or dict or None, but got {type(optimizer)}')
+
+        # check the type of `logger`
+        if not isinstance(logger, logging.Logger):
+            raise TypeError(f'logger must be a logging.Logger object, '
+                            f'but got {type(logger)}')
+
+        # check the type of `meta`
+        if meta is not None and not isinstance(meta, dict):
+            raise TypeError(
+                f'meta must be a dict or None, but got {type(meta)}')
+
+        self.model = model
+        self.batch_processor = batch_processor
+        self.optimizer = optimizer
+        self.logger = logger
+        self.meta = meta
+        self.amp = amp
+
+        # create work_dir
+        if mmcv.is_str(work_dir):
+            self.work_dir = osp.abspath(work_dir)
+            mmcv.mkdir_or_exist(self.work_dir)
+        elif work_dir is None:
+            self.work_dir = None
+        else:
+            raise TypeError('"work_dir" must be a str or None')
+
+        # get model name from the model class
+        if hasattr(self.model, 'module'):
+            self._model_name = self.model.module.__class__.__name__
+        else:
+            self._model_name = self.model.__class__.__name__
+
+        self._rank, self._world_size = get_dist_info()
+        self.timestamp = get_time_str()
+        self.mode = None
+        self._hooks = []
+        self._epoch = 0
+        self._iter = 0
+        self._inner_iter = 0
+
+        if max_epochs is not None and max_iters is not None:
+            raise ValueError(
+                'Only one of `max_epochs` or `max_iters` can be set.')
+
+        self._max_epochs = max_epochs
+        self._max_iters = max_iters
+        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
+        self.log_buffer = LogBuffer()
+
+    @property
+    def model_name(self):
+        """str: Name of the model, usually the module class name."""
+        return self._model_name
+
+    @property
+    def rank(self):
+        """int: Rank of current process. (distributed training)"""
+        return self._rank
+
+    @property
+    def world_size(self):
+        """int: Number of processes participating in the job.
+        (distributed training)"""
+        return self._world_size
+
+    @property
+    def hooks(self):
+        """list[:obj:`Hook`]: A list of registered hooks."""
+        return self._hooks
+
+    @property
+    def epoch(self):
+        """int: Current epoch."""
+        return self._epoch
+
+    @property
+    def iter(self):
+        """int: Current iteration."""
+        return self._iter
+
+    @property
+    def inner_iter(self):
+        """int: Iteration in an epoch."""
+        return self._inner_iter
+
+    @property
+    def max_epochs(self):
+        """int: Maximum training epochs."""
+        return self._max_epochs
+
+    @property
+    def max_iters(self):
+        """int: Maximum training iterations."""
+        return self._max_iters
+
+    @abstractmethod
+    def train(self):
+        pass
+
+    @abstractmethod
+    def val(self):
+        pass
+
+    @abstractmethod
+    def run(self, data_loaders, workflow, **kwargs):
+        pass
+
+    @abstractmethod
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl,
+                        save_optimizer=True,
+                        meta=None,
+                        create_symlink=True):
+        pass
+
+    def current_lr(self):
+        """Get current learning rates.
+
+        Returns:
+            list[float] | dict[str, list[float]]: Current learning rates of all
+                param groups. If the runner has a dict of optimizers, this
+                method will return a dict.
+        """
+        if isinstance(self.optimizer, torch.optim.Optimizer):
+            lr = [group['lr'] for group in self.optimizer.param_groups]
+        elif isinstance(self.optimizer, dict):
+            lr = dict()
+            for name, optim in self.optimizer.items():
+                lr[name] = [group['lr'] for group in optim.param_groups]
+        else:
+            raise RuntimeError(
+                'lr is not applicable because optimizer does not exist.')
+        return lr
+
+    def current_momentum(self):
+        """Get current momentums.
+
+        Returns:
+            list[float] | dict[str, list[float]]: Current momentums of all
+                param groups. If the runner has a dict of optimizers, this
+                method will return a dict.
+        """
+
+        def _get_momentum(optimizer):
+            momentums = []
+            for group in optimizer.param_groups:
+                if 'momentum' in group.keys():
+                    momentums.append(group['momentum'])
+                elif 'betas' in group.keys():
+                    momentums.append(group['betas'][0])
+                else:
+                    momentums.append(0)
+            return momentums
+
+        if self.optimizer is None:
+            raise RuntimeError(
+                'momentum is not applicable because optimizer does not exist.')
+        elif isinstance(self.optimizer, torch.optim.Optimizer):
+            momentums = _get_momentum(self.optimizer)
+        elif isinstance(self.optimizer, dict):
+            momentums = dict()
+            for name, optim in self.optimizer.items():
+                momentums[name] = _get_momentum(optim)
+        return momentums
+
+    def register_hook(self, hook, priority='NORMAL'):
+        """Register a hook into the hook list.
+
+        The hook will be inserted into a priority queue, with the specified
+        priority (See :class:`Priority` for details of priorities).
+        Hooks with the same priority are triggered in the order in which
+        they are registered.
+
+        Args:
+            hook (:obj:`Hook`): The hook to be registered.
+            priority (int or str or :obj:`Priority`): Hook priority.
+                Lower value means higher priority.
+        """
+        assert isinstance(hook, Hook)
+        if hasattr(hook, 'priority'):
+            raise ValueError('"priority" is a reserved attribute for hooks')
+        priority = get_priority(priority)
+        hook.priority = priority
+        # insert the hook to a sorted list
+        inserted = False
+        for i in range(len(self._hooks) - 1, -1, -1):
+            if priority >= self._hooks[i].priority:
+                self._hooks.insert(i + 1, hook)
+                inserted = True
+                break
+        if not inserted:
+            self._hooks.insert(0, hook)
+
+    def register_hook_from_cfg(self, hook_cfg):
+        """Register a hook from its cfg.
+
+        Args:
+            hook_cfg (dict): Hook config. It should have at least keys 'type'
+              and 'priority' indicating its type and priority.
+
+        Notes:
+            The specific hook class to register should not use 'type' and
+            'priority' arguments during initialization.
+        """
+        hook_cfg = hook_cfg.copy()
+        priority = hook_cfg.pop('priority', 'NORMAL')
+        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
+        self.register_hook(hook, priority=priority)
+
+    def call_hook(self, fn_name):
+        """Call all hooks.
+
+        Args:
+            fn_name (str): The function name in each hook to be called, such as
+                "before_train_epoch".
+        """
+        for hook in self._hooks:
+            getattr(hook, fn_name)(self)
+
+    def load_checkpoint(self, filename, map_location='cpu', strict=False):
+        self.logger.info('load checkpoint from %s', filename)
+        return load_checkpoint(self.model, filename, map_location, strict,
+                               self.logger)
+
+    def resume(self,
+               checkpoint,
+               resume_optimizer=True,
+               map_location='default',
+               resume_amp=False):
+        if map_location == 'default':
+            if torch.cuda.is_available():
+                device_id = torch.cuda.current_device()
+                checkpoint = self.load_checkpoint(
+                    checkpoint,
+                    map_location=lambda storage, loc: storage.cuda(device_id))
+            else:
+                checkpoint = self.load_checkpoint(checkpoint)
+        else:
+            checkpoint = self.load_checkpoint(
+                checkpoint, map_location=map_location)
+
+        self._epoch = checkpoint['meta']['epoch']
+        self._iter = checkpoint['meta']['iter']
+        if 'optimizer' in checkpoint and resume_optimizer:
+            if isinstance(self.optimizer, Optimizer):
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            elif isinstance(self.optimizer, dict):
+                for k in self.optimizer.keys():
+                    self.optimizer[k].load_state_dict(
+                        checkpoint['optimizer'][k])
+            else:
+                raise TypeError(
+                    'Optimizer should be dict or torch.optim.Optimizer '
+                    f'but got {type(self.optimizer)}')
+
+        if 'amp' in checkpoint and resume_amp:
+            apex.amp.load_state_dict(checkpoint['amp'])
+            self.logger.info('load amp state dict')
+
+        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
+
+    def register_lr_hook(self, lr_config):
+        if lr_config is None:
+            return
+        elif isinstance(lr_config, dict):
+            assert 'policy' in lr_config
+            policy_type = lr_config.pop('policy')
+            # If the type of policy is all in lower case, e.g., 'cyclic',
+            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+            # This is for the convenient usage of Lr updater.
+            # Since this is not applicable for
+            # `CosineAnnealingLrUpdater`,
+            # the string will not be changed if it contains capital letters.
+            if policy_type == policy_type.lower():
+                policy_type = policy_type.title()
+            hook_type = policy_type + 'LrUpdaterHook'
+            lr_config['type'] = hook_type
+            hook = mmcv.build_from_cfg(lr_config, HOOKS)
+        else:
+            hook = lr_config
+        self.register_hook(hook)
+
+    def register_momentum_hook(self, momentum_config):
+        if momentum_config is None:
+            return
+        if isinstance(momentum_config, dict):
+            assert 'policy' in momentum_config
+            policy_type = momentum_config.pop('policy')
+            # If the type of policy is all in lower case, e.g., 'cyclic',
+            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+            # This is for the convenient usage of momentum updater.
+            # Since this is not applicable for
+            # `CosineAnnealingMomentumUpdater`,
+            # the string will not be changed if it contains capital letters.
+            if policy_type == policy_type.lower():
+                policy_type = policy_type.title()
+            hook_type = policy_type + 'MomentumUpdaterHook'
+            momentum_config['type'] = hook_type
+            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
+        else:
+            hook = momentum_config
+        self.register_hook(hook)
+
+    def register_optimizer_hook(self, optimizer_config):
+        if optimizer_config is None:
+            return
+        if isinstance(optimizer_config, dict):
+            optimizer_config.setdefault('type', 'OptimizerHook')
+            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
+        else:
+            hook = optimizer_config
+        self.register_hook(hook)
+
+    def register_checkpoint_hook(self, checkpoint_config):
+        if checkpoint_config is None:
+            return
+        if isinstance(checkpoint_config, dict):
+            checkpoint_config.setdefault('type', 'CheckpointHook')
+            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
+        else:
+            hook = checkpoint_config
+        self.register_hook(hook)
+
+    def register_logger_hooks(self, log_config):
+        if log_config is None:
+            return
+        log_interval = log_config['interval']
+        for info in log_config['hooks']:
+            logger_hook = mmcv.build_from_cfg(
+                info, HOOKS, default_args=dict(interval=log_interval))
+            self.register_hook(logger_hook, priority='VERY_LOW')
+
+    def register_training_hooks(self,
+                                lr_config,
+                                optimizer_config=None,
+                                checkpoint_config=None,
+                                log_config=None,
+                                momentum_config=None):
+        """Register default hooks for training.
+
+        Default hooks include:
+
+        - LrUpdaterHook
+        - MomentumUpdaterHook
+        - OptimizerStepperHook
+        - CheckpointSaverHook
+        - IterTimerHook
+        - LoggerHook(s)
+        """
+        self.register_lr_hook(lr_config)
+        self.register_momentum_hook(momentum_config)
+        self.register_optimizer_hook(optimizer_config)
+        self.register_checkpoint_hook(checkpoint_config)
+        self.register_hook(IterTimerHook())
+        self.register_logger_hooks(log_config)
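
For context, a minimal usage sketch of the AMP-aware runner above, assuming apex is installed; the toy model, work_dir path and hyper-parameters are hypothetical, and real training goes through `mmdet/apis/train.py` further below:

    # Sketch: a toy module exposing the `train_step` interface that BaseRunner
    # asserts on, passed through apex AMP and the AMP-aware runner.
    import logging
    import apex
    import torch
    from mmcv_custom.runner import EpochBasedRunnerAmp

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 1)

        def train_step(self, data, optimizer):
            loss = self.fc(data).mean()
            return dict(loss=loss, log_vars=dict(loss=loss.item()),
                        num_samples=data.size(0))

    model = ToyModel().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # 'O1' is the mixed-precision opt_level also used in mmdet/apis/train.py.
    model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O1')

    runner = EpochBasedRunnerAmp(
        model,
        optimizer=optimizer,
        work_dir='./work_dirs/amp_demo',      # hypothetical path
        logger=logging.getLogger(__name__),
        max_epochs=1,
        amp=True)
    # Resuming with resume_amp=True also restores the apex loss-scaler state
    # via `apex.amp.load_state_dict(checkpoint['amp'])` in BaseRunner.resume().
    # runner.resume('./work_dirs/amp_demo/latest.pth', resume_amp=True)
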
diff --git a/mmcv_custom/runner/checkpoint.py b/mmcv_custom/runner/checkpoint.py
new file mode 100644
index 00000000..30a7d414
--- /dev/null
+++ b/mmcv_custom/runner/checkpoint.py
@@ -0,0 +1,473 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+
+import mmcv
+from mmcv.fileio import FileClient
+from mmcv.fileio import load as load_file
+from mmcv.parallel import is_module_wrapper
+from mmcv.utils import mkdir_or_exist
+from mmcv.runner import get_dist_info
+
+try:
+    import apex
+except ImportError:
+    print('apex is not installed')
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+    mmcv_home = os.path.expanduser(
+        os.getenv(
+            ENV_MMCV_HOME,
+            os.path.join(
+                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+    mkdir_or_exist(mmcv_home)
+    return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+    """Load state_dict to a module.
+
+    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+    Default value for ``strict`` is set to ``False`` and the message for
+    param mismatch will be shown even if strict is False.
+
+    Args:
+        module (Module): Module that receives the state_dict.
+        state_dict (OrderedDict): Weights.
+        strict (bool): whether to strictly enforce that the keys
+            in :attr:`state_dict` match the keys returned by this module's
+            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+        logger (:obj:`logging.Logger`, optional): Logger to log the error
+            message. If not specified, print function will be used.
+    """
+    unexpected_keys = []
+    all_missing_keys = []
+    err_msg = []
+
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    # use _load_from_state_dict to enable checkpoint version control
+    def load(module, prefix=''):
+        # recursively check parallel module in case that the model has a
+        # complicated structure, e.g., nn.Module(nn.Module(DDP))
+        if is_module_wrapper(module):
+            module = module.module
+        local_metadata = {} if metadata is None else metadata.get(
+            prefix[:-1], {})
+        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+                                     all_missing_keys, unexpected_keys,
+                                     err_msg)
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + '.')
+
+    load(module)
+    load = None  # break load->load reference cycle
+
+    # ignore "num_batches_tracked" of BN layers
+    missing_keys = [
+        key for key in all_missing_keys if 'num_batches_tracked' not in key
+    ]
+
+    if unexpected_keys:
+        err_msg.append('unexpected key in source '
+                       f'state_dict: {", ".join(unexpected_keys)}\n')
+    if missing_keys:
+        err_msg.append(
+            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+    rank, _ = get_dist_info()
+    if len(err_msg) > 0 and rank == 0:
+        err_msg.insert(
+            0, 'The model and loaded state dict do not match exactly\n')
+        err_msg = '\n'.join(err_msg)
+        if strict:
+            raise RuntimeError(err_msg)
+        elif logger is not None:
+            logger.warning(err_msg)
+        else:
+            print(err_msg)
+
+
+def load_url_dist(url, model_dir=None):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    if rank == 0:
+        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+    return checkpoint
+
+
+def load_pavimodel_dist(model_path, map_location=None):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    try:
+        from pavi import modelcloud
+    except ImportError:
+        raise ImportError(
+            'Please install pavi to load checkpoint from modelcloud.')
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    if rank == 0:
+        model = modelcloud.get(model_path)
+        with TemporaryDirectory() as tmp_dir:
+            downloaded_file = osp.join(tmp_dir, model.name)
+            model.download(downloaded_file)
+            checkpoint = torch.load(downloaded_file, map_location=map_location)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            model = modelcloud.get(model_path)
+            with TemporaryDirectory() as tmp_dir:
+                downloaded_file = osp.join(tmp_dir, model.name)
+                model.download(downloaded_file)
+                checkpoint = torch.load(
+                    downloaded_file, map_location=map_location)
+    return checkpoint
+
+
+def load_fileclient_dist(filename, backend, map_location):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    allowed_backends = ['ceph']
+    if backend not in allowed_backends:
+        raise ValueError(f'Load from Backend {backend} is not supported.')
+    if rank == 0:
+        fileclient = FileClient(backend=backend)
+        buffer = io.BytesIO(fileclient.get(filename))
+        checkpoint = torch.load(buffer, map_location=map_location)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            fileclient = FileClient(backend=backend)
+            buffer = io.BytesIO(fileclient.get(filename))
+            checkpoint = torch.load(buffer, map_location=map_location)
+    return checkpoint
+
+
+def get_torchvision_models():
+    model_urls = dict()
+    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+        if ispkg:
+            continue
+        _zoo = import_module(f'torchvision.models.{name}')
+        if hasattr(_zoo, 'model_urls'):
+            _urls = getattr(_zoo, 'model_urls')
+            model_urls.update(_urls)
+    return model_urls
+
+
+def get_external_models():
+    mmcv_home = _get_mmcv_home()
+    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+    default_urls = load_file(default_json_path)
+    assert isinstance(default_urls, dict)
+    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+    if osp.exists(external_json_path):
+        external_urls = load_file(external_json_path)
+        assert isinstance(external_urls, dict)
+        default_urls.update(external_urls)
+
+    return default_urls
+
+
+def get_mmcls_models():
+    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+    mmcls_urls = load_file(mmcls_json_path)
+
+    return mmcls_urls
+
+
+def get_deprecated_model_names():
+    deprecate_json_path = osp.join(mmcv.__path__[0],
+                                   'model_zoo/deprecated.json')
+    deprecate_urls = load_file(deprecate_json_path)
+    assert isinstance(deprecate_urls, dict)
+
+    return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+    state_dict = checkpoint['state_dict']
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith('backbone.'):
+            new_state_dict[k[9:]] = v
+    new_checkpoint = dict(state_dict=new_state_dict)
+
+    return new_checkpoint
+
+
+def _load_checkpoint(filename, map_location=None):
+    """Load checkpoint from somewhere (modelzoo, file, url).
+
+    Args:
+        filename (str): Accepts a local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+    Returns:
+        dict | OrderedDict: The loaded checkpoint. It can be either an
+            OrderedDict storing model weights or a dict containing other
+            information, which depends on the checkpoint.
+    """
+    if filename.startswith('modelzoo://'):
+        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+                      'use "torchvision://" instead')
+        model_urls = get_torchvision_models()
+        model_name = filename[11:]
+        checkpoint = load_url_dist(model_urls[model_name])
+    elif filename.startswith('torchvision://'):
+        model_urls = get_torchvision_models()
+        model_name = filename[14:]
+        checkpoint = load_url_dist(model_urls[model_name])
+    elif filename.startswith('open-mmlab://'):
+        model_urls = get_external_models()
+        model_name = filename[13:]
+        deprecated_urls = get_deprecated_model_names()
+        if model_name in deprecated_urls:
+            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
+                          f'of open-mmlab://{deprecated_urls[model_name]}')
+            model_name = deprecated_urls[model_name]
+        model_url = model_urls[model_name]
+        # check if is url
+        if model_url.startswith(('http://', 'https://')):
+            checkpoint = load_url_dist(model_url)
+        else:
+            filename = osp.join(_get_mmcv_home(), model_url)
+            if not osp.isfile(filename):
+                raise IOError(f'{filename} is not a checkpoint file')
+            checkpoint = torch.load(filename, map_location=map_location)
+    elif filename.startswith('mmcls://'):
+        model_urls = get_mmcls_models()
+        model_name = filename[8:]
+        checkpoint = load_url_dist(model_urls[model_name])
+        checkpoint = _process_mmcls_checkpoint(checkpoint)
+    elif filename.startswith(('http://', 'https://')):
+        checkpoint = load_url_dist(filename)
+    elif filename.startswith('pavi://'):
+        model_path = filename[7:]
+        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
+    elif filename.startswith('s3://'):
+        checkpoint = load_fileclient_dist(
+            filename, backend='ceph', map_location=map_location)
+    else:
+        if not osp.isfile(filename):
+            raise IOError(f'{filename} is not a checkpoint file')
+        checkpoint = torch.load(filename, map_location=map_location)
+    return checkpoint
+
+
+def load_checkpoint(model,
+                    filename,
+                    map_location=None,
+                    strict=False,
+                    logger=None):
+    """Load checkpoint from a file or URI.
+
+    Args:
+        model (Module): Module to load checkpoint.
+        filename (str): Accepts a local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to strictly enforce that the keys in the
+            checkpoint match the keys returned by the model's state_dict.
+        logger (:mod:`logging.Logger` or None): The logger for error message.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    checkpoint = _load_checkpoint(filename, map_location)
+    # OrderedDict is a subclass of dict
+    if not isinstance(checkpoint, dict):
+        raise RuntimeError(
+            f'No state_dict found in checkpoint file {filename}')
+    # get state_dict from checkpoint
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+    # strip prefix of state_dict
+    if list(state_dict.keys())[0].startswith('module.'):
+        state_dict = {k[7:]: v for k, v in state_dict.items()}
+    # load state_dict
+    load_state_dict(model, state_dict, strict, logger)
+    return checkpoint
+
+
+def weights_to_cpu(state_dict):
+    """Copy a model state_dict to cpu.
+
+    Args:
+        state_dict (OrderedDict): Model weights on GPU.
+
+    Returns:
+        OrderedDict: Model weights on CPU.
+    """
+    state_dict_cpu = OrderedDict()
+    for key, val in state_dict.items():
+        state_dict_cpu[key] = val.cpu()
+    return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+    """Saves module state to `destination` dictionary.
+
+    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+    """
+    for name, param in module._parameters.items():
+        if param is not None:
+            destination[prefix + name] = param if keep_vars else param.detach()
+    for name, buf in module._buffers.items():
+        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+        if buf is not None:
+            destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+    """Returns a dictionary containing a whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are
+    included. Keys are corresponding parameter and buffer names.
+
+    This method is modified from :meth:`torch.nn.Module.state_dict` to
+    recursively check parallel module in case that the model has a complicated
+    structure, e.g., nn.Module(nn.Module(DDP)).
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (OrderedDict): Returned dict for the state of the
+            module.
+        prefix (str): Prefix of the key.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters. Default: False.
+
+    Returns:
+        dict: A dictionary containing a whole state of the module.
+    """
+    # recursively check parallel module in case that the model has a
+    # complicated structure, e.g., nn.Module(nn.Module(DDP))
+    if is_module_wrapper(module):
+        module = module.module
+
+    # below is the same as torch.nn.Module.state_dict()
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(
+        version=module._version)
+    _save_to_state_dict(module, destination, prefix, keep_vars)
+    for name, child in module._modules.items():
+        if child is not None:
+            get_state_dict(
+                child, destination, prefix + name + '.', keep_vars=keep_vars)
+    for hook in module._state_dict_hooks.values():
+        hook_result = hook(module, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+    return destination
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None, amp=False):
+    """Save checkpoint to file.
+
+    The checkpoint will have fields ``meta``, ``state_dict`` and ``optimizer``,
+    plus ``amp`` when ``amp=True``. By default ``meta`` will contain version
+    and time info.
+
+    Args:
+        model (Module): Module whose params are to be saved.
+        filename (str): Checkpoint filename.
+        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+        meta (dict, optional): Metadata to be saved in checkpoint.
+        amp (bool, optional): Whether to also save the apex AMP state dict.
+            Default: False.
+    if meta is None:
+        meta = {}
+    elif not isinstance(meta, dict):
+        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+    if is_module_wrapper(model):
+        model = model.module
+
+    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+        # save class name to the meta
+        meta.update(CLASSES=model.CLASSES)
+
+    checkpoint = {
+        'meta': meta,
+        'state_dict': weights_to_cpu(get_state_dict(model))
+    }
+    # save optimizer state dict in the checkpoint
+    if isinstance(optimizer, Optimizer):
+        checkpoint['optimizer'] = optimizer.state_dict()
+    elif isinstance(optimizer, dict):
+        checkpoint['optimizer'] = {}
+        for name, optim in optimizer.items():
+            checkpoint['optimizer'][name] = optim.state_dict()
+
+    # save amp state dict in the checkpoint
+    if amp:
+        checkpoint['amp'] = apex.amp.state_dict()
+
+    if filename.startswith('pavi://'):
+        try:
+            from pavi import modelcloud
+            from pavi.exception import NodeNotFoundError
+        except ImportError:
+            raise ImportError(
+                'Please install pavi to load checkpoint from modelcloud.')
+        model_path = filename[7:]
+        root = modelcloud.Folder()
+        model_dir, model_name = osp.split(model_path)
+        try:
+            model = modelcloud.get(model_dir)
+        except NodeNotFoundError:
+            model = root.create_training_model(model_dir)
+        with TemporaryDirectory() as tmp_dir:
+            checkpoint_file = osp.join(tmp_dir, model_name)
+            with open(checkpoint_file, 'wb') as f:
+                torch.save(checkpoint, f)
+                f.flush()
+            model.create_file(checkpoint_file, name=model_name)
+    else:
+        mmcv.mkdir_or_exist(osp.dirname(filename))
+        # immediately flush buffer
+        with open(filename, 'wb') as f:
+            torch.save(checkpoint, f)
+            f.flush()
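
A short sketch of what the resulting checkpoint looks like on disk, assuming a file written by `save_checkpoint(..., amp=True)`; the path below is hypothetical. The extra `amp` entry holds the apex loss-scaler state that `BaseRunner.resume(..., resume_amp=True)` reads back:

    # Sketch: inspecting a checkpoint written with amp=True (hypothetical path).
    import torch

    ckpt = torch.load('work_dirs/amp_demo/epoch_1.pth', map_location='cpu')
    print(sorted(ckpt.keys()))        # e.g. ['amp', 'meta', 'optimizer', 'state_dict']
    print(ckpt['meta']['epoch'], ckpt['meta']['iter'])
    # Restoring the scaler state outside the runner would look like:
    #   apex.amp.load_state_dict(ckpt['amp'])
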
diff --git a/mmcv_custom/runner/epoch_based_runner.py b/mmcv_custom/runner/epoch_based_runner.py
new file mode 100644
index 00000000..334b9f5d
--- /dev/null
+++ b/mmcv_custom/runner/epoch_based_runner.py
@@ -0,0 +1,171 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+
+import mmcv
+from .base_runner import BaseRunner
+from .checkpoint import save_checkpoint
+from mmcv.runner import RUNNERS, get_host_info
+
+
+@RUNNERS.register_module()
+class EpochBasedRunnerAmp(BaseRunner):
+    """Epoch-based Runner.
+
+    This runner train models epoch by epoch.
+    """
+
+    def run_iter(self, data_batch, train_mode, **kwargs):
+        if self.batch_processor is not None:
+            outputs = self.batch_processor(
+                self.model, data_batch, train_mode=train_mode, **kwargs)
+        elif train_mode:
+            outputs = self.model.train_step(data_batch, self.optimizer,
+                                            **kwargs)
+        else:
+            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('"batch_processor()" or "model.train_step()"'
+                            'and "model.val_step()" must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+
+    def train(self, data_loader, **kwargs):
+        self.model.train()
+        self.mode = 'train'
+        self.data_loader = data_loader
+        self._max_iters = self._max_epochs * len(self.data_loader)
+        self.call_hook('before_train_epoch')
+        time.sleep(2)  # Prevent possible deadlock during epoch transition
+        for i, data_batch in enumerate(self.data_loader):
+            self._inner_iter = i
+            self.call_hook('before_train_iter')
+            self.run_iter(data_batch, train_mode=True)
+            self.call_hook('after_train_iter')
+            self._iter += 1
+
+        self.call_hook('after_train_epoch')
+        self._epoch += 1
+
+    @torch.no_grad()
+    def val(self, data_loader, **kwargs):
+        self.model.eval()
+        self.mode = 'val'
+        self.data_loader = data_loader
+        self.call_hook('before_val_epoch')
+        time.sleep(2)  # Prevent possible deadlock during epoch transition
+        for i, data_batch in enumerate(self.data_loader):
+            self._inner_iter = i
+            self.call_hook('before_val_iter')
+            self.run_iter(data_batch, train_mode=False)
+            self.call_hook('after_val_iter')
+
+        self.call_hook('after_val_epoch')
+
+    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, epochs) to specify the
+                running order and epochs. E.g., [('train', 2), ('val', 1)] means
+                running 2 epochs for training and 1 epoch for validation,
+                iteratively.
+        """
+        assert isinstance(data_loaders, list)
+        assert mmcv.is_list_of(workflow, tuple)
+        assert len(data_loaders) == len(workflow)
+        if max_epochs is not None:
+            warnings.warn(
+                'setting max_epochs in run is deprecated, '
+                'please set max_epochs in runner_config', DeprecationWarning)
+            self._max_epochs = max_epochs
+
+        assert self._max_epochs is not None, (
+            'max_epochs must be specified during instantiation')
+
+        for i, flow in enumerate(workflow):
+            mode, epochs = flow
+            if mode == 'train':
+                self._max_iters = self._max_epochs * len(data_loaders[i])
+                break
+
+        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+        self.logger.info('Start running, host: %s, work_dir: %s',
+                         get_host_info(), work_dir)
+        self.logger.info('workflow: %s, max: %d epochs', workflow,
+                         self._max_epochs)
+        self.call_hook('before_run')
+
+        while self.epoch < self._max_epochs:
+            for i, flow in enumerate(workflow):
+                mode, epochs = flow
+                if isinstance(mode, str):  # self.train()
+                    if not hasattr(self, mode):
+                        raise ValueError(
+                            f'runner has no method named "{mode}" to run an '
+                            'epoch')
+                    epoch_runner = getattr(self, mode)
+                else:
+                    raise TypeError(
+                        'mode in workflow must be a str, but got {}'.format(
+                            type(mode)))
+
+                for _ in range(epochs):
+                    if mode == 'train' and self.epoch >= self._max_epochs:
+                        break
+                    epoch_runner(data_loaders[i], **kwargs)
+
+        time.sleep(1)  # wait for some hooks like loggers to finish
+        self.call_hook('after_run')
+
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl='epoch_{}.pth',
+                        save_optimizer=True,
+                        meta=None,
+                        create_symlink=True):
+        """Save the checkpoint.
+
+        Args:
+            out_dir (str): The directory that checkpoints are saved.
+            filename_tmpl (str, optional): The checkpoint filename template,
+                which contains a placeholder for the epoch number.
+                Defaults to 'epoch_{}.pth'.
+            save_optimizer (bool, optional): Whether to save the optimizer to
+                the checkpoint. Defaults to True.
+            meta (dict, optional): The meta information to be saved in the
+                checkpoint. Defaults to None.
+            create_symlink (bool, optional): Whether to create a symlink
+                "latest.pth" to point to the latest checkpoint.
+                Defaults to True.
+        """
+        if meta is None:
+            meta = dict(epoch=self.epoch + 1, iter=self.iter)
+        elif isinstance(meta, dict):
+            meta.update(epoch=self.epoch + 1, iter=self.iter)
+        else:
+            raise TypeError(
+                f'meta should be a dict or None, but got {type(meta)}')
+        if self.meta is not None:
+            meta.update(self.meta)
+
+        filename = filename_tmpl.format(self.epoch + 1)
+        filepath = osp.join(out_dir, filename)
+        optimizer = self.optimizer if save_optimizer else None
+        save_checkpoint(
+            self.model, filepath, optimizer=optimizer, meta=meta, amp=self.amp)
+        # in some environments, `os.symlink` is not supported, you may need to
+        # set `create_symlink` to False
+        if create_symlink:
+            dst_file = osp.join(out_dir, 'latest.pth')
+            if platform.system() != 'Windows':
+                mmcv.symlink(filename, dst_file)
+            else:
+                shutil.copy(filepath, dst_file)
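
Continuing the toy sketch from `base_runner.py` above (hook configs and data loader are illustrative), the new runner plugs into the standard mmcv hook machinery; the workflow `[('train', 1)]` trains one epoch per pass until `max_epochs` is reached:

    # Sketch continuing the toy `runner` built after base_runner.py above.
    import torch
    from torch.utils.data import DataLoader
    from mmdet.utils import DistOptimizerHook  # noqa: F401  (registers the hook)

    # A plain tensor works as a toy dataset; default collation stacks rows.
    data_loaders = [DataLoader(torch.randn(16, 4).cuda(), batch_size=4)]

    runner.register_training_hooks(
        lr_config=dict(policy='step', step=[8, 11]),
        optimizer_config=dict(type='DistOptimizerHook', use_fp16=True),
        checkpoint_config=dict(interval=1),
        log_config=dict(interval=1, hooks=[dict(type='TextLoggerHook')]))
    runner.run(data_loaders, [('train', 1)])
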
diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 82c20bff..44fd2cd9 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -13,6 +13,11 @@ from mmdet.core import DistEvalHook, EvalHook
 from mmdet.datasets import (build_dataloader, build_dataset,
                             replace_ImageToTensor)
 from mmdet.utils import get_root_logger
+from mmcv_custom.runner import EpochBasedRunnerAmp
+try:
+    import apex
+except ImportError:
+    print('apex is not installed')
 
 
 def set_random_seed(seed, deterministic=False):
@@ -70,6 +75,20 @@ def train_detector(model,
             seed=cfg.seed) for ds in dataset
     ]
 
+    # build optimizer
+    optimizer = build_optimizer(model, cfg.optimizer)
+
+    # use apex fp16 optimizer
+    use_amp = False
+    if cfg.optimizer_config.get("type") == "DistOptimizerHook":
+        if cfg.optimizer_config.get("use_fp16", False):
+            model, optimizer = apex.amp.initialize(
+                model.cuda(), optimizer, opt_level="O1")
+            for m in model.modules():
+                if hasattr(m, "fp16_enabled"):
+                    m.fp16_enabled = True
+            use_amp = True
+
     # put model on gpus
     if distributed:
         find_unused_parameters = cfg.get('find_unused_parameters', False)
@@ -84,9 +103,6 @@ def train_detector(model,
         model = MMDataParallel(
             model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
 
-    # build runner
-    optimizer = build_optimizer(model, cfg.optimizer)
-
     if 'runner' not in cfg:
         cfg.runner = {
             'type': 'EpochBasedRunner',
@@ -99,14 +115,18 @@ def train_detector(model,
         if 'total_epochs' in cfg:
             assert cfg.total_epochs == cfg.runner.max_epochs
 
+    # build runner
+    runner_default_args = dict(
+        model=model,
+        optimizer=optimizer,
+        work_dir=cfg.work_dir,
+        logger=logger,
+        meta=meta)
+    if cfg.runner['type'] == 'EpochBasedRunnerAmp':
+        runner_default_args['amp'] = use_amp
     runner = build_runner(
         cfg.runner,
-        default_args=dict(
-            model=model,
-            optimizer=optimizer,
-            work_dir=cfg.work_dir,
-            logger=logger,
-            meta=meta))
+        default_args=runner_default_args)
 
     # an ugly workaround to make .log and .log.json filenames the same
     runner.timestamp = timestamp
@@ -164,7 +184,10 @@ def train_detector(model,
             runner.register_hook(hook, priority=priority)
 
     if cfg.resume_from:
-        runner.resume(cfg.resume_from)
+        if cfg.runner['type'] == 'EpochBasedRunnerAmp':
+            runner.resume(cfg.resume_from, resume_amp=use_amp)
+        else:
+            runner.resume(cfg.resume_from)
     elif cfg.load_from:
         runner.load_checkpoint(cfg.load_from)
     runner.run(data_loaders, cfg.workflow)
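
The branch above is driven entirely by the config. A sketch of the fields it reads, with illustrative values only:

    # Config sketch (illustrative values): `use_fp16=True` on DistOptimizerHook
    # triggers `apex.amp.initialize(...)`, and the 'EpochBasedRunnerAmp' runner
    # type makes train_detector pass `amp=use_amp` and resume with the amp state.
    optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.05)
    optimizer_config = dict(
        type='DistOptimizerHook',
        update_interval=1,
        grad_clip=None,
        coalesce=True,
        bucket_size_mb=-1,
        use_fp16=True)
    runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
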
diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py
index ac489e2d..e79ad8c0 100644
--- a/mmdet/utils/__init__.py
+++ b/mmdet/utils/__init__.py
@@ -1,4 +1,5 @@
 from .collect_env import collect_env
 from .logger import get_root_logger
+from .optimizer import DistOptimizerHook
 
-__all__ = ['get_root_logger', 'collect_env']
+__all__ = ['get_root_logger', 'collect_env', 'DistOptimizerHook']
diff --git a/mmdet/utils/optimizer.py b/mmdet/utils/optimizer.py
new file mode 100644
index 00000000..9c9d1194
--- /dev/null
+++ b/mmdet/utils/optimizer.py
@@ -0,0 +1,33 @@
+from mmcv.runner import OptimizerHook, HOOKS
+try:
+    import apex
+except ImportError:
+    print('apex is not installed')
+
+
+@HOOKS.register_module()
+class DistOptimizerHook(OptimizerHook):
+    """Optimizer hook for distributed training."""
+
+    def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
+        self.grad_clip = grad_clip
+        self.coalesce = coalesce
+        self.bucket_size_mb = bucket_size_mb
+        self.update_interval = update_interval
+        self.use_fp16 = use_fp16
+
+    def before_run(self, runner):
+        runner.optimizer.zero_grad()
+
+    def after_train_iter(self, runner):
+        runner.outputs['loss'] /= self.update_interval
+        if self.use_fp16:
+            with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            runner.outputs['loss'].backward()
+        if self.every_n_iters(runner, self.update_interval):
+            if self.grad_clip is not None:
+                self.clip_grads(runner.model.parameters())
+            runner.optimizer.step()
+            runner.optimizer.zero_grad()
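
For clarity, the hook combines apex loss scaling with `update_interval`-step gradient accumulation. An equivalent plain loop, assuming `model`, `optimizer` and `loader` were already set up via `apex.amp.initialize` as in `train.py` above:

    # Sketch: the effective schedule of DistOptimizerHook, written as a loop.
    import apex

    update_interval = 2
    for it, data in enumerate(loader):
        # `train_step` returning a dict with 'loss' is assumed, as in mmdet.
        loss = model.train_step(data, optimizer)['loss'] / update_interval
        # backward on the scaled loss so fp16 gradients do not underflow
        with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        if (it + 1) % update_interval == 0:   # accumulation boundary
            optimizer.step()
            optimizer.zero_grad()
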
-- 
GitLab