Skip to content
Snippets Groups Projects
Unverified commit d2a8ba76, authored by Hongkai Zhang, committed by GitHub
Browse files

Code for paper "Dynamic R-CNN: Towards High Quality Object Detection via Dynamic Training" (#3040)


* Code for paper "Dynamic R-CNN: Towards High Quality Object Detection via Dynamic Training"

* update configs/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x.py

* reformat code

* simplify code

* update model link

* simplify code

* simplify code logic

* simplify code and add comments

* minor updates of the docstring

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
parent 47d663f0
No related branches found
No related tags found
No related merge requests found
......@@ -76,6 +76,7 @@ Results and models are available in the [model zoo](docs/model_zoo.md).
| PAFPN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| NAS-FCOS | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| PISA | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
| Dynamic R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ |
Other features
- [x] [CARAFE](configs/carafe/README.md)
......
# Dynamic R-CNN: Towards High Quality Object Detection via Dynamic Training
## Introduction
```
@article{DynamicRCNN,
author = {Hongkai Zhang and Hong Chang and Bingpeng Ma and Naiyan Wang and Xilin Chen},
title = {Dynamic {R-CNN}: Towards High Quality Object Detection via Dynamic Training},
journal = {arXiv preprint arXiv:2004.06002},
year = {2020}
}
```
## Results and Models
| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download |
|:---------:|:-------:|:-------:|:--------:|:--------------:|:------:|:--------:|
| R-50 | pytorch | 1x | 3.8 | | 38.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x/dynamic_rcnn_r50_fpn_1x-62a3f276.pth) &#124; [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x/dynamic_rcnn_r50_fpn_1x_20200618_095048.log.json) |
# Dynamic R-CNN on top of the Faster R-CNN R-50 FPN 1x baseline config.
# NOTE: no helper variables are introduced at module level on purpose, so the
# set of top-level names (and therefore config keys) stays exactly the same.
_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'

model = {
    'roi_head': {
        # swap in the RoI head that adapts its IoU threshold and SmoothL1
        # beta during training
        'type': 'DynamicRoIHead',
        'bbox_head': {
            'type': 'Shared2FCBBoxHead',
            'in_channels': 256,
            'fc_out_channels': 1024,
            'roi_feat_size': 7,
            'num_classes': 80,
            'bbox_coder': {
                'type': 'DeltaXYWHBBoxCoder',
                'target_means': [0., 0., 0., 0.],
                'target_stds': [0.1, 0.1, 0.2, 0.2],
            },
            'reg_class_agnostic': False,
            'loss_cls': {
                'type': 'CrossEntropyLoss',
                'use_sigmoid': False,
                'loss_weight': 1.0,
            },
            # the dynamic head asserts that the bbox loss is a SmoothL1Loss
            'loss_bbox': {
                'type': 'SmoothL1Loss',
                'beta': 1.0,
                'loss_weight': 1.0,
            },
        },
    },
}

train_cfg = {
    'rpn_proposal': {'nms_thr': 0.85},
    'rcnn': {
        'dynamic_rcnn': {
            # rank of the per-image IoU that is recorded every iteration
            'iou_topk': 75,
            # per-image rank of the regression target recorded for beta
            'beta_topk': 10,
            # number of iterations between two hyperparameter updates
            'update_iter_interval': 100,
            # warm-up values for the IoU threshold and the SmoothL1 beta
            'initial_iou': 0.4,
            'initial_beta': 1.0,
        },
    },
}

