From c29a789c46014cc1ae181f1bbb2e8ec01c44a01c Mon Sep 17 00:00:00 2001
From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Date: Tue, 31 Mar 2020 21:47:20 +0800
Subject: [PATCH] Encapsulate the second stage into RoI heads (#1999)

* create roi_head

* create roi_head
mv shared_head into roi_head

* fix conflict

* refactor(cascade rcnn): create cascade head in roi head, simplify cascade rcnn

* fix (cascade_head): mv stage_loss_weights from train_cfg to roi_head creation

* fix (roi_heads): handle the case when train_cfg is None

* fix (roi_heads): use test_cfg instead of test_cfg.rcnn in mask inference

* fix mask eval bug

* refactor(double_head): refactor double head roi head

* fix(double_head): clean code and fix __init__.py

* refactor(mask scoring): create mask scoring roi head

* refactor(htc_head): add htc_head

* fix(htc_head): add htc_head in __init__.py

* fix (htc_head): fix htc_head bugs

* fix (ms_roi_head): rename test_cfg.rcnn to test_cfg

* feature (grid_rcnn): add grid roi head

* fix (grid_head): fix grid head bug

* reformat and change all cfgs

* reformat (roi_head): reformat __init__.py for isort=4.3.21

* fix roi_head test bug

* fix (carafe): change carafe cfgs to use roi_head

* fix (roi_head): fix missing roi_head refactor

* reformat to pass CI

* test all cfgs

* match keys of configs with previous commit 77d073a

* add and pass unittest for all roi heads

* Refactor (roi_head): extract abstract base class for roi head

* Refactor (roi_head): refactor init functions

* Refactor (roi_head): weight init

* Refactor (roi_head): add _bbox_forward & _mask_forward as basic functions

* Fix (grid_roi_head): fix bug in bbox_forward_train

* Refactor (roi_head): change to use img.device in forward_dummy

* Refactor (roi_heads): simplify init functions and _mask_forward

* Fix (cascade_roi_head): fix init bug of cascade roi_head

* Refactor (roi_head): use dict as outputs of _bbox_forward and _mask_forward

* Refactor (test_config): scan valid configs rather than list them all
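
For downstream configs, the practical change is that all second-stage
components now nest under a single roi_head key whose type selects the
head class. A minimal sketch of the migration pattern, abbreviated to
illustrative keys only (the full configs are in the diff below):

    # Before this patch, second-stage modules sat at the top level of
    # the model dict:
    old_model = dict(
        type='FasterRCNN',
        bbox_roi_extractor=dict(type='SingleRoIExtractor'),
        bbox_head=dict(type='Shared2FCBBoxHead'))

    # After the patch they nest under roi_head, whose `type` selects the
    # head class (StandardRoIHead, CascadeRoIHead, GridRoIHead, ...).
    # Cascade-style heads also take num_stages and stage_loss_weights
    # here instead of in train_cfg:
    new_model = dict(
        type='FasterRCNN',
        roi_head=dict(
            type='StandardRoIHead',
            bbox_roi_extractor=old_model['bbox_roi_extractor'],
            bbox_head=old_model['bbox_head']))

The same pattern applies to mask_roi_extractor/mask_head and to
model-specific keys (e.g. grid_head, mask_iou_head, semantic_head),
which all move inside roi_head, as the config diffs below show.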
---
 .gitignore                                    |   1 +
 .../models/cascade_mask_rcnn_r50_fpn.py       | 124 +++--
 configs/_base_/models/cascade_rcnn_r50_fpn.py | 102 ++--
 configs/_base_/models/fast_rcnn_r50_fpn.py    |  36 +-
 .../_base_/models/faster_rcnn_r50_caffe_c4.py |  54 +-
 configs/_base_/models/faster_rcnn_r50_fpn.py  |  36 +-
 .../_base_/models/mask_rcnn_r50_caffe_c4.py   |  72 +--
 configs/_base_/models/mask_rcnn_r50_fpn.py    |  62 +--
 .../mask_rcnn_r50_fpn_carafe_1x_coco.py       |  19 +-
 .../cascade_rcnn_x101_64x4d_fpn_1x_coco.py    |   1 -
 .../faster_rcnn_r50_fpn_1x_cityscapes.py      |  25 +-
 .../mask_rcnn_r50_fpn_1x_cityscapes.py        |  41 +-
 .../dcn/faster_rcnn_r50_fpn_dpool_1x_coco.py  |  23 +-
 .../dcn/faster_rcnn_r50_fpn_mdpool_1x_coco.py |  23 +-
 .../dh_faster_rcnn_r50_fpn_1x_coco.py         |  37 +-
 .../faster_rcnn_r50_caffe_fpn_1x_coco.py      |   3 +-
 .../faster_rcnn_r50_fpn_gn_ws-all_1x_coco.py  |  11 +-
 .../mask_rcnn_r50_fpn_gn_ws-all_2x_coco.py    |  13 +-
 .../gn/mask_rcnn_r50_fpn_gn-all_2x_coco.py    |  11 +-
 ...ask_rcnn_r50_fpn_gn-all_contrib_2x_coco.py |  11 +-
 .../grid_rcnn_r50_fpn_gn-head_2x_coco.py      |  60 ++-
 .../ga_fast_r50_caffe_fpn_1x_coco.py          |   3 +-
 .../ga_faster_r50_caffe_fpn_1x_coco.py        |   2 +-
 .../ga_faster_r50_fpn_1x_coco.py              |   2 +-
 configs/htc/htc_r50_fpn_1x_coco.py            |  31 +-
 .../htc_without_semantic_r50_fpn_1x_coco.py   | 170 +++---
 .../libra_fast_rcnn_r50_fpn_1x_coco.py        |  17 +-
 .../libra_faster_rcnn_r50_fpn_1x_coco.py      |  17 +-
 .../ms_rcnn/ms_rcnn_r50_caffe_fpn_1x_coco.py  |  20 +-
 configs/ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py    |  20 +-
 .../faster_rcnn_r50_fpn_1x_voc0712.py         |   2 +-
 .../retinanet_r50_caffe_fpn_1x_coco.py        |   3 +-
 configs/rpn/rpn_r50_caffe_fpn_1x_coco.py      |   3 +-
 ...ter_rcnn_r50_fpn_gn-all_scratch_6x_coco.py |   9 +-
 ...ask_rcnn_r50_fpn_gn-all_scratch_6x_coco.py |  11 +-
 mmdet/models/__init__.py                      |   1 +
 mmdet/models/detectors/__init__.py            |   5 +-
 mmdet/models/detectors/base.py                |  13 +-
 mmdet/models/detectors/cascade_rcnn.py        | 508 +----------------
 mmdet/models/detectors/double_head_rcnn.py    | 178 ------
 mmdet/models/detectors/fast_rcnn.py           |  12 +-
 mmdet/models/detectors/faster_rcnn.py         |   8 +-
 mmdet/models/detectors/grid_rcnn.py           | 208 +------
 mmdet/models/detectors/htc.py                 | 510 +-----------------
 mmdet/models/detectors/mask_rcnn.py           |  12 +-
 mmdet/models/detectors/mask_scoring_rcnn.py   | 178 +-----
 mmdet/models/detectors/two_stage.py           | 214 ++------
 mmdet/models/roi_heads/__init__.py            |  11 +
 mmdet/models/roi_heads/base_roi_head.py       |  93 ++++
 mmdet/models/roi_heads/cascade_roi_head.py    | 423 +++++++++++++++
 mmdet/models/roi_heads/double_roi_head.py     |  32 ++
 mmdet/models/roi_heads/grid_roi_head.py       | 153 ++++++
 mmdet/models/roi_heads/htc_roi_head.py        | 498 +++++++++++++++++
 .../models/roi_heads/mask_scoring_roi_head.py |  86 +++
 mmdet/models/roi_heads/standard_roi_head.py   | 280 ++++++++++
 mmdet/models/roi_heads/test_mixins.py         | 202 +++++++
 tests/test_config.py                          | 287 +++++-----
 tests/test_forward.py                         |   9 +-
 tests/test_sampler.py                         |   6 +-
 59 files changed, 2592 insertions(+), 2410 deletions(-)
 delete mode 100644 mmdet/models/detectors/double_head_rcnn.py
 create mode 100644 mmdet/models/roi_heads/__init__.py
 create mode 100644 mmdet/models/roi_heads/base_roi_head.py
 create mode 100644 mmdet/models/roi_heads/cascade_roi_head.py
 create mode 100644 mmdet/models/roi_heads/double_roi_head.py
 create mode 100644 mmdet/models/roi_heads/grid_roi_head.py
 create mode 100644 mmdet/models/roi_heads/htc_roi_head.py
 create mode 100644 mmdet/models/roi_heads/mask_scoring_roi_head.py
 create mode 100644 mmdet/models/roi_heads/standard_roi_head.py
 create mode 100644 mmdet/models/roi_heads/test_mixins.py

diff --git a/.gitignore b/.gitignore
index 407baefc..5a80e725 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,3 +116,4 @@ work_dirs/
 
 # Pytorch
 *.pth
+*.py~
diff --git a/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py b/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py
index a42aca3c..784c97e3 100644
--- a/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py
+++ b/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py
@@ -1,7 +1,6 @@
 # model settings
 model = dict(
     type='CascadeRCNN',
-    num_stages=3,
     pretrained='torchvision://resnet50',
     backbone=dict(
         type='ResNet',
@@ -29,62 +28,74 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    bbox_head=[
-        dict(
-            type='Shared2FCBBoxHead',
-            in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
-            num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.1, 0.1, 0.2, 0.2],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-        dict(
-            type='Shared2FCBBoxHead',
+    roi_head=dict(
+        type='CascadeRoIHead',
+        num_stages=3,
+        stage_loss_weights=[1, 0.5, 0.25],
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=[
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.1, 0.1, 0.2, 0.2],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+                               loss_weight=1.0)),
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.05, 0.05, 0.1, 0.1],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+                               loss_weight=1.0)),
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.033, 0.033, 0.067, 0.067],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
+        ],
+        mask_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        mask_head=dict(
+            type='FCNMaskHead',
+            num_convs=4,
             in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
+            conv_out_channels=256,
             num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.05, 0.05, 0.1, 0.1],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-        dict(
-            type='Shared2FCBBoxHead',
-            in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
-            num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.033, 0.033, 0.067, 0.067],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
-    ],
-    mask_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    mask_head=dict(
-        type='FCNMaskHead',
-        num_convs=4,
-        in_channels=256,
-        conv_out_channels=256,
-        num_classes=81,
-        loss_mask=dict(
-            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+            loss_mask=dict(
+                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
@@ -159,8 +170,7 @@ train_cfg = dict(
             mask_size=28,
             pos_weight=-1,
             debug=False)
-    ],
-    stage_loss_weights=[1, 0.5, 0.25])
+    ])
 test_cfg = dict(
     rpn=dict(
         nms_across_levels=False,
diff --git a/configs/_base_/models/cascade_rcnn_r50_fpn.py b/configs/_base_/models/cascade_rcnn_r50_fpn.py
index 6ccdb99f..abb9037b 100644
--- a/configs/_base_/models/cascade_rcnn_r50_fpn.py
+++ b/configs/_base_/models/cascade_rcnn_r50_fpn.py
@@ -1,7 +1,6 @@
 # model settings
 model = dict(
     type='CascadeRCNN',
-    num_stages=3,
     pretrained='torchvision://resnet50',
     backbone=dict(
         type='ResNet',
@@ -29,49 +28,61 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    bbox_head=[
-        dict(
-            type='Shared2FCBBoxHead',
-            in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
-            num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.1, 0.1, 0.2, 0.2],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-        dict(
-            type='Shared2FCBBoxHead',
-            in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
-            num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.05, 0.05, 0.1, 0.1],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-        dict(
-            type='Shared2FCBBoxHead',
-            in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
-            num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.033, 0.033, 0.067, 0.067],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
-    ])
+    roi_head=dict(
+        type='CascadeRoIHead',
+        num_stages=3,
+        stage_loss_weights=[1, 0.5, 0.25],
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=[
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.1, 0.1, 0.2, 0.2],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+                               loss_weight=1.0)),
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.05, 0.05, 0.1, 0.1],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+                               loss_weight=1.0)),
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.033, 0.033, 0.067, 0.067],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
+        ]))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
@@ -143,8 +154,7 @@ train_cfg = dict(
                 add_gt_as_proposals=True),
             pos_weight=-1,
             debug=False)
-    ],
-    stage_loss_weights=[1, 0.5, 0.25])
+    ])
 test_cfg = dict(
     rpn=dict(
         nms_across_levels=False,
diff --git a/configs/_base_/models/fast_rcnn_r50_fpn.py b/configs/_base_/models/fast_rcnn_r50_fpn.py
index f29b256c..63da9fac 100644
--- a/configs/_base_/models/fast_rcnn_r50_fpn.py
+++ b/configs/_base_/models/fast_rcnn_r50_fpn.py
@@ -16,23 +16,25 @@ model = dict(
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    bbox_head=dict(
-        type='Shared2FCBBoxHead',
-        in_channels=256,
-        fc_out_channels=1024,
-        roi_feat_size=7,
-        num_classes=81,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False,
-        loss_cls=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+    roi_head=dict(
+        type='StandardRoIHead',
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))))
 # model training and testing settings
 train_cfg = dict(
     rcnn=dict(
diff --git a/configs/_base_/models/faster_rcnn_r50_caffe_c4.py b/configs/_base_/models/faster_rcnn_r50_caffe_c4.py
index 14ca1ab1..d97c87d5 100644
--- a/configs/_base_/models/faster_rcnn_r50_caffe_c4.py
+++ b/configs/_base_/models/faster_rcnn_r50_caffe_c4.py
@@ -14,15 +14,6 @@ model = dict(
         norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
-    shared_head=dict(
-        type='ResLayer',
-        depth=50,
-        stage=3,
-        stride=2,
-        dilation=1,
-        style='caffe',
-        norm_cfg=norm_cfg,
-        norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
         in_channels=1024,
@@ -35,23 +26,34 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
-        out_channels=1024,
-        featmap_strides=[16]),
-    bbox_head=dict(
-        type='BBoxHead',
-        with_avg_pool=True,
-        roi_feat_size=7,
-        in_channels=2048,
-        num_classes=81,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False,
-        loss_cls=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+    roi_head=dict(
+        type='StandardRoIHead',
+        shared_head=dict(
+            type='ResLayer',
+            depth=50,
+            stage=3,
+            stride=2,
+            dilation=1,
+            style='caffe',
+            norm_cfg=norm_cfg,
+            norm_eval=True),
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=1024,
+            featmap_strides=[16]),
+        bbox_head=dict(
+            type='BBoxHead',
+            with_avg_pool=True,
+            roi_feat_size=7,
+            in_channels=2048,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/_base_/models/faster_rcnn_r50_fpn.py b/configs/_base_/models/faster_rcnn_r50_fpn.py
index 467e61dc..eb656722 100644
--- a/configs/_base_/models/faster_rcnn_r50_fpn.py
+++ b/configs/_base_/models/faster_rcnn_r50_fpn.py
@@ -27,23 +27,25 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    bbox_head=dict(
-        type='Shared2FCBBoxHead',
-        in_channels=256,
-        fc_out_channels=1024,
-        roi_feat_size=7,
-        num_classes=81,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False,
-        loss_cls=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+    roi_head=dict(
+        type='StandardRoIHead',
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/_base_/models/mask_rcnn_r50_caffe_c4.py b/configs/_base_/models/mask_rcnn_r50_caffe_c4.py
index 3f1c07bb..d9660af2 100644
--- a/configs/_base_/models/mask_rcnn_r50_caffe_c4.py
+++ b/configs/_base_/models/mask_rcnn_r50_caffe_c4.py
@@ -14,15 +14,6 @@ model = dict(
         norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
-    shared_head=dict(
-        type='ResLayer',
-        depth=50,
-        stage=3,
-        stride=2,
-        dilation=1,
-        style='caffe',
-        norm_cfg=norm_cfg,
-        norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
         in_channels=1024,
@@ -35,32 +26,43 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
-        out_channels=1024,
-        featmap_strides=[16]),
-    bbox_head=dict(
-        type='BBoxHead',
-        with_avg_pool=True,
-        roi_feat_size=7,
-        in_channels=2048,
-        num_classes=81,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False,
-        loss_cls=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-    mask_roi_extractor=None,
-    mask_head=dict(
-        type='FCNMaskHead',
-        num_convs=0,
-        in_channels=2048,
-        conv_out_channels=256,
-        num_classes=81,
-        loss_mask=dict(
-            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+    roi_head=dict(
+        type='StandardRoIHead',
+        shared_head=dict(
+            type='ResLayer',
+            depth=50,
+            stage=3,
+            stride=2,
+            dilation=1,
+            style='caffe',
+            norm_cfg=norm_cfg,
+            norm_eval=True),
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=1024,
+            featmap_strides=[16]),
+        bbox_head=dict(
+            type='BBoxHead',
+            with_avg_pool=True,
+            roi_feat_size=7,
+            in_channels=2048,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+        mask_roi_extractor=None,
+        mask_head=dict(
+            type='FCNMaskHead',
+            num_convs=0,
+            in_channels=2048,
+            conv_out_channels=256,
+            num_classes=81,
+            loss_mask=dict(
+                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/_base_/models/mask_rcnn_r50_fpn.py b/configs/_base_/models/mask_rcnn_r50_fpn.py
index d0dce444..b47e08ee 100644
--- a/configs/_base_/models/mask_rcnn_r50_fpn.py
+++ b/configs/_base_/models/mask_rcnn_r50_fpn.py
@@ -28,36 +28,38 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    bbox_head=dict(
-        type='Shared2FCBBoxHead',
-        in_channels=256,
-        fc_out_channels=1024,
-        roi_feat_size=7,
-        num_classes=81,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False,
-        loss_cls=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-    mask_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    mask_head=dict(
-        type='FCNMaskHead',
-        num_convs=4,
-        in_channels=256,
-        conv_out_channels=256,
-        num_classes=81,
-        loss_mask=dict(
-            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+    roi_head=dict(
+        type='StandardRoIHead',
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+        mask_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        mask_head=dict(
+            type='FCNMaskHead',
+            num_convs=4,
+            in_channels=256,
+            conv_out_channels=256,
+            num_classes=81,
+            loss_mask=dict(
+                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/carafe/mask_rcnn_r50_fpn_carafe_1x_coco.py b/configs/carafe/mask_rcnn_r50_fpn_carafe_1x_coco.py
index abc19ad0..668c0239 100644
--- a/configs/carafe/mask_rcnn_r50_fpn_carafe_1x_coco.py
+++ b/configs/carafe/mask_rcnn_r50_fpn_carafe_1x_coco.py
@@ -17,15 +17,16 @@ model = dict(
             encoder_kernel=3,
             encoder_dilation=1,
             compressed_channels=64)),
-    mask_head=dict(
-        upsample_cfg=dict(
-            type='carafe',
-            scale_factor=2,
-            up_kernel=5,
-            up_group=1,
-            encoder_kernel=3,
-            encoder_dilation=1,
-            compressed_channels=64)))
+    roi_head=dict(
+        mask_head=dict(
+            upsample_cfg=dict(
+                type='carafe',
+                scale_factor=2,
+                up_kernel=5,
+                up_group=1,
+                encoder_kernel=3,
+                encoder_dilation=1,
+                compressed_channels=64))))
 img_norm_cfg = dict(
     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
 train_pipeline = [
diff --git a/configs/cascade_rcnn/cascade_rcnn_x101_64x4d_fpn_1x_coco.py b/configs/cascade_rcnn/cascade_rcnn_x101_64x4d_fpn_1x_coco.py
index 25439fdc..b249bfa0 100644
--- a/configs/cascade_rcnn/cascade_rcnn_x101_64x4d_fpn_1x_coco.py
+++ b/configs/cascade_rcnn/cascade_rcnn_x101_64x4d_fpn_1x_coco.py
@@ -1,7 +1,6 @@
 _base_ = './cascade_rcnn_r50_fpn_1x_coco.py'
 model = dict(
     type='CascadeRCNN',
-    num_stages=3,
     pretrained='open-mmlab://resnext101_64x4d',
     backbone=dict(
         type='ResNeXt',
diff --git a/configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py b/configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py
index fa31cb07..416520a3 100644
--- a/configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py
+++ b/configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py
@@ -5,18 +5,19 @@ _base_ = [
 ]
 model = dict(
     pretrained=None,
-    bbox_head=dict(
-        type='Shared2FCBBoxHead',
-        in_channels=256,
-        fc_out_channels=1024,
-        roi_feat_size=7,
-        num_classes=9,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False,
-        loss_cls=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+    roi_head=dict(
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=9,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))))
 # optimizer
 # lr is set for a batch size of 8
 optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
diff --git a/configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py b/configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py
index 37a3303a..6c64ae46 100644
--- a/configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py
+++ b/configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py
@@ -4,26 +4,27 @@ _base_ = [
 ]
 model = dict(
     pretrained=None,
-    bbox_head=dict(
-        type='Shared2FCBBoxHead',
-        in_channels=256,
-        fc_out_channels=1024,
-        roi_feat_size=7,
-        num_classes=9,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False,
-        loss_cls=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-    mask_head=dict(
-        type='FCNMaskHead',
-        num_convs=4,
-        in_channels=256,
-        conv_out_channels=256,
-        num_classes=9,
-        loss_mask=dict(
-            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+    roi_head=dict(
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=9,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+        mask_head=dict(
+            type='FCNMaskHead',
+            num_convs=4,
+            in_channels=256,
+            conv_out_channels=256,
+            num_classes=9,
+            loss_mask=dict(
+                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))))
 # optimizer
 # lr is set for a batch size of 8
 optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
diff --git a/configs/dcn/faster_rcnn_r50_fpn_dpool_1x_coco.py b/configs/dcn/faster_rcnn_r50_fpn_dpool_1x_coco.py
index 1ba7dfdd..40396b91 100644
--- a/configs/dcn/faster_rcnn_r50_fpn_dpool_1x_coco.py
+++ b/configs/dcn/faster_rcnn_r50_fpn_dpool_1x_coco.py
@@ -1,14 +1,15 @@
 _base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
 model = dict(
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(
-            _delete_=True,
-            type='DeformRoIPoolingPack',
-            out_size=7,
+    roi_head=dict(
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(
+                _delete_=True,
+                type='DeformRoIPoolingPack',
+                out_size=7,
+                out_channels=256,
+                no_trans=False,
+                group_size=1,
+                trans_std=0.1),
             out_channels=256,
-            no_trans=False,
-            group_size=1,
-            trans_std=0.1),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]))
+            featmap_strides=[4, 8, 16, 32])))
diff --git a/configs/dcn/faster_rcnn_r50_fpn_mdpool_1x_coco.py b/configs/dcn/faster_rcnn_r50_fpn_mdpool_1x_coco.py
index eb3daa2e..cfeb6d92 100644
--- a/configs/dcn/faster_rcnn_r50_fpn_mdpool_1x_coco.py
+++ b/configs/dcn/faster_rcnn_r50_fpn_mdpool_1x_coco.py
@@ -1,14 +1,15 @@
 _base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
 model = dict(
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(
-            _delete_=True,
-            type='ModulatedDeformRoIPoolingPack',
-            out_size=7,
+    roi_head=dict(
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(
+                _delete_=True,
+                type='ModulatedDeformRoIPoolingPack',
+                out_size=7,
+                out_channels=256,
+                no_trans=False,
+                group_size=1,
+                trans_std=0.1),
             out_channels=256,
-            no_trans=False,
-            group_size=1,
-            trans_std=0.1),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]))
+            featmap_strides=[4, 8, 16, 32])))
diff --git a/configs/double_heads/dh_faster_rcnn_r50_fpn_1x_coco.py b/configs/double_heads/dh_faster_rcnn_r50_fpn_1x_coco.py
index 0fc9e93e..defb8fb6 100644
--- a/configs/double_heads/dh_faster_rcnn_r50_fpn_1x_coco.py
+++ b/configs/double_heads/dh_faster_rcnn_r50_fpn_1x_coco.py
@@ -1,20 +1,21 @@
 _base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
 model = dict(
-    type='DoubleHeadRCNN',
-    reg_roi_scale_factor=1.3,
-    bbox_head=dict(
-        _delete_=True,
-        type='DoubleConvFCBBoxHead',
-        num_convs=4,
-        num_fcs=2,
-        in_channels=256,
-        conv_out_channels=1024,
-        fc_out_channels=1024,
-        roi_feat_size=7,
-        num_classes=81,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False,
-        loss_cls=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0)))
+    roi_head=dict(
+        type='DoubleHeadRoIHead',
+        reg_roi_scale_factor=1.3,
+        bbox_head=dict(
+            _delete_=True,
+            type='DoubleConvFCBBoxHead',
+            num_convs=4,
+            num_fcs=2,
+            in_channels=256,
+            conv_out_channels=1024,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))
diff --git a/configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco.py b/configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco.py
index 413edcf5..49ba7c50 100644
--- a/configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco.py
+++ b/configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco.py
@@ -1,7 +1,8 @@
 _base_ = './faster_rcnn_r50_fpn_1x_coco.py'
 model = dict(
     pretrained='open-mmlab://resnet50_caffe',
-    backbone=dict(norm_cfg=dict(requires_grad=False), style='caffe'))
+    backbone=dict(
+        norm_cfg=dict(requires_grad=False), norm_eval=True, style='caffe'))
 # use caffe img_norm
 img_norm_cfg = dict(
     mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
diff --git a/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws-all_1x_coco.py b/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws-all_1x_coco.py
index 6180b9e4..497267b6 100644
--- a/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws-all_1x_coco.py
+++ b/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws-all_1x_coco.py
@@ -5,8 +5,9 @@ model = dict(
     pretrained='open-mmlab://jhu/resnet50_gn_ws',
     backbone=dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg),
     neck=dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg),
-    bbox_head=dict(
-        type='Shared4Conv1FCBBoxHead',
-        conv_out_channels=256,
-        conv_cfg=conv_cfg,
-        norm_cfg=norm_cfg))
+    roi_head=dict(
+        bbox_head=dict(
+            type='Shared4Conv1FCBBoxHead',
+            conv_out_channels=256,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg)))
diff --git a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws-all_2x_coco.py b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws-all_2x_coco.py
index 15c755db..2032b932 100644
--- a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws-all_2x_coco.py
+++ b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws-all_2x_coco.py
@@ -5,12 +5,13 @@ model = dict(
     pretrained='open-mmlab://jhu/resnet50_gn_ws',
     backbone=dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg),
     neck=dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg),
-    bbox_head=dict(
-        type='Shared4Conv1FCBBoxHead',
-        conv_out_channels=256,
-        conv_cfg=conv_cfg,
-        norm_cfg=norm_cfg),
-    mask_head=dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg))
+    roi_head=dict(
+        bbox_head=dict(
+            type='Shared4Conv1FCBBoxHead',
+            conv_out_channels=256,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg),
+        mask_head=dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)))
 # learning policy
 lr_config = dict(step=[16, 22])
 total_epochs = 24
diff --git a/configs/gn/mask_rcnn_r50_fpn_gn-all_2x_coco.py b/configs/gn/mask_rcnn_r50_fpn_gn-all_2x_coco.py
index 71ec6734..66ea47d7 100644
--- a/configs/gn/mask_rcnn_r50_fpn_gn-all_2x_coco.py
+++ b/configs/gn/mask_rcnn_r50_fpn_gn-all_2x_coco.py
@@ -4,11 +4,12 @@ model = dict(
     pretrained='open-mmlab://detectron/resnet50_gn',
     backbone=dict(norm_cfg=norm_cfg),
     neck=dict(norm_cfg=norm_cfg),
-    bbox_head=dict(
-        type='Shared4Conv1FCBBoxHead',
-        conv_out_channels=256,
-        norm_cfg=norm_cfg),
-    mask_head=dict(norm_cfg=norm_cfg))
+    roi_head=dict(
+        bbox_head=dict(
+            type='Shared4Conv1FCBBoxHead',
+            conv_out_channels=256,
+            norm_cfg=norm_cfg),
+        mask_head=dict(norm_cfg=norm_cfg)))
 img_norm_cfg = dict(
     mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
 train_pipeline = [
diff --git a/configs/gn/mask_rcnn_r50_fpn_gn-all_contrib_2x_coco.py b/configs/gn/mask_rcnn_r50_fpn_gn-all_contrib_2x_coco.py
index 973b1d55..3c690aec 100644
--- a/configs/gn/mask_rcnn_r50_fpn_gn-all_contrib_2x_coco.py
+++ b/configs/gn/mask_rcnn_r50_fpn_gn-all_contrib_2x_coco.py
@@ -4,11 +4,12 @@ model = dict(
     pretrained='open-mmlab://contrib/resnet50_gn',
     backbone=dict(norm_cfg=norm_cfg),
     neck=dict(norm_cfg=norm_cfg),
-    bbox_head=dict(
-        type='Shared4Conv1FCBBoxHead',
-        conv_out_channels=256,
-        norm_cfg=norm_cfg),
-    mask_head=dict(norm_cfg=norm_cfg))
+    roi_head=dict(
+        bbox_head=dict(
+            type='Shared4Conv1FCBBoxHead',
+            conv_out_channels=256,
+            norm_cfg=norm_cfg),
+        mask_head=dict(norm_cfg=norm_cfg)))
 # learning policy
 lr_config = dict(step=[16, 22])
 total_epochs = 24
diff --git a/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py b/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py
index 06e170ac..c42155f7 100644
--- a/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py
+++ b/configs/grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py
@@ -31,35 +31,37 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    bbox_head=dict(
-        type='Shared2FCBBoxHead',
-        with_reg=False,
-        in_channels=256,
-        fc_out_channels=1024,
-        roi_feat_size=7,
-        num_classes=81,
-        target_means=[0., 0., 0., 0.],
-        target_stds=[0.1, 0.1, 0.2, 0.2],
-        reg_class_agnostic=False),
-    grid_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    grid_head=dict(
-        type='GridHead',
-        grid_points=9,
-        num_convs=8,
-        in_channels=256,
-        point_feat_channels=64,
-        norm_cfg=dict(type='GN', num_groups=36),
-        loss_grid=dict(
-            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=15)))
+    roi_head=dict(
+        type='GridRoIHead',
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            with_reg=False,
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=False),
+        grid_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        grid_head=dict(
+            type='GridHead',
+            grid_points=9,
+            num_convs=8,
+            in_channels=256,
+            point_feat_channels=64,
+            norm_cfg=dict(type='GN', num_groups=36),
+            loss_grid=dict(
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=15))))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/guided_anchoring/ga_fast_r50_caffe_fpn_1x_coco.py b/configs/guided_anchoring/ga_fast_r50_caffe_fpn_1x_coco.py
index 501ba078..2c63e47e 100644
--- a/configs/guided_anchoring/ga_fast_r50_caffe_fpn_1x_coco.py
+++ b/configs/guided_anchoring/ga_fast_r50_caffe_fpn_1x_coco.py
@@ -8,8 +8,9 @@ model = dict(
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
         norm_cfg=dict(type='BN', requires_grad=False),
+        norm_eval=True,
         style='caffe'),
-    bbox_head=dict(target_stds=[0.05, 0.05, 0.1, 0.1]))
+    roi_head=dict(bbox_head=dict(target_stds=[0.05, 0.05, 0.1, 0.1])))
 # model training and testing settings
 train_cfg = dict(
     rcnn=dict(
diff --git a/configs/guided_anchoring/ga_faster_r50_caffe_fpn_1x_coco.py b/configs/guided_anchoring/ga_faster_r50_caffe_fpn_1x_coco.py
index a2419272..ae8b863e 100644
--- a/configs/guided_anchoring/ga_faster_r50_caffe_fpn_1x_coco.py
+++ b/configs/guided_anchoring/ga_faster_r50_caffe_fpn_1x_coco.py
@@ -25,7 +25,7 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-    bbox_head=dict(target_stds=[0.05, 0.05, 0.1, 0.1]))
+    roi_head=dict(bbox_head=dict(target_stds=[0.05, 0.05, 0.1, 0.1])))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/guided_anchoring/ga_faster_r50_fpn_1x_coco.py b/configs/guided_anchoring/ga_faster_r50_fpn_1x_coco.py
index dc5dd603..41b46a58 100644
--- a/configs/guided_anchoring/ga_faster_r50_fpn_1x_coco.py
+++ b/configs/guided_anchoring/ga_faster_r50_fpn_1x_coco.py
@@ -25,7 +25,7 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-    bbox_head=dict(target_stds=[0.05, 0.05, 0.1, 0.1]))
+    roi_head=dict(bbox_head=dict(target_stds=[0.05, 0.05, 0.1, 0.1])))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/htc/htc_r50_fpn_1x_coco.py b/configs/htc/htc_r50_fpn_1x_coco.py
index d210dfba..9fef1e6e 100644
--- a/configs/htc/htc_r50_fpn_1x_coco.py
+++ b/configs/htc/htc_r50_fpn_1x_coco.py
@@ -1,20 +1,21 @@
 _base_ = './htc_without_semantic_r50_fpn_1x_coco.py'
 model = dict(
-    semantic_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
-        out_channels=256,
-        featmap_strides=[8]),
-    semantic_head=dict(
-        type='FusedSemanticHead',
-        num_ins=5,
-        fusion_level=1,
-        num_convs=4,
-        in_channels=256,
-        conv_out_channels=256,
-        num_classes=183,
-        ignore_label=255,
-        loss_weight=0.2))
+    roi_head=dict(
+        semantic_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[8]),
+        semantic_head=dict(
+            type='FusedSemanticHead',
+            num_ins=5,
+            fusion_level=1,
+            num_convs=4,
+            in_channels=256,
+            conv_out_channels=256,
+            num_classes=183,
+            ignore_label=255,
+            loss_weight=0.2)))
 data_root = 'data/coco/'
 img_norm_cfg = dict(
     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
diff --git a/configs/htc/htc_without_semantic_r50_fpn_1x_coco.py b/configs/htc/htc_without_semantic_r50_fpn_1x_coco.py
index dd467c86..a97a10b0 100644
--- a/configs/htc/htc_without_semantic_r50_fpn_1x_coco.py
+++ b/configs/htc/htc_without_semantic_r50_fpn_1x_coco.py
@@ -5,10 +5,7 @@ _base_ = [
 # model settings
 model = dict(
     type='HybridTaskCascade',
-    num_stages=3,
     pretrained='torchvision://resnet50',
-    interleaved=True,
-    mask_info_flow=True,
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -35,81 +32,95 @@ model = dict(
         loss_cls=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
-    bbox_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    bbox_head=[
-        dict(
-            type='Shared2FCBBoxHead',
-            in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
-            num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.1, 0.1, 0.2, 0.2],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-        dict(
-            type='Shared2FCBBoxHead',
-            in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
-            num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.05, 0.05, 0.1, 0.1],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
-        dict(
-            type='Shared2FCBBoxHead',
-            in_channels=256,
-            fc_out_channels=1024,
-            roi_feat_size=7,
-            num_classes=81,
-            target_means=[0., 0., 0., 0.],
-            target_stds=[0.033, 0.033, 0.067, 0.067],
-            reg_class_agnostic=True,
-            loss_cls=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
-            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
-    ],
-    mask_roi_extractor=dict(
-        type='SingleRoIExtractor',
-        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
-        out_channels=256,
-        featmap_strides=[4, 8, 16, 32]),
-    mask_head=[
-        dict(
-            type='HTCMaskHead',
-            with_conv_res=False,
-            num_convs=4,
-            in_channels=256,
-            conv_out_channels=256,
-            num_classes=81,
-            loss_mask=dict(
-                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
-        dict(
-            type='HTCMaskHead',
-            num_convs=4,
-            in_channels=256,
-            conv_out_channels=256,
-            num_classes=81,
-            loss_mask=dict(
-                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
-        dict(
-            type='HTCMaskHead',
-            num_convs=4,
-            in_channels=256,
-            conv_out_channels=256,
-            num_classes=81,
-            loss_mask=dict(
-                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))
-    ])
+    roi_head=dict(
+        type='HybridTaskCascadeRoIHead',
+        interleaved=True,
+        mask_info_flow=True,
+        num_stages=3,
+        stage_loss_weights=[1, 0.5, 0.25],
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=[
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.1, 0.1, 0.2, 0.2],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+                               loss_weight=1.0)),
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.05, 0.05, 0.1, 0.1],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+                               loss_weight=1.0)),
+            dict(
+                type='Shared2FCBBoxHead',
+                in_channels=256,
+                fc_out_channels=1024,
+                roi_feat_size=7,
+                num_classes=81,
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.033, 0.033, 0.067, 0.067],
+                reg_class_agnostic=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0),
+                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
+        ],
+        mask_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        mask_head=[
+            dict(
+                type='HTCMaskHead',
+                with_conv_res=False,
+                num_convs=4,
+                in_channels=256,
+                conv_out_channels=256,
+                num_classes=81,
+                loss_mask=dict(
+                    type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
+            dict(
+                type='HTCMaskHead',
+                num_convs=4,
+                in_channels=256,
+                conv_out_channels=256,
+                num_classes=81,
+                loss_mask=dict(
+                    type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
+            dict(
+                type='HTCMaskHead',
+                num_convs=4,
+                in_channels=256,
+                conv_out_channels=256,
+                num_classes=81,
+                loss_mask=dict(
+                    type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))
+        ]))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
@@ -184,8 +195,7 @@ train_cfg = dict(
             mask_size=28,
             pos_weight=-1,
             debug=False)
-    ],
-    stage_loss_weights=[1, 0.5, 0.25])
+    ])
 test_cfg = dict(
     rpn=dict(
         nms_across_levels=False,
diff --git a/configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x_coco.py b/configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x_coco.py
index ef4d555c..b416c8d0 100644
--- a/configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x_coco.py
+++ b/configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x_coco.py
@@ -14,14 +14,15 @@ model = dict(
             refine_level=2,
             refine_type='non_local')
     ],
-    bbox_head=dict(
-        loss_bbox=dict(
-            _delete_=True,
-            type='BalancedL1Loss',
-            alpha=0.5,
-            gamma=1.5,
-            beta=1.0,
-            loss_weight=1.0)))
+    roi_head=dict(
+        bbox_head=dict(
+            loss_bbox=dict(
+                _delete_=True,
+                type='BalancedL1Loss',
+                alpha=0.5,
+                gamma=1.5,
+                beta=1.0,
+                loss_weight=1.0))))
 # model training and testing settings
 train_cfg = dict(
     rcnn=dict(
diff --git a/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x_coco.py b/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x_coco.py
index 2f19b267..9e9b6172 100644
--- a/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x_coco.py
+++ b/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x_coco.py
@@ -14,14 +14,15 @@ model = dict(
             refine_level=2,
             refine_type='non_local')
     ],
-    bbox_head=dict(
-        loss_bbox=dict(
-            _delete_=True,
-            type='BalancedL1Loss',
-            alpha=0.5,
-            gamma=1.5,
-            beta=1.0,
-            loss_weight=1.0)))
+    roi_head=dict(
+        bbox_head=dict(
+            loss_bbox=dict(
+                _delete_=True,
+                type='BalancedL1Loss',
+                alpha=0.5,
+                gamma=1.5,
+                beta=1.0,
+                loss_weight=1.0))))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(sampler=dict(neg_pos_ub=5), allowed_border=-1),
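Both Libra configs keep _delete_=True on loss_bbox while re-nesting it under roi_head. With mmcv-style config inheritance, that flag makes the override replace the base loss dict wholesale instead of merging keys into it, so nothing from the base SmoothL1Loss lingers. A rough illustration with plain dicts (merge behavior paraphrased, not the actual mmcv implementation):

def merge_cfg(base, override):
    # Simplified sketch of mmcv config merging: a child dict carrying
    # _delete_=True replaces the base value wholesale; otherwise dicts
    # are merged key by key.
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and value.get('_delete_', False):
            merged[key] = {k: v for k, v in value.items() if k != '_delete_'}
        elif isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = value
    return merged

Without _delete_, the BalancedL1Loss keys would be merged on top of the inherited SmoothL1Loss dict and stale keys could silently survive.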
diff --git a/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x_coco.py b/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x_coco.py
index c5d6974f..17551c53 100644
--- a/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x_coco.py
+++ b/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x_coco.py
@@ -1,14 +1,16 @@
 _base_ = '../mask_rcnn/mask_rcnn_r50_caffe_fpn_1x_coco.py'
 model = dict(
     type='MaskScoringRCNN',
-    mask_iou_head=dict(
-        type='MaskIoUHead',
-        num_convs=4,
-        num_fcs=2,
-        roi_feat_size=14,
-        in_channels=256,
-        conv_out_channels=256,
-        fc_out_channels=1024,
-        num_classes=81))
+    roi_head=dict(
+        type='MaskScoringRoIHead',
+        mask_iou_head=dict(
+            type='MaskIoUHead',
+            num_convs=4,
+            num_fcs=2,
+            roi_feat_size=14,
+            in_channels=256,
+            conv_out_channels=256,
+            fc_out_channels=1024,
+            num_classes=81)))
 # model training and testing settings
 train_cfg = dict(rcnn=dict(mask_thr_binary=0.5))
diff --git a/configs/ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py b/configs/ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py
index f5d764f8..d218a8f7 100644
--- a/configs/ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py
+++ b/configs/ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py
@@ -1,14 +1,16 @@
 _base_ = '../mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py'
 model = dict(
     type='MaskScoringRCNN',
-    mask_iou_head=dict(
-        type='MaskIoUHead',
-        num_convs=4,
-        num_fcs=2,
-        roi_feat_size=14,
-        in_channels=256,
-        conv_out_channels=256,
-        fc_out_channels=1024,
-        num_classes=81))
+    roi_head=dict(
+        type='MaskScoringRoIHead',
+        mask_iou_head=dict(
+            type='MaskIoUHead',
+            num_convs=4,
+            num_fcs=2,
+            roi_feat_size=14,
+            in_channels=256,
+            conv_out_channels=256,
+            fc_out_channels=1024,
+            num_classes=81)))
 # model training and testing settings
 train_cfg = dict(rcnn=dict(mask_thr_binary=0.5))
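Both ms_rcnn configs now express the extra head through roi_head=dict(type='MaskScoringRoIHead', ...). The idea the MaskIoUHead serves, sketched below with bare tensors (function name illustrative): the final mask confidence is the classification score rescaled by the predicted mask IoU, as in the Mask Scoring R-CNN paper.

import torch

def rescore_masks(cls_scores, pred_mask_ious):
    # Mask Scoring R-CNN: final mask score = cls score * predicted mask IoU.
    return cls_scores * pred_mask_ious.clamp(min=0, max=1)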
diff --git a/configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py b/configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
index 148ae7a1..89b749bd 100644
--- a/configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
+++ b/configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
@@ -2,7 +2,7 @@ _base_ = [
     '../_base_/models/faster_rcnn_r50_fpn.py', '../_base_/datasets/voc0712.py',
     '../_base_/default_runtime.py'
 ]
-model = dict(bbox_head=dict(num_classes=21))
+model = dict(roi_head=dict(bbox_head=dict(num_classes=21)))
 # optimizer
 optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
 optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
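The VOC override only needs to re-nest num_classes under roi_head. In this codebase num_classes includes a background slot, so 20 VOC categories become 21, just as 80 COCO categories appear as 81 in the configs above:

num_voc_categories = 20
num_classes = num_voc_categories + 1  # +1 for the background slot these heads use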
diff --git a/configs/retinanet/retinanet_r50_caffe_fpn_1x_coco.py b/configs/retinanet/retinanet_r50_caffe_fpn_1x_coco.py
index fe29b26b..64728063 100644
--- a/configs/retinanet/retinanet_r50_caffe_fpn_1x_coco.py
+++ b/configs/retinanet/retinanet_r50_caffe_fpn_1x_coco.py
@@ -1,7 +1,8 @@
 _base_ = './retinanet_r50_fpn_1x_coco.py'
 model = dict(
     pretrained='open-mmlab://resnet50_caffe',
-    backbone=dict(norm_cfg=dict(requires_grad=False), style='caffe'))
+    backbone=dict(
+        norm_cfg=dict(requires_grad=False), norm_eval=True, style='caffe'))
 # use caffe img_norm
 img_norm_cfg = dict(
     mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
diff --git a/configs/rpn/rpn_r50_caffe_fpn_1x_coco.py b/configs/rpn/rpn_r50_caffe_fpn_1x_coco.py
index 1ed57fb2..a4e645ef 100644
--- a/configs/rpn/rpn_r50_caffe_fpn_1x_coco.py
+++ b/configs/rpn/rpn_r50_caffe_fpn_1x_coco.py
@@ -1,7 +1,8 @@
 _base_ = './rpn_r50_fpn_1x_coco.py'
 model = dict(
     pretrained='open-mmlab://resnet50_caffe',
-    backbone=dict(norm_cfg=dict(requires_grad=False), style='caffe'))
+    backbone=dict(
+        norm_cfg=dict(requires_grad=False), norm_eval=True, style='caffe'))
 # use caffe img_norm
 img_norm_cfg = dict(
     mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
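The two caffe-style configs above (retinanet and rpn) now state norm_eval=True explicitly alongside requires_grad=False, making both halves of BN freezing visible: the affine parameters stop receiving gradients, and the running statistics stop updating. A sketch of what norm_eval implies, assuming the backbone's train() override behaves like mmdet's ResNet:

import torch.nn as nn

def train_with_frozen_bn(backbone, mode=True):
    # norm_eval=True: switch to train mode but keep every BatchNorm layer
    # in eval mode so running_mean/running_var stay frozen.
    backbone.train(mode)
    if mode:
        for m in backbone.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
    return backbone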
diff --git a/configs/scratch/faster_rcnn_r50_fpn_gn-all_scratch_6x_coco.py b/configs/scratch/faster_rcnn_r50_fpn_gn-all_scratch_6x_coco.py
index b72b0dff..d91fca06 100644
--- a/configs/scratch/faster_rcnn_r50_fpn_gn-all_scratch_6x_coco.py
+++ b/configs/scratch/faster_rcnn_r50_fpn_gn-all_scratch_6x_coco.py
@@ -9,10 +9,11 @@ model = dict(
     backbone=dict(
         frozen_stages=-1, zero_init_residual=False, norm_cfg=norm_cfg),
     neck=dict(norm_cfg=norm_cfg),
-    bbox_head=dict(
-        type='Shared4Conv1FCBBoxHead',
-        conv_out_channels=256,
-        norm_cfg=norm_cfg))
+    roi_head=dict(
+        bbox_head=dict(
+            type='Shared4Conv1FCBBoxHead',
+            conv_out_channels=256,
+            norm_cfg=norm_cfg)))
 # optimizer
 optimizer = dict(paramwise_options=dict(norm_decay_mult=0))
 optimizer_config = dict(_delete_=True, grad_clip=None)
diff --git a/configs/scratch/mask_rcnn_r50_fpn_gn-all_scratch_6x_coco.py b/configs/scratch/mask_rcnn_r50_fpn_gn-all_scratch_6x_coco.py
index 5cae32ee..03808485 100644
--- a/configs/scratch/mask_rcnn_r50_fpn_gn-all_scratch_6x_coco.py
+++ b/configs/scratch/mask_rcnn_r50_fpn_gn-all_scratch_6x_coco.py
@@ -9,11 +9,12 @@ model = dict(
     backbone=dict(
         frozen_stages=-1, zero_init_residual=False, norm_cfg=norm_cfg),
     neck=dict(norm_cfg=norm_cfg),
-    bbox_head=dict(
-        type='Shared4Conv1FCBBoxHead',
-        conv_out_channels=256,
-        norm_cfg=norm_cfg),
-    mask_head=dict(norm_cfg=norm_cfg))
+    roi_head=dict(
+        bbox_head=dict(
+            type='Shared4Conv1FCBBoxHead',
+            conv_out_channels=256,
+            norm_cfg=norm_cfg),
+        mask_head=dict(norm_cfg=norm_cfg)))
 # optimizer
 optimizer = dict(paramwise_options=dict(norm_decay_mult=0))
 optimizer_config = dict(_delete_=True, grad_clip=None)
diff --git a/mmdet/models/__init__.py b/mmdet/models/__init__.py
index 35f0a09e..f25c3195 100644
--- a/mmdet/models/__init__.py
+++ b/mmdet/models/__init__.py
@@ -10,6 +10,7 @@ from .necks import *  # noqa: F401,F403
 from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                        ROI_EXTRACTORS, SHARED_HEADS)
 from .roi_extractors import *  # noqa: F401,F403
+from .roi_heads import *  # noqa: F401,F403
 from .shared_heads import *  # noqa: F401,F403
 
 __all__ = [
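The new wildcard import is load-bearing: registration happens as a side effect of importing each module, so without it the registries would never see the new RoI head classes and configs could not resolve them by type name. A minimal sketch of the pattern (hypothetical class; registry import path as in the hunk above):

from mmdet.models.registry import HEADS

@HEADS.register_module
class ToyRoIHead(object):
    """Hypothetical head; registered with HEADS at import time."""
    pass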
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 70a53d87..852c9d60 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -1,7 +1,6 @@
 from .atss import ATSS
 from .base import BaseDetector
 from .cascade_rcnn import CascadeRCNN
-from .double_head_rcnn import DoubleHeadRCNN
 from .fast_rcnn import FastRCNN
 from .faster_rcnn import FasterRCNN
 from .fcos import FCOS
@@ -19,6 +18,6 @@ from .two_stage import TwoStageDetector
 __all__ = [
     'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
     'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
-    'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN',
-    'RepPointsDetector', 'FOVEA'
+    'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
+    'FOVEA'
 ]
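Dropping DoubleHeadRCNN from the detector registry is consistent with the rest of the patch: the double-head logic is re-homed as an RoI head, so the model becomes a plain two-stage detector plus a specialized roi_head. A hypothetical config fragment in the new style (type name and scale factor assumed for illustration, not quoted from this patch):

model = dict(
    type='FasterRCNN',
    roi_head=dict(
        type='DoubleHeadRoIHead',      # assumed name of the refactored head
        reg_roi_scale_factor=1.3))     # illustrative value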
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 0a713074..0c9bb746 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -20,17 +20,24 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
     def with_neck(self):
         return hasattr(self, 'neck') and self.neck is not None
 
+    # TODO: these properties need to be carefully handled
+    # for both single stage & two stage detectors
     @property
     def with_shared_head(self):
-        return hasattr(self, 'shared_head') and self.shared_head is not None
+        return hasattr(self.roi_head,
+                       'shared_head') and self.roi_head.shared_head is not None
 
     @property
     def with_bbox(self):
-        return hasattr(self, 'bbox_head') and self.bbox_head is not None
+        return ((hasattr(self.roi_head, 'bbox_head')
+                 and self.roi_head.bbox_head is not None)
+                or (hasattr(self, 'bbox_head') and self.bbox_head is not None))
 
     @property
     def with_mask(self):
-        return hasattr(self, 'mask_head') and self.mask_head is not None
+        return ((hasattr(self.roi_head, 'mask_head')
+                 and self.roi_head.mask_head is not None)
+                or (hasattr(self, 'mask_head') and self.mask_head is not None))
 
     @abstractmethod
     def extract_feat(self, imgs):
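The rewritten properties consult self.roi_head first and fall back to a detector-level bbox_head/mask_head, which covers single-stage models that keep their heads directly. One sharp edge the new TODO hints at: hasattr(self.roi_head, ...) evaluates self.roi_head eagerly, so a detector that never sets that attribute would raise AttributeError rather than return False. A defensive variant, as a sketch only:

class SafeProps(object):
    # Sketch: guard the roi_head attribute itself before probing inside it.
    @property
    def with_bbox(self):
        roi_head = getattr(self, 'roi_head', None)
        if roi_head is not None and getattr(roi_head, 'bbox_head', None) is not None:
            return True
        return getattr(self, 'bbox_head', None) is not None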
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index dc71df1a..b5d12f31 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -1,512 +1,26 @@
-from __future__ import division
-
-import torch
-import torch.nn as nn
-
-from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
-                        build_sampler, merge_aug_bboxes, merge_aug_masks,
-                        multiclass_nms)
-from .. import builder
 from ..registry import DETECTORS
-from .base import BaseDetector
-from .test_mixins import RPNTestMixin
+from .two_stage import TwoStageDetector
 
 
 @DETECTORS.register_module
-class CascadeRCNN(BaseDetector, RPNTestMixin):
+class CascadeRCNN(TwoStageDetector):
 
     def __init__(self,
-                 num_stages,
                  backbone,
                  neck=None,
-                 shared_head=None,
                  rpn_head=None,
-                 bbox_roi_extractor=None,
-                 bbox_head=None,
-                 mask_roi_extractor=None,
-                 mask_head=None,
+                 roi_head=None,
                  train_cfg=None,
                  test_cfg=None,
                  pretrained=None):
-        assert bbox_roi_extractor is not None
-        assert bbox_head is not None
-        super(CascadeRCNN, self).__init__()
-
-        self.num_stages = num_stages
-        self.backbone = builder.build_backbone(backbone)
-
-        if neck is not None:
-            self.neck = builder.build_neck(neck)
-
-        if rpn_head is not None:
-            self.rpn_head = builder.build_head(rpn_head)
-
-        if shared_head is not None:
-            self.shared_head = builder.build_shared_head(shared_head)
-
-        if bbox_head is not None:
-            self.bbox_roi_extractor = nn.ModuleList()
-            self.bbox_head = nn.ModuleList()
-            if not isinstance(bbox_roi_extractor, list):
-                bbox_roi_extractor = [
-                    bbox_roi_extractor for _ in range(num_stages)
-                ]
-            if not isinstance(bbox_head, list):
-                bbox_head = [bbox_head for _ in range(num_stages)]
-            assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages
-            for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
-                self.bbox_roi_extractor.append(
-                    builder.build_roi_extractor(roi_extractor))
-                self.bbox_head.append(builder.build_head(head))
-
-        if mask_head is not None:
-            self.mask_head = nn.ModuleList()
-            if not isinstance(mask_head, list):
-                mask_head = [mask_head for _ in range(num_stages)]
-            assert len(mask_head) == self.num_stages
-            for head in mask_head:
-                self.mask_head.append(builder.build_head(head))
-            if mask_roi_extractor is not None:
-                self.share_roi_extractor = False
-                self.mask_roi_extractor = nn.ModuleList()
-                if not isinstance(mask_roi_extractor, list):
-                    mask_roi_extractor = [
-                        mask_roi_extractor for _ in range(num_stages)
-                    ]
-                assert len(mask_roi_extractor) == self.num_stages
-                for roi_extractor in mask_roi_extractor:
-                    self.mask_roi_extractor.append(
-                        builder.build_roi_extractor(roi_extractor))
-            else:
-                self.share_roi_extractor = True
-                self.mask_roi_extractor = self.bbox_roi_extractor
-
-        self.train_cfg = train_cfg
-        self.test_cfg = test_cfg
-
-        self.init_weights(pretrained=pretrained)
-
-    @property
-    def with_rpn(self):
-        return hasattr(self, 'rpn_head') and self.rpn_head is not None
-
-    def init_weights(self, pretrained=None):
-        super(CascadeRCNN, self).init_weights(pretrained)
-        self.backbone.init_weights(pretrained=pretrained)
-        if self.with_neck:
-            if isinstance(self.neck, nn.Sequential):
-                for m in self.neck:
-                    m.init_weights()
-            else:
-                self.neck.init_weights()
-        if self.with_rpn:
-            self.rpn_head.init_weights()
-        if self.with_shared_head:
-            self.shared_head.init_weights(pretrained=pretrained)
-        for i in range(self.num_stages):
-            if self.with_bbox:
-                self.bbox_roi_extractor[i].init_weights()
-                self.bbox_head[i].init_weights()
-            if self.with_mask:
-                if not self.share_roi_extractor:
-                    self.mask_roi_extractor[i].init_weights()
-                self.mask_head[i].init_weights()
-
-    def extract_feat(self, img):
-        x = self.backbone(img)
-        if self.with_neck:
-            x = self.neck(x)
-        return x
-
-    def forward_dummy(self, img):
-        outs = ()
-        # backbone
-        x = self.extract_feat(img)
-        # rpn
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).to(device=img.device)
-        # bbox heads
-        rois = bbox2roi([proposals])
-        if self.with_bbox:
-            for i in range(self.num_stages):
-                bbox_feats = self.bbox_roi_extractor[i](
-                    x[:self.bbox_roi_extractor[i].num_inputs], rois)
-                if self.with_shared_head:
-                    bbox_feats = self.shared_head(bbox_feats)
-                cls_score, bbox_pred = self.bbox_head[i](bbox_feats)
-                outs = outs + (cls_score, bbox_pred)
-        # mask heads
-        if self.with_mask:
-            mask_rois = rois[:100]
-            for i in range(self.num_stages):
-                mask_feats = self.mask_roi_extractor[i](
-                    x[:self.mask_roi_extractor[i].num_inputs], mask_rois)
-                if self.with_shared_head:
-                    mask_feats = self.shared_head(mask_feats)
-                mask_pred = self.mask_head[i](mask_feats)
-                outs = outs + (mask_pred, )
-        return outs
-
-    def forward_train(self,
-                      img,
-                      img_metas,
-                      gt_bboxes,
-                      gt_labels,
-                      gt_bboxes_ignore=None,
-                      gt_masks=None,
-                      proposals=None):
-        """
-        Args:
-            img (Tensor): of shape (N, C, H, W) encoding input images.
-                Typically these should be mean centered and std scaled.
-
-            img_metas (list[dict]): list of image info dict where each dict
-                has: 'img_shape', 'scale_factor', 'flip', and may also contain
-                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
-                For details on the values of these keys see
-                `mmdet/datasets/pipelines/formatting.py:Collect`.
-
-            gt_bboxes (list[Tensor]): each item are the truth boxes for each
-                image in [tl_x, tl_y, br_x, br_y] format.
-
-            gt_labels (list[Tensor]): class indices corresponding to each box
-
-            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
-                boxes can be ignored when computing the loss.
-
-            gt_masks (None | Tensor) : true segmentation masks for each box
-                used if the architecture supports a segmentation task.
-
-            proposals : override rpn proposals with custom proposals. Use when
-                `with_rpn` is False.
-
-        Returns:
-            dict[str, Tensor]: a dictionary of loss components
-        """
-        x = self.extract_feat(img)
-
-        losses = dict()
-
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
-                                          self.train_cfg.rpn)
-            rpn_losses = self.rpn_head.loss(
-                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
-            losses.update(rpn_losses)
-
-            proposal_cfg = self.train_cfg.get('rpn_proposal',
-                                              self.test_cfg.rpn)
-            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
-            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
-        else:
-            proposal_list = proposals
-
-        for i in range(self.num_stages):
-            self.current_stage = i
-            rcnn_train_cfg = self.train_cfg.rcnn[i]
-            lw = self.train_cfg.stage_loss_weights[i]
-
-            # assign gts and sample proposals
-            sampling_results = []
-            if self.with_bbox or self.with_mask:
-                bbox_assigner = build_assigner(rcnn_train_cfg.assigner)
-                bbox_sampler = build_sampler(
-                    rcnn_train_cfg.sampler, context=self)
-                num_imgs = img.size(0)
-                if gt_bboxes_ignore is None:
-                    gt_bboxes_ignore = [None for _ in range(num_imgs)]
-
-                for j in range(num_imgs):
-                    assign_result = bbox_assigner.assign(
-                        proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j],
-                        gt_labels[j])
-                    sampling_result = bbox_sampler.sample(
-                        assign_result,
-                        proposal_list[j],
-                        gt_bboxes[j],
-                        gt_labels[j],
-                        feats=[lvl_feat[j][None] for lvl_feat in x])
-                    sampling_results.append(sampling_result)
-
-            # bbox head forward and loss
-            bbox_roi_extractor = self.bbox_roi_extractor[i]
-            bbox_head = self.bbox_head[i]
-
-            rois = bbox2roi([res.bboxes for res in sampling_results])
-
-            if len(rois) == 0:
-                # If there are no predicted and/or truth boxes, then we cannot
-                # compute head / mask losses
-                continue
-
-            bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
-                                            rois)
-            if self.with_shared_head:
-                bbox_feats = self.shared_head(bbox_feats)
-            cls_score, bbox_pred = bbox_head(bbox_feats)
-
-            bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes,
-                                                gt_labels, rcnn_train_cfg)
-            loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
-            for name, value in loss_bbox.items():
-                losses['s{}.{}'.format(i, name)] = (
-                    value * lw if 'loss' in name else value)
-
-            # mask head forward and loss
-            if self.with_mask:
-                if not self.share_roi_extractor:
-                    mask_roi_extractor = self.mask_roi_extractor[i]
-                    pos_rois = bbox2roi(
-                        [res.pos_bboxes for res in sampling_results])
-                    mask_feats = mask_roi_extractor(
-                        x[:mask_roi_extractor.num_inputs], pos_rois)
-                    if self.with_shared_head:
-                        mask_feats = self.shared_head(mask_feats)
-                else:
-                    # reuse positive bbox feats
-                    pos_inds = []
-                    device = bbox_feats.device
-                    for res in sampling_results:
-                        pos_inds.append(
-                            torch.ones(
-                                res.pos_bboxes.shape[0],
-                                device=device,
-                                dtype=torch.uint8))
-                        pos_inds.append(
-                            torch.zeros(
-                                res.neg_bboxes.shape[0],
-                                device=device,
-                                dtype=torch.uint8))
-                    pos_inds = torch.cat(pos_inds)
-                    mask_feats = bbox_feats[pos_inds.type(torch.bool)]
-                mask_head = self.mask_head[i]
-                mask_pred = mask_head(mask_feats)
-                mask_targets = mask_head.get_target(sampling_results, gt_masks,
-                                                    rcnn_train_cfg)
-                pos_labels = torch.cat(
-                    [res.pos_gt_labels for res in sampling_results])
-                loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)
-                for name, value in loss_mask.items():
-                    losses['s{}.{}'.format(i, name)] = (
-                        value * lw if 'loss' in name else value)
-
-            # refine bboxes
-            if i < self.num_stages - 1:
-                pos_is_gts = [res.pos_is_gt for res in sampling_results]
-                roi_labels = bbox_targets[0]  # bbox_targets is a tuple
-                with torch.no_grad():
-                    proposal_list = bbox_head.refine_bboxes(
-                        rois, roi_labels, bbox_pred, pos_is_gts, img_metas)
-
-        return losses
-
-    def simple_test(self, img, img_metas, proposals=None, rescale=False):
-        """Run inference on a single image.
-
-        Args:
-            img (Tensor): must be in shape (N, C, H, W)
-            img_metas (list[dict]): a list with one dictionary element.
-                See `mmdet/datasets/pipelines/formatting.py:Collect` for
-                details of meta dicts.
-            proposals : if specified overrides rpn proposals
-            rescale (bool): if True returns boxes in original image space
-
-        Returns:
-            dict: results
-        """
-        x = self.extract_feat(img)
-
-        proposal_list = self.simple_test_rpn(
-            x, img_metas,
-            self.test_cfg.rpn) if proposals is None else proposals
-
-        img_shape = img_metas[0]['img_shape']
-        ori_shape = img_metas[0]['ori_shape']
-        scale_factor = img_metas[0]['scale_factor']
-
-        # "ms" in variable names means multi-stage
-        ms_bbox_result = {}
-        ms_segm_result = {}
-        ms_scores = []
-        rcnn_test_cfg = self.test_cfg.rcnn
-
-        rois = bbox2roi(proposal_list)
-        for i in range(self.num_stages):
-            bbox_roi_extractor = self.bbox_roi_extractor[i]
-            bbox_head = self.bbox_head[i]
-
-            bbox_feats = bbox_roi_extractor(
-                x[:len(bbox_roi_extractor.featmap_strides)], rois)
-            if self.with_shared_head:
-                bbox_feats = self.shared_head(bbox_feats)
-
-            cls_score, bbox_pred = bbox_head(bbox_feats)
-            ms_scores.append(cls_score)
-
-            if i < self.num_stages - 1:
-                bbox_label = cls_score.argmax(dim=1)
-                rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
-                                                  img_metas[0])
-
-        cls_score = sum(ms_scores) / self.num_stages
-        det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
-            rois,
-            cls_score,
-            bbox_pred,
-            img_shape,
-            scale_factor,
-            rescale=rescale,
-            cfg=rcnn_test_cfg)
-        bbox_result = bbox2result(det_bboxes, det_labels,
-                                  self.bbox_head[-1].num_classes)
-        ms_bbox_result['ensemble'] = bbox_result
-
-        if self.with_mask:
-            if det_bboxes.shape[0] == 0:
-                mask_classes = self.mask_head[-1].num_classes - 1
-                segm_result = [[] for _ in range(mask_classes)]
-            else:
-                if isinstance(scale_factor, float):  # aspect ratio fixed
-                    _bboxes = (
-                        det_bboxes[:, :4] *
-                        scale_factor if rescale else det_bboxes)
-                else:
-                    _bboxes = (
-                        det_bboxes[:, :4] * det_bboxes.new_tensor(scale_factor)
-                        if rescale else det_bboxes)
-
-                mask_rois = bbox2roi([_bboxes])
-                aug_masks = []
-                for i in range(self.num_stages):
-                    mask_roi_extractor = self.mask_roi_extractor[i]
-                    mask_feats = mask_roi_extractor(
-                        x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
-                    if self.with_shared_head:
-                        mask_feats = self.shared_head(mask_feats)
-                    mask_pred = self.mask_head[i](mask_feats)
-                    aug_masks.append(mask_pred.sigmoid().cpu().numpy())
-                merged_masks = merge_aug_masks(aug_masks,
-                                               [img_metas] * self.num_stages,
-                                               self.test_cfg.rcnn)
-                segm_result = self.mask_head[-1].get_seg_masks(
-                    merged_masks, _bboxes, det_labels, rcnn_test_cfg,
-                    ori_shape, scale_factor, rescale)
-            ms_segm_result['ensemble'] = segm_result
-
-        if self.with_mask:
-            results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble'])
-        else:
-            results = ms_bbox_result['ensemble']
-
-        return results
-
-    def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
-        """Test with augmentations.
-
-        If rescale is False, then returned bboxes and masks will fit the scale
-        of imgs[0].
-        """
-        # recompute feats to save memory
-        proposal_list = self.aug_test_rpn(
-            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
-
-        rcnn_test_cfg = self.test_cfg.rcnn
-        aug_bboxes = []
-        aug_scores = []
-        for x, img_meta in zip(self.extract_feats(imgs), img_metas):
-            # only one image in the batch
-            img_shape = img_meta[0]['img_shape']
-            scale_factor = img_meta[0]['scale_factor']
-            flip = img_meta[0]['flip']
-
-            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
-                                     scale_factor, flip)
-            # "ms" in variable names means multi-stage
-            ms_scores = []
-
-            rois = bbox2roi([proposals])
-            for i in range(self.num_stages):
-                bbox_roi_extractor = self.bbox_roi_extractor[i]
-                bbox_head = self.bbox_head[i]
-
-                bbox_feats = bbox_roi_extractor(
-                    x[:len(bbox_roi_extractor.featmap_strides)], rois)
-                if self.with_shared_head:
-                    bbox_feats = self.shared_head(bbox_feats)
-
-                cls_score, bbox_pred = bbox_head(bbox_feats)
-                ms_scores.append(cls_score)
-
-                if i < self.num_stages - 1:
-                    bbox_label = cls_score.argmax(dim=1)
-                    rois = bbox_head.regress_by_class(rois, bbox_label,
-                                                      bbox_pred, img_meta[0])
-
-            cls_score = sum(ms_scores) / float(len(ms_scores))
-            bboxes, scores = self.bbox_head[-1].get_det_bboxes(
-                rois,
-                cls_score,
-                bbox_pred,
-                img_shape,
-                scale_factor,
-                rescale=False,
-                cfg=None)
-            aug_bboxes.append(bboxes)
-            aug_scores.append(scores)
-
-        # after merging, bboxes will be rescaled to the original image size
-        merged_bboxes, merged_scores = merge_aug_bboxes(
-            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
-        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
-                                                rcnn_test_cfg.score_thr,
-                                                rcnn_test_cfg.nms,
-                                                rcnn_test_cfg.max_per_img)
-
-        bbox_result = bbox2result(det_bboxes, det_labels,
-                                  self.bbox_head[-1].num_classes)
-
-        if self.with_mask:
-            if det_bboxes.shape[0] == 0:
-                segm_result = [[]
-                               for _ in range(self.mask_head[-1].num_classes -
-                                              1)]
-            else:
-                aug_masks = []
-                aug_img_metas = []
-                for x, img_meta in zip(self.extract_feats(imgs), img_metas):
-                    img_shape = img_meta[0]['img_shape']
-                    scale_factor = img_meta[0]['scale_factor']
-                    flip = img_meta[0]['flip']
-                    _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
-                                           scale_factor, flip)
-                    mask_rois = bbox2roi([_bboxes])
-                    for i in range(self.num_stages):
-                        mask_feats = self.mask_roi_extractor[i](
-                            x[:len(self.mask_roi_extractor[i].featmap_strides
-                                   )], mask_rois)
-                        if self.with_shared_head:
-                            mask_feats = self.shared_head(mask_feats)
-                        mask_pred = self.mask_head[i](mask_feats)
-                        aug_masks.append(mask_pred.sigmoid().cpu().numpy())
-                        aug_img_metas.append(img_meta)
-                merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
-                                               self.test_cfg.rcnn)
-
-                ori_shape = img_metas[0][0]['ori_shape']
-                segm_result = self.mask_head[-1].get_seg_masks(
-                    merged_masks,
-                    det_bboxes,
-                    det_labels,
-                    rcnn_test_cfg,
-                    ori_shape,
-                    scale_factor=1.0,
-                    rescale=False)
-            return bbox_result, segm_result
-        else:
-            return bbox_result
+        super(CascadeRCNN, self).__init__(
+            backbone=backbone,
+            neck=neck,
+            rpn_head=rpn_head,
+            roi_head=roi_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            pretrained=pretrained)
 
     def show_result(self, data, result, **kwargs):
         if self.with_mask:
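With the cascade machinery moved out, CascadeRCNN reduces to constructor wiring over TwoStageDetector, and the stage bookkeeping travels with it: the hunk near the top of this patch deletes stage_loss_weights from train_cfg, so per-stage settings now live on the RoI head. A fragment of the new shape (type name assumed; per-stage lists follow the pattern in the config hunks above):

roi_head = dict(
    type='CascadeRoIHead',               # assumed registered name
    num_stages=3,
    stage_loss_weights=[1, 0.5, 0.25])   # moved here out of train_cfg
# bbox_roi_extractor / bbox_head / mask_head are per-stage lists, as in the
# Shared2FCBBoxHead and HTCMaskHead triples earlier in this patch.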
diff --git a/mmdet/models/detectors/double_head_rcnn.py b/mmdet/models/detectors/double_head_rcnn.py
deleted file mode 100644
index 15e04ad5..00000000
--- a/mmdet/models/detectors/double_head_rcnn.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import torch
-
-from mmdet.core import bbox2roi, build_assigner, build_sampler
-from ..registry import DETECTORS
-from .two_stage import TwoStageDetector
-
-
-@DETECTORS.register_module
-class DoubleHeadRCNN(TwoStageDetector):
-
-    def __init__(self, reg_roi_scale_factor, **kwargs):
-        super().__init__(**kwargs)
-        self.reg_roi_scale_factor = reg_roi_scale_factor
-
-    def forward_dummy(self, img):
-        outs = ()
-        # backbone
-        x = self.extract_feat(img)
-        # rpn
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).to(device=img.device)
-        # bbox head
-        rois = bbox2roi([proposals])
-        bbox_cls_feats = self.bbox_roi_extractor(
-            x[:self.bbox_roi_extractor.num_inputs], rois)
-        bbox_reg_feats = self.bbox_roi_extractor(
-            x[:self.bbox_roi_extractor.num_inputs],
-            rois,
-            roi_scale_factor=self.reg_roi_scale_factor)
-        if self.with_shared_head:
-            bbox_cls_feats = self.shared_head(bbox_cls_feats)
-            bbox_reg_feats = self.shared_head(bbox_reg_feats)
-        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
-        outs += (cls_score, bbox_pred)
-        return outs
-
-    def forward_train(self,
-                      img,
-                      img_metas,
-                      gt_bboxes,
-                      gt_labels,
-                      gt_bboxes_ignore=None,
-                      gt_masks=None,
-                      proposals=None):
-        x = self.extract_feat(img)
-
-        losses = dict()
-
-        # RPN forward and loss
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
-                                          self.train_cfg.rpn)
-            rpn_losses = self.rpn_head.loss(
-                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
-            losses.update(rpn_losses)
-
-            proposal_cfg = self.train_cfg.get('rpn_proposal',
-                                              self.test_cfg.rpn)
-            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
-            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
-        else:
-            proposal_list = proposals
-
-        # assign gts and sample proposals
-        if self.with_bbox or self.with_mask:
-            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
-            bbox_sampler = build_sampler(
-                self.train_cfg.rcnn.sampler, context=self)
-            num_imgs = img.size(0)
-            if gt_bboxes_ignore is None:
-                gt_bboxes_ignore = [None for _ in range(num_imgs)]
-            sampling_results = []
-            for i in range(num_imgs):
-                assign_result = bbox_assigner.assign(proposal_list[i],
-                                                     gt_bboxes[i],
-                                                     gt_bboxes_ignore[i],
-                                                     gt_labels[i])
-                sampling_result = bbox_sampler.sample(
-                    assign_result,
-                    proposal_list[i],
-                    gt_bboxes[i],
-                    gt_labels[i],
-                    feats=[lvl_feat[i][None] for lvl_feat in x])
-                sampling_results.append(sampling_result)
-
-        # bbox head forward and loss
-        if self.with_bbox:
-            rois = bbox2roi([res.bboxes for res in sampling_results])
-            # TODO: a more flexible way to decide which feature maps to use
-            bbox_cls_feats = self.bbox_roi_extractor(
-                x[:self.bbox_roi_extractor.num_inputs], rois)
-            bbox_reg_feats = self.bbox_roi_extractor(
-                x[:self.bbox_roi_extractor.num_inputs],
-                rois,
-                roi_scale_factor=self.reg_roi_scale_factor)
-            if self.with_shared_head:
-                bbox_cls_feats = self.shared_head(bbox_cls_feats)
-                bbox_reg_feats = self.shared_head(bbox_reg_feats)
-            cls_score, bbox_pred = self.bbox_head(bbox_cls_feats,
-                                                  bbox_reg_feats)
-
-            bbox_targets = self.bbox_head.get_target(sampling_results,
-                                                     gt_bboxes, gt_labels,
-                                                     self.train_cfg.rcnn)
-            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
-                                            *bbox_targets)
-            losses.update(loss_bbox)
-
-        # mask head forward and loss
-        if self.with_mask:
-            if not self.share_roi_extractor:
-                pos_rois = bbox2roi(
-                    [res.pos_bboxes for res in sampling_results])
-                mask_feats = self.mask_roi_extractor(
-                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
-                if self.with_shared_head:
-                    mask_feats = self.shared_head(mask_feats)
-            else:
-                pos_inds = []
-                device = bbox_cls_feats.device
-                for res in sampling_results:
-                    pos_inds.append(
-                        torch.ones(
-                            res.pos_bboxes.shape[0],
-                            device=device,
-                            dtype=torch.uint8))
-                    pos_inds.append(
-                        torch.zeros(
-                            res.neg_bboxes.shape[0],
-                            device=device,
-                            dtype=torch.uint8))
-                pos_inds = torch.cat(pos_inds)
-                mask_feats = bbox_cls_feats[pos_inds]
-            mask_pred = self.mask_head(mask_feats)
-
-            mask_targets = self.mask_head.get_target(sampling_results,
-                                                     gt_masks,
-                                                     self.train_cfg.rcnn)
-            pos_labels = torch.cat(
-                [res.pos_gt_labels for res in sampling_results])
-            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
-                                            pos_labels)
-            losses.update(loss_mask)
-
-        return losses
-
-    def simple_test_bboxes(self,
-                           x,
-                           img_metas,
-                           proposals,
-                           rcnn_test_cfg,
-                           rescale=False):
-        """Test only det bboxes without augmentation."""
-        rois = bbox2roi(proposals)
-        bbox_cls_feats = self.bbox_roi_extractor(
-            x[:self.bbox_roi_extractor.num_inputs], rois)
-        bbox_reg_feats = self.bbox_roi_extractor(
-            x[:self.bbox_roi_extractor.num_inputs],
-            rois,
-            roi_scale_factor=self.reg_roi_scale_factor)
-        if self.with_shared_head:
-            bbox_cls_feats = self.shared_head(bbox_cls_feats)
-            bbox_reg_feats = self.shared_head(bbox_reg_feats)
-        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
-        img_shape = img_metas[0]['img_shape']
-        scale_factor = img_metas[0]['scale_factor']
-        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
-            rois,
-            cls_score,
-            bbox_pred,
-            img_shape,
-            scale_factor,
-            rescale=rescale,
-            cfg=rcnn_test_cfg)
-        return det_bboxes, det_labels
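The deleted detector's distinguishing trick was calling one RoI extractor twice: once as-is for classification features, once with an enlarged RoI (roi_scale_factor) for regression features. Transcribed from the removed code as a standalone sketch (the extractor is assumed to accept roi_scale_factor, as above):

def double_head_feats(extractor, feats, rois, reg_roi_scale_factor):
    # Classification branch pools the RoI as-is; the regression branch
    # pools an enlarged RoI for more context around the box.
    cls_feats = extractor(feats[:extractor.num_inputs], rois)
    reg_feats = extractor(feats[:extractor.num_inputs], rois,
                          roi_scale_factor=reg_roi_scale_factor)
    return cls_feats, reg_feats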
diff --git a/mmdet/models/detectors/fast_rcnn.py b/mmdet/models/detectors/fast_rcnn.py
index 2145c240..dfe563c1 100644
--- a/mmdet/models/detectors/fast_rcnn.py
+++ b/mmdet/models/detectors/fast_rcnn.py
@@ -7,25 +7,17 @@ class FastRCNN(TwoStageDetector):
 
     def __init__(self,
                  backbone,
-                 bbox_roi_extractor,
-                 bbox_head,
+                 roi_head,
                  train_cfg,
                  test_cfg,
                  neck=None,
-                 shared_head=None,
-                 mask_roi_extractor=None,
-                 mask_head=None,
                  pretrained=None):
         super(FastRCNN, self).__init__(
             backbone=backbone,
             neck=neck,
-            shared_head=shared_head,
-            bbox_roi_extractor=bbox_roi_extractor,
-            bbox_head=bbox_head,
+            roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
-            mask_roi_extractor=mask_roi_extractor,
-            mask_head=mask_head,
             pretrained=pretrained)
 
     def forward_test(self, imgs, img_metas, proposals, **kwargs):
diff --git a/mmdet/models/detectors/faster_rcnn.py b/mmdet/models/detectors/faster_rcnn.py
index 969cd7cc..d4e73421 100644
--- a/mmdet/models/detectors/faster_rcnn.py
+++ b/mmdet/models/detectors/faster_rcnn.py
@@ -8,20 +8,16 @@ class FasterRCNN(TwoStageDetector):
     def __init__(self,
                  backbone,
                  rpn_head,
-                 bbox_roi_extractor,
-                 bbox_head,
+                 roi_head,
                  train_cfg,
                  test_cfg,
                  neck=None,
-                 shared_head=None,
                  pretrained=None):
         super(FasterRCNN, self).__init__(
             backbone=backbone,
             neck=neck,
-            shared_head=shared_head,
             rpn_head=rpn_head,
-            bbox_roi_extractor=bbox_roi_extractor,
-            bbox_head=bbox_head,
+            roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             pretrained=pretrained)
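FastRCNN and FasterRCNN now differ from TwoStageDetector only in which arguments they require, and for downstream configs the migration is mechanical: the flat second-stage keys nest under a single roi_head dict. Illustrative before/after (default head type name assumed):

old_style = dict(
    bbox_roi_extractor=dict(type='SingleRoIExtractor'),
    bbox_head=dict(type='Shared2FCBBoxHead'))
new_style = dict(
    roi_head=dict(
        type='StandardRoIHead',          # assumed default RoI head name
        bbox_roi_extractor=dict(type='SingleRoIExtractor'),
        bbox_head=dict(type='Shared2FCBBoxHead')))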
diff --git a/mmdet/models/detectors/grid_rcnn.py b/mmdet/models/detectors/grid_rcnn.py
index 20d9f8d6..ed9c7ad5 100644
--- a/mmdet/models/detectors/grid_rcnn.py
+++ b/mmdet/models/detectors/grid_rcnn.py
@@ -1,7 +1,3 @@
-import torch
-
-from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
-from .. import builder
 from ..registry import DETECTORS
 from .two_stage import TwoStageDetector
 
@@ -18,216 +14,16 @@ class GridRCNN(TwoStageDetector):
     def __init__(self,
                  backbone,
                  rpn_head,
-                 bbox_roi_extractor,
-                 bbox_head,
-                 grid_roi_extractor,
-                 grid_head,
+                 roi_head,
                  train_cfg,
                  test_cfg,
                  neck=None,
-                 shared_head=None,
                  pretrained=None):
-        assert grid_head is not None
         super(GridRCNN, self).__init__(
             backbone=backbone,
             neck=neck,
-            shared_head=shared_head,
             rpn_head=rpn_head,
-            bbox_roi_extractor=bbox_roi_extractor,
-            bbox_head=bbox_head,
+            roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             pretrained=pretrained)
-
-        if grid_roi_extractor is not None:
-            self.grid_roi_extractor = builder.build_roi_extractor(
-                grid_roi_extractor)
-            self.share_roi_extractor = False
-        else:
-            self.share_roi_extractor = True
-            self.grid_roi_extractor = self.bbox_roi_extractor
-        self.grid_head = builder.build_head(grid_head)
-
-        self.init_extra_weights()
-
-    def init_extra_weights(self):
-        self.grid_head.init_weights()
-        if not self.share_roi_extractor:
-            self.grid_roi_extractor.init_weights()
-
-    def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
-        """Ramdom jitter positive proposals for training."""
-        for sampling_result, img_meta in zip(sampling_results, img_metas):
-            bboxes = sampling_result.pos_bboxes
-            random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
-                -amplitude, amplitude)
-            # before jittering
-            cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
-            wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
-            # after jittering
-            new_cxcy = cxcy + wh * random_offsets[:, :2]
-            new_wh = wh * (1 + random_offsets[:, 2:])
-            # xywh to xyxy
-            new_x1y1 = (new_cxcy - new_wh / 2)
-            new_x2y2 = (new_cxcy + new_wh / 2)
-            new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
-            # clip bboxes
-            max_shape = img_meta['img_shape']
-            if max_shape is not None:
-                new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
-                new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)
-
-            sampling_result.pos_bboxes = new_bboxes
-        return sampling_results
-
-    def forward_dummy(self, img):
-        outs = ()
-        # backbone
-        x = self.extract_feat(img)
-        # rpn
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).to(device=img.device)
-        # bbox head
-        rois = bbox2roi([proposals])
-        bbox_feats = self.bbox_roi_extractor(
-            x[:self.bbox_roi_extractor.num_inputs], rois)
-        if self.with_shared_head:
-            bbox_feats = self.shared_head(bbox_feats)
-        cls_score, bbox_pred = self.bbox_head(bbox_feats)
-        # grid head
-        grid_rois = rois[:100]
-        grid_feats = self.grid_roi_extractor(
-            x[:self.grid_roi_extractor.num_inputs], grid_rois)
-        if self.with_shared_head:
-            grid_feats = self.shared_head(grid_feats)
-        grid_pred = self.grid_head(grid_feats)
-        return rpn_outs, cls_score, bbox_pred, grid_pred
-
-    def forward_train(self,
-                      img,
-                      img_metas,
-                      gt_bboxes,
-                      gt_labels,
-                      gt_bboxes_ignore=None,
-                      gt_masks=None,
-                      proposals=None):
-        x = self.extract_feat(img)
-
-        losses = dict()
-
-        # RPN forward and loss
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
-                                          self.train_cfg.rpn)
-            rpn_losses = self.rpn_head.loss(
-                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
-            losses.update(rpn_losses)
-
-            proposal_cfg = self.train_cfg.get('rpn_proposal',
-                                              self.test_cfg.rpn)
-            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
-            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
-        else:
-            proposal_list = proposals
-
-        if self.with_bbox:
-            # assign gts and sample proposals
-            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
-            bbox_sampler = build_sampler(
-                self.train_cfg.rcnn.sampler, context=self)
-            num_imgs = img.size(0)
-            if gt_bboxes_ignore is None:
-                gt_bboxes_ignore = [None for _ in range(num_imgs)]
-            sampling_results = []
-            for i in range(num_imgs):
-                assign_result = bbox_assigner.assign(proposal_list[i],
-                                                     gt_bboxes[i],
-                                                     gt_bboxes_ignore[i],
-                                                     gt_labels[i])
-                sampling_result = bbox_sampler.sample(
-                    assign_result,
-                    proposal_list[i],
-                    gt_bboxes[i],
-                    gt_labels[i],
-                    feats=[lvl_feat[i][None] for lvl_feat in x])
-                sampling_results.append(sampling_result)
-
-            # bbox head forward and loss
-            rois = bbox2roi([res.bboxes for res in sampling_results])
-            # TODO: a more flexible way to decide which feature maps to use
-            bbox_feats = self.bbox_roi_extractor(
-                x[:self.bbox_roi_extractor.num_inputs], rois)
-            if self.with_shared_head:
-                bbox_feats = self.shared_head(bbox_feats)
-            cls_score, bbox_pred = self.bbox_head(bbox_feats)
-
-            bbox_targets = self.bbox_head.get_target(sampling_results,
-                                                     gt_bboxes, gt_labels,
-                                                     self.train_cfg.rcnn)
-            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
-                                            *bbox_targets)
-            losses.update(loss_bbox)
-
-            # Grid head forward and loss
-            sampling_results = self._random_jitter(sampling_results, img_metas)
-            pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
-            grid_feats = self.grid_roi_extractor(
-                x[:self.grid_roi_extractor.num_inputs], pos_rois)
-            if self.with_shared_head:
-                grid_feats = self.shared_head(grid_feats)
-            # Accelerate training
-            max_sample_num_grid = self.train_cfg.rcnn.get('max_num_grid', 192)
-            sample_idx = torch.randperm(
-                grid_feats.shape[0])[:min(grid_feats.
-                                          shape[0], max_sample_num_grid)]
-            grid_feats = grid_feats[sample_idx]
-
-            grid_pred = self.grid_head(grid_feats)
-
-            grid_targets = self.grid_head.get_target(sampling_results,
-                                                     self.train_cfg.rcnn)
-            grid_targets = grid_targets[sample_idx]
-
-            loss_grid = self.grid_head.loss(grid_pred, grid_targets)
-            losses.update(loss_grid)
-
-        return losses
-
-    def simple_test(self, img, img_metas, proposals=None, rescale=False):
-        """Test without augmentation."""
-        assert self.with_bbox, 'Bbox head must be implemented.'
-
-        x = self.extract_feat(img)
-
-        proposal_list = self.simple_test_rpn(
-            x, img_metas,
-            self.test_cfg.rpn) if proposals is None else proposals
-
-        det_bboxes, det_labels = self.simple_test_bboxes(
-            x, img_metas, proposal_list, self.test_cfg.rcnn, rescale=False)
-
-        # pack rois into bboxes
-        grid_rois = bbox2roi([det_bboxes[:, :4]])
-        grid_feats = self.grid_roi_extractor(
-            x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
-        if grid_rois.shape[0] != 0:
-            self.grid_head.test_mode = True
-            grid_pred = self.grid_head(grid_feats)
-            det_bboxes = self.grid_head.get_bboxes(det_bboxes,
-                                                   grid_pred['fused'],
-                                                   img_metas)
-            if rescale:
-                scale_factor = img_metas[0]['scale_factor']
-                if not isinstance(scale_factor, (float, torch.Tensor)):
-                    scale_factor = det_bboxes.new_tensor(scale_factor)
-                det_bboxes[:, :4] /= scale_factor
-        else:
-            det_bboxes = torch.Tensor([])
-
-        bbox_results = bbox2result(det_bboxes, det_labels,
-                                   self.bbox_head.num_classes)
-
-        return bbox_results
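GridRCNN likewise collapses to constructor wiring, and its training-time proposal jitter moves out with the rest. The computation, transcribed from the deleted _random_jitter (default amplitude 0.15) into a standalone sketch:

import torch

def random_jitter(bboxes, img_shape, amplitude=0.15):
    # Jitter positive proposals: shift centers by up to amplitude * wh and
    # rescale width/height by (1 +/- amplitude), then clip to the image.
    offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(-amplitude, amplitude)
    cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
    wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
    new_cxcy = cxcy + wh * offsets[:, :2]
    new_wh = wh * (1 + offsets[:, 2:])
    new_bboxes = torch.cat([new_cxcy - new_wh / 2, new_cxcy + new_wh / 2], dim=1)
    new_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1] - 1)
    new_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0] - 1)
    return new_bboxes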
diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py
index 68c64c4e..34fdd0dc 100644
--- a/mmdet/models/detectors/htc.py
+++ b/mmdet/models/detectors/htc.py
@@ -1,10 +1,3 @@
-import torch
-import torch.nn.functional as F
-
-from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
-                        build_sampler, merge_aug_bboxes, merge_aug_masks,
-                        multiclass_nms)
-from .. import builder
 from ..registry import DETECTORS
 from .cascade_rcnn import CascadeRCNN
 
@@ -12,506 +5,9 @@ from .cascade_rcnn import CascadeRCNN
 @DETECTORS.register_module
 class HybridTaskCascade(CascadeRCNN):
 
-    def __init__(self,
-                 num_stages,
-                 backbone,
-                 semantic_roi_extractor=None,
-                 semantic_head=None,
-                 semantic_fusion=('bbox', 'mask'),
-                 interleaved=True,
-                 mask_info_flow=True,
-                 **kwargs):
-        super(HybridTaskCascade, self).__init__(num_stages, backbone, **kwargs)
-        assert self.with_bbox and self.with_mask
-        assert not self.with_shared_head  # shared head not supported
-        if semantic_head is not None:
-            self.semantic_roi_extractor = builder.build_roi_extractor(
-                semantic_roi_extractor)
-            self.semantic_head = builder.build_head(semantic_head)
-
-        self.semantic_fusion = semantic_fusion
-        self.interleaved = interleaved
-        self.mask_info_flow = mask_info_flow
+    def __init__(self, **kwargs):
+        super(HybridTaskCascade, self).__init__(**kwargs)
 
     @property
     def with_semantic(self):
-        if hasattr(self, 'semantic_head') and self.semantic_head is not None:
-            return True
-        else:
-            return False
-
-    def _bbox_forward_train(self,
-                            stage,
-                            x,
-                            sampling_results,
-                            gt_bboxes,
-                            gt_labels,
-                            rcnn_train_cfg,
-                            semantic_feat=None):
-        rois = bbox2roi([res.bboxes for res in sampling_results])
-        bbox_roi_extractor = self.bbox_roi_extractor[stage]
-        bbox_head = self.bbox_head[stage]
-        bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
-                                        rois)
-        # semantic feature fusion
-        # element-wise sum for original features and pooled semantic features
-        if self.with_semantic and 'bbox' in self.semantic_fusion:
-            bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
-                                                             rois)
-            if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
-                bbox_semantic_feat = F.adaptive_avg_pool2d(
-                    bbox_semantic_feat, bbox_feats.shape[-2:])
-            bbox_feats += bbox_semantic_feat
-
-        cls_score, bbox_pred = bbox_head(bbox_feats)
-
-        bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes,
-                                            gt_labels, rcnn_train_cfg)
-        loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
-        return loss_bbox, rois, bbox_targets, bbox_pred
-
-    def _mask_forward_train(self,
-                            stage,
-                            x,
-                            sampling_results,
-                            gt_masks,
-                            rcnn_train_cfg,
-                            semantic_feat=None):
-        mask_roi_extractor = self.mask_roi_extractor[stage]
-        mask_head = self.mask_head[stage]
-        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
-        mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
-                                        pos_rois)
-
-        # semantic feature fusion
-        # element-wise sum for original features and pooled semantic features
-        if self.with_semantic and 'mask' in self.semantic_fusion:
-            mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
-                                                             pos_rois)
-            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
-                mask_semantic_feat = F.adaptive_avg_pool2d(
-                    mask_semantic_feat, mask_feats.shape[-2:])
-            mask_feats += mask_semantic_feat
-
-        # mask information flow
-        # forward all previous mask heads to obtain last_feat, and fuse it
-        # with the normal mask feature
-        if self.mask_info_flow:
-            last_feat = None
-            for i in range(stage):
-                last_feat = self.mask_head[i](
-                    mask_feats, last_feat, return_logits=False)
-            mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
-        else:
-            mask_pred = mask_head(mask_feats, return_feat=False)
-
-        mask_targets = mask_head.get_target(sampling_results, gt_masks,
-                                            rcnn_train_cfg)
-        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
-        loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)
-        return loss_mask
-
-    def _bbox_forward_test(self, stage, x, rois, semantic_feat=None):
-        bbox_roi_extractor = self.bbox_roi_extractor[stage]
-        bbox_head = self.bbox_head[stage]
-        bbox_feats = bbox_roi_extractor(
-            x[:len(bbox_roi_extractor.featmap_strides)], rois)
-        if self.with_semantic and 'bbox' in self.semantic_fusion:
-            bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
-                                                             rois)
-            if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
-                bbox_semantic_feat = F.adaptive_avg_pool2d(
-                    bbox_semantic_feat, bbox_feats.shape[-2:])
-            bbox_feats += bbox_semantic_feat
-        cls_score, bbox_pred = bbox_head(bbox_feats)
-        return cls_score, bbox_pred
-
-    def _mask_forward_test(self, stage, x, bboxes, semantic_feat=None):
-        mask_roi_extractor = self.mask_roi_extractor[stage]
-        mask_head = self.mask_head[stage]
-        mask_rois = bbox2roi([bboxes])
-        mask_feats = mask_roi_extractor(
-            x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
-        if self.with_semantic and 'mask' in self.semantic_fusion:
-            mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
-                                                             mask_rois)
-            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
-                mask_semantic_feat = F.adaptive_avg_pool2d(
-                    mask_semantic_feat, mask_feats.shape[-2:])
-            mask_feats += mask_semantic_feat
-        if self.mask_info_flow:
-            last_feat = None
-            last_pred = None
-            for i in range(stage):
-                mask_pred, last_feat = self.mask_head[i](mask_feats, last_feat)
-                if last_pred is not None:
-                    mask_pred = mask_pred + last_pred
-                last_pred = mask_pred
-            mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
-            if last_pred is not None:
-                mask_pred = mask_pred + last_pred
-        else:
-            mask_pred = mask_head(mask_feats)
-        return mask_pred
-
-    def forward_dummy(self, img):
-        outs = ()
-        # backbone
-        x = self.extract_feat(img)
-        # rpn
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).to(device=img.device)
-        # semantic head
-        if self.with_semantic:
-            _, semantic_feat = self.semantic_head(x)
-        else:
-            semantic_feat = None
-        # bbox heads
-        rois = bbox2roi([proposals])
-        for i in range(self.num_stages):
-            cls_score, bbox_pred = self._bbox_forward_test(
-                i, x, rois, semantic_feat=semantic_feat)
-            outs = outs + (cls_score, bbox_pred)
-        # mask heads
-        if self.with_mask:
-            mask_rois = rois[:100]
-            mask_roi_extractor = self.mask_roi_extractor[-1]
-            mask_feats = mask_roi_extractor(
-                x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
-            if self.with_semantic and 'mask' in self.semantic_fusion:
-                mask_semantic_feat = self.semantic_roi_extractor(
-                    [semantic_feat], mask_rois)
-                mask_feats += mask_semantic_feat
-            last_feat = None
-            for i in range(self.num_stages):
-                mask_head = self.mask_head[i]
-                if self.mask_info_flow:
-                    mask_pred, last_feat = mask_head(mask_feats, last_feat)
-                else:
-                    mask_pred = mask_head(mask_feats)
-                outs = outs + (mask_pred, )
-        return outs
-
-    def forward_train(self,
-                      img,
-                      img_metas,
-                      gt_bboxes,
-                      gt_labels,
-                      gt_bboxes_ignore=None,
-                      gt_masks=None,
-                      gt_semantic_seg=None,
-                      proposals=None):
-        x = self.extract_feat(img)
-
-        losses = dict()
-
-        # RPN part, the same as normal two-stage detectors
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
-                                          self.train_cfg.rpn)
-            rpn_losses = self.rpn_head.loss(
-                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
-            losses.update(rpn_losses)
-
-            proposal_cfg = self.train_cfg.get('rpn_proposal',
-                                              self.test_cfg.rpn)
-            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
-            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
-        else:
-            proposal_list = proposals
-
-        # semantic segmentation part
-        # 2 outputs: segmentation prediction and embedded features
-        if self.with_semantic:
-            semantic_pred, semantic_feat = self.semantic_head(x)
-            loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
-            losses['loss_semantic_seg'] = loss_seg
-        else:
-            semantic_feat = None
-
-        for i in range(self.num_stages):
-            self.current_stage = i
-            rcnn_train_cfg = self.train_cfg.rcnn[i]
-            lw = self.train_cfg.stage_loss_weights[i]
-
-            # assign gts and sample proposals
-            sampling_results = []
-            bbox_assigner = build_assigner(rcnn_train_cfg.assigner)
-            bbox_sampler = build_sampler(rcnn_train_cfg.sampler, context=self)
-            num_imgs = img.size(0)
-            if gt_bboxes_ignore is None:
-                gt_bboxes_ignore = [None for _ in range(num_imgs)]
-
-            for j in range(num_imgs):
-                assign_result = bbox_assigner.assign(proposal_list[j],
-                                                     gt_bboxes[j],
-                                                     gt_bboxes_ignore[j],
-                                                     gt_labels[j])
-                sampling_result = bbox_sampler.sample(
-                    assign_result,
-                    proposal_list[j],
-                    gt_bboxes[j],
-                    gt_labels[j],
-                    feats=[lvl_feat[j][None] for lvl_feat in x])
-                sampling_results.append(sampling_result)
-
-            # bbox head forward and loss
-            loss_bbox, rois, bbox_targets, bbox_pred = \
-                self._bbox_forward_train(
-                    i, x, sampling_results, gt_bboxes, gt_labels,
-                    rcnn_train_cfg, semantic_feat)
-            roi_labels = bbox_targets[0]
-
-            for name, value in loss_bbox.items():
-                losses['s{}.{}'.format(i, name)] = (
-                    value * lw if 'loss' in name else value)
-
-            # mask head forward and loss
-            if self.with_mask:
-                # interleaved execution: use regressed bboxes by the box branch
-                # to train the mask branch
-                if self.interleaved:
-                    pos_is_gts = [res.pos_is_gt for res in sampling_results]
-                    with torch.no_grad():
-                        proposal_list = self.bbox_head[i].refine_bboxes(
-                            rois, roi_labels, bbox_pred, pos_is_gts, img_metas)
-                        # re-assign and sample 512 RoIs from 512 RoIs
-                        sampling_results = []
-                        for j in range(num_imgs):
-                            assign_result = bbox_assigner.assign(
-                                proposal_list[j], gt_bboxes[j],
-                                gt_bboxes_ignore[j], gt_labels[j])
-                            sampling_result = bbox_sampler.sample(
-                                assign_result,
-                                proposal_list[j],
-                                gt_bboxes[j],
-                                gt_labels[j],
-                                feats=[lvl_feat[j][None] for lvl_feat in x])
-                            sampling_results.append(sampling_result)
-                loss_mask = self._mask_forward_train(i, x, sampling_results,
-                                                     gt_masks, rcnn_train_cfg,
-                                                     semantic_feat)
-                for name, value in loss_mask.items():
-                    losses['s{}.{}'.format(i, name)] = (
-                        value * lw if 'loss' in name else value)
-
-            # refine bboxes (same as Cascade R-CNN)
-            if i < self.num_stages - 1 and not self.interleaved:
-                pos_is_gts = [res.pos_is_gt for res in sampling_results]
-                with torch.no_grad():
-                    proposal_list = self.bbox_head[i].refine_bboxes(
-                        rois, roi_labels, bbox_pred, pos_is_gts, img_metas)
-
-        return losses
-
-    def simple_test(self, img, img_metas, proposals=None, rescale=False):
-        x = self.extract_feat(img)
-        proposal_list = self.simple_test_rpn(
-            x, img_metas,
-            self.test_cfg.rpn) if proposals is None else proposals
-
-        if self.with_semantic:
-            _, semantic_feat = self.semantic_head(x)
-        else:
-            semantic_feat = None
-
-        img_shape = img_metas[0]['img_shape']
-        ori_shape = img_metas[0]['ori_shape']
-        scale_factor = img_metas[0]['scale_factor']
-
-        # "ms" in variable names means multi-stage
-        ms_bbox_result = {}
-        ms_segm_result = {}
-        ms_scores = []
-        rcnn_test_cfg = self.test_cfg.rcnn
-
-        rois = bbox2roi(proposal_list)
-        for i in range(self.num_stages):
-            bbox_head = self.bbox_head[i]
-            cls_score, bbox_pred = self._bbox_forward_test(
-                i, x, rois, semantic_feat=semantic_feat)
-            ms_scores.append(cls_score)
-
-            if i < self.num_stages - 1:
-                bbox_label = cls_score.argmax(dim=1)
-                rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
-                                                  img_metas[0])
-
-        cls_score = sum(ms_scores) / float(len(ms_scores))
-        det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
-            rois,
-            cls_score,
-            bbox_pred,
-            img_shape,
-            scale_factor,
-            rescale=rescale,
-            cfg=rcnn_test_cfg)
-        bbox_result = bbox2result(det_bboxes, det_labels,
-                                  self.bbox_head[-1].num_classes)
-        ms_bbox_result['ensemble'] = bbox_result
-
-        if self.with_mask:
-            if det_bboxes.shape[0] == 0:
-                mask_classes = self.mask_head[-1].num_classes - 1
-                segm_result = [[] for _ in range(mask_classes)]
-            else:
-                _bboxes = (
-                    det_bboxes[:, :4] * det_bboxes.new_tensor(scale_factor)
-                    if rescale else det_bboxes)
-
-                mask_rois = bbox2roi([_bboxes])
-                aug_masks = []
-                mask_roi_extractor = self.mask_roi_extractor[-1]
-                mask_feats = mask_roi_extractor(
-                    x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
-                if self.with_semantic and 'mask' in self.semantic_fusion:
-                    mask_semantic_feat = self.semantic_roi_extractor(
-                        [semantic_feat], mask_rois)
-                    mask_feats += mask_semantic_feat
-                last_feat = None
-                for i in range(self.num_stages):
-                    mask_head = self.mask_head[i]
-                    if self.mask_info_flow:
-                        mask_pred, last_feat = mask_head(mask_feats, last_feat)
-                    else:
-                        mask_pred = mask_head(mask_feats)
-                    aug_masks.append(mask_pred.sigmoid().cpu().numpy())
-                merged_masks = merge_aug_masks(aug_masks,
-                                               [img_metas] * self.num_stages,
-                                               self.test_cfg.rcnn)
-                segm_result = self.mask_head[-1].get_seg_masks(
-                    merged_masks, _bboxes, det_labels, rcnn_test_cfg,
-                    ori_shape, scale_factor, rescale)
-            ms_segm_result['ensemble'] = segm_result
-
-        if self.with_mask:
-            results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble'])
-        else:
-            results = ms_bbox_result['ensemble']
-
-        return results
-
-    def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
-        """Test with augmentations.
-
-        If rescale is False, then returned bboxes and masks will fit the scale
-        of imgs[0].
-        """
-        if self.with_semantic:
-            semantic_feats = [
-                self.semantic_head(feat)[1]
-                for feat in self.extract_feats(imgs)
-            ]
-        else:
-            semantic_feats = [None] * len(img_metas)
-
-        # recompute feats to save memory
-        proposal_list = self.aug_test_rpn(
-            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
-
-        rcnn_test_cfg = self.test_cfg.rcnn
-        aug_bboxes = []
-        aug_scores = []
-        for x, img_meta, semantic in zip(
-                self.extract_feats(imgs), img_metas, semantic_feats):
-            # only one image in the batch
-            img_shape = img_meta[0]['img_shape']
-            scale_factor = img_meta[0]['scale_factor']
-            flip = img_meta[0]['flip']
-
-            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
-                                     scale_factor, flip)
-            # "ms" in variable names means multi-stage
-            ms_scores = []
-
-            rois = bbox2roi([proposals])
-            for i in range(self.num_stages):
-                bbox_head = self.bbox_head[i]
-                cls_score, bbox_pred = self._bbox_forward_test(
-                    i, x, rois, semantic_feat=semantic)
-                ms_scores.append(cls_score)
-
-                if i < self.num_stages - 1:
-                    bbox_label = cls_score.argmax(dim=1)
-                    rois = bbox_head.regress_by_class(rois, bbox_label,
-                                                      bbox_pred, img_meta[0])
-
-            cls_score = sum(ms_scores) / float(len(ms_scores))
-            bboxes, scores = self.bbox_head[-1].get_det_bboxes(
-                rois,
-                cls_score,
-                bbox_pred,
-                img_shape,
-                scale_factor,
-                rescale=False,
-                cfg=None)
-            aug_bboxes.append(bboxes)
-            aug_scores.append(scores)
-
-        # after merging, bboxes will be rescaled to the original image size
-        merged_bboxes, merged_scores = merge_aug_bboxes(
-            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
-        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
-                                                rcnn_test_cfg.score_thr,
-                                                rcnn_test_cfg.nms,
-                                                rcnn_test_cfg.max_per_img)
-
-        bbox_result = bbox2result(det_bboxes, det_labels,
-                                  self.bbox_head[-1].num_classes)
-
-        if self.with_mask:
-            if det_bboxes.shape[0] == 0:
-                segm_result = [[]
-                               for _ in range(self.mask_head[-1].num_classes -
-                                              1)]
-            else:
-                aug_masks = []
-                aug_img_metas = []
-                for x, img_meta, semantic in zip(
-                        self.extract_feats(imgs), img_metas, semantic_feats):
-                    img_shape = img_meta[0]['img_shape']
-                    scale_factor = img_meta[0]['scale_factor']
-                    flip = img_meta[0]['flip']
-                    _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
-                                           scale_factor, flip)
-                    mask_rois = bbox2roi([_bboxes])
-                    mask_feats = self.mask_roi_extractor[-1](
-                        x[:len(self.mask_roi_extractor[-1].featmap_strides)],
-                        mask_rois)
-                    if self.with_semantic:
-                        semantic_feat = semantic
-                        mask_semantic_feat = self.semantic_roi_extractor(
-                            [semantic_feat], mask_rois)
-                        if mask_semantic_feat.shape[-2:] != mask_feats.shape[
-                                -2:]:
-                            mask_semantic_feat = F.adaptive_avg_pool2d(
-                                mask_semantic_feat, mask_feats.shape[-2:])
-                        mask_feats += mask_semantic_feat
-                    last_feat = None
-                    for i in range(self.num_stages):
-                        mask_head = self.mask_head[i]
-                        if self.mask_info_flow:
-                            mask_pred, last_feat = mask_head(
-                                mask_feats, last_feat)
-                        else:
-                            mask_pred = mask_head(mask_feats)
-                        aug_masks.append(mask_pred.sigmoid().cpu().numpy())
-                        aug_img_metas.append(img_meta)
-                merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
-                                               self.test_cfg.rcnn)
-
-                ori_shape = img_metas[0][0]['ori_shape']
-                segm_result = self.mask_head[-1].get_seg_masks(
-                    merged_masks,
-                    det_bboxes,
-                    det_labels,
-                    rcnn_test_cfg,
-                    ori_shape,
-                    scale_factor=1.0,
-                    rescale=False)
-            return bbox_result, segm_result
-        else:
-            return bbox_result
+        return self.roi_head.with_semantic
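
With every stage-specific branch moved into HybridTaskCascadeRoIHead, the detector file shrinks to a thin wrapper. A minimal sketch of the resulting class, reassembled from the added lines above and assuming the same imports as the sibling detectors in this patch:

    from ..registry import DETECTORS
    from .two_stage import TwoStageDetector


    @DETECTORS.register_module
    class HybridTaskCascade(TwoStageDetector):

        def __init__(self, **kwargs):
            super(HybridTaskCascade, self).__init__(**kwargs)

        @property
        def with_semantic(self):
            # the RoI head now owns the semantic branch
            return self.roi_head.with_semantic

The semantic fusion removed above (element-wise sum of RoI features and pooled semantic features) presumably moves into HybridTaskCascadeRoIHead, which is not reproduced in this excerpt. The shape-matching part of that sum is easy to get wrong; a runnable toy:

    import torch
    import torch.nn.functional as F

    bbox_feats = torch.randn(8, 256, 7, 7)        # RoI features
    semantic_feat = torch.randn(8, 256, 14, 14)   # pooled semantic features
    if semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
        semantic_feat = F.adaptive_avg_pool2d(semantic_feat,
                                              bbox_feats.shape[-2:])
    bbox_feats = bbox_feats + semantic_feat       # element-wise fusion
    assert bbox_feats.shape == (8, 256, 7, 7)
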
diff --git a/mmdet/models/detectors/mask_rcnn.py b/mmdet/models/detectors/mask_rcnn.py
index becfdad5..718c5553 100644
--- a/mmdet/models/detectors/mask_rcnn.py
+++ b/mmdet/models/detectors/mask_rcnn.py
@@ -8,24 +8,16 @@ class MaskRCNN(TwoStageDetector):
     def __init__(self,
                  backbone,
                  rpn_head,
-                 bbox_roi_extractor,
-                 bbox_head,
-                 mask_roi_extractor,
-                 mask_head,
+                 roi_head,
                  train_cfg,
                  test_cfg,
                  neck=None,
-                 shared_head=None,
                  pretrained=None):
         super(MaskRCNN, self).__init__(
             backbone=backbone,
             neck=neck,
-            shared_head=shared_head,
             rpn_head=rpn_head,
-            bbox_roi_extractor=bbox_roi_extractor,
-            bbox_head=bbox_head,
-            mask_roi_extractor=mask_roi_extractor,
-            mask_head=mask_head,
+            roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             pretrained=pretrained)
diff --git a/mmdet/models/detectors/mask_scoring_rcnn.py b/mmdet/models/detectors/mask_scoring_rcnn.py
index b29b220f..f7d81125 100644
--- a/mmdet/models/detectors/mask_scoring_rcnn.py
+++ b/mmdet/models/detectors/mask_scoring_rcnn.py
@@ -1,7 +1,3 @@
-import torch
-
-from mmdet.core import bbox2roi, build_assigner, build_sampler
-from .. import builder
 from ..registry import DETECTORS
 from .two_stage import TwoStageDetector
 
@@ -16,186 +12,16 @@ class MaskScoringRCNN(TwoStageDetector):
     def __init__(self,
                  backbone,
                  rpn_head,
-                 bbox_roi_extractor,
-                 bbox_head,
-                 mask_roi_extractor,
-                 mask_head,
+                 roi_head,
                  train_cfg,
                  test_cfg,
                  neck=None,
-                 shared_head=None,
-                 mask_iou_head=None,
                  pretrained=None):
         super(MaskScoringRCNN, self).__init__(
             backbone=backbone,
             neck=neck,
-            shared_head=shared_head,
             rpn_head=rpn_head,
-            bbox_roi_extractor=bbox_roi_extractor,
-            bbox_head=bbox_head,
-            mask_roi_extractor=mask_roi_extractor,
-            mask_head=mask_head,
+            roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             pretrained=pretrained)
-
-        self.mask_iou_head = builder.build_head(mask_iou_head)
-        self.mask_iou_head.init_weights()
-
-    def forward_dummy(self, img):
-        raise NotImplementedError
-
-    # TODO: refactor forward_train in two stage to reduce code redundancy
-    def forward_train(self,
-                      img,
-                      img_metas,
-                      gt_bboxes,
-                      gt_labels,
-                      gt_bboxes_ignore=None,
-                      gt_masks=None,
-                      proposals=None):
-        x = self.extract_feat(img)
-
-        losses = dict()
-
-        # RPN forward and loss
-        if self.with_rpn:
-            rpn_outs = self.rpn_head(x)
-            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas,
-                                          self.train_cfg.rpn)
-            rpn_losses = self.rpn_head.loss(
-                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
-            losses.update(rpn_losses)
-
-            proposal_cfg = self.train_cfg.get('rpn_proposal',
-                                              self.test_cfg.rpn)
-            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
-            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
-        else:
-            proposal_list = proposals
-
-        # assign gts and sample proposals
-        if self.with_bbox or self.with_mask:
-            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
-            bbox_sampler = build_sampler(
-                self.train_cfg.rcnn.sampler, context=self)
-            num_imgs = img.size(0)
-            if gt_bboxes_ignore is None:
-                gt_bboxes_ignore = [None for _ in range(num_imgs)]
-            sampling_results = []
-            for i in range(num_imgs):
-                assign_result = bbox_assigner.assign(proposal_list[i],
-                                                     gt_bboxes[i],
-                                                     gt_bboxes_ignore[i],
-                                                     gt_labels[i])
-                sampling_result = bbox_sampler.sample(
-                    assign_result,
-                    proposal_list[i],
-                    gt_bboxes[i],
-                    gt_labels[i],
-                    feats=[lvl_feat[i][None] for lvl_feat in x])
-                sampling_results.append(sampling_result)
-
-        # bbox head forward and loss
-        if self.with_bbox:
-            rois = bbox2roi([res.bboxes for res in sampling_results])
-            # TODO: a more flexible way to decide which feature maps to use
-            bbox_feats = self.bbox_roi_extractor(
-                x[:self.bbox_roi_extractor.num_inputs], rois)
-            if self.with_shared_head:
-                bbox_feats = self.shared_head(bbox_feats)
-            cls_score, bbox_pred = self.bbox_head(bbox_feats)
-
-            bbox_targets = self.bbox_head.get_target(sampling_results,
-                                                     gt_bboxes, gt_labels,
-                                                     self.train_cfg.rcnn)
-            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
-                                            *bbox_targets)
-            losses.update(loss_bbox)
-
-        # mask head forward and loss
-        if self.with_mask:
-            if not self.share_roi_extractor:
-                pos_rois = bbox2roi(
-                    [res.pos_bboxes for res in sampling_results])
-                mask_feats = self.mask_roi_extractor(
-                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
-                if self.with_shared_head:
-                    mask_feats = self.shared_head(mask_feats)
-            else:
-                pos_inds = []
-                device = bbox_feats.device
-                for res in sampling_results:
-                    pos_inds.append(
-                        torch.ones(
-                            res.pos_bboxes.shape[0],
-                            device=device,
-                            dtype=torch.uint8))
-                    pos_inds.append(
-                        torch.zeros(
-                            res.neg_bboxes.shape[0],
-                            device=device,
-                            dtype=torch.uint8))
-                pos_inds = torch.cat(pos_inds)
-                mask_feats = bbox_feats[pos_inds]
-            mask_pred = self.mask_head(mask_feats)
-
-            mask_targets = self.mask_head.get_target(sampling_results,
-                                                     gt_masks,
-                                                     self.train_cfg.rcnn)
-            pos_labels = torch.cat(
-                [res.pos_gt_labels for res in sampling_results])
-            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
-                                            pos_labels)
-            losses.update(loss_mask)
-
-            # mask iou head forward and loss
-            pos_mask_pred = mask_pred[range(mask_pred.size(0)), pos_labels]
-            mask_iou_pred = self.mask_iou_head(mask_feats, pos_mask_pred)
-            pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)),
-                                              pos_labels]
-            mask_iou_targets = self.mask_iou_head.get_target(
-                sampling_results, gt_masks, pos_mask_pred, mask_targets,
-                self.train_cfg.rcnn)
-            loss_mask_iou = self.mask_iou_head.loss(pos_mask_iou_pred,
-                                                    mask_iou_targets)
-            losses.update(loss_mask_iou)
-        return losses
-
-    def simple_test_mask(self,
-                         x,
-                         img_metas,
-                         det_bboxes,
-                         det_labels,
-                         rescale=False):
-        # image shape of the first image in the batch (only one)
-        ori_shape = img_metas[0]['ori_shape']
-        scale_factor = img_metas[0]['scale_factor']
-
-        if det_bboxes.shape[0] == 0:
-            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
-            mask_scores = [[] for _ in range(self.mask_head.num_classes - 1)]
-        else:
-            # if det_bboxes is rescaled to the original image size, we need to
-            # rescale it back to the testing scale to obtain RoIs.
-            _bboxes = (
-                det_bboxes[:, :4] *
-                det_bboxes.new_tensor(scale_factor) if rescale else det_bboxes)
-            mask_rois = bbox2roi([_bboxes])
-            mask_feats = self.mask_roi_extractor(
-                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
-            if self.with_shared_head:
-                mask_feats = self.shared_head(mask_feats)
-            mask_pred = self.mask_head(mask_feats)
-            segm_result = self.mask_head.get_seg_masks(mask_pred, _bboxes,
-                                                       det_labels,
-                                                       self.test_cfg.rcnn,
-                                                       ori_shape, scale_factor,
-                                                       rescale)
-            # get mask scores with mask iou head
-            mask_iou_pred = self.mask_iou_head(
-                mask_feats, mask_pred[range(det_labels.size(0)),
-                                      det_labels + 1])
-            mask_scores = self.mask_iou_head.get_mask_scores(
-                mask_iou_pred, det_bboxes, det_labels)
-        return segm_result, mask_scores
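
The mask IoU training and scoring logic deleted here is taken over by MaskScoringRoIHead, registered in the new roi_heads package further below. One idiom from the removed code worth spelling out is the advanced indexing that picks, for each RoI, the mask predicted for its own class before feeding it to the mask IoU head; a runnable toy:

    import torch

    num_rois, num_classes, s = 4, 3, 28
    mask_pred = torch.randn(num_rois, num_classes, s, s)
    pos_labels = torch.tensor([0, 2, 1, 2])
    # select, for each RoI, the mask channel of its ground-truth class
    pos_mask_pred = mask_pred[range(num_rois), pos_labels]
    assert pos_mask_pred.shape == (num_rois, s, s)
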
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index fb18330c..79f62c45 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -1,16 +1,14 @@
 import torch
 import torch.nn as nn
 
-from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
 from .. import builder
 from ..registry import DETECTORS
 from .base import BaseDetector
-from .test_mixins import BBoxTestMixin, MaskTestMixin, RPNTestMixin
+from .test_mixins import RPNTestMixin
 
 
 @DETECTORS.register_module
-class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
-                       MaskTestMixin):
+class TwoStageDetector(BaseDetector, RPNTestMixin):
     """Base class for two-stage detectors.
 
     Two-stage detectors typically consist of a region proposal network and a
@@ -20,12 +19,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
     def __init__(self,
                  backbone,
                  neck=None,
-                 shared_head=None,
                  rpn_head=None,
-                 bbox_roi_extractor=None,
-                 bbox_head=None,
-                 mask_roi_extractor=None,
-                 mask_head=None,
+                 roi_head=None,
                  train_cfg=None,
                  test_cfg=None,
                  pretrained=None):
@@ -35,26 +30,16 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         if neck is not None:
             self.neck = builder.build_neck(neck)
 
-        if shared_head is not None:
-            self.shared_head = builder.build_shared_head(shared_head)
-
         if rpn_head is not None:
             self.rpn_head = builder.build_head(rpn_head)
 
-        if bbox_head is not None:
-            self.bbox_roi_extractor = builder.build_roi_extractor(
-                bbox_roi_extractor)
-            self.bbox_head = builder.build_head(bbox_head)
-
-        if mask_head is not None:
-            if mask_roi_extractor is not None:
-                self.mask_roi_extractor = builder.build_roi_extractor(
-                    mask_roi_extractor)
-                self.share_roi_extractor = False
-            else:
-                self.share_roi_extractor = True
-                self.mask_roi_extractor = self.bbox_roi_extractor
-            self.mask_head = builder.build_head(mask_head)
+        if roi_head is not None:
+            # update train and test cfg here for now
+            # TODO: refactor assigner & sampler
+            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
+            rcnn_test_cfg = test_cfg.rcnn if test_cfg is not None else None
+            roi_head.update(train_cfg=rcnn_train_cfg, test_cfg=rcnn_test_cfg)
+            self.roi_head = builder.build_head(roi_head)
 
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
@@ -65,6 +50,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
     def with_rpn(self):
         return hasattr(self, 'rpn_head') and self.rpn_head is not None
 
+    @property
+    def with_roi_head(self):
+        return hasattr(self, 'roi_head') and self.roi_head is not None
+
     def init_weights(self, pretrained=None):
         super(TwoStageDetector, self).init_weights(pretrained)
         self.backbone.init_weights(pretrained=pretrained)
@@ -74,17 +63,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                     m.init_weights()
             else:
                 self.neck.init_weights()
-        if self.with_shared_head:
-            self.shared_head.init_weights(pretrained=pretrained)
         if self.with_rpn:
             self.rpn_head.init_weights()
-        if self.with_bbox:
-            self.bbox_roi_extractor.init_weights()
-            self.bbox_head.init_weights()
-        if self.with_mask:
-            self.mask_head.init_weights()
-            if not self.share_roi_extractor:
-                self.mask_roi_extractor.init_weights()
+        if self.with_roi_head:
+            self.roi_head.init_weights(pretrained)
 
     def extract_feat(self, img):
         """Directly extract features from the backbone+neck
@@ -106,25 +88,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         if self.with_rpn:
             rpn_outs = self.rpn_head(x)
             outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).to(device=img.device)
-        # bbox head
-        rois = bbox2roi([proposals])
-        if self.with_bbox:
-            bbox_feats = self.bbox_roi_extractor(
-                x[:self.bbox_roi_extractor.num_inputs], rois)
-            if self.with_shared_head:
-                bbox_feats = self.shared_head(bbox_feats)
-            cls_score, bbox_pred = self.bbox_head(bbox_feats)
-            outs = outs + (cls_score, bbox_pred)
-        # mask head
-        if self.with_mask:
-            mask_rois = rois[:100]
-            mask_feats = self.mask_roi_extractor(
-                x[:self.mask_roi_extractor.num_inputs], mask_rois)
-            if self.with_shared_head:
-                mask_feats = self.shared_head(mask_feats)
-            mask_pred = self.mask_head(mask_feats)
-            outs = outs + (mask_pred, )
+        proposals = torch.randn(1000, 4).to(img.device)
+        # roi_head
+        roi_outs = self.roi_head.forward_dummy(x, proposals)
+        outs = outs + (roi_outs, )
         return outs
 
     def forward_train(self,
@@ -134,7 +101,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       gt_labels,
                       gt_bboxes_ignore=None,
                       gt_masks=None,
-                      proposals=None):
+                      proposals=None,
+                      **kwargs):
         """
         Args:
             img (Tensor): of shape (N, C, H, W) encoding input images.
@@ -183,80 +151,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         else:
             proposal_list = proposals
 
-        # assign gts and sample proposals
-        if self.with_bbox or self.with_mask:
-            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
-            bbox_sampler = build_sampler(
-                self.train_cfg.rcnn.sampler, context=self)
-            num_imgs = img.size(0)
-            if gt_bboxes_ignore is None:
-                gt_bboxes_ignore = [None for _ in range(num_imgs)]
-            sampling_results = []
-            for i in range(num_imgs):
-                assign_result = bbox_assigner.assign(proposal_list[i],
-                                                     gt_bboxes[i],
-                                                     gt_bboxes_ignore[i],
-                                                     gt_labels[i])
-                sampling_result = bbox_sampler.sample(
-                    assign_result,
-                    proposal_list[i],
-                    gt_bboxes[i],
-                    gt_labels[i],
-                    feats=[lvl_feat[i][None] for lvl_feat in x])
-                sampling_results.append(sampling_result)
-
-        # bbox head forward and loss
-        if self.with_bbox:
-            rois = bbox2roi([res.bboxes for res in sampling_results])
-            # TODO: a more flexible way to decide which feature maps to use
-            bbox_feats = self.bbox_roi_extractor(
-                x[:self.bbox_roi_extractor.num_inputs], rois)
-            if self.with_shared_head:
-                bbox_feats = self.shared_head(bbox_feats)
-            cls_score, bbox_pred = self.bbox_head(bbox_feats)
-
-            bbox_targets = self.bbox_head.get_target(sampling_results,
-                                                     gt_bboxes, gt_labels,
-                                                     self.train_cfg.rcnn)
-            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
-                                            *bbox_targets)
-            losses.update(loss_bbox)
-
-        # mask head forward and loss
-        if self.with_mask:
-            if not self.share_roi_extractor:
-                pos_rois = bbox2roi(
-                    [res.pos_bboxes for res in sampling_results])
-                mask_feats = self.mask_roi_extractor(
-                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
-                if self.with_shared_head:
-                    mask_feats = self.shared_head(mask_feats)
-            else:
-                pos_inds = []
-                device = bbox_feats.device
-                for res in sampling_results:
-                    pos_inds.append(
-                        torch.ones(
-                            res.pos_bboxes.shape[0],
-                            device=device,
-                            dtype=torch.uint8))
-                    pos_inds.append(
-                        torch.zeros(
-                            res.neg_bboxes.shape[0],
-                            device=device,
-                            dtype=torch.uint8))
-                pos_inds = torch.cat(pos_inds)
-                mask_feats = bbox_feats[pos_inds]
-
-            if mask_feats.shape[0] > 0:
-                mask_pred = self.mask_head(mask_feats)
-                mask_targets = self.mask_head.get_target(
-                    sampling_results, gt_masks, self.train_cfg.rcnn)
-                pos_labels = torch.cat(
-                    [res.pos_gt_labels for res in sampling_results])
-                loss_mask = self.mask_head.loss(mask_pred, mask_targets,
-                                                pos_labels)
-                losses.update(loss_mask)
+        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
+                                                 gt_bboxes, gt_labels,
+                                                 gt_bboxes_ignore, gt_masks,
+                                                 **kwargs)
+        losses.update(roi_losses)
 
         return losses
 
@@ -275,22 +174,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         else:
             proposal_list = proposals
 
-        det_bboxes, det_labels = await self.async_test_bboxes(
-            x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
-        bbox_results = bbox2result(det_bboxes, det_labels,
-                                   self.bbox_head.num_classes)
-
-        if not self.with_mask:
-            return bbox_results
-        else:
-            segm_results = await self.async_test_mask(
-                x,
-                img_meta,
-                det_bboxes,
-                det_labels,
-                rescale=rescale,
-                mask_test_cfg=self.test_cfg.get('mask'))
-            return bbox_results, segm_results
+        return await self.roi_head.async_simple_test(
+            x, proposal_list, img_meta, rescale=rescale)
 
     def simple_test(self, img, img_metas, proposals=None, rescale=False):
         """Test without augmentation."""
@@ -304,17 +189,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         else:
             proposal_list = proposals
 
-        det_bboxes, det_labels = self.simple_test_bboxes(
-            x, img_metas, proposal_list, self.test_cfg.rcnn, rescale=rescale)
-        bbox_results = bbox2result(det_bboxes, det_labels,
-                                   self.bbox_head.num_classes)
-
-        if not self.with_mask:
-            return bbox_results
-        else:
-            segm_results = self.simple_test_mask(
-                x, img_metas, det_bboxes, det_labels, rescale=rescale)
-            return bbox_results, segm_results
+        return self.roi_head.simple_test(
+            x, proposal_list, img_metas, rescale=rescale)
 
     def aug_test(self, imgs, img_metas, rescale=False):
         """Test with augmentations.
@@ -323,25 +199,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         of imgs[0].
         """
         # recompute feats to save memory
-        proposal_list = self.aug_test_rpn(
-            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
-        det_bboxes, det_labels = self.aug_test_bboxes(
-            self.extract_feats(imgs), img_metas, proposal_list,
-            self.test_cfg.rcnn)
-
-        if rescale:
-            _det_bboxes = det_bboxes
-        else:
-            _det_bboxes = det_bboxes.clone()
-            _det_bboxes[:, :4] *= det_bboxes.new_tensor(
-                img_metas[0][0]['scale_factor'])
-        bbox_results = bbox2result(_det_bboxes, det_labels,
-                                   self.bbox_head.num_classes)
-
-        # det_bboxes always keep the original scale
-        if self.with_mask:
-            segm_results = self.aug_test_mask(
-                self.extract_feats(imgs), img_metas, det_bboxes, det_labels)
-            return bbox_results, segm_results
-        else:
-            return bbox_results
+        x = self.extract_feats(imgs)
+        proposal_list = self.aug_test_rpn(x, img_metas, self.test_cfg.rpn)
+        return self.roi_head.aug_test(
+            x, proposal_list, img_metas, rescale=rescale)
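
After this rewrite the detector knows nothing about boxes or masks: everything past the RPN is a single call into self.roi_head. A stand-alone toy of the resulting composition (names are illustrative, not mmdet API):

    class ToyRoIHead(object):
        """Any second stage exposing forward_train with this signature fits."""

        def forward_train(self, x, img_metas, proposal_list, *gt_args):
            return {'loss_cls': 0.3, 'loss_bbox': 0.1}


    class ToyTwoStage(object):

        def __init__(self, roi_head):
            # cascade, mask scoring, HTC, ... are all interchangeable here
            self.roi_head = roi_head

        def forward_train(self, x, img_metas, proposal_list, *gt_args):
            losses = {'loss_rpn_cls': 0.2}
            losses.update(self.roi_head.forward_train(
                x, img_metas, proposal_list, *gt_args))
            return losses


    print(ToyTwoStage(ToyRoIHead()).forward_train(None, [], []))

Swapping the second stage becomes a config change rather than a new detector subclass, which is why MaskRCNN and MaskScoringRCNN above reduce to bare constructors.
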
diff --git a/mmdet/models/roi_heads/__init__.py b/mmdet/models/roi_heads/__init__.py
new file mode 100644
index 00000000..3c59239a
--- /dev/null
+++ b/mmdet/models/roi_heads/__init__.py
@@ -0,0 +1,11 @@
+from .base_roi_head import BaseRoIHead
+from .cascade_roi_head import CascadeRoIHead
+from .double_roi_head import DoubleHeadRoIHead
+from .grid_roi_head import GridRoIHead
+from .htc_roi_head import HybridTaskCascadeRoIHead
+from .mask_scoring_roi_head import MaskScoringRoIHead
+
+__all__ = [
+    'BaseRoIHead', 'CascadeRoIHead', 'DoubleHeadRoIHead', 'MaskScoringRoIHead',
+    'HybridTaskCascadeRoIHead', 'GridRoIHead'
+]
diff --git a/mmdet/models/roi_heads/base_roi_head.py b/mmdet/models/roi_heads/base_roi_head.py
new file mode 100644
index 00000000..6448c6b2
--- /dev/null
+++ b/mmdet/models/roi_heads/base_roi_head.py
@@ -0,0 +1,93 @@
+from abc import ABCMeta, abstractmethod
+
+import torch.nn as nn
+
+from .. import builder
+
+
+class BaseRoIHead(nn.Module, metaclass=ABCMeta):
+    """Base class for RoIHeads"""
+
+    def __init__(self,
+                 bbox_roi_extractor=None,
+                 bbox_head=None,
+                 mask_roi_extractor=None,
+                 mask_head=None,
+                 shared_head=None,
+                 train_cfg=None,
+                 test_cfg=None):
+        super(BaseRoIHead, self).__init__()
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        if shared_head is not None:
+            self.shared_head = builder.build_shared_head(shared_head)
+
+        if bbox_head is not None:
+            self.init_bbox_head(bbox_roi_extractor, bbox_head)
+
+        if mask_head is not None:
+            self.init_mask_head(mask_roi_extractor, mask_head)
+
+        self.init_assigner_sampler()
+
+    @property
+    def with_bbox(self):
+        return hasattr(self, 'bbox_head') and self.bbox_head is not None
+
+    @property
+    def with_mask(self):
+        return hasattr(self, 'mask_head') and self.mask_head is not None
+
+    @property
+    def with_shared_head(self):
+        return hasattr(self, 'shared_head') and self.shared_head is not None
+
+    @abstractmethod
+    def init_weights(self, pretrained):
+        pass
+
+    @abstractmethod
+    def init_bbox_head(self):
+        pass
+
+    @abstractmethod
+    def init_mask_head(self):
+        pass
+
+    @abstractmethod
+    def init_assigner_sampler(self):
+        pass
+
+    @abstractmethod
+    def forward_train(self,
+                      x,
+                      img_meta,
+                      proposal_list,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None,
+                      **kwargs):
+        """Forward function during training"""
+        pass
+
+    async def async_simple_test(self, x, img_meta, **kwargs):
+        raise NotImplementedError
+
+    def simple_test(self,
+                    x,
+                    proposal_list,
+                    img_meta,
+                    proposals=None,
+                    rescale=False,
+                    **kwargs):
+        """Test without augmentation."""
+        pass
+
+    def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):
+        """Test with augmentations.
+
+        If rescale is False, then returned bboxes and masks will fit the scale
+        of imgs[0].
+        """
+        pass
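
BaseRoIHead pins down the construction order (shared head, then bbox, then mask, then assigner/sampler) and defers the how to subclasses through the init_* hooks. A stripped-down, runnable illustration of the same template pattern, with no mmdet dependencies:

    from abc import ABCMeta, abstractmethod


    class ToyBaseHead(metaclass=ABCMeta):

        def __init__(self, bbox_head=None, mask_head=None):
            if bbox_head is not None:
                self.init_bbox_head(bbox_head)
            if mask_head is not None:
                self.init_mask_head(mask_head)

        @property
        def with_mask(self):
            return hasattr(self, 'mask_head') and self.mask_head is not None

        @abstractmethod
        def init_bbox_head(self, bbox_head):
            pass

        @abstractmethod
        def init_mask_head(self, mask_head):
            pass


    class ToyHead(ToyBaseHead):

        def init_bbox_head(self, bbox_head):
            self.bbox_head = bbox_head

        def init_mask_head(self, mask_head):
            self.mask_head = mask_head


    head = ToyHead(bbox_head=object())
    assert head.with_mask is False
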
diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py
new file mode 100644
index 00000000..168c46e7
--- /dev/null
+++ b/mmdet/models/roi_heads/cascade_roi_head.py
@@ -0,0 +1,429 @@
+import torch
+import torch.nn as nn
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
+                        build_sampler, merge_aug_bboxes, merge_aug_masks,
+                        multiclass_nms)
+from .. import builder
+from ..registry import HEADS
+from .base_roi_head import BaseRoIHead
+from .test_mixins import BBoxTestMixin, MaskTestMixin
+
+
+@HEADS.register_module
+class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+    """Cascade roi head including one bbox head and one mask head.
+
+    https://arxiv.org/abs/1712.00726
+    """
+
+    def __init__(self,
+                 num_stages,
+                 stage_loss_weights,
+                 bbox_roi_extractor=None,
+                 bbox_head=None,
+                 mask_roi_extractor=None,
+                 mask_head=None,
+                 shared_head=None,
+                 train_cfg=None,
+                 test_cfg=None):
+        assert bbox_roi_extractor is not None
+        assert bbox_head is not None
+        assert shared_head is None, \
+            'Shared head is not supported in Cascade RCNN anymore'
+        self.num_stages = num_stages
+        self.stage_loss_weights = stage_loss_weights
+        super(CascadeRoIHead, self).__init__(
+            bbox_roi_extractor=bbox_roi_extractor,
+            bbox_head=bbox_head,
+            mask_roi_extractor=mask_roi_extractor,
+            mask_head=mask_head,
+            shared_head=shared_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg)
+
+    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
+        self.bbox_roi_extractor = nn.ModuleList()
+        self.bbox_head = nn.ModuleList()
+        if not isinstance(bbox_roi_extractor, list):
+            bbox_roi_extractor = [
+                bbox_roi_extractor for _ in range(self.num_stages)
+            ]
+        if not isinstance(bbox_head, list):
+            bbox_head = [bbox_head for _ in range(self.num_stages)]
+        assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages
+        for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
+            self.bbox_roi_extractor.append(
+                builder.build_roi_extractor(roi_extractor))
+            self.bbox_head.append(builder.build_head(head))
+
+    def init_mask_head(self, mask_roi_extractor, mask_head):
+        self.mask_head = nn.ModuleList()
+        if not isinstance(mask_head, list):
+            mask_head = [mask_head for _ in range(self.num_stages)]
+        assert len(mask_head) == self.num_stages
+        for head in mask_head:
+            self.mask_head.append(builder.build_head(head))
+        if mask_roi_extractor is not None:
+            self.share_roi_extractor = False
+            self.mask_roi_extractor = nn.ModuleList()
+            if not isinstance(mask_roi_extractor, list):
+                mask_roi_extractor = [
+                    mask_roi_extractor for _ in range(self.num_stages)
+                ]
+            assert len(mask_roi_extractor) == self.num_stages
+            for roi_extractor in mask_roi_extractor:
+                self.mask_roi_extractor.append(
+                    builder.build_roi_extractor(roi_extractor))
+        else:
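+            # no dedicated mask extractor given: reuse the bbox extractors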
+            self.share_roi_extractor = True
+            self.mask_roi_extractor = self.bbox_roi_extractor
+
+    def init_assigner_sampler(self):
+        # build assigner and sampler for each stage
+        self.bbox_assigner = []
+        self.bbox_sampler = []
+        if self.train_cfg is not None:
+            for rcnn_train_cfg in self.train_cfg:
+                self.bbox_assigner.append(
+                    build_assigner(rcnn_train_cfg.assigner))
+                self.bbox_sampler.append(build_sampler(rcnn_train_cfg.sampler))
+
+    def init_weights(self, pretrained):
+        if self.with_shared_head:
+            self.shared_head.init_weights(pretrained=pretrained)
+        for i in range(self.num_stages):
+            if self.with_bbox:
+                self.bbox_roi_extractor[i].init_weights()
+                self.bbox_head[i].init_weights()
+            if self.with_mask:
+                if not self.share_roi_extractor:
+                    self.mask_roi_extractor[i].init_weights()
+                self.mask_head[i].init_weights()
+
+    def forward_dummy(self, x, proposals):
+        # bbox head
+        outs = ()
+        rois = bbox2roi([proposals])
+        if self.with_bbox:
+            for i in range(self.num_stages):
+                bbox_results = self._bbox_forward(i, x, rois)
+                outs = outs + (bbox_results['cls_score'],
+                               bbox_results['bbox_pred'])
+        # mask heads
+        if self.with_mask:
+            mask_rois = rois[:100]
+            for i in range(self.num_stages):
+                mask_results = self._mask_forward(i, x, mask_rois)
+                outs = outs + (mask_results['mask_pred'], )
+        return outs
+
+    def _bbox_forward(self, stage, x, rois):
+        bbox_roi_extractor = self.bbox_roi_extractor[stage]
+        bbox_head = self.bbox_head[stage]
+        bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
+                                        rois)
+        # do not support caffe_c4 model anymore
+        cls_score, bbox_pred = bbox_head(bbox_feats)
+
+        bbox_results = dict(
+            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
+        return bbox_results
+
+    def _bbox_forward_train(self, stage, x, sampling_results, gt_bboxes,
+                            gt_labels, rcnn_train_cfg):
+        rois = bbox2roi([res.bboxes for res in sampling_results])
+        bbox_results = self._bbox_forward(stage, x, rois)
+        bbox_targets = self.bbox_head[stage].get_target(
+            sampling_results, gt_bboxes, gt_labels, rcnn_train_cfg)
+        loss_bbox = self.bbox_head[stage].loss(bbox_results['cls_score'],
+                                               bbox_results['bbox_pred'],
+                                               *bbox_targets)
+
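+        # keep rois, targets and predictions in the dict: the refine step
+        # at the end of each training stage consumes them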
+        bbox_results.update(
+            loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
+        return bbox_results
+
+    def _mask_forward(self, stage, x, rois):
+        mask_roi_extractor = self.mask_roi_extractor[stage]
+        mask_head = self.mask_head[stage]
+        mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
+                                        rois)
+        # do not support caffe_c4 model anymore
+        mask_pred = mask_head(mask_feats)
+
+        mask_results = dict(mask_pred=mask_pred)
+        return mask_results
+
+    def _mask_forward_train(self,
+                            stage,
+                            x,
+                            sampling_results,
+                            gt_masks,
+                            rcnn_train_cfg,
+                            bbox_feats=None):
+        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+        mask_results = self._mask_forward(stage, x, pos_rois)
+
+        mask_targets = self.mask_head[stage].get_target(
+            sampling_results, gt_masks, rcnn_train_cfg)
+        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+        loss_mask = self.mask_head[stage].loss(mask_results['mask_pred'],
+                                               mask_targets, pos_labels)
+
+        mask_results.update(loss_mask=loss_mask)
+        return mask_results
+
+    def forward_train(self,
+                      x,
+                      img_metas,
+                      proposal_list,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None):
+        """
+        Args:
+            x (list[Tensor]): list of multi-level img features.
+
+            img_metas (list[dict]): list of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+            proposal_list (list[Tensor]): list of region proposals.
+
+            gt_bboxes (list[Tensor]): ground-truth bboxes for each image in
+                [tl_x, tl_y, br_x, br_y] format.
+
+            gt_labels (list[Tensor]): class indices corresponding to each box
+
+            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+                boxes can be ignored when computing the loss.
+
+            gt_masks (None | Tensor): true segmentation masks for each box
+                used if the architecture supports a segmentation task.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        losses = dict()
+        for i in range(self.num_stages):
+            self.current_stage = i
+            rcnn_train_cfg = self.train_cfg[i]
+            lw = self.stage_loss_weights[i]
+
+            # assign gts and sample proposals
+            sampling_results = []
+            if self.with_bbox or self.with_mask:
+                bbox_assigner = self.bbox_assigner[i]
+                bbox_sampler = self.bbox_sampler[i]
+                num_imgs = len(img_metas)
+                if gt_bboxes_ignore is None:
+                    gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+                for j in range(num_imgs):
+                    assign_result = bbox_assigner.assign(
+                        proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j],
+                        gt_labels[j])
+                    sampling_result = bbox_sampler.sample(
+                        assign_result,
+                        proposal_list[j],
+                        gt_bboxes[j],
+                        gt_labels[j],
+                        feats=[lvl_feat[j][None] for lvl_feat in x])
+                    sampling_results.append(sampling_result)
+
+            # bbox head forward and loss
+            bbox_results = self._bbox_forward_train(i, x, sampling_results,
+                                                    gt_bboxes, gt_labels,
+                                                    rcnn_train_cfg)
+
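+            # prefix losses with the stage index and apply the stage loss
+            # weight to loss terms only (e.g. accuracy is logged unscaled)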
+            for name, value in bbox_results['loss_bbox'].items():
+                losses['s{}.{}'.format(i, name)] = (
+                    value * lw if 'loss' in name else value)
+
+            # mask head forward and loss
+            if self.with_mask:
+                mask_results = self._mask_forward_train(
+                    i, x, sampling_results, gt_masks, rcnn_train_cfg,
+                    bbox_results['bbox_feats'])
+
+                for name, value in mask_results['loss_mask'].items():
+                    losses['s{}.{}'.format(i, name)] = (
+                        value * lw if 'loss' in name else value)
+
+            # refine bboxes
+            if i < self.num_stages - 1:
+                pos_is_gts = [res.pos_is_gt for res in sampling_results]
+                # bbox_targets is a tuple
+                roi_labels = bbox_results['bbox_targets'][0]
+                with torch.no_grad():
+                    proposal_list = self.bbox_head[i].refine_bboxes(
+                        bbox_results['rois'], roi_labels,
+                        bbox_results['bbox_pred'], pos_is_gts, img_metas)
+
+        return losses
+
+    def simple_test(self, x, proposal_list, img_metas, rescale=False):
+        """Test without augmentation."""
+        assert self.with_bbox, 'Bbox head must be implemented.'
+        img_shape = img_metas[0]['img_shape']
+        ori_shape = img_metas[0]['ori_shape']
+        scale_factor = img_metas[0]['scale_factor']
+
+        # "ms" in variable names means multi-stage
+        ms_bbox_result = {}
+        ms_segm_result = {}
+        ms_scores = []
+        rcnn_test_cfg = self.test_cfg
+
+        rois = bbox2roi(proposal_list)
+        for i in range(self.num_stages):
+            bbox_results = self._bbox_forward(i, x, rois)
+            ms_scores.append(bbox_results['cls_score'])
+
+            if i < self.num_stages - 1:
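+                # refine RoIs with this stage's class-specific regression
+                # before passing them to the next stage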
+                bbox_label = bbox_results['cls_score'].argmax(dim=1)
+                rois = self.bbox_head[i].regress_by_class(
+                    rois, bbox_label, bbox_results['bbox_pred'], img_metas[0])
+
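+        # classification scores are averaged over all stages (the
+        # "ensemble" result)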
+        cls_score = sum(ms_scores) / self.num_stages
+        det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
+            rois,
+            cls_score,
+            bbox_results['bbox_pred'],
+            img_shape,
+            scale_factor,
+            rescale=rescale,
+            cfg=rcnn_test_cfg)
+        bbox_result = bbox2result(det_bboxes, det_labels,
+                                  self.bbox_head[-1].num_classes)
+        ms_bbox_result['ensemble'] = bbox_result
+
+        if self.with_mask:
+            if det_bboxes.shape[0] == 0:
+                mask_classes = self.mask_head[-1].num_classes - 1
+                segm_result = [[] for _ in range(mask_classes)]
+            else:
+                _bboxes = (
+                    det_bboxes[:, :4] * det_bboxes.new_tensor(scale_factor)
+                    if rescale else det_bboxes)
+
+                mask_rois = bbox2roi([_bboxes])
+                aug_masks = []
+                for i in range(self.num_stages):
+                    mask_results = self._mask_forward(i, x, mask_rois)
+                    aug_masks.append(
+                        mask_results['mask_pred'].sigmoid().cpu().numpy())
+                merged_masks = merge_aug_masks(aug_masks,
+                                               [img_metas] * self.num_stages,
+                                               self.test_cfg)
+                segm_result = self.mask_head[-1].get_seg_masks(
+                    merged_masks, _bboxes, det_labels, rcnn_test_cfg,
+                    ori_shape, scale_factor, rescale)
+            ms_segm_result['ensemble'] = segm_result
+
+        if self.with_mask:
+            results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble'])
+        else:
+            results = ms_bbox_result['ensemble']
+
+        return results
+
+    def aug_test(self, features, proposal_list, img_metas, rescale=False):
+        """Test with augmentations.
+
+        If rescale is False, then returned bboxes and masks will fit the scale
+        of imgs[0].
+        """
+        rcnn_test_cfg = self.test_cfg
+        aug_bboxes = []
+        aug_scores = []
+        for x, img_meta in zip(features, img_metas):
+            # only one image in the batch
+            img_shape = img_meta[0]['img_shape']
+            scale_factor = img_meta[0]['scale_factor']
+            flip = img_meta[0]['flip']
+
+            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+                                     scale_factor, flip)
+            # "ms" in variable names means multi-stage
+            ms_scores = []
+
+            rois = bbox2roi([proposals])
+            for i in range(self.num_stages):
+                bbox_results = self._bbox_forward(i, x, rois)
+                ms_scores.append(bbox_results['cls_score'])
+
+                if i < self.num_stages - 1:
+                    bbox_label = bbox_results['cls_score'].argmax(dim=1)
+                    rois = self.bbox_head[i].regress_by_class(
+                        rois, bbox_label, bbox_results['bbox_pred'],
+                        img_meta[0])
+
+            cls_score = sum(ms_scores) / float(len(ms_scores))
+            bboxes, scores = self.bbox_head[-1].get_det_bboxes(
+                rois,
+                cls_score,
+                bbox_results['bbox_pred'],
+                img_shape,
+                scale_factor,
+                rescale=False,
+                cfg=None)
+            aug_bboxes.append(bboxes)
+            aug_scores.append(scores)
+
+        # after merging, bboxes will be rescaled to the original image size
+        merged_bboxes, merged_scores = merge_aug_bboxes(
+            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+                                                rcnn_test_cfg.score_thr,
+                                                rcnn_test_cfg.nms,
+                                                rcnn_test_cfg.max_per_img)
+
+        bbox_result = bbox2result(det_bboxes, det_labels,
+                                  self.bbox_head[-1].num_classes)
+
+        if self.with_mask:
+            if det_bboxes.shape[0] == 0:
+                segm_result = [[]
+                               for _ in range(self.mask_head[-1].num_classes -
+                                              1)]
+            else:
+                aug_masks = []
+                aug_img_metas = []
+                for x, img_meta in zip(features, img_metas):
+                    img_shape = img_meta[0]['img_shape']
+                    scale_factor = img_meta[0]['scale_factor']
+                    flip = img_meta[0]['flip']
+                    _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+                                           scale_factor, flip)
+                    mask_rois = bbox2roi([_bboxes])
+                    for i in range(self.num_stages):
+                        mask_results = self._mask_forward(i, x, mask_rois)
+                        aug_masks.append(
+                            mask_results['mask_pred'].sigmoid().cpu().numpy())
+                        aug_img_metas.append(img_meta)
+                merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
+                                               self.test_cfg)
+
+                ori_shape = img_metas[0][0]['ori_shape']
+                segm_result = self.mask_head[-1].get_seg_masks(
+                    merged_masks,
+                    det_bboxes,
+                    det_labels,
+                    rcnn_test_cfg,
+                    ori_shape,
+                    scale_factor=1.0,
+                    rescale=False)
+            return bbox_result, segm_result
+        else:
+            return bbox_result
diff --git a/mmdet/models/roi_heads/double_roi_head.py b/mmdet/models/roi_heads/double_roi_head.py
new file mode 100644
index 00000000..606b99aa
--- /dev/null
+++ b/mmdet/models/roi_heads/double_roi_head.py
@@ -0,0 +1,40 @@
+from ..registry import HEADS
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module
+class DoubleHeadRoIHead(StandardRoIHead):
+    """RoI head for Double Head RCNN
+
+    https://arxiv.org/abs/1904.06493
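+
+    Example:
+        >>> # a minimal, illustrative config fragment; only the extra field
+        >>> # this subclass introduces is shown, and a full roi_head config
+        >>> # also needs bbox_roi_extractor and bbox_head entries
+        >>> roi_head = dict(
+        ...     type='DoubleHeadRoIHead',
+        ...     reg_roi_scale_factor=1.3)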
+    """
+
+    def __init__(self, reg_roi_scale_factor, **kwargs):
+        super(DoubleHeadRoIHead, self).__init__(**kwargs)
+        self.reg_roi_scale_factor = reg_roi_scale_factor
+
+    def _bbox_forward(self, x, rois):
+        bbox_cls_feats = self.bbox_roi_extractor(
+            x[:self.bbox_roi_extractor.num_inputs], rois)
+        bbox_reg_feats = self.bbox_roi_extractor(
+            x[:self.bbox_roi_extractor.num_inputs],
+            rois,
+            roi_scale_factor=self.reg_roi_scale_factor)
+        if self.with_shared_head:
+            bbox_cls_feats = self.shared_head(bbox_cls_feats)
+            bbox_reg_feats = self.shared_head(bbox_reg_feats)
+        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
+
+        bbox_results = dict(
+            cls_score=cls_score,
+            bbox_pred=bbox_pred,
+            bbox_feats=bbox_cls_feats)
+        return bbox_results
diff --git a/mmdet/models/roi_heads/grid_roi_head.py b/mmdet/models/roi_heads/grid_roi_head.py
new file mode 100644
index 00000000..af35ef13
--- /dev/null
+++ b/mmdet/models/roi_heads/grid_roi_head.py
@@ -0,0 +1,155 @@
+import torch
+
+from mmdet.core import bbox2result, bbox2roi
+from .. import builder
+from ..registry import HEADS
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module
+class GridRoIHead(StandardRoIHead):
+    """Grid roi head for Grid R-CNN.
+
+    https://arxiv.org/abs/1811.12030
+    """
+
+    def __init__(self, grid_roi_extractor, grid_head, **kwargs):
+        assert grid_head is not None
+        super(GridRoIHead, self).__init__(**kwargs)
+        if grid_roi_extractor is not None:
+            self.grid_roi_extractor = builder.build_roi_extractor(
+                grid_roi_extractor)
+            self.share_roi_extractor = False
+        else:
+            self.share_roi_extractor = True
+            self.grid_roi_extractor = self.bbox_roi_extractor
+        self.grid_head = builder.build_head(grid_head)
+
+    def init_weights(self, pretrained):
+        super(GridRoIHead, self).init_weights(pretrained)
+        self.grid_head.init_weights()
+        if not self.share_roi_extractor:
+            self.grid_roi_extractor.init_weights()
+
+    def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
+        """Ramdom jitter positive proposals for training."""
+        for sampling_result, img_meta in zip(sampling_results, img_metas):
+            bboxes = sampling_result.pos_bboxes
+            random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
+                -amplitude, amplitude)
+            # before jittering
+            cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
+            wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
+            # after jittering
+            new_cxcy = cxcy + wh * random_offsets[:, :2]
+            new_wh = wh * (1 + random_offsets[:, 2:])
+            # xywh to xyxy
+            new_x1y1 = (new_cxcy - new_wh / 2)
+            new_x2y2 = (new_cxcy + new_wh / 2)
+            new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
+            # clip bboxes
+            max_shape = img_meta['img_shape']
+            if max_shape is not None:
+                new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
+                new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)
+
+            sampling_result.pos_bboxes = new_bboxes
+        return sampling_results
+
+    def forward_dummy(self, x, proposals):
+        # bbox head
+        outs = ()
+        rois = bbox2roi([proposals])
+        if self.with_bbox:
+            bbox_results = self._bbox_forward(x, rois)
+            outs = outs + (bbox_results['cls_score'],
+                           bbox_results['bbox_pred'])
+
+        # grid head
+        grid_rois = rois[:100]
+        grid_feats = self.grid_roi_extractor(
+            x[:self.grid_roi_extractor.num_inputs], grid_rois)
+        if self.with_shared_head:
+            grid_feats = self.shared_head(grid_feats)
+        grid_pred = self.grid_head(grid_feats)
+        outs = outs + (grid_pred, )
+
+        # mask head
+        if self.with_mask:
+            mask_rois = rois[:100]
+            mask_results = self._mask_forward(x, mask_rois)
+            outs = outs + (mask_results['mask_pred'], )
+        return outs
+
+    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+                            img_metas):
+        bbox_results = super(GridRoIHead,
+                             self)._bbox_forward_train(x, sampling_results,
+                                                       gt_bboxes, gt_labels,
+                                                       img_metas)
+
+        # Grid head forward and loss
+        sampling_results = self._random_jitter(sampling_results, img_metas)
+        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+        grid_feats = self.grid_roi_extractor(
+            x[:self.grid_roi_extractor.num_inputs], pos_rois)
+        if self.with_shared_head:
+            grid_feats = self.shared_head(grid_feats)
+        # accelerate training: sample at most max_num_grid (default 192) RoIs
+        max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
+        sample_idx = torch.randperm(
+            grid_feats.shape[0])[:min(grid_feats.shape[0], max_sample_num_grid
+                                      )]
+        grid_feats = grid_feats[sample_idx]
+
+        grid_pred = self.grid_head(grid_feats)
+
+        grid_targets = self.grid_head.get_target(sampling_results,
+                                                 self.train_cfg)
+        grid_targets = grid_targets[sample_idx]
+
+        loss_grid = self.grid_head.loss(grid_pred, grid_targets)
+
+        bbox_results['loss_bbox'].update(loss_grid)
+        return bbox_results
+
+    def simple_test(self,
+                    x,
+                    proposal_list,
+                    img_metas,
+                    proposals=None,
+                    rescale=False):
+        """Test without augmentation."""
+        assert self.with_bbox, 'Bbox head must be implemented.'
+
+        det_bboxes, det_labels = self.simple_test_bboxes(
+            x, img_metas, proposal_list, self.test_cfg, rescale=False)
+        # pack rois into bboxes
+        grid_rois = bbox2roi([det_bboxes[:, :4]])
+        grid_feats = self.grid_roi_extractor(
+            x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
+        if grid_rois.shape[0] != 0:
+            self.grid_head.test_mode = True
+            grid_pred = self.grid_head(grid_feats)
+            det_bboxes = self.grid_head.get_bboxes(det_bboxes,
+                                                   grid_pred['fused'],
+                                                   img_metas)
+            if rescale:
+                scale_factor = img_metas[0]['scale_factor']
+                if not isinstance(scale_factor, (float, torch.Tensor)):
+                    scale_factor = det_bboxes.new_tensor(scale_factor)
+                det_bboxes[:, :4] /= scale_factor
+        else:
+            det_bboxes = torch.Tensor([])
+
+        bbox_results = bbox2result(det_bboxes, det_labels,
+                                   self.bbox_head.num_classes)
+
+        if not self.with_mask:
+            return bbox_results
+        else:
+            segm_results = self.simple_test_mask(
+                x, img_metas, det_bboxes, det_labels, rescale=rescale)
+            return bbox_results, segm_results
diff --git a/mmdet/models/roi_heads/htc_roi_head.py b/mmdet/models/roi_heads/htc_roi_head.py
new file mode 100644
index 00000000..8191feab
--- /dev/null
+++ b/mmdet/models/roi_heads/htc_roi_head.py
@@ -0,0 +1,509 @@
+import torch
+import torch.nn.functional as F
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
+                        merge_aug_masks, multiclass_nms)
+from .. import builder
+from ..registry import HEADS
+from .cascade_roi_head import CascadeRoIHead
+
+
+@HEADS.register_module
+class HybridTaskCascadeRoIHead(CascadeRoIHead):
+    """Hybrid task cascade roi head including one bbox head and one mask head.
+
+    https://arxiv.org/abs/1901.07518
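+
+    Example:
+        >>> # illustrative fragment: the flags mirror this class's defaults
+        >>> # and the usual 3-stage cascade loss weights; a full config
+        >>> # also needs the cascaded bbox and mask head entries
+        >>> roi_head = dict(
+        ...     type='HybridTaskCascadeRoIHead',
+        ...     num_stages=3,
+        ...     stage_loss_weights=[1, 0.5, 0.25],
+        ...     semantic_fusion=('bbox', 'mask'),
+        ...     interleaved=True,
+        ...     mask_info_flow=True)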
+    """
+
+    def __init__(self,
+                 num_stages,
+                 stage_loss_weights,
+                 semantic_roi_extractor=None,
+                 semantic_head=None,
+                 semantic_fusion=('bbox', 'mask'),
+                 interleaved=True,
+                 mask_info_flow=True,
+                 **kwargs):
+        super(HybridTaskCascadeRoIHead,
+              self).__init__(num_stages, stage_loss_weights, **kwargs)
+        assert self.with_bbox and self.with_mask
+        assert not self.with_shared_head  # shared head is not supported
+
+        if semantic_head is not None:
+            self.semantic_roi_extractor = builder.build_roi_extractor(
+                semantic_roi_extractor)
+            self.semantic_head = builder.build_head(semantic_head)
+
+        self.semantic_fusion = semantic_fusion
+        self.interleaved = interleaved
+        self.mask_info_flow = mask_info_flow
+
+    def init_weights(self, pretrained):
+        super(HybridTaskCascadeRoIHead, self).init_weights(pretrained)
+        if self.with_semantic:
+            self.semantic_head.init_weights()
+
+    @property
+    def with_semantic(self):
+        if hasattr(self, 'semantic_head') and self.semantic_head is not None:
+            return True
+        else:
+            return False
+
+    def forward_dummy(self, x, proposals):
+        outs = ()
+        # semantic head
+        if self.with_semantic:
+            _, semantic_feat = self.semantic_head(x)
+        else:
+            semantic_feat = None
+        # bbox heads
+        rois = bbox2roi([proposals])
+        for i in range(self.num_stages):
+            bbox_results = self._bbox_forward(
+                i, x, rois, semantic_feat=semantic_feat)
+            outs = outs + (bbox_results['cls_score'],
+                           bbox_results['bbox_pred'])
+        # mask heads
+        if self.with_mask:
+            mask_rois = rois[:100]
+            mask_roi_extractor = self.mask_roi_extractor[-1]
+            mask_feats = mask_roi_extractor(
+                x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+            if self.with_semantic and 'mask' in self.semantic_fusion:
+                mask_semantic_feat = self.semantic_roi_extractor(
+                    [semantic_feat], mask_rois)
+                mask_feats += mask_semantic_feat
+            last_feat = None
+            for i in range(self.num_stages):
+                mask_head = self.mask_head[i]
+                if self.mask_info_flow:
+                    mask_pred, last_feat = mask_head(mask_feats, last_feat)
+                else:
+                    mask_pred = mask_head(mask_feats)
+                outs = outs + (mask_pred, )
+        return outs
+
+    def _bbox_forward_train(self,
+                            stage,
+                            x,
+                            sampling_results,
+                            gt_bboxes,
+                            gt_labels,
+                            rcnn_train_cfg,
+                            semantic_feat=None):
+        bbox_head = self.bbox_head[stage]
+        rois = bbox2roi([res.bboxes for res in sampling_results])
+        bbox_results = self._bbox_forward(
+            stage, x, rois, semantic_feat=semantic_feat)
+
+        bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes,
+                                            gt_labels, rcnn_train_cfg)
+        loss_bbox = bbox_head.loss(bbox_results['cls_score'],
+                                   bbox_results['bbox_pred'], *bbox_targets)
+
+        bbox_results.update(
+            loss_bbox=loss_bbox,
+            rois=rois,
+            bbox_targets=bbox_targets,
+        )
+        return bbox_results
+
+    def _mask_forward_train(self,
+                            stage,
+                            x,
+                            sampling_results,
+                            gt_masks,
+                            rcnn_train_cfg,
+                            semantic_feat=None):
+        mask_roi_extractor = self.mask_roi_extractor[stage]
+        mask_head = self.mask_head[stage]
+        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+        mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
+                                        pos_rois)
+
+        # semantic feature fusion
+        # element-wise sum for original features and pooled semantic features
+        if self.with_semantic and 'mask' in self.semantic_fusion:
+            mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+                                                             pos_rois)
+            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
+                mask_semantic_feat = F.adaptive_avg_pool2d(
+                    mask_semantic_feat, mask_feats.shape[-2:])
+            mask_feats += mask_semantic_feat
+
+        # mask information flow
+        # forward all previous mask heads to obtain last_feat, and fuse it
+        # with the normal mask feature
+        if self.mask_info_flow:
+            last_feat = None
+            for i in range(stage):
+                last_feat = self.mask_head[i](
+                    mask_feats, last_feat, return_logits=False)
+            mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
+        else:
+            mask_pred = mask_head(mask_feats, return_feat=False)
+
+        mask_targets = mask_head.get_target(sampling_results, gt_masks,
+                                            rcnn_train_cfg)
+        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+        loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)
+
+        mask_results = dict(loss_mask=loss_mask)
+        return mask_results
+
+    def _bbox_forward(self, stage, x, rois, semantic_feat=None):
+        bbox_roi_extractor = self.bbox_roi_extractor[stage]
+        bbox_head = self.bbox_head[stage]
+        bbox_feats = bbox_roi_extractor(
+            x[:len(bbox_roi_extractor.featmap_strides)], rois)
+        if self.with_semantic and 'bbox' in self.semantic_fusion:
+            bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+                                                             rois)
+            if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
+                bbox_semantic_feat = F.adaptive_avg_pool2d(
+                    bbox_semantic_feat, bbox_feats.shape[-2:])
+            bbox_feats += bbox_semantic_feat
+        cls_score, bbox_pred = bbox_head(bbox_feats)
+
+        bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
+        return bbox_results
+
+    def _mask_forward_test(self, stage, x, bboxes, semantic_feat=None):
+        mask_roi_extractor = self.mask_roi_extractor[stage]
+        mask_head = self.mask_head[stage]
+        mask_rois = bbox2roi([bboxes])
+        mask_feats = mask_roi_extractor(
+            x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+        if self.with_semantic and 'mask' in self.semantic_fusion:
+            mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+                                                             mask_rois)
+            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
+                mask_semantic_feat = F.adaptive_avg_pool2d(
+                    mask_semantic_feat, mask_feats.shape[-2:])
+            mask_feats += mask_semantic_feat
+        if self.mask_info_flow:
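+            # at test time, predictions of all previous stages are
+            # accumulated (summed) into the current stage's prediction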
+            last_feat = None
+            last_pred = None
+            for i in range(stage):
+                mask_pred, last_feat = self.mask_head[i](mask_feats, last_feat)
+                if last_pred is not None:
+                    mask_pred = mask_pred + last_pred
+                last_pred = mask_pred
+            mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
+            if last_pred is not None:
+                mask_pred = mask_pred + last_pred
+        else:
+            mask_pred = mask_head(mask_feats)
+        return mask_pred
+
+    def forward_train(self,
+                      x,
+                      img_metas,
+                      proposal_list,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None,
+                      gt_semantic_seg=None):
+        # semantic segmentation part
+        # 2 outputs: segmentation prediction and embedded features
+        losses = dict()
+        if self.with_semantic:
+            semantic_pred, semantic_feat = self.semantic_head(x)
+            loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
+            losses['loss_semantic_seg'] = loss_seg
+        else:
+            semantic_feat = None
+
+        for i in range(self.num_stages):
+            self.current_stage = i
+            rcnn_train_cfg = self.train_cfg[i]
+            lw = self.stage_loss_weights[i]
+
+            # assign gts and sample proposals
+            sampling_results = []
+            bbox_assigner = self.bbox_assigner[i]
+            bbox_sampler = self.bbox_sampler[i]
+            num_imgs = len(img_metas)
+            if gt_bboxes_ignore is None:
+                gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+            for j in range(num_imgs):
+                assign_result = bbox_assigner.assign(proposal_list[j],
+                                                     gt_bboxes[j],
+                                                     gt_bboxes_ignore[j],
+                                                     gt_labels[j])
+                sampling_result = bbox_sampler.sample(
+                    assign_result,
+                    proposal_list[j],
+                    gt_bboxes[j],
+                    gt_labels[j],
+                    feats=[lvl_feat[j][None] for lvl_feat in x])
+                sampling_results.append(sampling_result)
+
+            # bbox head forward and loss
+            bbox_results = \
+                self._bbox_forward_train(
+                    i, x, sampling_results, gt_bboxes, gt_labels,
+                    rcnn_train_cfg, semantic_feat)
+            roi_labels = bbox_results['bbox_targets'][0]
+
+            for name, value in bbox_results['loss_bbox'].items():
+                losses['s{}.{}'.format(i, name)] = (
+                    value * lw if 'loss' in name else value)
+
+            # mask head forward and loss
+            if self.with_mask:
+                # interleaved execution: use regressed bboxes by the box branch
+                # to train the mask branch
+                if self.interleaved:
+                    pos_is_gts = [res.pos_is_gt for res in sampling_results]
+                    with torch.no_grad():
+                        proposal_list = self.bbox_head[i].refine_bboxes(
+                            bbox_results['rois'], roi_labels,
+                            bbox_results['bbox_pred'], pos_is_gts, img_metas)
+                        # re-assign and re-sample on the refined proposals
+                        sampling_results = []
+                        for j in range(num_imgs):
+                            assign_result = bbox_assigner.assign(
+                                proposal_list[j], gt_bboxes[j],
+                                gt_bboxes_ignore[j], gt_labels[j])
+                            sampling_result = bbox_sampler.sample(
+                                assign_result,
+                                proposal_list[j],
+                                gt_bboxes[j],
+                                gt_labels[j],
+                                feats=[lvl_feat[j][None] for lvl_feat in x])
+                            sampling_results.append(sampling_result)
+                mask_results = self._mask_forward_train(
+                    i, x, sampling_results, gt_masks, rcnn_train_cfg,
+                    semantic_feat)
+                for name, value in mask_results['loss_mask'].items():
+                    losses['s{}.{}'.format(i, name)] = (
+                        value * lw if 'loss' in name else value)
+
+            # refine bboxes (same as Cascade R-CNN)
+            if i < self.num_stages - 1 and not self.interleaved:
+                pos_is_gts = [res.pos_is_gt for res in sampling_results]
+                with torch.no_grad():
+                    proposal_list = self.bbox_head[i].refine_bboxes(
+                        bbox_results['rois'], roi_labels,
+                        bbox_results['bbox_pred'], pos_is_gts, img_metas)
+
+        return losses
+
+    def simple_test(self, x, proposal_list, img_metas, rescale=False):
+        if self.with_semantic:
+            _, semantic_feat = self.semantic_head(x)
+        else:
+            semantic_feat = None
+
+        img_shape = img_metas[0]['img_shape']
+        ori_shape = img_metas[0]['ori_shape']
+        scale_factor = img_metas[0]['scale_factor']
+
+        # "ms" in variable names means multi-stage
+        ms_bbox_result = {}
+        ms_segm_result = {}
+        ms_scores = []
+        rcnn_test_cfg = self.test_cfg
+
+        rois = bbox2roi(proposal_list)
+        for i in range(self.num_stages):
+            bbox_head = self.bbox_head[i]
+            bbox_results = self._bbox_forward(
+                i, x, rois, semantic_feat=semantic_feat)
+            ms_scores.append(bbox_results['cls_score'])
+
+            if i < self.num_stages - 1:
+                bbox_label = bbox_results['cls_score'].argmax(dim=1)
+                rois = bbox_head.regress_by_class(rois, bbox_label,
+                                                  bbox_results['bbox_pred'],
+                                                  img_metas[0])
+
+        cls_score = sum(ms_scores) / float(len(ms_scores))
+        det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
+            rois,
+            cls_score,
+            bbox_results['bbox_pred'],
+            img_shape,
+            scale_factor,
+            rescale=rescale,
+            cfg=rcnn_test_cfg)
+        bbox_result = bbox2result(det_bboxes, det_labels,
+                                  self.bbox_head[-1].num_classes)
+        ms_bbox_result['ensemble'] = bbox_result
+
+        if self.with_mask:
+            if det_bboxes.shape[0] == 0:
+                mask_classes = self.mask_head[-1].num_classes - 1
+                segm_result = [[] for _ in range(mask_classes)]
+            else:
+                _bboxes = (
+                    det_bboxes[:, :4] * det_bboxes.new_tensor(scale_factor)
+                    if rescale else det_bboxes)
+
+                mask_rois = bbox2roi([_bboxes])
+                aug_masks = []
+                mask_roi_extractor = self.mask_roi_extractor[-1]
+                mask_feats = mask_roi_extractor(
+                    x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+                if self.with_semantic and 'mask' in self.semantic_fusion:
+                    mask_semantic_feat = self.semantic_roi_extractor(
+                        [semantic_feat], mask_rois)
+                    mask_feats += mask_semantic_feat
+                last_feat = None
+                for i in range(self.num_stages):
+                    mask_head = self.mask_head[i]
+                    if self.mask_info_flow:
+                        mask_pred, last_feat = mask_head(mask_feats, last_feat)
+                    else:
+                        mask_pred = mask_head(mask_feats)
+                    aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+                merged_masks = merge_aug_masks(aug_masks,
+                                               [img_metas] * self.num_stages,
+                                               self.test_cfg)
+                segm_result = self.mask_head[-1].get_seg_masks(
+                    merged_masks, _bboxes, det_labels, rcnn_test_cfg,
+                    ori_shape, scale_factor, rescale)
+            ms_segm_result['ensemble'] = segm_result
+
+        if self.with_mask:
+            results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble'])
+        else:
+            results = ms_bbox_result['ensemble']
+
+        return results
+
+    def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
+        """Test with augmentations.
+
+        If rescale is False, then returned bboxes and masks will fit the scale
+        of imgs[0].
+        """
+        if self.with_semantic:
+            semantic_feats = [
+                self.semantic_head(feat)[1] for feat in img_feats
+            ]
+        else:
+            semantic_feats = [None] * len(img_metas)
+
+        rcnn_test_cfg = self.test_cfg
+        aug_bboxes = []
+        aug_scores = []
+        for x, img_meta, semantic in zip(img_feats, img_metas, semantic_feats):
+            # only one image in the batch
+            img_shape = img_meta[0]['img_shape']
+            scale_factor = img_meta[0]['scale_factor']
+            flip = img_meta[0]['flip']
+
+            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+                                     scale_factor, flip)
+            # "ms" in variable names means multi-stage
+            ms_scores = []
+
+            rois = bbox2roi([proposals])
+            for i in range(self.num_stages):
+                bbox_head = self.bbox_head[i]
+                bbox_results = self._bbox_forward(
+                    i, x, rois, semantic_feat=semantic)
+                ms_scores.append(bbox_results['cls_score'])
+
+                if i < self.num_stages - 1:
+                    bbox_label = bbox_results['cls_score'].argmax(dim=1)
+                    rois = bbox_head.regress_by_class(
+                        rois, bbox_label, bbox_results['bbox_pred'],
+                        img_meta[0])
+
+            cls_score = sum(ms_scores) / float(len(ms_scores))
+            bboxes, scores = self.bbox_head[-1].get_det_bboxes(
+                rois,
+                cls_score,
+                bbox_results['bbox_pred'],
+                img_shape,
+                scale_factor,
+                rescale=False,
+                cfg=None)
+            aug_bboxes.append(bboxes)
+            aug_scores.append(scores)
+
+        # after merging, bboxes will be rescaled to the original image size
+        merged_bboxes, merged_scores = merge_aug_bboxes(
+            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+                                                rcnn_test_cfg.score_thr,
+                                                rcnn_test_cfg.nms,
+                                                rcnn_test_cfg.max_per_img)
+
+        bbox_result = bbox2result(det_bboxes, det_labels,
+                                  self.bbox_head[-1].num_classes)
+
+        if self.with_mask:
+            if det_bboxes.shape[0] == 0:
+                segm_result = [[]
+                               for _ in range(self.mask_head[-1].num_classes -
+                                              1)]
+            else:
+                aug_masks = []
+                aug_img_metas = []
+                for x, img_meta, semantic in zip(img_feats, img_metas,
+                                                 semantic_feats):
+                    img_shape = img_meta[0]['img_shape']
+                    scale_factor = img_meta[0]['scale_factor']
+                    flip = img_meta[0]['flip']
+                    _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+                                           scale_factor, flip)
+                    mask_rois = bbox2roi([_bboxes])
+                    mask_feats = self.mask_roi_extractor[-1](
+                        x[:len(self.mask_roi_extractor[-1].featmap_strides)],
+                        mask_rois)
+                    if self.with_semantic:
+                        semantic_feat = semantic
+                        mask_semantic_feat = self.semantic_roi_extractor(
+                            [semantic_feat], mask_rois)
+                        if mask_semantic_feat.shape[-2:] != mask_feats.shape[
+                                -2:]:
+                            mask_semantic_feat = F.adaptive_avg_pool2d(
+                                mask_semantic_feat, mask_feats.shape[-2:])
+                        mask_feats += mask_semantic_feat
+                    last_feat = None
+                    for i in range(self.num_stages):
+                        mask_head = self.mask_head[i]
+                        if self.mask_info_flow:
+                            mask_pred, last_feat = mask_head(
+                                mask_feats, last_feat)
+                        else:
+                            mask_pred = mask_head(mask_feats)
+                        aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+                        aug_img_metas.append(img_meta)
+                merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
+                                               self.test_cfg)
+
+                ori_shape = img_metas[0][0]['ori_shape']
+                segm_result = self.mask_head[-1].get_seg_masks(
+                    merged_masks,
+                    det_bboxes,
+                    det_labels,
+                    rcnn_test_cfg,
+                    ori_shape,
+                    scale_factor=1.0,
+                    rescale=False)
+            return bbox_result, segm_result
+        else:
+            return bbox_result
diff --git a/mmdet/models/roi_heads/mask_scoring_roi_head.py b/mmdet/models/roi_heads/mask_scoring_roi_head.py
new file mode 100644
index 00000000..7fbd0616
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_scoring_roi_head.py
@@ -0,0 +1,89 @@
+import torch
+
+from mmdet.core import bbox2roi
+from .. import builder
+from ..registry import HEADS
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module
+class MaskScoringRoIHead(StandardRoIHead):
+    """Mask Scoring RoIHead for Mask Scoring RCNN.
+
+    https://arxiv.org/abs/1903.00241
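+
+    Example:
+        >>> # illustrative fragment; only the extra field this subclass
+        >>> # requires is shown (81 = 80 COCO classes + background)
+        >>> roi_head = dict(
+        ...     type='MaskScoringRoIHead',
+        ...     mask_iou_head=dict(type='MaskIoUHead', num_classes=81))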
+    """
+
+    def __init__(self, mask_iou_head, **kwargs):
+        assert mask_iou_head is not None
+        super(MaskScoringRoIHead, self).__init__(**kwargs)
+        self.mask_iou_head = builder.build_head(mask_iou_head)
+
+    def init_weights(self, pretrained):
+        super(MaskScoringRoIHead, self).init_weights(pretrained)
+        self.mask_iou_head.init_weights()
+
+    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
+                            img_metas):
+        # note: the C4 (shared-head) model is no longer supported in
+        # Mask Scoring R-CNN
+        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+        mask_results = super(MaskScoringRoIHead,
+                             self)._mask_forward_train(x, sampling_results,
+                                                       bbox_feats, gt_masks,
+                                                       img_metas)
+
+        # mask iou head forward and loss
+        pos_mask_pred = mask_results['mask_pred'][
+            range(mask_results['mask_pred'].size(0)), pos_labels]
+        mask_iou_pred = self.mask_iou_head(mask_results['mask_feats'],
+                                           pos_mask_pred)
+        pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)),
+                                          pos_labels]
+
+        mask_iou_targets = self.mask_iou_head.get_target(
+            sampling_results, gt_masks, pos_mask_pred,
+            mask_results['mask_targets'], self.train_cfg)
+        loss_mask_iou = self.mask_iou_head.loss(pos_mask_iou_pred,
+                                                mask_iou_targets)
+        mask_results['loss_mask'].update(loss_mask_iou)
+        return mask_results
+
+    def simple_test_mask(self,
+                         x,
+                         img_metas,
+                         det_bboxes,
+                         det_labels,
+                         rescale=False):
+        # image shape of the first image in the batch (only one)
+        ori_shape = img_metas[0]['ori_shape']
+        scale_factor = img_metas[0]['scale_factor']
+
+        if det_bboxes.shape[0] == 0:
+            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
+            mask_scores = [[] for _ in range(self.mask_head.num_classes - 1)]
+        else:
+            # if det_bboxes is rescaled to the original image size, we need to
+            # rescale it back to the testing scale to obtain RoIs.
+            _bboxes = (
+                det_bboxes[:, :4] *
+                det_bboxes.new_tensor(scale_factor) if rescale else det_bboxes)
+            mask_rois = bbox2roi([_bboxes])
+            mask_results = self._mask_forward(x, mask_rois)
+            segm_result = self.mask_head.get_seg_masks(
+                mask_results['mask_pred'], _bboxes, det_labels, self.test_cfg,
+                ori_shape, scale_factor, rescale)
+            # get mask scores with mask iou head
+            mask_iou_pred = self.mask_iou_head(
+                mask_results['mask_feats'],
+                mask_results['mask_pred'][range(det_labels.size(0)),
+                                          det_labels + 1])
+            mask_scores = self.mask_iou_head.get_mask_scores(
+                mask_iou_pred, det_bboxes, det_labels)
+        return segm_result, mask_scores
diff --git a/mmdet/models/roi_heads/standard_roi_head.py b/mmdet/models/roi_heads/standard_roi_head.py
new file mode 100644
index 00000000..69071a26
--- /dev/null
+++ b/mmdet/models/roi_heads/standard_roi_head.py
@@ -0,0 +1,299 @@
+import torch
+
+from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
+from .. import builder
+from ..registry import HEADS
+from .base_roi_head import BaseRoIHead
+from .test_mixins import BBoxTestMixin, MaskTestMixin
+
+
+@HEADS.register_module
+class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+    """Simplest base roi head including one bbox head and one mask head.
+    """
+
+    def init_assigner_sampler(self):
+        self.bbox_assigner = None
+        self.bbox_sampler = None
+        if self.train_cfg:
+            self.bbox_assigner = build_assigner(self.train_cfg.assigner)
+            self.bbox_sampler = build_sampler(
+                self.train_cfg.sampler, context=self)
+
+    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
+        self.bbox_roi_extractor = builder.build_roi_extractor(
+            bbox_roi_extractor)
+        self.bbox_head = builder.build_head(bbox_head)
+
+    def init_mask_head(self, mask_roi_extractor, mask_head):
+        if mask_roi_extractor is not None:
+            self.mask_roi_extractor = builder.build_roi_extractor(
+                mask_roi_extractor)
+            self.share_roi_extractor = False
+        else:
+            self.share_roi_extractor = True
+            self.mask_roi_extractor = self.bbox_roi_extractor
+        self.mask_head = builder.build_head(mask_head)
+
+    def init_weights(self, pretrained):
+        if self.with_shared_head:
+            self.shared_head.init_weights(pretrained=pretrained)
+        if self.with_bbox:
+            self.bbox_roi_extractor.init_weights()
+            self.bbox_head.init_weights()
+        if self.with_mask:
+            self.mask_head.init_weights()
+            if not self.share_roi_extractor:
+                self.mask_roi_extractor.init_weights()
+
+    def forward_dummy(self, x, proposals):
+        # bbox head
+        outs = ()
+        rois = bbox2roi([proposals])
+        if self.with_bbox:
+            bbox_results = self._bbox_forward(x, rois)
+            outs = outs + (bbox_results['cls_score'],
+                           bbox_results['bbox_pred'])
+        # mask head
+        if self.with_mask:
+            mask_rois = rois[:100]
+            mask_results = self._mask_forward(x, mask_rois)
+            outs = outs + (mask_results['mask_pred'], )
+        return outs
+
+    def forward_train(self,
+                      x,
+                      img_metas,
+                      proposal_list,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None):
+        """
+        Args:
+            x (list[Tensor]): list of multi-level img features.
+
+            img_metas (list[dict]): list of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+            proposal_list (list[Tensor]): list of region proposals.
+
+            gt_bboxes (list[Tensor]): ground-truth bboxes for each image,
+                each in [tl_x, tl_y, br_x, br_y] format.
+
+            gt_labels (list[Tensor]): class indices corresponding to each
+                box.
+
+            gt_bboxes_ignore (None | list[Tensor]): specifies which bounding
+                boxes can be ignored when computing the loss.
+
+            gt_masks (None | Tensor): ground-truth masks for each box, used
+                if the architecture supports a segmentation task.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        # assign gts and sample proposals
+        if self.with_bbox or self.with_mask:
+            num_imgs = len(img_metas)
+            if gt_bboxes_ignore is None:
+                gt_bboxes_ignore = [None for _ in range(num_imgs)]
+            sampling_results = []
+            for i in range(num_imgs):
+                assign_result = self.bbox_assigner.assign(
+                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
+                    gt_labels[i])
+                sampling_result = self.bbox_sampler.sample(
+                    assign_result,
+                    proposal_list[i],
+                    gt_bboxes[i],
+                    gt_labels[i],
+                    feats=[lvl_feat[i][None] for lvl_feat in x])
+                sampling_results.append(sampling_result)
+
+        losses = dict()
+        # bbox head forward and loss
+        if self.with_bbox:
+            bbox_results = self._bbox_forward_train(x, sampling_results,
+                                                    gt_bboxes, gt_labels,
+                                                    img_metas)
+            losses.update(bbox_results['loss_bbox'])
+
+        # mask head forward and loss
+        if self.with_mask:
+            mask_results = self._mask_forward_train(x, sampling_results,
+                                                    bbox_results['bbox_feats'],
+                                                    gt_masks, img_metas)
+            if mask_results['loss_mask'] is not None:
+                losses.update(mask_results['loss_mask'])
+
+        return losses
+
+    def _bbox_forward(self, x, rois):
+        # TODO: a more flexible way to decide which feature maps to use
+        bbox_feats = self.bbox_roi_extractor(
+            x[:self.bbox_roi_extractor.num_inputs], rois)
+        if self.with_shared_head:
+            bbox_feats = self.shared_head(bbox_feats)
+        cls_score, bbox_pred = self.bbox_head(bbox_feats)
+
+        bbox_results = dict(
+            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
+        return bbox_results
+
+    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+                            img_metas):
+        rois = bbox2roi([res.bboxes for res in sampling_results])
+        bbox_results = self._bbox_forward(x, rois)
+
+        bbox_targets = self.bbox_head.get_target(sampling_results, gt_bboxes,
+                                                 gt_labels, self.train_cfg)
+        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
+                                        bbox_results['bbox_pred'],
+                                        *bbox_targets)
+
+        bbox_results.update(loss_bbox=loss_bbox)
+        return bbox_results
+
+    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
+                            img_metas):
+        if not self.share_roi_extractor:
+            pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+            mask_results = self._mask_forward(x, pos_rois)
+        else:
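+            # the mask branch shares the bbox RoI extractor: reuse the
+            # pooled bbox_feats and keep only positive samples, selected by
+            # a uint8 mask laid out in sampling order (pos then neg)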
+            pos_inds = []
+            device = bbox_feats.device
+            for res in sampling_results:
+                pos_inds.append(
+                    torch.ones(
+                        res.pos_bboxes.shape[0],
+                        device=device,
+                        dtype=torch.uint8))
+                pos_inds.append(
+                    torch.zeros(
+                        res.neg_bboxes.shape[0],
+                        device=device,
+                        dtype=torch.uint8))
+            pos_inds = torch.cat(pos_inds)
+            mask_results = self._mask_forward(
+                x, pos_inds=pos_inds, bbox_feats=bbox_feats)
+
+        mask_targets = self.mask_head.get_target(sampling_results, gt_masks,
+                                                 self.train_cfg)
+        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+        loss_mask = self.mask_head.loss(mask_results['mask_pred'],
+                                        mask_targets, pos_labels)
+
+        mask_results.update(loss_mask=loss_mask, mask_targets=mask_targets)
+        return mask_results
+
+    def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
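+        # exactly one input mode: either RoIs to pool features from, or
+        # precomputed bbox_feats indexed by pos_inds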
+        assert ((rois is not None) ^
+                (pos_inds is not None and bbox_feats is not None))
+        if rois is not None:
+            mask_feats = self.mask_roi_extractor(
+                x[:self.mask_roi_extractor.num_inputs], rois)
+            if self.with_shared_head:
+                mask_feats = self.shared_head(mask_feats)
+        else:
+            assert bbox_feats is not None
+            mask_feats = bbox_feats[pos_inds]
+
+        mask_pred = self.mask_head(mask_feats)
+        mask_results = dict(mask_pred=mask_pred, mask_feats=mask_feats)
+        return mask_results
+
+    async def async_simple_test(self,
+                                x,
+                                proposal_list,
+                                img_metas,
+                                proposals=None,
+                                rescale=False):
+        """Async test without augmentation."""
+        assert self.with_bbox, 'Bbox head must be implemented.'
+
+        det_bboxes, det_labels = await self.async_test_bboxes(
+            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+        bbox_results = bbox2result(det_bboxes, det_labels,
+                                   self.bbox_head.num_classes)
+        if not self.with_mask:
+            return bbox_results
+        else:
+            segm_results = await self.async_test_mask(
+                x,
+                img_metas,
+                det_bboxes,
+                det_labels,
+                rescale=rescale,
+                mask_test_cfg=self.test_cfg.get('mask'))
+            return bbox_results, segm_results
+
+    def simple_test(self,
+                    x,
+                    proposal_list,
+                    img_metas,
+                    proposals=None,
+                    rescale=False):
+        """Test without augmentation."""
+        assert self.with_bbox, 'Bbox head must be implemented.'
+
+        det_bboxes, det_labels = self.simple_test_bboxes(
+            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+        bbox_results = bbox2result(det_bboxes, det_labels,
+                                   self.bbox_head.num_classes)
+
+        if not self.with_mask:
+            return bbox_results
+        else:
+            segm_results = self.simple_test_mask(
+                x, img_metas, det_bboxes, det_labels, rescale=rescale)
+            return bbox_results, segm_results
+
+    def aug_test(self, x, proposal_list, img_metas, rescale=False):
+        """Test with augmentations.
+
+        If rescale is False, then returned bboxes and masks will fit the scale
+        of imgs[0].
+        """
+        # recompute feats to save memory
+        det_bboxes, det_labels = self.aug_test_bboxes(x, img_metas,
+                                                      proposal_list,
+                                                      self.test_cfg)
+
+        if rescale:
+            _det_bboxes = det_bboxes
+        else:
+            _det_bboxes = det_bboxes.clone()
+            _det_bboxes[:, :4] *= det_bboxes.new_tensor(
+                img_metas[0][0]['scale_factor'])
+        bbox_results = bbox2result(_det_bboxes, det_labels,
+                                   self.bbox_head.num_classes)
+
+        # det_bboxes always keep the original scale
+        if self.with_mask:
+            segm_results = self.aug_test_mask(x, img_metas, det_bboxes,
+                                              det_labels)
+            return bbox_results, segm_results
+        else:
+            return bbox_results
diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py
new file mode 100644
index 00000000..4cc704a4
--- /dev/null
+++ b/mmdet/models/roi_heads/test_mixins.py
@@ -0,0 +1,202 @@
+import logging
+import sys
+
+import torch
+
+from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
+                        merge_aug_masks, multiclass_nms)
+
+logger = logging.getLogger(__name__)
+
+if sys.version_info >= (3, 7):
+    from mmdet.utils.contextmanagers import completed
+
+
+class BBoxTestMixin(object):
+
+    if sys.version_info >= (3, 7):
+
+        async def async_test_bboxes(self,
+                                    x,
+                                    img_metas,
+                                    proposals,
+                                    rcnn_test_cfg,
+                                    rescale=False,
+                                    bbox_semaphore=None,
+                                    global_lock=None):
+            """Async test only det bboxes without augmentation."""
+            rois = bbox2roi(proposals)
+            roi_feats = self.bbox_roi_extractor(
+                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
+            if self.with_shared_head:
+                roi_feats = self.shared_head(roi_feats)
+            sleep_interval = rcnn_test_cfg.get('async_sleep_interval', 0.017)
+
+            async with completed(
+                    __name__, 'bbox_head_forward',
+                    sleep_interval=sleep_interval):
+                cls_score, bbox_pred = self.bbox_head(roi_feats)
+
+            img_shape = img_metas[0]['img_shape']
+            scale_factor = img_metas[0]['scale_factor']
+            det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
+                rois,
+                cls_score,
+                bbox_pred,
+                img_shape,
+                scale_factor,
+                rescale=rescale,
+                cfg=rcnn_test_cfg)
+            return det_bboxes, det_labels
+
+    def simple_test_bboxes(self,
+                           x,
+                           img_metas,
+                           proposals,
+                           rcnn_test_cfg,
+                           rescale=False):
+        """Test only det bboxes without augmentation."""
+        rois = bbox2roi(proposals)
+        bbox_results = self._bbox_forward(x, rois)
+        img_shape = img_metas[0]['img_shape']
+        scale_factor = img_metas[0]['scale_factor']
+        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
+            rois,
+            bbox_results['cls_score'],
+            bbox_results['bbox_pred'],
+            img_shape,
+            scale_factor,
+            rescale=rescale,
+            cfg=rcnn_test_cfg)
+        return det_bboxes, det_labels
+
+    def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
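+        """Test det bboxes with test-time augmentation."""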
+        aug_bboxes = []
+        aug_scores = []
+        for x, img_meta in zip(feats, img_metas):
+            # only one image in the batch
+            img_shape = img_meta[0]['img_shape']
+            scale_factor = img_meta[0]['scale_factor']
+            flip = img_meta[0]['flip']
+            # TODO: make the proposal mapping more flexible
+            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+                                     scale_factor, flip)
+            rois = bbox2roi([proposals])
+            # recompute feature maps to save GPU memory
+            bbox_results = self._bbox_forward(x, rois)
+            bboxes, scores = self.bbox_head.get_det_bboxes(
+                rois,
+                bbox_results['cls_score'],
+                bbox_results['bbox_pred'],
+                img_shape,
+                scale_factor,
+                rescale=False,
+                cfg=None)
+            aug_bboxes.append(bboxes)
+            aug_scores.append(scores)
+        # after merging, bboxes will be rescaled to the original image size
+        merged_bboxes, merged_scores = merge_aug_bboxes(
+            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+                                                rcnn_test_cfg.score_thr,
+                                                rcnn_test_cfg.nms,
+                                                rcnn_test_cfg.max_per_img)
+        return det_bboxes, det_labels
+
+
+class MaskTestMixin(object):
+
+    if sys.version_info >= (3, 7):
+
+        async def async_test_mask(self,
+                                  x,
+                                  img_metas,
+                                  det_bboxes,
+                                  det_labels,
+                                  rescale=False,
+                                  mask_test_cfg=None):
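+            """Asynchronously test masks without augmentation."""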
+            # image shape of the first image in the batch (only one)
+            ori_shape = img_metas[0]['ori_shape']
+            scale_factor = img_metas[0]['scale_factor']
+            if det_bboxes.shape[0] == 0:
+                segm_result = [[]
+                               for _ in range(self.mask_head.num_classes - 1)]
+            else:
+                _bboxes = (
+                    det_bboxes[:, :4] *
+                    scale_factor if rescale else det_bboxes)
+                mask_rois = bbox2roi([_bboxes])
+                mask_feats = self.mask_roi_extractor(
+                    x[:len(self.mask_roi_extractor.featmap_strides)],
+                    mask_rois)
+
+                if self.with_shared_head:
+                    mask_feats = self.shared_head(mask_feats)
+                if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
+                    sleep_interval = mask_test_cfg['async_sleep_interval']
+                else:
+                    sleep_interval = 0.035
+                async with completed(
+                        __name__,
+                        'mask_head_forward',
+                        sleep_interval=sleep_interval):
+                    mask_pred = self.mask_head(mask_feats)
+                segm_result = self.mask_head.get_seg_masks(
+                    mask_pred, _bboxes, det_labels, self.test_cfg, ori_shape,
+                    scale_factor, rescale)
+            return segm_result
+
+    def simple_test_mask(self,
+                         x,
+                         img_metas,
+                         det_bboxes,
+                         det_labels,
+                         rescale=False):
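+        """Simple test for mask head without augmentation."""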
+        # image shape of the first image in the batch (only one)
+        ori_shape = img_metas[0]['ori_shape']
+        scale_factor = img_metas[0]['scale_factor']
+        if det_bboxes.shape[0] == 0:
+            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
+        else:
+            # if det_bboxes is rescaled to the original image size, we need to
+            # rescale it back to the testing scale to obtain RoIs.
+            if rescale and not isinstance(scale_factor, float):
+                scale_factor = torch.from_numpy(scale_factor).to(
+                    det_bboxes.device)
+            _bboxes = (
+                det_bboxes[:, :4] * scale_factor if rescale else det_bboxes)
+            mask_rois = bbox2roi([_bboxes])
+            mask_results = self._mask_forward(x, mask_rois)
+            segm_result = self.mask_head.get_seg_masks(
+                mask_results['mask_pred'], _bboxes, det_labels, self.test_cfg,
+                ori_shape, scale_factor, rescale)
+        return segm_result
+
+    def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
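+        """Test masks with test-time augmentation."""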
+        if det_bboxes.shape[0] == 0:
+            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
+        else:
+            aug_masks = []
+            for x, img_meta in zip(feats, img_metas):
+                img_shape = img_meta[0]['img_shape']
+                scale_factor = img_meta[0]['scale_factor']
+                flip = img_meta[0]['flip']
+                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+                                       scale_factor, flip)
+                mask_rois = bbox2roi([_bboxes])
+                mask_results = self._mask_forward(x, mask_rois)
+                # convert to numpy array to save memory
+                aug_masks.append(
+                    mask_results['mask_pred'].sigmoid().cpu().numpy())
+            merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)
+
+            ori_shape = img_metas[0][0]['ori_shape']
+            segm_result = self.mask_head.get_seg_masks(
+                merged_masks,
+                det_bboxes,
+                det_labels,
+                self.test_cfg,
+                ori_shape,
+                scale_factor=1.0,
+                rescale=False)
+        return segm_result
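Both mixins assume the host RoI head implements _bbox_forward/_mask_forward
and that these return dicts keyed 'cls_score'/'bbox_pred' and 'mask_pred'
respectively. A hedged sketch of that contract (ToyRoIHead and its head
attributes are illustrative placeholders, not part of this patch):

    from mmdet.models.roi_heads.test_mixins import (BBoxTestMixin,
                                                    MaskTestMixin)

    class ToyRoIHead(BBoxTestMixin, MaskTestMixin):

        def _bbox_forward(self, x, rois):
            # pool RoI features from as many levels as the extractor covers
            bbox_feats = self.bbox_roi_extractor(
                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
            cls_score, bbox_pred = self.bbox_head(bbox_feats)
            return dict(cls_score=cls_score, bbox_pred=bbox_pred)

        def _mask_forward(self, x, rois):
            mask_feats = self.mask_roi_extractor(
                x[:len(self.mask_roi_extractor.featmap_strides)], rois)
            return dict(mask_pred=self.mask_head(mask_feats))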
diff --git a/tests/test_config.py b/tests/test_config.py
index 968b0ed3..7e92dca0 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,4 +1,4 @@
-from os.path import dirname, exists, join
+from os.path import dirname, exists, join, relpath
 
 
 def _get_config_directory():
@@ -26,131 +26,10 @@ def test_config_build_detector():
     config_dpath = _get_config_directory()
     print('Found config_dpath = {!r}'.format(config_dpath))
 
-    # import glob
-    # config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
-    # config_names = [relpath(p, config_dpath) for p in config_fpaths]
-
-    # Only tests a representative subset of configurations
-
-    config_names = [
-        # 'dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x_coco.py',
-        # 'dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x_coco.py',
-        # 'dcn/faster_rcnn_dpool_r50_fpn_1x_coco.py',
-        'dcn/mask_rcnn_r50_fpn_dconv_c3-c5_1x_coco.py',
-        # 'dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x_coco.py',
-        # 'dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x_coco.py',
-        # 'dcn/faster_rcnn_mdpool_r50_fpn_1x_coco.py',
-        # 'dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x_coco.py',
-        # 'dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x_coco.py',
-        # ---
-        # 'htc/htc_x101_32x4d_fpn_20e_16gpu_coco.py',
-        'htc/htc_without_semantic_r50_fpn_1x_coco.py',
-        # 'htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e_coco.py',
-        # 'htc/htc_x101_64x4d_fpn_20e_16gpu_coco.py',
-        # 'htc/htc_r50_fpn_1x_coco.py',
-        # 'htc/htc_r101_fpn_20e.py',
-        # 'htc/htc_r50_fpn_20e.py',
-        # ---
-        'cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py',
-        # 'cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py',
-        # ---
-        # 'scratch/scratch_faster_rcnn_r50_fpn_gn_6x_coco.py',
-        # 'scratch/scratch_mask_rcnn_r50_fpn_gn_6x_coco.py',
-        # ---
-        # 'grid_rcnn/grid_rcnn_gn_head_x101_32x4d_fpn_2x_coco.py',
-        'grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py',
-        # ---
-        'double_heads/dh_faster_rcnn_r50_fpn_1x_coco.py',
-        # ---
-        'empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x_coco'
-        '.py',
-        # 'empirical_attention/faster_rcnn_r50_fpn_attention_1111_1x_coco.py',
-        # 'empirical_attention/faster_rcnn_r50_fpn_attention_0010_1x_coco.py',
-        # 'empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x_coco
-        # .py',
-        # ---
-        # 'ms_rcnn/ms_rcnn_r101_caffe_fpn_1x_coco.py',
-        # 'ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x_coco.py',
-        # 'ms_rcnn/ms_rcnn_r50_caffe_fpn_1x_coco.py',
-        # ---
-        # 'guided_anchoring/ga_faster_x101_32x4d_fpn_1x_coco.py',
-        # 'guided_anchoring/ga_rpn_x101_32x4d_fpn_1x_coco.py',
-        # 'guided_anchoring/ga_retinanet_r50_caffe_fpn_1x_coco.py',
-        # 'guided_anchoring/ga_fast_r50_caffe_fpn_1x_coco.py',
-        # 'guided_anchoring/ga_retinanet_x101_32x4d_fpn_1x_coco.py',
-        # 'guided_anchoring/ga_rpn_r101_caffe_rpn_1x_coco.py',
-        # 'guided_anchoring/ga_faster_r50_caffe_fpn_1x_coco.py',
-        'guided_anchoring/ga_rpn_r50_caffe_fpn_1x_coco.py',
-        # ---
-        'foveabox/fovea_r50_fpn_4x4_1x_coco.py',
-        # 'foveabox/fovea_align_gn_ms_r101_fpn_4gpu_2x_coco.py',
-        # 'foveabox/fovea_align_gn_r50_fpn_4gpu_2x_coco.py',
-        # 'foveabox/fovea_align_gn_r101_fpn_4gpu_2x_coco.py',
-        'foveabox/fovea_align_r50_fpn_gn-head_mstrain_640-800_4x4_2x_coco.py',
-        # ---
-        # 'hrnet/cascade_rcnn_hrnetv2p_w32_20e_coco.py',
-        # 'hrnet/mask_rcnn_hrnetv2p_w32_1x_coco.py',
-        # 'hrnet/cascade_mask_rcnn_hrnetv2p_w32_20e_coco.py',
-        # 'hrnet/htc_hrnetv2p_w32_20e_coco.py',
-        # 'hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.py',
-        # 'hrnet/mask_rcnn_hrnetv2p_w18_1x_coco.py',
-        # 'hrnet/faster_rcnn_hrnetv2p_w32_1x_coco.py',
-        # 'hrnet/faster_rcnn_hrnetv2p_w40_1x_coco.py',
-        'hrnet/fcos_hrnetv2p_w32_gn-head_4x4_1x_coco.py',
-        # ---
-        # 'gn+ws/faster_rcnn_r50_fpn_gn_ws_1x_coco.py',
-        # 'gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x_coco.py',
-        'gn+ws/mask_rcnn_r50_fpn_gn_ws-all_2x_coco.py',
-        # 'gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e_coco.py',
-        # ---
-        # 'wider_face/ssd300_wider_face.py',
-        # ---
-        'pascal_voc/ssd300_voc0712.py',
-        'pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py',
-        'pascal_voc/ssd512_voc0712.py',
-        # ---
-        # 'gcnet/mask_rcnn_r4_gcb_c3-c5_r50_fpn_syncbn_1x_coco.py',
-        # 'gcnet/mask_rcnn_r16_gcb_c3-c5_r50_fpn_syncbn_1x_coco.py',
-        # 'gcnet/mask_rcnn_r4_gcb_c3-c5_r50_fpn_1x_coco.py',
-        # 'gcnet/mask_rcnn_r16_gcb_c3-c5_r50_fpn_1x_coco.py',
-        'gcnet/mask_rcnn_r50_fpn_syncbn-backbone_1x_coco.py',
-        # ---
-        'gn/mask_rcnn_r50_fpn_gn-all_contrib_2x_coco.py',
-        # 'gn/mask_rcnn_r50_fpn_gn_2x_coco.py',
-        # 'gn/mask_rcnn_r101_fpn_gn_2x_coco.py',
-        # ---
-        # 'reppoints/reppoints_moment_x101_dcn_fpn_2x.py',
-        'reppoints/reppoints_moment_r50_fpn_gn-neck+head_2x_coco.py',
-        # 'reppoints/reppoints_moment_x101_dcn_fpn_2x_mt_coco.py',
-        'reppoints/reppoints_partial_minmax_r50_fpn_gn-neck+head_1x_coco.py',
-        'reppoints/bbox_r50_grid_center_fpn_gn-neck+head_1x_coco.py',
-        # 'reppoints/reppoints_moment_r101_dcn_fpn_2x_coco.py',
-        # 'reppoints/reppoints_moment_r101_fpn_2x_mt_coco.py',
-        # 'reppoints/reppoints_moment_r50_fpn_2x_mt_coco.py',
-        'reppoints/reppoints_minmax_r50_fpn_gn-neck+head_1x_coco.py',
-        # 'reppoints/reppoints_moment_r50_fpn_1x_coco.py',
-        # 'reppoints/reppoints_moment_r101_fpn_2x_coco.py',
-        # 'reppoints/reppoints_moment_r101_dcn_fpn_2x_mt_coco.py',
-        'reppoints/bbox_r50_grid_fpn_gn-neck+head_1x_coco.py',
-        # ---
-        # 'fcos/fcos_mstrain_640_800_x101_64x4d_fpn_gn_2x_coco.py',
-        # 'fcos/fcos_mstrain_640_800_r101_caffe_fpn_gn_2x_4gpu_coco.py',
-        'fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py',
-        # ---
-        'albu_example/mask_rcnn_r50_fpn_albu_1x_coco.py',
-        # ---
-        'libra_rcnn/libra_faster_rcnn_r50_fpn_1x_coco.py',
-        # 'libra_rcnn/libra_retinanet_r50_fpn_1x_coco.py',
-        # 'libra_rcnn/libra_faster_rcnn_r101_fpn_1x_coco.py',
-        # 'libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x_coco.py',
-        # 'libra_rcnn/libra_fast_rcnn_r50_fpn_1x_coco.py',
-        # ---
-        # 'ghm/retinanet_ghm_r50_fpn_1x_coco.py',
-        # ---
-        # 'fp16/retinanet_r50_fpn_fp16_1x_coco.py',
-        'fp16/mask_rcnn_r50_fpn_fp16_1x_coco.py',
-        'fp16/faster_rcnn_r50_fpn_fp16_1x_coco.py'
-    ]
+    import glob
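+    # note: without recursive=True, '**' in glob matches exactly one
+    # directory level, which fits the flat configs/<family>/<config>.py layout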
+    config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
+    config_fpaths = [p for p in config_fpaths if '_base_' not in p]
+    config_names = [relpath(p, config_dpath) for p in config_fpaths]
 
     print('Using {} config files'.format(len(config_names)))
 
@@ -173,6 +52,21 @@ def test_config_build_detector():
             test_cfg=config_mod.test_cfg)
         assert detector is not None
 
+        if 'roi_head' in config_mod.model:
+            # two-stage detectors must have a bbox head
+            assert detector.roi_head.with_bbox and detector.with_bbox
+            assert detector.roi_head.with_mask == detector.with_mask
+
+            head_config = config_mod.model['roi_head']
+            _check_roihead(head_config, detector.roi_head)
+        # TODO: also check single-stage detectors, e.g.:
+        # else:
+        #     head_config = config_mod.model['bbox_head']
+        #     _check_bboxhead(head_config, detector.bbox_head)
+
 
 def test_config_data_pipeline():
     """
@@ -279,3 +173,144 @@ def test_config_data_pipeline():
         results['mask_fields'] = ['gt_masks']
         output_results = test_pipeline(results)
         assert output_results is not None
+
+
+def _check_roihead(config, head):
+    # check consistency between head_config and roi_head
+    assert config['type'] == head.__class__.__name__
+
+    # check the bbox RoI extractor
+    bbox_roi_cfg = config.bbox_roi_extractor
+    bbox_roi_extractor = head.bbox_roi_extractor
+    _check_roi_extractor(bbox_roi_cfg, bbox_roi_extractor)
+
+    # check bbox head settings
+    bbox_cfg = config.bbox_head
+    bbox_head = head.bbox_head
+    _check_bboxhead(bbox_cfg, bbox_head)
+
+    if head.with_mask:
+        # check the mask RoI extractor
+        if config.mask_roi_extractor:
+            mask_roi_cfg = config.mask_roi_extractor
+            mask_roi_extractor = head.mask_roi_extractor
+            _check_roi_extractor(mask_roi_cfg, mask_roi_extractor,
+                                 bbox_roi_extractor)
+
+        # check mask head settings
+        mask_head = head.mask_head
+        mask_cfg = config.mask_head
+        _check_maskhead(mask_cfg, mask_head)
+
+    # check architecture-specific settings, e.g., cascade / HTC heads
+    if config['type'] in ['CascadeRoIHead', 'HybridTaskCascadeRoIHead']:
+        assert config.num_stages == len(head.bbox_head)
+        assert config.num_stages == len(head.bbox_roi_extractor)
+
+        if head.with_mask:
+            assert config.num_stages == len(head.mask_head)
+            assert config.num_stages == len(head.mask_roi_extractor)
+
+    elif config['type'] in ['MaskScoringRoIHead']:
+        assert (hasattr(head, 'mask_iou_head')
+                and head.mask_iou_head is not None)
+        mask_iou_cfg = config.mask_iou_head
+        mask_iou_head = head.mask_iou_head
+        assert (mask_iou_cfg.fc_out_channels ==
+                mask_iou_head.fc_mask_iou.in_features)
+
+    elif config['type'] in ['GridRoIHead']:
+        grid_roi_cfg = config.grid_roi_extractor
+        grid_roi_extractor = head.grid_roi_extractor
+        _check_roi_extractor(grid_roi_cfg, grid_roi_extractor,
+                             bbox_roi_extractor)
+
+        assert config.grid_head.grid_points == head.grid_head.grid_points
+
+
+def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
+    import torch.nn as nn
+    from torch.nn.modules.utils import _pair
+    if isinstance(roi_extractor, nn.ModuleList):
+        if prev_roi_extractor:
+            prev_roi_extractor = prev_roi_extractor[0]
+        roi_extractor = roi_extractor[0]
+
+    assert (len(config.featmap_strides) == len(roi_extractor.roi_layers))
+    assert (config.out_channels == roi_extractor.out_channels)
+    assert (_pair(
+        config.roi_layer.out_size) == roi_extractor.roi_layers[0].out_size)
+
+    if 'use_torchvision' in config.roi_layer:
+        assert (config.roi_layer.use_torchvision ==
+                roi_extractor.roi_layers[0].use_torchvision)
+    elif 'aligned' in config.roi_layer:
+        assert (
+            config.roi_layer.aligned == roi_extractor.roi_layers[0].aligned)
+
+    if prev_roi_extractor:
+        assert (roi_extractor.roi_layers[0].aligned ==
+                prev_roi_extractor.roi_layers[0].aligned)
+        assert (roi_extractor.roi_layers[0].use_torchvision ==
+                prev_roi_extractor.roi_layers[0].use_torchvision)
+
+
+def _check_maskhead(mask_cfg, mask_head):
+    import torch.nn as nn
+    if isinstance(mask_cfg, list):
+        for single_mask_cfg, single_mask_head in zip(mask_cfg, mask_head):
+            _check_maskhead(single_mask_cfg, single_mask_head)
+    elif isinstance(mask_head, nn.ModuleList):
+        for single_mask_head in mask_head:
+            _check_maskhead(mask_cfg, single_mask_head)
+    else:
+        assert mask_cfg['type'] == mask_head.__class__.__name__
+        assert mask_cfg.in_channels == mask_head.in_channels
+        assert (
+            mask_cfg.conv_out_channels == mask_head.conv_logits.in_channels)
+        class_agnostic = mask_cfg.get('class_agnostic', False)
+        out_dim = (1 if class_agnostic else mask_cfg.num_classes)
+        assert mask_head.conv_logits.out_channels == out_dim
+
+
+def _check_bboxhead(bbox_cfg, bbox_head):
+    import torch.nn as nn
+    if isinstance(bbox_cfg, list):
+        for single_bbox_cfg, single_bbox_head in zip(bbox_cfg, bbox_head):
+            _check_bboxhead(single_bbox_cfg, single_bbox_head)
+    elif isinstance(bbox_head, nn.ModuleList):
+        for single_bbox_head in bbox_head:
+            _check_bboxhead(bbox_cfg, single_bbox_head)
+    else:
+        assert bbox_cfg['type'] == bbox_head.__class__.__name__
+        assert bbox_cfg.in_channels == bbox_head.in_channels
+        with_cls = bbox_cfg.get('with_cls', True)
+        if with_cls:
+            fc_out_channels = bbox_cfg.get('fc_out_channels', 2048)
+            assert (fc_out_channels == bbox_head.fc_cls.in_features)
+            assert bbox_cfg.num_classes == bbox_head.fc_cls.out_features
+
+        with_reg = bbox_cfg.get('with_reg', True)
+        if with_reg:
+            out_dim = (4 if bbox_cfg.reg_class_agnostic else
+                       4 * bbox_cfg.num_classes)
+            assert bbox_head.fc_reg.out_features == out_dim
+
+
+def _check_anchorhead(config, head):
+    # check consistency between the head config and the anchor head
+    assert config['type'] == head.__class__.__name__
+    assert config.in_channels == head.in_channels
+
+    use_sigmoid = config.loss_cls.get('use_sigmoid', False)
+    num_classes = (
+        config.num_classes - 1 if use_sigmoid else config.num_classes)
+    if config['type'] == 'ATSSHead':
+        assert (config.feat_channels == head.atss_cls.in_channels)
+        assert (config.feat_channels == head.atss_reg.in_channels)
+        assert (config.feat_channels == head.atss_centerness.in_channels)
+    else:
+        assert (config.in_channels == head.conv_cls.in_channels)
+        assert (config.in_channels == head.conv_reg.in_channels)
+        assert (head.conv_cls.out_channels == num_classes * head.num_anchors)
+        assert head.conv_reg.out_channels == 4 * head.num_anchors
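For reference, a condensed sketch of how these helpers are exercised,
mirroring test_config_build_detector above (the config path is just an
example):

    from mmcv import Config
    from mmdet.models import build_detector

    cfg = Config.fromfile('configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
    cfg.model.pretrained = None
    detector = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    if 'roi_head' in cfg.model:
        _check_roihead(cfg.model['roi_head'], detector.roi_head)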
diff --git a/tests/test_forward.py b/tests/test_forward.py
index 2599fb11..2ace7bbd 100644
--- a/tests/test_forward.py
+++ b/tests/test_forward.py
@@ -180,7 +180,8 @@ def test_cascade_forward():
         'cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py')
     model['pretrained'] = None
     # torchvision roi align supports CPU
-    model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True
+    model['roi_head']['bbox_roi_extractor']['roi_layer'][
+        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
@@ -233,7 +234,8 @@ def test_faster_rcnn_forward():
         'faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
     model['pretrained'] = None
     # torchvision roi align supports CPU
-    model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True
+    model['roi_head']['bbox_roi_extractor']['roi_layer'][
+        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
@@ -286,7 +288,8 @@ def test_faster_rcnn_ohem_forward():
         'faster_rcnn/faster_rcnn_r50_fpn_ohem_1x_coco.py')
     model['pretrained'] = None
     # torchvision roi align supports CPU
-    model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True
+    model['roi_head']['bbox_roi_extractor']['roi_layer'][
+        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
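These forward-test updates track the new config layout: the RoI-level
components moved from the top level of model into model['roi_head'].
Schematically (a hedged sketch; 'StandardRoIHead' and the RoIAlign parameters
are illustrative values, not copied from a real config):

    # old, flat layout
    flat = dict(bbox_roi_extractor=dict(
        roi_layer=dict(type='RoIAlign', out_size=7, use_torchvision=True)))

    # new layout introduced by this patch
    nested = dict(roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            roi_layer=dict(type='RoIAlign', out_size=7,
                           use_torchvision=True))))

    # hence the extra level in the updated test lookups
    nested['roi_head']['bbox_roi_extractor']['roi_layer'][
        'use_torchvision'] = True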
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index cdae8f48..53ef380f 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -103,9 +103,11 @@ def _context_for_ohem():
         'faster_rcnn/faster_rcnn_r50_fpn_ohem_1x_coco.py')
     model['pretrained'] = None
     # torchvision roi align supports CPU
-    model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True
+    model['roi_head']['bbox_roi_extractor']['roi_layer'][
+        'use_torchvision'] = True
     from mmdet.models import build_detector
-    context = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
+    context = build_detector(
+        model, train_cfg=train_cfg, test_cfg=test_cfg).roi_head
     return context
 
 
-- 
GitLab