diff --git a/.gitignore b/.gitignore index e83ec70ebe57efac39ff559dee5c5743913bf675..9918ddfd120d72ac6c8bf2a6f4615e7551c87471 100644 --- a/.gitignore +++ b/.gitignore @@ -104,7 +104,7 @@ venv.bak/ .mypy_cache/ mmdet/version.py -./data/ +data/ .vscode .idea .DS_Store diff --git a/README.md b/README.md index a737a2df4e2163c820ddfd4f5eb46fd45bd37eae..c9dc9ac851d7fe3e0f8307226efbbe0dd057cba1 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ Other features - [x] [GCNet](configs/gcnet/README.md) - [x] [Mixed Precision (FP16) Training](configs/fp16/README.md) - [x] [InstaBoost](configs/instaboost/README.md) +- [x] [GRoIE](configs/groie/README.md) Some other methods are also supported in [projects using MMDetection](./docs/projects.md). diff --git a/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py b/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3e3ef594243be1335aa3b3d2f78f50f4477082 --- /dev/null +++ b/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py @@ -0,0 +1,11 @@ +_base_ = ['../grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py'] +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# runtime settings +total_epochs = 12 diff --git a/configs/groie/README.md b/configs/groie/README.md new file mode 100644 index 0000000000000000000000000000000000000000..50a2f1373927a8eb4b0300f1d1e20844930e954d --- /dev/null +++ b/configs/groie/README.md @@ -0,0 +1,65 @@ +# GRoIE + +## A novel Region of Interest Extraction Layer for Instance Segmentation + +By Leonardo Rossi, Akbar Karimi and Andrea Prati from +[IMPLab](http://implab.ce.unipr.it/). + +We provide config files to reproduce the results in the paper for +"*A novel Region of Interest Extraction Layer for Instance Segmentation*" +on COCO object detection. + +## Introduction + +This paper is motivated by the need to overcome to the limitations of existing +RoI extractors which select only one (the best) layer from FPN. + +Our intuition is that all the layers of FPN retain useful information. + +Therefore, the proposed layer (called Generic RoI Extractor - **GRoIE**) +introduces non-local building blocks and attention mechanisms to boost the +performance. + +## Results and models + +The results on COCO 2017 minival (5k images) are shown in the below table. +You can find +[here](https://drive.google.com/drive/folders/19ssstbq_h0Z1cgxHmJYFO8s1arf3QJbT) +the trained models. + +### Application of GRoIE to different architectures + +| Backbone | Model | Lr schd | box AP | mask AP | Config file | +| :-------: | :--------------: | :-----: | :----: | :-----: | :-----------------------------------------------------------------: | +| R-50-FPN | Faster Original | 1x | 37.4 | | [config file](../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | +| R-50-FPN | + GRoIE | 1x | 38.3 | | [config file](./faster_rcnn_r50_fpn_groie_1x_coco.py) | +| R-50-FPN | Grid R-CNN | 1x | 39.1 | | [config file](../grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py) | +| R-50-FPN | + GRoIE | 1x | | | [config file](./grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py) | +| R-50-FPN | Mask R-CNN | 1x | 38.2 | 34.7 | [config file](../mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py) | +| R-50-FPN | + GRoIE | 1x | 39.0 | 36.0 | [config file](./mask_rcnn_r50_fpn_groie_1x_coco.py) | +| R-50-FPN | GC-Net | 1x | 40.7 | 36.5 | [config file](../gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py) | +| R-50-FPN | + GRoIE | 1x | 41.0 | 37.8 | [config file](./mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py) | +| R-101-FPN | GC-Net | 1x | 42.2 | 37.8 | [config file](../configs/gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py) | +| R-101-FPN | + GRoIE | 1x | | | [config file](./mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py) | + + + +## Citation + +If you use this work or benchmark in your research, please cite this project. + +``` +@misc{rossi2020novel, + title={A novel Region of Interest Extraction Layer for Instance Segmentation}, + author={Leonardo Rossi and Akbar Karimi and Andrea Prati}, + year={2020}, + eprint={2004.13665}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +## Contact + +The implementation of GROI is currently maintained by +[Leonardo Rossi](https://github.com/hachreak/). diff --git a/configs/groie/faster_rcnn_r50_fpn_groie_1x_coco.py b/configs/groie/faster_rcnn_r50_fpn_groie_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..f916653890c145d1ac4680a4dd1141eb54aae311 --- /dev/null +++ b/configs/groie/faster_rcnn_r50_fpn_groie_1x_coco.py @@ -0,0 +1,24 @@ +_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py' +# model settings +model = dict( + roi_head=dict( + bbox_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)))) diff --git a/configs/groie/grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py b/configs/groie/grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..668ce3a04dae0216da55d6c65d143d4c92295513 --- /dev/null +++ b/configs/groie/grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py @@ -0,0 +1,44 @@ +_base_ = '../grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py' +# model settings +model = dict( + roi_head=dict( + bbox_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)), + grid_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)))) diff --git a/configs/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py b/configs/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf22a09aed3fceb9f9c20f189b3f0f31ad470ff --- /dev/null +++ b/configs/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py @@ -0,0 +1,44 @@ +_base_ = '../gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py' +# model settings +model = dict( + roi_head=dict( + bbox_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)), + mask_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)))) diff --git a/configs/groie/mask_rcnn_r50_fpn_groie_1x_coco.py b/configs/groie/mask_rcnn_r50_fpn_groie_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..3afa52ef81324835134e5c037dd6ae109bc4c9c5 --- /dev/null +++ b/configs/groie/mask_rcnn_r50_fpn_groie_1x_coco.py @@ -0,0 +1,44 @@ +_base_ = '../mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py' +# model settings +model = dict( + roi_head=dict( + bbox_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)), + mask_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)))) diff --git a/configs/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py b/configs/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc9ed4574ae6278806e7ce989a76b3e3e214aa2 --- /dev/null +++ b/configs/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py @@ -0,0 +1,44 @@ +_base_ = '../gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py' +# model settings +model = dict( + roi_head=dict( + bbox_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)), + mask_roi_extractor=dict( + type='SumGenericRoiExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='GeneralizedAttention', + in_channels=256, + spatial_range=-1, + num_heads=6, + attention_type='0100', + kv_stride=2)))) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index d8e9c6d3e9f59627540c8299ed1ba1ae0fdcb0ef..f216f054536576b118e60e7ea2d0b1dfaf7c87c6 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -129,6 +129,9 @@ Please refer to [RegNet](https://github.com/open-mmlab/mmdetection/blob/master/c ### Res2Net Please refer to [Res2Net](https://github.com/open-mmlab/mmdetection/blob/master/configs/res2net) for details. +### GRoIE +Please refer to [GRoIE](https://github.com/open-mmlab/mmdetection/blob/master/configs/groie) for details. + ### Other datasets We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face). diff --git a/mmdet/models/roi_heads/roi_extractors/__init__.py b/mmdet/models/roi_heads/roi_extractors/__init__.py index 9161708ce13fa4f0a6bb188e82a19a163b9b7e4f..2fed89ba58c117bfa64ce2cb7d7059e81cdf91a9 100644 --- a/mmdet/models/roi_heads/roi_extractors/__init__.py +++ b/mmdet/models/roi_heads/roi_extractors/__init__.py @@ -1,3 +1,7 @@ +from .groie import SumGenericRoiExtractor from .single_level import SingleRoIExtractor -__all__ = ['SingleRoIExtractor'] +__all__ = [ + 'SingleRoIExtractor', + 'SumGenericRoiExtractor', +] diff --git a/mmdet/models/roi_heads/roi_extractors/groie.py b/mmdet/models/roi_heads/roi_extractors/groie.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7105538ce808bb6a76106308aae38412d374e8 --- /dev/null +++ b/mmdet/models/roi_heads/roi_extractors/groie.py @@ -0,0 +1,62 @@ +"""Generic RoI Extractor. + +A novel Region of Interest Extraction Layer for Instance Segmentation. +""" + +from torch import nn + +from mmdet.core import force_fp32 +from mmdet.models.builder import ROI_EXTRACTORS +from mmdet.ops.plugin import build_plugin_layer +from .single_level import SingleRoIExtractor + + +@ROI_EXTRACTORS.register_module +class SumGenericRoiExtractor(SingleRoIExtractor): + """Extract RoI features from all summed feature maps levels. + + https://arxiv.org/abs/2004.13665 + + Args: + pre_cfg (dict): Specify pre-processing modules. + post_cfg (dict): Specify post-processing modules. + kwargs (keyword arguments): Arguments that are the same + as :class:`SingleRoIExtractor`. + """ + + def __init__(self, pre_cfg, post_cfg, **kwargs): + super(SumGenericRoiExtractor, self).__init__(**kwargs) + + # build pre/post processing modules + self.post_module = build_plugin_layer(post_cfg, '_post_module')[1] + self.pre_module = build_plugin_layer(pre_cfg, '_pre_module')[1] + self.relu = nn.ReLU(inplace=False) + + @force_fp32(apply_to=('feats', ), out_fp16=True) + def forward(self, feats, rois, roi_scale_factor=None): + if len(feats) == 1: + return self.roi_layers[0](feats[0], rois) + + out_size = self.roi_layers[0].out_size + num_levels = len(feats) + roi_feats = feats[0].new_zeros( + rois.size(0), self.out_channels, *out_size) + + # some times rois is an empty tensor + if roi_feats.shape[0] == 0: + return roi_feats + + if roi_scale_factor is not None: + rois = self.roi_rescale(rois, roi_scale_factor) + + for i in range(num_levels): + # apply pre-processing to a RoI extracted from each layer + roi_feats_t = self.roi_layers[i](feats[i], rois) + roi_feats_t = self.pre_module(roi_feats_t) + roi_feats_t = self.relu(roi_feats_t) + # and sum them all + roi_feats += roi_feats_t + + # apply post-processing before return the result + x = self.post_module(roi_feats) + return x diff --git a/mmdet/ops/plugin.py b/mmdet/ops/plugin.py index a104c1f56b10de23d95a0239f752cb9552dcfb19..4270f471a49ce8064c5ee4bd604b24ec75984c0a 100644 --- a/mmdet/ops/plugin.py +++ b/mmdet/ops/plugin.py @@ -1,3 +1,5 @@ +from mmcv.cnn import ConvModule + from .context_block import ContextBlock from .generalized_attention import GeneralizedAttention from .non_local import NonLocal2D @@ -6,7 +8,8 @@ plugin_cfg = { # format: layer_type: (abbreviation, module) 'ContextBlock': ('context_block', ContextBlock), 'GeneralizedAttention': ('gen_attention_block', GeneralizedAttention), - 'NonLocal2D': ('nonlocal_block', NonLocal2D) + 'NonLocal2D': ('nonlocal_block', NonLocal2D), + 'ConvModule': ('conv_block', ConvModule), } diff --git a/tests/test_roi_extractor.py b/tests/test_roi_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..332f8b07edf70932938dedb2f9f984746216dc95 --- /dev/null +++ b/tests/test_roi_extractor.py @@ -0,0 +1,41 @@ +import mmcv +import torch + +from mmdet.models.roi_heads.roi_extractors import SumGenericRoiExtractor + + +def test_groie(): + cfg = mmcv.Config( + dict( + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + pre_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False, + ), + post_cfg=dict( + type='ConvModule', + in_channels=256, + out_channels=256, + kernel_size=5, + padding=2, + inplace=False))) + + groie = SumGenericRoiExtractor(**cfg) + + feats = ( + torch.rand((1, 256, 200, 336)), + torch.rand((1, 256, 100, 168)), + torch.rand((1, 256, 50, 84)), + torch.rand((1, 256, 25, 42)), + ) + + rois = torch.tensor([[0.0000, 587.8285, 52.1405, 886.2484, 341.5644]]) + + res = groie(feats, rois) + assert res.shape == torch.Size([1, 256, 7, 7])