test_cfg = {'rpn': {'nms_thr': 0.85}}
......@@ -132,6 +132,9 @@ Please refer to [Res2Net](https://github.com/open-mmlab/mmdetection/blob/master/
### GRoIE
Please refer to [GRoIE](https://github.com/open-mmlab/mmdetection/blob/master/configs/groie) for details.
### Dynamic R-CNN
Please refer to [Dynamic R-CNN](https://github.com/open-mmlab/mmdetection/blob/master/configs/dynamic_rcnn) for details.
### Other datasets
We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
......
......@@ -3,6 +3,7 @@ from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DoubleConvFCBBoxHead,
Shared2FCBBoxHead, Shared4Conv1FCBBoxHead)
from .cascade_roi_head import CascadeRoIHead
from .double_roi_head import DoubleHeadRoIHead
from .dynamic_roi_head import DynamicRoIHead
from .grid_roi_head import GridRoIHead
from .htc_roi_head import HybridTaskCascadeRoIHead
from .mask_heads import (FCNMaskHead, FusedSemanticHead, GridHead, HTCMaskHead,
......@@ -17,5 +18,6 @@ __all__ = [
'HybridTaskCascadeRoIHead', 'GridRoIHead', 'ResLayer', 'BBoxHead',
'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'Shared4Conv1FCBBoxHead',
'DoubleConvFCBBoxHead', 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead',
'GridHead', 'MaskIoUHead', 'SingleRoIExtractor', 'PISARoIHead'
'GridHead', 'MaskIoUHead', 'SingleRoIExtractor', 'PISARoIHead',
'DynamicRoIHead'
]
import numpy as np
import torch
from mmdet.core import bbox2roi
from mmdet.models.losses import SmoothL1Loss
from ..builder import HEADS
from .standard_roi_head import StandardRoIHead
@HEADS.register_module()
class DynamicRoIHead(StandardRoIHead):
    """RoI head for `Dynamic R-CNN <https://arxiv.org/abs/2004.06002>`_.

    During training the head records, every iteration, an IoU statistic of
    the assigned proposals and a statistic of the regression targets. Every
    ``update_iter_interval`` iterations these histories are turned into a new
    IoU threshold for the bbox assigner and a new ``beta`` for the SmoothL1
    bbox loss, and the histories are cleared.
    """

    # Medians of the recorded regression targets below this value are treated
    # as zero: using them as the SmoothL1 ``beta`` would make the loss
    # degenerate, so the previous ``beta`` is kept instead.
    _BETA_EPS = 1e-15

    def __init__(self, **kwargs):
        """Build the head and read the ``dynamic_rcnn`` training settings.

        Args:
            **kwargs: forwarded unchanged to ``StandardRoIHead.__init__``.
        """
        super(DynamicRoIHead, self).__init__(**kwargs)
        # `beta` is adjusted on the fly, so the bbox loss must expose it
        assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss)
        self.iou_topk = self.train_cfg.dynamic_rcnn.iou_topk
        self.beta_topk = self.train_cfg.dynamic_rcnn.beta_topk
        self.update_iter_interval = \
            self.train_cfg.dynamic_rcnn.update_iter_interval
        # warm-up values for IoU and beta
        self.initial_iou = self.train_cfg.dynamic_rcnn.initial_iou
        self.initial_beta = self.train_cfg.dynamic_rcnn.initial_beta
        # the IoU history of the past `update_iter_interval` iterations
        self.iou_history = []
        # the beta history of the past `update_iter_interval` iterations
        self.beta_history = []

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """Compute the RoI losses and record the statistics for this step.

        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensors]): list of region proposals.
            gt_bboxes (list[Tensor]): each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            cur_iou = []
            for i in range(num_imgs):
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                    gt_labels[i])
                # NOTE: sampling is kept before the IoU bookkeeping below on
                # purpose; the sampler receives `assign_result` and may modify
                # it (TODO confirm), so the order must not be swapped.
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                # record the `iou_topk`-th largest IoU in an image
                iou_topk = min(self.iou_topk, len(assign_result.max_overlaps))
                ious, _ = torch.topk(assign_result.max_overlaps, iou_topk)
                cur_iou.append(ious[-1].item())
                sampling_results.append(sampling_result)
            # average the current IoUs over images
            cur_iou = np.mean(cur_iou)
            self.iou_history.append(cur_iou)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    img_metas)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        if self.with_mask:
            mask_results = self._mask_forward_train(x, sampling_results,
                                                    bbox_results['bbox_feats'],
                                                    gt_masks, img_metas)
            # TODO: Support empty tensor input. #2280
            if mask_results['loss_mask'] is not None:
                losses.update(mask_results['loss_mask'])

        # update IoU threshold and SmoothL1 beta; the length of the IoU
        # history doubles as the iteration counter since the last update
        if len(self.iou_history) % self.update_iter_interval == 0:
            self.update_hyperparameters()

        return losses

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run the bbox head, record the beta statistic and compute the loss.

        Args:
            x (list[Tensor]): list of multi-level img features.
            sampling_results (list): per-image sampling results produced in
                ``forward_train``.
            gt_bboxes (list[Tensor]): ground-truth boxes of each image.
            gt_labels (list[Tensor]): class indices of each box.
            img_metas (list[dict]): list of image info dicts; only its length
                (the number of images) is used here.

        Returns:
            dict: the output of ``self._bbox_forward`` extended with the
            ``loss_bbox`` entry.
        """
        num_imgs = len(img_metas)
        rois = bbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)

        # record the `beta_topk`-th smallest target
        # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets
        # and bbox_weights, respectively
        pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
        num_pos = len(pos_inds)
        # Without any positive sample there is nothing to record, and
        # `torch.kthvalue` would be called with k == 0 on an empty tensor,
        # which raises. Skip the bookkeeping for such a batch.
        if num_pos > 0:
            # mean absolute value of the first two target components
            cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
            beta_topk = min(self.beta_topk * num_imgs, num_pos)
            cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
            self.beta_history.append(cur_target)

        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results

    def update_hyperparameters(self):
        """Update hyperparameters like `iou_thr` and `SmoothL1 beta` based
        on the training statistics.

        The IoU threshold never drops below ``initial_iou`` and beta never
        exceeds ``initial_beta``. Both histories are cleared afterwards.

        Returns:
            tuple[float]: the updated `iou_thr` and `SmoothL1 beta`
        """
        new_iou_thr = max(self.initial_iou, np.mean(self.iou_history))
        self.iou_history = []
        self.bbox_assigner.pos_iou_thr = new_iou_thr
        self.bbox_assigner.neg_iou_thr = new_iou_thr
        self.bbox_assigner.min_pos_iou = new_iou_thr

        # An empty history (no positive sample in the whole interval) or a
        # (near-)zero median must not be turned into the new beta: keep the
        # value the loss is currently using instead.
        if self.beta_history:
            median_target = np.median(self.beta_history)
        else:
            median_target = 0.
        if median_target < self._BETA_EPS:
            new_beta = self.bbox_head.loss_bbox.beta
        else:
            new_beta = min(self.initial_beta, median_target)
        self.beta_history = []
        self.bbox_head.loss_bbox.beta = new_beta
        return new_iou_thr, new_beta
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment