Unverified commit a04a482e, authored by Leonardo Rossi, committed by GitHub

roi_extractors: GRoIE addition (#2584)


* roi_extractors: GRoIE addition

Signed-off-by: Leonardo Rossi <leonardo.rossi@unipr.it>

* Fix bug caused by empty tensor and update benchmark results

Co-authored-by: ZwwWayne <wayne.zw@outlook.com>
parent c6948c60
Showing changes with 393 additions and 3 deletions
@@ -104,7 +104,7 @@ venv.bak/
 .mypy_cache/
 mmdet/version.py
-./data/
+data/
 .vscode
 .idea
 .DS_Store
......
@@ -88,6 +88,7 @@ Other features
 - [x] [GCNet](configs/gcnet/README.md)
 - [x] [Mixed Precision (FP16) Training](configs/fp16/README.md)
 - [x] [InstaBoost](configs/instaboost/README.md)
+- [x] [GRoIE](configs/groie/README.md)
 Some other methods are also supported in [projects using MMDetection](./docs/projects.md).
......
_base_ = ['../grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py']
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
checkpoint_config = dict(interval=1)
# runtime settings
total_epochs = 12
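The learning policy in this config (step decay with linear warmup) can be sketched in plain Python as below. This is a rough illustration only: the base learning rate of 0.02 and the decay factor `gamma=0.1` are assumptions that do not appear in this config, and `lr_at` is a hypothetical helper, not part of MMDetection.

```python
def lr_at(epoch, cur_iter, base_lr=0.02, steps=(8, 11), gamma=0.1,
          warmup_iters=500, warmup_ratio=0.001):
    """Learning rate for a 0-based epoch and a 0-based global iteration.

    base_lr and gamma are assumed values; the other defaults mirror
    the lr_config shown above.
    """
    # Step decay: multiply by gamma once for every milestone already reached.
    regular_lr = base_lr * gamma ** sum(epoch >= s for s in steps)
    if cur_iter >= warmup_iters:
        return regular_lr
    # Linear warmup: ramp from warmup_ratio * lr up to the regular lr.
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return regular_lr * (1 - k)


if __name__ == "__main__":
    for epoch, cur_iter in [(0, 0), (0, 250), (0, 500), (8, 60000), (11, 90000)]:
        print(epoch, cur_iter, lr_at(epoch, cur_iter))
```

Under these assumptions the rate starts at a small fraction of the base value, reaches it after 500 iterations, and is then reduced at epochs 8 and 11 of the 12-epoch schedule.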
# GRoIE
## A novel Region of Interest Extraction Layer for Instance Segmentation
By Leonardo Rossi, Akbar Karimi and Andrea Prati from
[IMPLab](http://implab.ce.unipr.it/).
We provide config files to reproduce the results in the paper for
"*A novel Region of Interest Extraction Layer for Instance Segmentation*"
on COCO object detection.
## Introduction
This paper is motivated by the need to overcome the limitations of existing
RoI extractors which select only one (the best) layer from FPN.
Our intuition is that all the layers of FPN retain useful information.
Therefore, the proposed layer (called Generic RoI Extractor - **GRoIE**)
introduces non-local building blocks and attention mechanisms to boost the
performance.
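
The difference between selecting one FPN level and aggregating all of them can be illustrated with a toy sketch. It assumes each pooled RoI feature is a flat list of floats standing in for a tensor, and that the aggregation is a sum of per-level features after pre-processing and ReLU, followed by post-processing. The function names and the `pre`/`post` callables are hypothetical placeholders, not the real modules.

```python
def identity(values):
    """Placeholder for a pre- or post-processing module."""
    return values


def extract_single_level(level_feats, level_index):
    """Baseline behaviour: keep the RoI feature of one chosen level only."""
    return level_feats[level_index]


def extract_summed(level_feats, pre=identity, post=identity):
    """Sum the RoI features of every level, then post-process the sum."""
    total = [0.0] * len(level_feats[0])
    for feat in level_feats:
        # Pre-process each level, then apply ReLU element-wise.
        processed = [max(v, 0.0) for v in pre(feat)]
        total = [t + p for t, p in zip(total, processed)]
    return post(total)


if __name__ == "__main__":
    feats = [[1.0, -2.0], [0.5, 3.0], [-1.0, 1.0]]
    print(extract_single_level(feats, 1))
    print(extract_summed(feats))
```

In the sketch, every level contributes to the output, whereas the single-level variant discards all but one.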
## Results and models
The results on COCO 2017 minival (5k images) are shown in the table below.
The trained models are available
[here](https://drive.google.com/drive/folders/19ssstbq_h0Z1cgxHmJYFO8s1arf3QJbT).
### Application of GRoIE to different architectures
| Backbone | Model | Lr schd | box AP | mask AP | Config file |
| :-------: | :--------------: | :-----: | :----: | :-----: | :-----------------------------------------------------------------: |
| R-50-FPN | Faster Original | 1x | 37.4 | | [config file](../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) |
| R-50-FPN | + GRoIE | 1x | 38.3 | | [config file](./faster_rcnn_r50_fpn_groie_1x_coco.py) |
| R-50-FPN | Grid R-CNN | 1x | 39.1 | | [config file](../grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py) |
| R-50-FPN | + GRoIE | 1x | | | [config file](./grid_rcnn_r50_fpn_gn-head_groie_1x_coco.py) |
| R-50-FPN | Mask R-CNN | 1x | 38.2 | 34.7 | [config file](../mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py) |
| R-50-FPN | + GRoIE | 1x | 39.0 | 36.0 | [config file](./mask_rcnn_r50_fpn_groie_1x_coco.py) |
| R-50-FPN | GC-Net | 1x | 40.7 | 36.5 | [config file](../gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py) |
| R-50-FPN | + GRoIE | 1x | 41.0 | 37.8 | [config file](./mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py) |
| R-101-FPN | GC-Net | 1x | 42.2 | 37.8 | [config file](../gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py) |
| R-101-FPN | + GRoIE | 1x | | | [config file](./mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_groie_1x_coco.py) |
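
To read the table as gains over each baseline, the differences can be computed directly. The short script below uses only the rows where both the baseline and the GRoIE value are reported (the Grid R-CNN and R-101 GC-Net rows with empty cells are skipped); the dictionary layout and the `ap_gains` helper are illustrative choices.

```python
# (baseline AP, AP with GRoIE), copied from the table above.
results = {
    "Faster R-CNN R-50-FPN": {"box": (37.4, 38.3)},
    "Mask R-CNN R-50-FPN": {"box": (38.2, 39.0), "mask": (34.7, 36.0)},
    "GC-Net R-50-FPN": {"box": (40.7, 41.0), "mask": (36.5, 37.8)},
}


def ap_gains(table):
    """Return the AP difference (GRoIE minus baseline), rounded to 0.1."""
    return {
        model: {
            metric: round(after - before, 1)
            for metric, (before, after) in metrics.items()
        }
        for model, metrics in table.items()
    }


if __name__ == "__main__":
    for model, gains in ap_gains(results).items():
        print(model, gains)
```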
## Citation
If you use this work or benchmark in your research, please cite this project.
```
@misc{rossi2020novel,
title={A novel Region of Interest Extraction Layer for Instance Segmentation},
author={Leonardo Rossi and Akbar Karimi and Andrea Prati},
year={2020},
eprint={2004.13665},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Contact
The implementation of GRoIE is currently maintained by
[Leonardo Rossi](https://github.com/hachreak/).
_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
# model settings
model = dict(
    roi_head=dict(
        bbox_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2))))
_base_ = '../grid_rcnn/grid_rcnn_r50_fpn_gn-head_1x_coco.py'
# model settings
model = dict(
    roi_head=dict(
        bbox_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2)),
        grid_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2))))
_base_ = '../gcnet/mask_rcnn_r101_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py'
# model settings
model = dict(
    roi_head=dict(
        bbox_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2)),
        mask_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2))))
_base_ = '../mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py'
# model settings
model = dict(
    roi_head=dict(
        bbox_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2)),
        mask_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2))))
_base_ = '../gcnet/mask_rcnn_r50_fpn_syncbn-backbone_r4_gcb_c3-c5_1x_coco.py'
# model settings
model = dict(
    roi_head=dict(
        bbox_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2)),
        mask_roi_extractor=dict(
            type='SumGenericRoiExtractor',
            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='GeneralizedAttention',
                in_channels=256,
                spatial_range=-1,
                num_heads=6,
                attention_type='0100',
                kv_stride=2))))
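Each GRoIE config lists only the keys it changes and relies on `_base_` to supply everything else. A minimal sketch of that kind of inheritance, assuming nested dicts are merged key by key and any other value is overwritten. `merge_config` and the base fragment below are hypothetical illustrations, not the actual config loader or the real base file.

```python
import copy


def merge_config(base, child):
    """Return a new dict where `child` recursively overrides `base`."""
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them instead of replacing.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical fragment of a base config (assumed for illustration).
base = {
    "roi_head": {
        "bbox_roi_extractor": {"type": "SingleRoIExtractor", "out_channels": 256},
        "bbox_head": {"num_classes": 80},
    }
}
# Override in the style of the GRoIE configs.
child = {
    "roi_head": {
        "bbox_roi_extractor": {
            "type": "SumGenericRoiExtractor",
            "pre_cfg": {"type": "ConvModule"},
        }
    }
}

if __name__ == "__main__":
    print(merge_config(base, child))
```

In the merged result the extractor type is replaced and `pre_cfg` is added, while untouched keys such as `bbox_head` are inherited from the base.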
@@ -129,6 +129,9 @@ Please refer to [RegNet](https://github.com/open-mmlab/mmdetection/blob/master/c
 ### Res2Net
 Please refer to [Res2Net](https://github.com/open-mmlab/mmdetection/blob/master/configs/res2net) for details.
+### GRoIE
+Please refer to [GRoIE](https://github.com/open-mmlab/mmdetection/blob/master/configs/groie) for details.
 ### Other datasets
 We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
......
+from .groie import SumGenericRoiExtractor
 from .single_level import SingleRoIExtractor
-__all__ = ['SingleRoIExtractor']
+__all__ = [
+    'SingleRoIExtractor',
+    'SumGenericRoiExtractor',
+]
"""Generic RoI Extractor.

A novel Region of Interest Extraction Layer for Instance Segmentation.
"""
from torch import nn

from mmdet.core import force_fp32
from mmdet.models.builder import ROI_EXTRACTORS
from mmdet.ops.plugin import build_plugin_layer
from .single_level import SingleRoIExtractor


@ROI_EXTRACTORS.register_module
class SumGenericRoiExtractor(SingleRoIExtractor):
    """Extract RoI features from all feature map levels and sum them.

    https://arxiv.org/abs/2004.13665

    Args:
        pre_cfg (dict): Specify pre-processing modules.
        post_cfg (dict): Specify post-processing modules.
        kwargs (keyword arguments): Arguments that are the same
            as :class:`SingleRoIExtractor`.
    """

    def __init__(self, pre_cfg, post_cfg, **kwargs):
        super(SumGenericRoiExtractor, self).__init__(**kwargs)

        # build pre/post processing modules
        self.post_module = build_plugin_layer(post_cfg, '_post_module')[1]
        self.pre_module = build_plugin_layer(pre_cfg, '_pre_module')[1]
        self.relu = nn.ReLU(inplace=False)

    @force_fp32(apply_to=('feats', ), out_fp16=True)
    def forward(self, feats, rois, roi_scale_factor=None):
        # with a single feature level there is nothing to aggregate
        if len(feats) == 1:
            return self.roi_layers[0](feats[0], rois)

        out_size = self.roi_layers[0].out_size
        num_levels = len(feats)
        roi_feats = feats[0].new_zeros(
            rois.size(0), self.out_channels, *out_size)

        # sometimes rois is an empty tensor
        if roi_feats.shape[0] == 0:
            return roi_feats

        if roi_scale_factor is not None:
            rois = self.roi_rescale(rois, roi_scale_factor)

        for i in range(num_levels):
            # apply pre-processing to the RoI extracted from each level
            roi_feats_t = self.roi_layers[i](feats[i], rois)
            roi_feats_t = self.pre_module(roi_feats_t)
            roi_feats_t = self.relu(roi_feats_t)
            # and sum them all
            roi_feats += roi_feats_t

        # apply post-processing before returning the result
        x = self.post_module(roi_feats)
        return x
+from mmcv.cnn import ConvModule
 from .context_block import ContextBlock
 from .generalized_attention import GeneralizedAttention
 from .non_local import NonLocal2D
@@ -6,7 +8,8 @@ plugin_cfg = {
     # format: layer_type: (abbreviation, module)
     'ContextBlock': ('context_block', ContextBlock),
     'GeneralizedAttention': ('gen_attention_block', GeneralizedAttention),
-    'NonLocal2D': ('nonlocal_block', NonLocal2D)
+    'NonLocal2D': ('nonlocal_block', NonLocal2D),
+    'ConvModule': ('conv_block', ConvModule),
 }
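The `plugin_cfg` table maps a `type` string to an abbreviation and a class, which is what lets a config dict such as `pre_cfg` be turned into a layer. A simplified sketch of that lookup, assuming the builder pops `type`, passes the remaining keys to the class, and returns a `(name, layer)` pair. `ScaleLayer`, `PLUGIN_CFG` and `build_layer` are hypothetical stand-ins, not the library's code.

```python
class ScaleLayer:
    """Hypothetical stand-in for a real layer such as ConvModule."""

    def __init__(self, factor=1.0):
        self.factor = factor

    def __call__(self, values):
        return [v * self.factor for v in values]


PLUGIN_CFG = {
    # format: layer_type: (abbreviation, class)
    "ScaleLayer": ("scale_block", ScaleLayer),
}


def build_layer(cfg, postfix=""):
    """Build a layer from a config dict and return (name, layer)."""
    cfg = dict(cfg)  # copy so the caller's dict is left untouched
    layer_type = cfg.pop("type")
    if layer_type not in PLUGIN_CFG:
        raise KeyError(f"Unrecognized plugin type {layer_type}")
    abbreviation, layer_cls = PLUGIN_CFG[layer_type]
    # Remaining keys become constructor arguments.
    return abbreviation + str(postfix), layer_cls(**cfg)


if __name__ == "__main__":
    name, layer = build_layer({"type": "ScaleLayer", "factor": 2.0}, "_pre_module")
    print(name, layer([1.0, 2.0]))
```

Returning a pair is consistent with the extractor in this commit indexing the builder's result with `[1]` to keep only the layer.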
......
import mmcv
import torch

from mmdet.models.roi_heads.roi_extractors import SumGenericRoiExtractor


def test_groie():
    cfg = mmcv.Config(
        dict(
            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
            pre_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False,
            ),
            post_cfg=dict(
                type='ConvModule',
                in_channels=256,
                out_channels=256,
                kernel_size=5,
                padding=2,
                inplace=False)))

    groie = SumGenericRoiExtractor(**cfg)

    feats = (
        torch.rand((1, 256, 200, 336)),
        torch.rand((1, 256, 100, 168)),
        torch.rand((1, 256, 50, 84)),
        torch.rand((1, 256, 25, 42)),
    )

    rois = torch.tensor([[0.0000, 587.8285, 52.1405, 886.2484, 341.5644]])

    res = groie(feats, rois)
    assert res.shape == torch.Size([1, 256, 7, 7])