diff --git a/README.md b/README.md index bd8ac8142975159f3a3a9ee6abc22b918378f06d..537b040d51f29b3076158aa28dce98bcf7c5bac0 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,8 @@ Supported methods: - [x] [Mixed Precision (FP16) Training](configs/fp16/README.md) - [x] [InstaBoost](configs/instaboost/README.md) - [x] [GRoIE](configs/groie/README.md) +- [x] [DetectoRS](configs/detectors/README.md) +- [x] [Generalized Focal Loss](configs/gfl/README.md) Some other methods are also supported in [projects using MMDetection](./docs/projects.md). diff --git a/configs/detectors/README.md b/configs/detectors/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bca758eb6ccdfb8c23a6f844de818ace82fd75c2 --- /dev/null +++ b/configs/detectors/README.md @@ -0,0 +1,37 @@ +# DetectoRS + +## Introduction + +We provide the config files for [DetectoRS: Detecting Objects with Recursive Feature Pyramid and Switchable Atrous Convolution](https://arxiv.org/pdf/2006.02334.pdf). + +```BibTeX +@article{qiao2020detectors, + title={DetectoRS: Detecting Objects with Recursive Feature Pyramid and Switchable Atrous Convolution}, + author={Qiao, Siyuan and Chen, Liang-Chieh and Yuille, Alan}, + journal={arXiv preprint arXiv:2006.02334}, + year={2020} +} +``` + +## Results and Models + +DetectoRS includes two major components: + +- Recursive Feature Pyramid (RFP). +- Switchable Atrous Convolution (SAC). + +They can be used independently. +Combining them together results in DetectoRS. +The results on COCO 2017 val are shown in the below table. + +| Method | Detector | Lr schd | Mem (GB) | Inf time (fps) | box AP | mask AP | Download | +|:------:|:--------:|:-------:|:--------:|:--------------:|:------:|:-------:|:--------:| +| RFP | Cascade + ResNet-50 | 1x | 7.5 | - | 44.8 | | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/cascade_rcnn_r50_rfp_1x_coco/cascade_rcnn_r50_rfp_1x_coco-8cf51bfd.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/cascade_rcnn_r50_rfp_1x_coco/cascade_rcnn_r50_rfp_1x_coco_20200624_104126.log.json) | +| SAC | Cascade + ResNet-50 | 1x | 5.6 | - | 45.0| | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/cascade_rcnn_r50_sac_1x_coco/cascade_rcnn_r50_sac_1x_coco-24bfda62.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/cascade_rcnn_r50_sac_1x_coco/cascade_rcnn_r50_sac_1x_coco_20200624_104402.log.json) | +| DetectoRS | Cascade + ResNet-50 | 1x | 9.9 | - | 46.9 | | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/detectors_cascade_rcnn_r50_1x_coco/detectors_cascade_rcnn_r50_1x_coco-0db1ab6a.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/detectors_cascade_rcnn_r50_1x_coco/detectors_cascade_rcnn_r50_1x_coco_20200624_103448.log.json) | +| RFP | HTC + ResNet-50 | 1x | 11.2 | - | 46.6 | 40.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/htc_r50_rfp_1x_coco/htc_r50_rfp_1x_coco-8ff87c51.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/htc_r50_rfp_1x_coco/htc_r50_rfp_1x_coco_20200624_103053.log.json) | +| SAC | HTC + ResNet-50 | 1x | 9.3 | - | 46.4 | 40.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/htc_r50_sac_1x_coco/htc_r50_sac_1x_coco-bfa60c54.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/htc_r50_sac_1x_coco/htc_r50_sac_1x_coco_20200624_103111.log.json) | +| DetectoRS | HTC + ResNet-50 | 1x | 13.6 | - | 49.1 | 42.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/detectors_htc_r50_1x_coco/detectors_htc_r50_1x_coco-329b1453.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/detectors/detectors_htc_r50_1x_coco/detectors_htc_r50_1x_coco_20200624_103659.log.json) | + +*Note*: This is a re-implementation based on MMDetection-V2. +The original implementation is based on MMDetection-V1. diff --git a/configs/detectors/cascade_rcnn_r50_rfp_1x_coco.py b/configs/detectors/cascade_rcnn_r50_rfp_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..4430d8a677e48f84552eb23403bc874c56bda506 --- /dev/null +++ b/configs/detectors/cascade_rcnn_r50_rfp_1x_coco.py @@ -0,0 +1,28 @@ +_base_ = [ + '../_base_/models/cascade_rcnn_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +model = dict( + backbone=dict( + type='DetectoRS_ResNet', + conv_cfg=dict(type='ConvAWS'), + output_img=True), + neck=dict( + type='RFP', + rfp_steps=2, + aspp_out_channels=64, + aspp_dilations=(1, 3, 6, 1), + rfp_backbone=dict( + rfp_inplanes=256, + type='DetectoRS_ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + conv_cfg=dict(type='ConvAWS'), + pretrained='torchvision://resnet50', + style='pytorch'))) diff --git a/configs/detectors/cascade_rcnn_r50_sac_1x_coco.py b/configs/detectors/cascade_rcnn_r50_sac_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd9319b2d1badebf3b891c8e3bdd55a435a4b7c --- /dev/null +++ b/configs/detectors/cascade_rcnn_r50_sac_1x_coco.py @@ -0,0 +1,12 @@ +_base_ = [ + '../_base_/models/cascade_rcnn_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +model = dict( + backbone=dict( + type='DetectoRS_ResNet', + conv_cfg=dict(type='ConvAWS'), + sac=dict(type='SAC', use_deform=True), + stage_with_sac=(False, True, True, True))) diff --git a/configs/detectors/detectors_cascade_rcnn_r50_1x_coco.py b/configs/detectors/detectors_cascade_rcnn_r50_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..f76040434f1ff07608c83202f779dfacfe91c323 --- /dev/null +++ b/configs/detectors/detectors_cascade_rcnn_r50_1x_coco.py @@ -0,0 +1,32 @@ +_base_ = [ + '../_base_/models/cascade_rcnn_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +model = dict( + backbone=dict( + type='DetectoRS_ResNet', + conv_cfg=dict(type='ConvAWS'), + sac=dict(type='SAC', use_deform=True), + stage_with_sac=(False, True, True, True), + output_img=True), + neck=dict( + type='RFP', + rfp_steps=2, + aspp_out_channels=64, + aspp_dilations=(1, 3, 6, 1), + rfp_backbone=dict( + rfp_inplanes=256, + type='DetectoRS_ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + conv_cfg=dict(type='ConvAWS'), + sac=dict(type='SAC', use_deform=True), + stage_with_sac=(False, True, True, True), + pretrained='torchvision://resnet50', + style='pytorch'))) diff --git a/configs/detectors/detectors_htc_r50_1x_coco.py b/configs/detectors/detectors_htc_r50_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2fc4f77fcca715c1dfb613306d214b636aa0c0 --- /dev/null +++ b/configs/detectors/detectors_htc_r50_1x_coco.py @@ -0,0 +1,28 @@ +_base_ = '../htc/htc_r50_fpn_1x_coco.py' + +model = dict( + backbone=dict( + type='DetectoRS_ResNet', + conv_cfg=dict(type='ConvAWS'), + sac=dict(type='SAC', use_deform=True), + stage_with_sac=(False, True, True, True), + output_img=True), + neck=dict( + type='RFP', + rfp_steps=2, + aspp_out_channels=64, + aspp_dilations=(1, 3, 6, 1), + rfp_backbone=dict( + rfp_inplanes=256, + type='DetectoRS_ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + conv_cfg=dict(type='ConvAWS'), + sac=dict(type='SAC', use_deform=True), + stage_with_sac=(False, True, True, True), + pretrained='torchvision://resnet50', + style='pytorch'))) diff --git a/configs/detectors/htc_r50_rfp_1x_coco.py b/configs/detectors/htc_r50_rfp_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..496104e12550a1985f9c9e3748a343f69d7df6d8 --- /dev/null +++ b/configs/detectors/htc_r50_rfp_1x_coco.py @@ -0,0 +1,24 @@ +_base_ = '../htc/htc_r50_fpn_1x_coco.py' + +model = dict( + backbone=dict( + type='DetectoRS_ResNet', + conv_cfg=dict(type='ConvAWS'), + output_img=True), + neck=dict( + type='RFP', + rfp_steps=2, + aspp_out_channels=64, + aspp_dilations=(1, 3, 6, 1), + rfp_backbone=dict( + rfp_inplanes=256, + type='DetectoRS_ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + conv_cfg=dict(type='ConvAWS'), + pretrained='torchvision://resnet50', + style='pytorch'))) diff --git a/configs/detectors/htc_r50_sac_1x_coco.py b/configs/detectors/htc_r50_sac_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..72d4db963ffd95851b945911b3db9941426583ab --- /dev/null +++ b/configs/detectors/htc_r50_sac_1x_coco.py @@ -0,0 +1,8 @@ +_base_ = '../htc/htc_r50_fpn_1x_coco.py' + +model = dict( + backbone=dict( + type='DetectoRS_ResNet', + conv_cfg=dict(type='ConvAWS'), + sac=dict(type='SAC', use_deform=True), + stage_with_sac=(False, True, True, True))) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 960960d15819645691db3625614d4de8a6feb31c..266def7f355beddca832369b60f8bf77accc686d 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -138,6 +138,12 @@ Please refer to [Dynamic R-CNN](https://github.com/open-mmlab/mmdetection/blob/m ### PointRend Please refer to [PointRend](https://github.com/open-mmlab/mmdetection/blob/master/configs/point_rend) for details. +### DetectoRS +Please refer to [DetectoRS](https://github.com/open-mmlab/mmdetection/blob/master/configs/detectors) for details. + +### Generalized Focal Loss +Please refer to [Generalized Focal Loss](https://github.com/open-mmlab/mmdetection/blob/master/configs/gfl) for details. + ### Other datasets We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face). diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py index 92c199c21a6da5642bd97d9b1422159d280558c8..6d4cc1db79e2c3d05afa3b6fd17e7b9a84450697 100644 --- a/mmdet/models/backbones/__init__.py +++ b/mmdet/models/backbones/__init__.py @@ -1,3 +1,5 @@ +from .detectors_resnet import DetectoRS_ResNet +from .detectors_resnext import DetectoRS_ResNeXt from .hourglass import HourglassNet from .hrnet import HRNet from .regnet import RegNet @@ -8,5 +10,5 @@ from .ssd_vgg import SSDVGG __all__ = [ 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net', - 'HourglassNet' + 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt' ] diff --git a/mmdet/models/backbones/detectors_resnet.py b/mmdet/models/backbones/detectors_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d252579e726a5118fa7405422718f38f1ce3e2fe --- /dev/null +++ b/mmdet/models/backbones/detectors_resnet.py @@ -0,0 +1,305 @@ +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer, constant_init + +from ..builder import BACKBONES +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck for the ResNet backbone in `DetectoRS + <https://arxiv.org/pdf/2006.02334.pdf>`_. + + This bottleneck allows the users to specify whether to use + SAC (Switchable Atrous Convolution) and RFP (Recursive Feature Pyramid). + + Args: + inplanes (int): The number of input channels. + planes (int): The number of output channels before expansion. + rfp_inplanes (int, optional): The number of channels from RFP. + Default: None. If specified, an additional conv layer will be + added for ``rfp_feat``. Otherwise, the structure is the same as + base class. + sac (dict, optional): Dictionary to construct SAC. Default: None. + """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + rfp_inplanes=None, + sac=None, + **kwargs): + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + assert sac is None or isinstance(sac, dict) + self.sac = sac + self.with_sac = sac is not None + if self.with_sac: + self.conv2 = build_conv_layer( + self.sac, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + bias=False) + + self.rfp_inplanes = rfp_inplanes + if self.rfp_inplanes: + self.rfp_conv = build_conv_layer( + None, + self.rfp_inplanes, + planes * self.expansion, + 1, + stride=1, + bias=True) + self.init_weights() + + def init_weights(self): + """Initialize the weights.""" + if self.rfp_inplanes: + constant_init(self.rfp_conv, 0) + + def rfp_forward(self, x, rfp_feat): + """The forward function that also takes the RFP features as input.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + if self.rfp_inplanes: + rfp_feat = self.rfp_conv(rfp_feat) + out = out + rfp_feat + + out = self.relu(out) + + return out + + +class ResLayer(nn.Sequential): + """ResLayer to build ResNet style backbone for RPF in detectoRS. + + The difference between this module and base class is that we pass + ``rfp_inplanes`` to the first block. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. Default: True + rfp_inplanes (int, optional): The number of channels from RFP. + Default: None. If specified, an additional conv layer will be + added for ``rfp_feat``. Otherwise, the structure is the same as + base class. + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + downsample_first=True, + rfp_inplanes=None, + **kwargs): + self.block = block + assert downsample_first, f'downsampel_first={downsample_first} is ' \ + 'not supported in DetectoRS' + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + rfp_inplanes=rfp_inplanes, + **kwargs)) + inplanes = planes * block.expansion + for _ in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + + super(ResLayer, self).__init__(*layers) + + +@BACKBONES.register_module() +class DetectoRS_ResNet(ResNet): + """ResNet backbone for DetectoRS. + + Args: + sac (dict, optional): Dictionary to construct SAC (Switchable Atrous + Convolution). Default: None. + stage_with_sac (list): Which stage to use sac. Default: (False, False, + False, False). + rfp_inplanes (int, optional): The number of channels from RFP. + Default: None. If specified, an additional conv layer will be + added for ``rfp_feat``. Otherwise, the structure is the same as + base class. + output_img (bool): If ``True``, the input image will be inserted into + the starting position of output. Default: False. + pretrained (str, optional): The pretrained model to load. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + sac=None, + stage_with_sac=(False, False, False, False), + rfp_inplanes=None, + output_img=False, + pretrained=None, + **kwargs): + self.sac = sac + self.stage_with_sac = stage_with_sac + self.rfp_inplanes = rfp_inplanes + self.output_img = output_img + self.pretrained = pretrained + super(DetectoRS_ResNet, self).__init__(**kwargs) + + self.inplanes = self.stem_channels + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = self.strides[i] + dilation = self.dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + sac = self.sac if self.stage_with_sac[i] else None + if self.plugins is not None: + stage_plugins = self.make_stage_plugins(self.plugins, i) + else: + stage_plugins = None + planes = self.base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=dcn, + sac=sac, + rfp_inplanes=rfp_inplanes if i > 0 else None, + plugins=stage_plugins) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer`` for DetectoRS""" + return ResLayer(**kwargs) + + def forward(self, x): + """Forward function""" + outs = list(super(DetectoRS_ResNet, self).forward(x)) + if self.output_img: + outs.insert(0, x) + return tuple(outs) + + def rfp_forward(self, x, rfp_feats): + """Forward function for RFP""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + rfp_feat = rfp_feats[i] if i > 0 else None + for layer in res_layer: + x = layer.rfp_forward(x, rfp_feat) + if i in self.out_indices: + outs.append(x) + return tuple(outs) diff --git a/mmdet/models/backbones/detectors_resnext.py b/mmdet/models/backbones/detectors_resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a1a34a60d510e837a08a9834c4b495586e572b --- /dev/null +++ b/mmdet/models/backbones/detectors_resnext.py @@ -0,0 +1,121 @@ +import math + +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from .detectors_resnet import Bottleneck as _Bottleneck +from .detectors_resnet import DetectoRS_ResNet + + +class Bottleneck(_Bottleneck): + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): + """Bottleneck block for ResNeXt. + If style is "pytorch", the stride-two layer is the 3x3 conv layer, + if it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_sac: + self.conv2 = build_conv_layer( + self.sac, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + elif not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@BACKBONES.register_module() +class DetectoRS_ResNeXt(DetectoRS_ResNet): + """ResNeXt backbone for DetectoRS. + + Args: + groups (int): The number of groups in ResNeXt. + base_width (int): The base width of ResNeXt. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, groups=1, base_width=4, **kwargs): + self.groups = groups + self.base_width = base_width + super(DetectoRS_ResNeXt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + return super().make_res_layer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py index 1d75e018a5761c5474c7ffc9da1b5a5b6a328f88..2fdea02f4fabbcdac210876cfa7a17ec6ebc4029 100644 --- a/mmdet/models/necks/__init__.py +++ b/mmdet/models/necks/__init__.py @@ -5,7 +5,9 @@ from .hrfpn import HRFPN from .nas_fpn import NASFPN from .nasfcos_fpn import NASFCOS_FPN from .pafpn import PAFPN +from .rfp import RFP __all__ = [ - 'FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN', 'NASFCOS_FPN' + 'FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN', 'NASFCOS_FPN', + 'RFP' ] diff --git a/mmdet/models/necks/rfp.py b/mmdet/models/necks/rfp.py new file mode 100644 index 0000000000000000000000000000000000000000..00d9ab653eab80deb2125f52a778711557f93418 --- /dev/null +++ b/mmdet/models/necks/rfp.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import constant_init, kaiming_init + +from ..builder import NECKS, build_backbone +from .fpn import FPN + + +class ASPP(nn.Module): + """ASPP (Atrous Spatial Pyramid Pooling) + + This is an implementation of the ASPP module used in DetectoRS + (https://arxiv.org/pdf/2006.02334.pdf) + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of channels produced by this module + dilations (tuple[int]): Dilations of the four branches. + Default: (1, 3, 6, 1) + """ + + def __init__(self, in_channels, out_channels, dilations=(1, 3, 6, 1)): + super().__init__() + assert dilations[-1] == 1 + self.aspp = nn.ModuleList() + for dilation in dilations: + kernel_size = 3 if dilation > 1 else 1 + padding = dilation if dilation > 1 else 0 + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + dilation=dilation, + padding=padding, + bias=True) + self.aspp.append(conv) + self.gap = nn.AdaptiveAvgPool2d(1) + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + + def forward(self, x): + avg_x = self.gap(x) + out = [] + for aspp_idx in range(len(self.aspp)): + inp = avg_x if (aspp_idx == len(self.aspp) - 1) else x + out.append(F.relu_(self.aspp[aspp_idx](inp))) + out[-1] = out[-1].expand_as(out[-2]) + out = torch.cat(out, dim=1) + return out + + +@NECKS.register_module() +class RFP(FPN): + """RFP (Recursive Feature Pyramid) + + This is an implementation of RFP in `DetectoRS + <https://arxiv.org/pdf/2006.02334.pdf>`_. Different from standard FPN, the + input of RFP should be multi level features along with origin input image + of backbone. + + Args: + rfp_steps (int): Number of unrolled steps of RFP. + rfp_backbone (dict): Configuration of the backbone for RFP. + aspp_out_channels (int): Number of output channels of ASPP module. + aspp_dilations (tuple[int]): Dilation rates of four branches. + Default: (1, 3, 6, 1) + """ + + def __init__(self, + rfp_steps, + rfp_backbone, + aspp_out_channels, + aspp_dilations=(1, 3, 6, 1), + **kwargs): + super().__init__(**kwargs) + self.rfp_steps = rfp_steps + self.rfp_modules = nn.ModuleList() + for rfp_idx in range(1, rfp_steps): + rfp_module = build_backbone(rfp_backbone) + self.rfp_modules.append(rfp_module) + self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels, + aspp_dilations) + self.rfp_weight = nn.Conv2d( + self.out_channels, + 1, + kernel_size=1, + stride=1, + padding=0, + bias=True) + + def init_weights(self): + super().init_weights() + for rfp_idx in range(self.rfp_steps - 1): + self.rfp_modules[rfp_idx].init_weights( + self.rfp_modules[rfp_idx].pretrained) + constant_init(self.rfp_weight, 0) + + def forward(self, inputs): + inputs = list(inputs) + assert len(inputs) == len(self.in_channels) + 1 # +1 for input image + img = inputs.pop(0) + # FPN forward + x = super().forward(tuple(inputs)) + for rfp_idx in range(self.rfp_steps - 1): + rfp_feats = [x[0]] + list( + self.rfp_aspp(x[i]) for i in range(1, len(x))) + x_idx = self.rfp_modules[rfp_idx].rfp_forward(img, rfp_feats) + # FPN forward + x_idx = super().forward(x_idx) + x_new = [] + for ft_idx in range(len(x_idx)): + add_weight = torch.sigmoid(self.rfp_weight(x_idx[ft_idx])) + x_new.append(add_weight * x_idx[ft_idx] + + (1 - add_weight) * x[ft_idx]) + x = x_new + return x diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py index 61529d2df4bc6d53b1901a9e78809b6e4b644fb9..e05334e76fc9b820d316ea916494cefbfd37984d 100644 --- a/mmdet/ops/__init__.py +++ b/mmdet/ops/__init__.py @@ -14,6 +14,7 @@ from .point_sample import (SimpleRoIAlign, point_sample, rel_roi_point_to_rel_img_point) from .roi_align import RoIAlign, roi_align from .roi_pool import RoIPool, roi_pool +from .saconv import SAConv2d from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss from .utils import get_compiler_version, get_compiling_cuda_version from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d @@ -28,5 +29,6 @@ __all__ = [ 'get_compiler_version', 'get_compiling_cuda_version', 'ConvWS2d', 'conv_ws_2d', 'build_plugin_layer', 'batched_nms', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'Linear', 'nms_match', 'CornerPool', - 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign' + 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', + 'SAConv2d' ] diff --git a/mmdet/ops/conv_ws.py b/mmdet/ops/conv_ws.py index 7704683ffc777297b94cb2e08ff6cba36233c95f..c68d80ae24845e7c3a3e1ffc800822b514ff45c0 100644 --- a/mmdet/ops/conv_ws.py +++ b/mmdet/ops/conv_ws.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import CONV_LAYERS @@ -46,3 +47,100 @@ class ConvWS2d(nn.Conv2d): def forward(self, x): return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.eps) + + +@CONV_LAYERS.register_module('ConvAWS') +class ConvAWS2d(nn.Conv2d): + """AWS (Adaptive Weight Standardization) + + This is a variant of Weight Standardization + (https://arxiv.org/pdf/1903.10520.pdf) + It is used in DetectoRS to avoid NaN + (https://arxiv.org/pdf/2006.02334.pdf) + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the conv kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If set True, adds a learnable bias to the + output. Default: True + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.register_buffer('weight_gamma', + torch.ones(self.out_channels, 1, 1, 1)) + self.register_buffer('weight_beta', + torch.zeros(self.out_channels, 1, 1, 1)) + + def _get_weight(self, weight): + weight_flat = weight.view(weight.size(0), -1) + mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1) + std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1) + weight = (weight - mean) / std + weight = self.weight_gamma * weight + self.weight_beta + return weight + + def forward(self, x): + weight = self._get_weight(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Override default load function + + AWS overrides the function _load_from_state_dict to recover + weight_gamma and weight_beta if they are missing. If weight_gamma and + weight_beta are found in the checkpoint, this function will return + after super()._load_from_state_dict. Otherwise, it will compute the + mean and std of the pretrained weights and store them in weight_beta + and weight_gamma. + """ + + self.weight_gamma.data.fill_(-1) + local_missing_keys = [] + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, local_missing_keys, + unexpected_keys, error_msgs) + if self.weight_gamma.data.mean() > 0: + for k in local_missing_keys: + missing_keys.append(k) + return + weight = self.weight.data + weight_flat = weight.view(weight.size(0), -1) + mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1) + std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1) + self.weight_beta.data.copy_(mean) + self.weight_gamma.data.copy_(std) + missing_gamma_beta = [ + k for k in local_missing_keys + if k.endswith('weight_gamma') or k.endswith('weight_beta') + ] + for k in missing_gamma_beta: + local_missing_keys.remove(k) + for k in local_missing_keys: + missing_keys.append(k) diff --git a/mmdet/ops/saconv.py b/mmdet/ops/saconv.py new file mode 100644 index 0000000000000000000000000000000000000000..fb35be67e0ee947460c1c1712cdc34334a97b3b5 --- /dev/null +++ b/mmdet/ops/saconv.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import CONV_LAYERS, constant_init + +from .conv_ws import ConvAWS2d +from .dcn import deform_conv + + +@CONV_LAYERS.register_module(name='SAC') +class SAConv2d(ConvAWS2d): + """SAC (Switchable Atrous Convolution) + + This is an implementation of SAC in DetectoRS + (https://arxiv.org/pdf/2006.02334.pdf). + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + use_deform: If ``True``, replace convolution with deformable + convolution. Default: ``False``. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + use_deform=False): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.use_deform = use_deform + self.switch = nn.Conv2d( + self.in_channels, 1, kernel_size=1, stride=stride, bias=True) + self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size())) + self.pre_context = nn.Conv2d( + self.in_channels, self.in_channels, kernel_size=1, bias=True) + self.post_context = nn.Conv2d( + self.out_channels, self.out_channels, kernel_size=1, bias=True) + if self.use_deform: + self.offset_s = nn.Conv2d( + self.in_channels, + 18, + kernel_size=3, + padding=1, + stride=stride, + bias=True) + self.offset_l = nn.Conv2d( + self.in_channels, + 18, + kernel_size=3, + padding=1, + stride=stride, + bias=True) + self.init_weights() + + def init_weights(self): + constant_init(self.switch, 0, bias=1) + self.weight_diff.data.zero_() + constant_init(self.pre_context, 0) + constant_init(self.post_context, 0) + if self.use_deform: + constant_init(self.offset_s, 0) + constant_init(self.offset_l, 0) + + def forward(self, x): + # pre-context + avg_x = F.adaptive_avg_pool2d(x, output_size=1) + avg_x = self.pre_context(avg_x) + avg_x = avg_x.expand_as(x) + x = x + avg_x + # switch + avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect') + avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0) + switch = self.switch(avg_x) + # sac + weight = self._get_weight(self.weight) + if self.use_deform: + offset = self.offset_s(avg_x) + out_s = deform_conv(x, offset, weight, self.stride, self.padding, + self.dilation, self.groups, 1) + else: + out_s = super().conv2d_forward(x, weight) + ori_p = self.padding + ori_d = self.dilation + self.padding = tuple(3 * p for p in self.padding) + self.dilation = tuple(3 * d for d in self.dilation) + weight = weight + self.weight_diff + if self.use_deform: + offset = self.offset_l(avg_x) + out_l = deform_conv(x, offset, weight, self.stride, self.padding, + self.dilation, self.groups, 1) + else: + out_l = super().conv2d_forward(x, weight) + out = switch * out_s + (1 - switch) * out_l + self.padding = ori_p + self.dilation = ori_d + # post-context + avg_x = F.adaptive_avg_pool2d(out, output_size=1) + avg_x = self.post_context(avg_x) + avg_x = avg_x.expand_as(out) + out = out + avg_x + return out diff --git a/tools/test.py b/tools/test.py index 6904ba13db06969250cf397947c1db0ee100dcec..b9ed6d4fdaece5dff058f77e776de140bba97f23 100644 --- a/tools/test.py +++ b/tools/test.py @@ -87,6 +87,10 @@ def main(): if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True cfg.model.pretrained = None + if cfg.model.get('neck'): + if cfg.model.neck.get('rfp_backbone'): + if cfg.model.neck.rfp_backbone.get('pretrained'): + cfg.model.neck.rfp_backbone.pretrained = None cfg.data.test.test_mode = True # init distributed env first, since logger depends on the dist info.