diff --git a/README.md b/README.md index e9dd0d5d009dd48c555c175dff149ffb50b1f001..cc8363b886b3905e9c01a6dbb6f0f4512874ffa7 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ Results and models are available in the [model zoo](docs/model_zoo.md). | PAFPN | 鉁� | 鉁� | 鈽� | 鉁� | 鉁� | 鈽� | 鈽� | | NAS-FCOS | 鉁� | 鉁� | 鈽� | 鉁� | 鉁� | 鈽� | 鈽� | | PISA | 鉁� | 鉁� | 鈽� | 鉁� | 鉁� | 鈽� | 鈽� | +| Dynamic R-CNN | 鉁� | 鉁� | 鈽� | 鉁� | 鉁� | 鈽� | 鈽� | Other features - [x] [CARAFE](configs/carafe/README.md) diff --git a/configs/dynamic_rcnn/README.md b/configs/dynamic_rcnn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f1b0bab957e780ebac59d5f069125b916cd9179c --- /dev/null +++ b/configs/dynamic_rcnn/README.md @@ -0,0 +1,18 @@ +# Dynamic R-CNN: Towards High Quality Object Detection via Dynamic Training + +## Introduction + +``` +@article{DynamicRCNN, + author = {Hongkai Zhang and Hong Chang and Bingpeng Ma and Naiyan Wang and Xilin Chen}, + title = {Dynamic {R-CNN}: Towards High Quality Object Detection via Dynamic Training}, + journal = {arXiv preprint arXiv:2004.06002}, + year = {2020} +} +``` + +## Results and Models + +| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Download | +|:---------:|:-------:|:-------:|:--------:|:--------------:|:------:|:--------:| +| R-50 | pytorch | 1x | 3.8 | | 38.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x/dynamic_rcnn_r50_fpn_1x-62a3f276.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x/dynamic_rcnn_r50_fpn_1x_20200618_095048.log.json) | diff --git a/configs/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x.py b/configs/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..60f9c5043a6d8e7da0c6038aca868ad7e966c534 --- /dev/null +++ b/configs/dynamic_rcnn/dynamic_rcnn_r50_fpn_1x.py @@ -0,0 +1,28 @@ +_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py' +model = dict( + roi_head=dict( + type='DynamicRoIHead', + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))) +train_cfg = dict( + rpn_proposal=dict(nms_thr=0.85), + rcnn=dict( + dynamic_rcnn=dict( + iou_topk=75, + beta_topk=10, + update_iter_interval=100, + initial_iou=0.4, + initial_beta=1.0))) +test_cfg = dict(rpn=dict(nms_thr=0.85)) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index f216f054536576b118e60e7ea2d0b1dfaf7c87c6..266ccac5a75ca14e39987b04d9bd8b7607cd47df 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -132,6 +132,9 @@ Please refer to [Res2Net](https://github.com/open-mmlab/mmdetection/blob/master/ ### GRoIE Please refer to [GRoIE](https://github.com/open-mmlab/mmdetection/blob/master/configs/groie) for details. +### Dynamic R-CNN +Please refer to [Dynamic R-CNN](https://github.com/open-mmlab/mmdetection/blob/master/configs/dynamic_rcnn) for details. + ### Other datasets We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face). diff --git a/mmdet/models/roi_heads/__init__.py b/mmdet/models/roi_heads/__init__.py index 0e05b5a57fbe1081248bb1df185ee6c021ac2f02..6d50b202684ed4a06fcdf51c287cf55c430a90ba 100644 --- a/mmdet/models/roi_heads/__init__.py +++ b/mmdet/models/roi_heads/__init__.py @@ -3,6 +3,7 @@ from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DoubleConvFCBBoxHead, Shared2FCBBoxHead, Shared4Conv1FCBBoxHead) from .cascade_roi_head import CascadeRoIHead from .double_roi_head import DoubleHeadRoIHead +from .dynamic_roi_head import DynamicRoIHead from .grid_roi_head import GridRoIHead from .htc_roi_head import HybridTaskCascadeRoIHead from .mask_heads import (FCNMaskHead, FusedSemanticHead, GridHead, HTCMaskHead, @@ -17,5 +18,6 @@ __all__ = [ 'HybridTaskCascadeRoIHead', 'GridRoIHead', 'ResLayer', 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', - 'GridHead', 'MaskIoUHead', 'SingleRoIExtractor', 'PISARoIHead' + 'GridHead', 'MaskIoUHead', 'SingleRoIExtractor', 'PISARoIHead', + 'DynamicRoIHead' ] diff --git a/mmdet/models/roi_heads/dynamic_roi_head.py b/mmdet/models/roi_heads/dynamic_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcbd9a296a3278d2270c08effdc3500f5298141 --- /dev/null +++ b/mmdet/models/roi_heads/dynamic_roi_head.py @@ -0,0 +1,152 @@ +import numpy as np +import torch + +from mmdet.core import bbox2roi +from mmdet.models.losses import SmoothL1Loss +from ..builder import HEADS +from .standard_roi_head import StandardRoIHead + + +@HEADS.register_module() +class DynamicRoIHead(StandardRoIHead): + """RoI head for `Dynamic R-CNN <https://arxiv.org/abs/2004.06002>`_.""" + + def __init__(self, **kwargs): + super(DynamicRoIHead, self).__init__(**kwargs) + assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss) + self.iou_topk = self.train_cfg.dynamic_rcnn.iou_topk + self.beta_topk = self.train_cfg.dynamic_rcnn.beta_topk + self.update_iter_interval = \ + self.train_cfg.dynamic_rcnn.update_iter_interval + # warm-up values for IoU and beta + self.initial_iou = self.train_cfg.dynamic_rcnn.initial_iou + self.initial_beta = self.train_cfg.dynamic_rcnn.initial_beta + # the IoU history of the past `update_iter_interval` iterations + self.iou_history = [] + # the beta history of the past `update_iter_interval` iterations + self.beta_history = [] + + def forward_train(self, + x, + img_metas, + proposal_list, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, + gt_masks=None): + """ + Args: + x (list[Tensor]): list of multi-level img features. + + img_metas (list[dict]): list of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmdet/datasets/pipelines/formatting.py:Collect`. + + proposals (list[Tensors]): list of region proposals. + + gt_bboxes (list[Tensor]): each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + + gt_labels (list[Tensor]): class indices corresponding to each box + + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. + + gt_masks (None | Tensor) : true segmentation masks for each box + used if the architecture supports a segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # assign gts and sample proposals + if self.with_bbox or self.with_mask: + num_imgs = len(img_metas) + if gt_bboxes_ignore is None: + gt_bboxes_ignore = [None for _ in range(num_imgs)] + sampling_results = [] + cur_iou = [] + for i in range(num_imgs): + assign_result = self.bbox_assigner.assign( + proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i], + gt_labels[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + proposal_list[i], + gt_bboxes[i], + gt_labels[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + # record the `iou_topk`-th largest IoU in an image + iou_topk = min(self.iou_topk, len(assign_result.max_overlaps)) + ious, _ = torch.topk(assign_result.max_overlaps, iou_topk) + cur_iou.append(ious[-1].item()) + sampling_results.append(sampling_result) + # average the current IoUs over images + cur_iou = np.mean(cur_iou) + self.iou_history.append(cur_iou) + + losses = dict() + # bbox head forward and loss + if self.with_bbox: + bbox_results = self._bbox_forward_train(x, sampling_results, + gt_bboxes, gt_labels, + img_metas) + losses.update(bbox_results['loss_bbox']) + + # mask head forward and loss + if self.with_mask: + mask_results = self._mask_forward_train(x, sampling_results, + bbox_results['bbox_feats'], + gt_masks, img_metas) + # TODO: Support empty tensor input. #2280 + if mask_results['loss_mask'] is not None: + losses.update(mask_results['loss_mask']) + + # update IoU threshold and SmoothL1 beta + if len(self.iou_history) % self.update_iter_interval == 0: + new_iou_thr, new_beta = self.update_hyperparameters() + + return losses + + def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels, + img_metas): + num_imgs = len(img_metas) + rois = bbox2roi([res.bboxes for res in sampling_results]) + bbox_results = self._bbox_forward(x, rois) + + bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes, + gt_labels, self.train_cfg) + # record the `beta_topk`-th smallest target + # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets + # and bbox_weights, respectively + pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1) + num_pos = len(pos_inds) + cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1) + beta_topk = min(self.beta_topk * num_imgs, num_pos) + cur_target = torch.kthvalue(cur_target, beta_topk)[0].item() + self.beta_history.append(cur_target) + loss_bbox = self.bbox_head.loss(bbox_results['cls_score'], + bbox_results['bbox_pred'], rois, + *bbox_targets) + + bbox_results.update(loss_bbox=loss_bbox) + return bbox_results + + def update_hyperparameters(self): + """ + Update hyperparameters like `iou_thr` and `SmoothL1 beta` based + on the training statistics. + + Returns: + tuple[float]: the updated `iou_thr` and `SmoothL1 beta` + """ + new_iou_thr = max(self.initial_iou, np.mean(self.iou_history)) + self.iou_history = [] + self.bbox_assigner.pos_iou_thr = new_iou_thr + self.bbox_assigner.neg_iou_thr = new_iou_thr + self.bbox_assigner.min_pos_iou = new_iou_thr + new_beta = min(self.initial_beta, np.median(self.beta_history)) + self.beta_history = [] + self.bbox_head.loss_bbox.beta = new_beta + return new_iou_thr, new_beta