From a04a482e5412b45507bb712afa6c6c3a44515de5 Mon Sep 17 00:00:00 2001
From: Leonardo Rossi <hachreak@gmail.com>
Date: Mon, 8 Jun 2020 17:51:31 +0200
Subject: [PATCH] roi_extractors: GRoIE addition (#2584)

* roi_extractors: GRoIE addition

Signed-off-by: Leonardo Rossi <leonardo.rossi@unipr.it>

* Fix bug caused by empty tensor and update benchmark results

Co-authored-by: ZwwWayne <wayne.zw@outlook.com>
---
 .gitignore                                    |  2 +-
 README.md                                     |  1 +
 .../grid_rcnn_r50_fpn_gn-head_1x_coco.py      | 11 ++++
 configs/groie/README.md                       | 65 +++++++++++++++++++
 .../faster_rcnn_r50_fpn_groie_1x_coco.py      | 24 +++++++
 ...grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py | 44 +++++++++++++
 ...cbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py | 44 +++++++++++++
 .../groie/mask_rcnn_r50_fpn_groie_1x_coco.py  | 44 +++++++++++++
 ...cbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py | 44 +++++++++++++
 docs/model_zoo.md                             |  3 +
 .../roi_heads/roi_extractors/__init__.py      |  6 +-
 .../models/roi_heads/roi_extractors/groie.py  | 62 ++++++++++++++++++
 mmdet/ops/plugin.py                           |  5 +-
 tests/test_roi_extractor.py                   | 41 ++++++++++++
 14 files changed, 393 insertions(+), 3 deletions(-)
 create mode 100644 configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py
 create mode 100644 configs/groie/README.md
 create mode 100644 configs/groie/faster_rcnn_r50_fpn_groie_1x_coco.py
 create mode 100644 configs/groie/grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py
 create mode 100644 configs/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py
 create mode 100644 configs/groie/mask_rcnn_r50_fpn_groie_1x_coco.py
 create mode 100644 configs/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py
 create mode 100644 mmdet/models/roi_heads/roi_extractors/groie.py
 create mode 100644 tests/test_roi_extractor.py

diff --git a/.gitignore b/.gitignore
index e83ec70e..9918ddfd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -104,7 +104,7 @@ venv.bak/
 .mypy_cache/
 
 mmdet/version.py
-./data/
+data/
 .vscode
 .idea
 .DS_Store
diff --git a/README.md b/README.md
index a737a2df..c9dc9ac8 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,7 @@ Other features
 - [x] [GCNet](configs/gcnet/README.md)
 - [x] [Mixed Precision (FP16) Training](configs/fp16/README.md)
 - [x] [InstaBoost](configs/instaboost/README.md)
+- [x] [GRoIE](configs/groie/README.md)
 
 Some other methods are also supported in [projects using MMDetection](./docs/projects.md).
 
diff --git a/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py b/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py
new file mode 100644
index 00000000..cc3e3ef5
--- /dev/null
+++ b/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py
@@ -0,0 +1,11 @@
+_base_ = ['../grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py']
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=0.001,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# runtime settings
+total_epochs = 12
diff --git a/configs/groie/README.md b/configs/groie/README.md
new file mode 100644
index 00000000..50a2f137
--- /dev/null
+++ b/configs/groie/README.md
@@ -0,0 +1,65 @@
+# GRoIE
+
+## A novel Region of Interest Extraction Layer for Instance Segmentation
+
+By Leonardo Rossi, Akbar Karimi and Andrea Prati from
+[IMPLab](http://implab.ce.unipr.it/).
+
+We provide config files to reproduce the results in the paper for
+"*A novel Region of Interest Extraction Layer for Instance Segmentation*"
+on COCO object detection.
+
+## Introduction
+
+This paper is motivated by the need to overcome to the limitations of existing
+RoI extractors which select only one (the best) layer from FPN.
+
+Our intuition is that all the layers of FPN retain useful information.
+
+Therefore, the proposed layer (called Generic RoI Extractor - **GRoIE**)
+introduces non-local building blocks and attention mechanisms to boost the
+performance.
+
+## Results and models
+
+The results on COCO 2017 minival (5k images) are shown in the below table.
+You can find
+[here](https://drive.google.com/drive/folders/19ssstbq_h0Z1cgxHmJYFO8s1arf3QJbT)
+the trained models.
+
+### Application of GRoIE to different architectures
+
+| Backbone  | Model            | Lr schd | box AP | mask AP | Config file                                                                                 |
+| :-------: | :--------------: | :-----: | :----: | :-----: | :-----------------------------------------------------------------:                         |
+| R-50-FPN  | Faster Original  |   1x    |  37.4  |         | [config file](../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py)                                |
+| R-50-FPN  | + GRoIE          |   1x    |  38.3  |         | [config file](./faster_rcnn_r50_fpn_groie_1x_coco.py)                                       |
+| R-50-FPN  | Grid R-CNN       |   1x    |  39.1  |         | [config file](../grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py)                            |
+| R-50-FPN  | + GRoIE          |   1x    |    |         | [config file](./grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py)                                 |
+| R-50-FPN  | Mask R-CNN       |   1x    |  38.2  |  34.7   | [config file](../mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py)                                    |
+| R-50-FPN  | + GRoIE          |   1x    |  39.0  |  36.0   | [config file](./mask_rcnn_r50_fpn_groie_1x_coco.py)                                         |
+| R-50-FPN  | GC-Net           |   1x    |  40.7  |  36.5   | [config file](../gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py)           |
+| R-50-FPN  | + GRoIE          |   1x    |  41.0  |  37.8   | [config file](./mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py)            |
+| R-101-FPN | GC-Net           |   1x    |  42.2  |  37.8   | [config file](../configs/gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py)  |
+| R-101-FPN | + GRoIE          |   1x    |   |    | [config file](./mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py)           |
+
+
+
+## Citation
+
+If you use this work or benchmark in your research, please cite this project.
+
+```
+@misc{rossi2020novel,
+    title={A novel Region of Interest Extraction Layer for Instance Segmentation},
+    author={Leonardo Rossi and Akbar Karimi and Andrea Prati},
+    year={2020},
+    eprint={2004.13665},
+    archivePrefix={arXiv},
+    primaryClass={cs.CV}
+}
+```
+
+## Contact
+
+The implementation of GROI is currently maintained by
+[Leonardo Rossi](https://github.com/hachreak/).
diff --git a/configs/groie/faster_rcnn_r50_fpn_groie_1x_coco.py b/configs/groie/faster_rcnn_r50_fpn_groie_1x_coco.py
new file mode 100644
index 00000000..f9166538
--- /dev/null
+++ b/configs/groie/faster_rcnn_r50_fpn_groie_1x_coco.py
@@ -0,0 +1,24 @@
+_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
+# model settings
+model = dict(
+    roi_head=dict(
+        bbox_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2))))
diff --git a/configs/groie/grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py b/configs/groie/grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py
new file mode 100644
index 00000000..668ce3a0
--- /dev/null
+++ b/configs/groie/grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py
@@ -0,0 +1,44 @@
+_base_ = '../grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py'
+# model settings
+model = dict(
+    roi_head=dict(
+        bbox_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2)),
+        grid_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2))))
diff --git a/configs/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py b/configs/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py
new file mode 100644
index 00000000..1cf22a09
--- /dev/null
+++ b/configs/groie/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py
@@ -0,0 +1,44 @@
+_base_ = '../gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py'
+# model settings
+model = dict(
+    roi_head=dict(
+        bbox_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2)),
+        mask_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2))))
diff --git a/configs/groie/mask_rcnn_r50_fpn_groie_1x_coco.py b/configs/groie/mask_rcnn_r50_fpn_groie_1x_coco.py
new file mode 100644
index 00000000..3afa52ef
--- /dev/null
+++ b/configs/groie/mask_rcnn_r50_fpn_groie_1x_coco.py
@@ -0,0 +1,44 @@
+_base_ = '../mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py'
+# model settings
+model = dict(
+    roi_head=dict(
+        bbox_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2)),
+        mask_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2))))
diff --git a/configs/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py b/configs/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py
new file mode 100644
index 00000000..8dc9ed45
--- /dev/null
+++ b/configs/groie/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py
@@ -0,0 +1,44 @@
+_base_ = '../gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py'
+# model settings
+model = dict(
+    roi_head=dict(
+        bbox_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2)),
+        mask_roi_extractor=dict(
+            type='SumGenericRoiExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='GeneralizedAttention',
+                in_channels=256,
+                spatial_range=-1,
+                num_heads=6,
+                attention_type='0100',
+                kv_stride=2))))
diff --git a/docs/model_zoo.md b/docs/model_zoo.md
index d8e9c6d3..f216f054 100644
--- a/docs/model_zoo.md
+++ b/docs/model_zoo.md
@@ -129,6 +129,9 @@ Please refer to [RegNet](https://github.com/open-mmlab/mmdetection/blob/master/c
 ### Res2Net
 Please refer to [Res2Net](https://github.com/open-mmlab/mmdetection/blob/master/configs/res2net) for details.
 
+### GRoIE
+Please refer to [GRoIE](https://github.com/open-mmlab/mmdetection/blob/master/configs/groie) for details.
+
 ### Other datasets
 
 We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
diff --git a/mmdet/models/roi_heads/roi_extractors/__init__.py b/mmdet/models/roi_heads/roi_extractors/__init__.py
index 9161708c..2fed89ba 100644
--- a/mmdet/models/roi_heads/roi_extractors/__init__.py
+++ b/mmdet/models/roi_heads/roi_extractors/__init__.py
@@ -1,3 +1,7 @@
+from .groie import SumGenericRoiExtractor
 from .single_level import SingleRoIExtractor
 
-__all__ = ['SingleRoIExtractor']
+__all__ = [
+    'SingleRoIExtractor',
+    'SumGenericRoiExtractor',
+]
diff --git a/mmdet/models/roi_heads/roi_extractors/groie.py b/mmdet/models/roi_heads/roi_extractors/groie.py
new file mode 100644
index 00000000..1c710553
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/groie.py
@@ -0,0 +1,62 @@
+"""Generic RoI Extractor.
+
+A novel Region of Interest Extraction Layer for Instance Segmentation.
+"""
+
+from torch import nn
+
+from mmdet.core import force_fp32
+from mmdet.models.builder import ROI_EXTRACTORS
+from mmdet.ops.plugin import build_plugin_layer
+from .single_level import SingleRoIExtractor
+
+
+@ROI_EXTRACTORS.register_module
+class SumGenericRoiExtractor(SingleRoIExtractor):
+    """Extract RoI features from all summed feature maps levels.
+
+    https://arxiv.org/abs/2004.13665
+
+    Args:
+        pre_cfg (dict): Specify pre-processing modules.
+        post_cfg (dict): Specify post-processing modules.
+        kwargs (keyword arguments): Arguments that are the same
+            as :class:`SingleRoIExtractor`.
+    """
+
+    def __init__(self, pre_cfg, post_cfg, **kwargs):
+        super(SumGenericRoiExtractor, self).__init__(**kwargs)
+
+        # build pre/post processing modules
+        self.post_module = build_plugin_layer(post_cfg, '_post_module')[1]
+        self.pre_module = build_plugin_layer(pre_cfg, '_pre_module')[1]
+        self.relu = nn.ReLU(inplace=False)
+
+    @force_fp32(apply_to=('feats', ), out_fp16=True)
+    def forward(self, feats, rois, roi_scale_factor=None):
+        if len(feats) == 1:
+            return self.roi_layers[0](feats[0], rois)
+
+        out_size = self.roi_layers[0].out_size
+        num_levels = len(feats)
+        roi_feats = feats[0].new_zeros(
+            rois.size(0), self.out_channels, *out_size)
+
+        # some times rois is an empty tensor
+        if roi_feats.shape[0] == 0:
+            return roi_feats
+
+        if roi_scale_factor is not None:
+            rois = self.roi_rescale(rois, roi_scale_factor)
+
+        for i in range(num_levels):
+            # apply pre-processing to a RoI extracted from each layer
+            roi_feats_t = self.roi_layers[i](feats[i], rois)
+            roi_feats_t = self.pre_module(roi_feats_t)
+            roi_feats_t = self.relu(roi_feats_t)
+            # and sum them all
+            roi_feats += roi_feats_t
+
+        # apply post-processing before return the result
+        x = self.post_module(roi_feats)
+        return x
diff --git a/mmdet/ops/plugin.py b/mmdet/ops/plugin.py
index a104c1f5..4270f471 100644
--- a/mmdet/ops/plugin.py
+++ b/mmdet/ops/plugin.py
@@ -1,3 +1,5 @@
+from mmcv.cnn import ConvModule
+
 from .context_block import ContextBlock
 from .generalized_attention import GeneralizedAttention
 from .non_local import NonLocal2D
@@ -6,7 +8,8 @@ plugin_cfg = {
     # format: layer_type: (abbreviation, module)
     'ContextBlock': ('context_block', ContextBlock),
     'GeneralizedAttention': ('gen_attention_block', GeneralizedAttention),
-    'NonLocal2D': ('nonlocal_block', NonLocal2D)
+    'NonLocal2D': ('nonlocal_block', NonLocal2D),
+    'ConvModule': ('conv_block', ConvModule),
 }
 
 
diff --git a/tests/test_roi_extractor.py b/tests/test_roi_extractor.py
new file mode 100644
index 00000000..332f8b07
--- /dev/null
+++ b/tests/test_roi_extractor.py
@@ -0,0 +1,41 @@
+import mmcv
+import torch
+
+from mmdet.models.roi_heads.roi_extractors import SumGenericRoiExtractor
+
+
+def test_groie():
+    cfg = mmcv.Config(
+        dict(
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32],
+            pre_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False,
+            ),
+            post_cfg=dict(
+                type='ConvModule',
+                in_channels=256,
+                out_channels=256,
+                kernel_size=5,
+                padding=2,
+                inplace=False)))
+
+    groie = SumGenericRoiExtractor(**cfg)
+
+    feats = (
+        torch.rand((1, 256, 200, 336)),
+        torch.rand((1, 256, 100, 168)),
+        torch.rand((1, 256, 50, 84)),
+        torch.rand((1, 256, 25, 42)),
+    )
+
+    rois = torch.tensor([[0.0000, 587.8285, 52.1405, 886.2484, 341.5644]])
+
+    res = groie(feats, rois)
+    assert res.shape == torch.Size([1, 256, 7, 7])
-- 
GitLab