Skip to content
Snippets Groups Projects
Unverified Commit 84d0bd62 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Switch to EpochBasedRunner (#2976)

* switch to EpochBasedRunner

* add docstrings

* update the minimum version of mmcv to 0.6.0

* fix unit tests

* fix mmcv version in travis
parent 58772189
No related branches found
No related tags found
No related merge requests found
......@@ -35,7 +35,7 @@ before_install:
install:
- pip install Pillow==6.2.2 # remove this line when torchvision>=0.5
- pip install torch==${TORCH} torchvision==${TORCHVISION}
- pip install mmcv-nightly
- pip install mmcv
- pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools"
- pip install -r requirements.txt
......
import random
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, OptimizerHook, Runner,
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
build_optimizer)
from mmdet.core import DistEvalHook, EvalHook, Fp16OptimizerHook
......@@ -32,54 +30,6 @@ def set_random_seed(seed, deterministic=False):
torch.backends.cudnn.benchmark = False
def parse_losses(losses):
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(f'{loss_name} is not a tensor or list of tensors')
loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
def batch_processor(model, data, train_mode):
"""Process a data batch.
This method is required as an argument of Runner, which defines how to
process a data batch and obtain proper outputs. The first 3 arguments of
batch_processor are fixed.
Args:
model (nn.Module): A PyTorch model.
data (dict): The data batch in a dict.
train_mode (bool): Training mode or not. It may be useless for some
models.
Returns:
dict: A dict containing losses and log vars.
"""
losses = model(**data)
loss, log_vars = parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs
def train_detector(model,
dataset,
cfg,
......@@ -132,11 +82,10 @@ def train_detector(model,
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
runner = Runner(
runner = EpochBasedRunner(
model,
batch_processor,
optimizer,
cfg.work_dir,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)
# an ugly workaround to make .log and .log.json filenames the same
......
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.utils import print_log
......@@ -149,6 +152,90 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
else:
return self.forward_test(img, img_metas, **kwargs)
def _parse_losses(self, losses):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary infomation.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
which may be a weighted sum of all losses, log_vars contains
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
def train_step(self, data, optimizer):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating is also defined in
this method, such as GAN.
Args:
data (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
``num_samples``.
``loss`` is a tensor for back propagation, which can be a
weighted sum of multiple losses.
``log_vars`` contains all the variables to be sent to the
logger.
``num_samples`` indicates the batch size (when the model is
DDP, it means the batch size on each GPU), which is used for
averaging the logs.
"""
losses = self(**data)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs
def val_step(self, data, optimizer):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but used
during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook.
"""
losses = self(**data)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs
def show_result(self,
img,
result,
......
matplotlib
mmcv>=0.5.9
mmcv>=0.6.0
numpy
# need older pillow until torchvision is fixed
Pillow<=6.2.2
......
......@@ -153,9 +153,8 @@ def test_faster_rcnn_ohem_forward():
gt_labels=gt_labels,
return_loss=True)
assert isinstance(losses, dict)
from mmdet.apis.train import parse_losses
total_loss = float(parse_losses(losses)[0].item())
assert total_loss > 0
loss, _ = detector._parse_losses(losses)
assert float(loss.item()) > 0
# Test forward train with an empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
......@@ -170,9 +169,8 @@ def test_faster_rcnn_ohem_forward():
gt_labels=gt_labels,
return_loss=True)
assert isinstance(losses, dict)
from mmdet.apis.train import parse_losses
total_loss = float(parse_losses(losses)[0].item())
assert total_loss > 0
loss, _ = detector._parse_losses(losses)
assert float(loss.item()) > 0
# HTC is not ready yet
......@@ -206,10 +204,10 @@ def test_two_stage_forward(cfg_file):
gt_masks=gt_masks,
return_loss=True)
assert isinstance(losses, dict)
from mmdet.apis.train import parse_losses
total_loss = parse_losses(losses)[0].requires_grad_(True)
assert float(total_loss.item()) > 0
total_loss.backward()
loss, _ = detector._parse_losses(losses)
loss.requires_grad_(True)
assert float(loss.item()) > 0
loss.backward()
# Test forward train with an empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
......@@ -226,10 +224,10 @@ def test_two_stage_forward(cfg_file):
gt_masks=gt_masks,
return_loss=True)
assert isinstance(losses, dict)
from mmdet.apis.train import parse_losses
total_loss = parse_losses(losses)[0].requires_grad_(True)
assert float(total_loss.item()) > 0
total_loss.backward()
loss, _ = detector._parse_losses(losses)
loss.requires_grad_(True)
assert float(loss.item()) > 0
loss.backward()
# Test forward test
with torch.no_grad():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment