From b5431092505f7dcd7de616c8a79eba4d2532fbc8 Mon Sep 17 00:00:00 2001
From: Jiaqi Wang <1155098160@link.cuhk.edu.hk>
Date: Fri, 21 Feb 2020 20:57:33 +0800
Subject: [PATCH] Code Release: CARAFE: Content-Aware ReAssembly of FEatures
 (ICCV 2019) (#1583)

* add carafe ops

* rename carafe benchmark

* grad check fix

* update grad check

* update grad check output

* add fpn carafe & mask head carafe

* add ReadMe

* update readme

* add carafe setup

* update naive carafe

* update readme and setup

* readme typo fix

* fix flake8 error

* fix flake 8 error

* fix flake 8

* fix flake 8 more

* flake 8 fix plus

* flake 8 fix

* fix flake 8

* reformat ops files

* update fpn files and cfgs

* update readme

* update fcn_mask_head

* update fpn_carafe

* update kernel

* update

* update

* add docstring in FPN_CARAFE

* reformat with yapf

* update

* update

* add build upsampler

* fix mask head build error

* reformat build upsample layer

* add doc string for CARAFE and PixelShuffle

* update

* update upsample_cfg_

* update

* update doc string

* rm abbr in build upsample layer

* update readme

* update model_zoo

* add link to other features in ReadMe
---
 README.md                                     |  13 +-
 configs/carafe/README.md                      |  53 ++
 .../carafe/faster_rcnn_r50_fpn_carafe_1x.py   | 188 +++++++
 configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py | 210 ++++++++
 docs/MODEL_ZOO.md                             |   3 +
 mmdet/models/mask_heads/fcn_mask_head.py      |  51 +-
 mmdet/models/necks/__init__.py                |   3 +-
 mmdet/models/necks/fpn_carafe.py              | 254 ++++++++++
 mmdet/models/utils/__init__.py                |   5 +-
 mmdet/models/utils/upsample.py                |  78 +++
 mmdet/ops/carafe/__init__.py                  |   3 +
 mmdet/ops/carafe/carafe.py                    | 237 +++++++++
 mmdet/ops/carafe/grad_check.py                |  61 +++
 mmdet/ops/carafe/setup.py                     |  29 ++
 mmdet/ops/carafe/src/carafe_cuda.cpp          | 113 +++++
 mmdet/ops/carafe/src/carafe_cuda_kernel.cu    | 475 ++++++++++++++++++
 mmdet/ops/carafe/src/carafe_naive_cuda.cpp    |  75 +++
 .../carafe/src/carafe_naive_cuda_kernel.cu    | 176 +++++++
 setup.py                                      |  11 +
 19 files changed, 2012 insertions(+), 26 deletions(-)
 create mode 100644 configs/carafe/README.md
 create mode 100644 configs/carafe/faster_rcnn_r50_fpn_carafe_1x.py
 create mode 100644 configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py
 create mode 100644 mmdet/models/necks/fpn_carafe.py
 create mode 100644 mmdet/models/utils/upsample.py
 create mode 100644 mmdet/ops/carafe/__init__.py
 create mode 100644 mmdet/ops/carafe/carafe.py
 create mode 100644 mmdet/ops/carafe/grad_check.py
 create mode 100644 mmdet/ops/carafe/setup.py
 create mode 100644 mmdet/ops/carafe/src/carafe_cuda.cpp
 create mode 100644 mmdet/ops/carafe/src/carafe_cuda_kernel.cu
 create mode 100644 mmdet/ops/carafe/src/carafe_naive_cuda.cpp
 create mode 100644 mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu

diff --git a/README.md b/README.md
index 6b97890b..9c5ffe65 100644
--- a/README.md
+++ b/README.md
@@ -73,14 +73,15 @@ Results and models are available in the [Model zoo](docs/MODEL_ZOO.md).
 | ATSS               | ✓        | ✓        | ☆        | ✓        | ✓     |
 
 Other features
-- [x] DCNv2
-- [x] Group Normalization
-- [x] Weight Standardization
+- [x] [CARAFE](configs/carafe/README.md)
+- [x] [DCNv2](configs/dcn/README.md)
+- [x] [Group Normalization](configs/gn/README.md)
+- [x] [Weight Standardization](configs/gn+ws/README.md)
 - [x] OHEM
 - [x] Soft-NMS
-- [x] Generalized Attention
-- [x] GCNet
-- [x] Mixed Precision (FP16) Training
+- [x] [Generalized Attention](configs/empirical_attention/README.md)
+- [x] [GCNet](configs/gcnet/README.md)
+- [x] [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmdetection/blob/master/configs/fp16)
 - [x] [InstaBoost](configs/instaboost/README.md)
 
 
diff --git a/configs/carafe/README.md b/configs/carafe/README.md
new file mode 100644
index 00000000..a1be11da
--- /dev/null
+++ b/configs/carafe/README.md
@@ -0,0 +1,53 @@
+# CARAFE: Content-Aware ReAssembly of FEatures
+
+## Introduction
+
+We provide config files to reproduce the object detection and instance segmentation results in the ICCV 2019 Oral paper [CARAFE: Content-Aware ReAssembly of FEatures](https://arxiv.org/abs/1905.02188).
+
+```
+@inproceedings{Wang_2019_ICCV,
+    title = {CARAFE: Content-Aware ReAssembly of FEatures},
+    author = {Wang, Jiaqi and Chen, Kai and Xu, Rui and Liu, Ziwei and Loy, Chen Change and Lin, Dahua},
+    booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
+    month = {October},
+    year = {2019}
+}
+```
+
+## Results and Models
+
+The results on COCO 2017 val are shown in the table below.
+
+| Method |    Backbone     |  Style  | Lr schd | Test Proposal Num | Box AP |   Mask AP |                                                                 Download                                                                    |
+| :--------------------: | :-------------: | :-----: | :-----: | :--------------: | :----: | :--------: |:----------------------------------------------------------------------------------------------------: |
+| Faster R-CNN w/ CARAFE |    R-50-FPN  |  pytorch  |   1x    | 1000 |  37.8  | -  | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/carafe/faster_rcnn_r50_fpn_carafe_1x-2ca2d094.pth)  |
+| - |    -  |  -  |   -    | 2000 |  37.9  | -  | -  |
+| Mask R-CNN w/ CARAFE |    R-50-FPN  |  pytorch  |   1x   | 1000 |  38.6   | 35.6| [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/carafe/mask_rcnn_r50_fpn_carafe_1x-2cc4b9fe.pth) |
+| - |   -  |  -  |   -   | 2000 |  38.6   | 35.7| - |
+
+## Implementation
+
+The CUDA implementation of CARAFE can be found at `mmdet/ops/carafe` in this repository.
+
+## Setup CARAFE
+
+a. Use CARAFE in mmdetection.
+
+Install mmdetection following the official guide.
+
+b. Use CARAFE in your own project.
+
+Clone the mmdetection repository.
+```shell
+git clone https://github.com/open-mmlab/mmdetection.git
+cd mmdetection
+```
+Set up CARAFE in your project.
+```shell
+cp -r ./mmdet/ops/carafe $Your_Project_Path$
+cd $Your_Project_Path$/carafe
+python setup.py develop
+# or "pip install -v -e ."
+cd ..
+python ./carafe/grad_check.py
+```
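+
+After setup, the CARAFE op can be called directly from PyTorch. The snippet below is a minimal sketch: it assumes CARAFE is used inside mmdetection (case a above, so the import path is `mmdet.ops.carafe`; adjust it if you copied the folder into your own project), the input size is arbitrary, and a CUDA device is required because only a CUDA kernel is provided.
+```python
+import torch
+from mmdet.ops.carafe import CARAFEPack
+
+# upsample a 256-channel feature map by 2x with a 5x5 reassembly kernel
+upsampler = CARAFEPack(channels=256, scale_factor=2, up_kernel=5).cuda()
+x = torch.rand(2, 256, 50, 50).cuda()
+out = upsampler(x)  # shape: (2, 256, 100, 100)
+```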
diff --git a/configs/carafe/faster_rcnn_r50_fpn_carafe_1x.py b/configs/carafe/faster_rcnn_r50_fpn_carafe_1x.py
new file mode 100644
index 00000000..94c8a0fc
--- /dev/null
+++ b/configs/carafe/faster_rcnn_r50_fpn_carafe_1x.py
@@ -0,0 +1,188 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN_CARAFE',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        start_level=0,
+        end_level=-1,
+        norm_cfg=None,
+        activation=None,
+        order=('conv', 'norm', 'act'),
+        upsample_cfg=dict(
+            type='carafe',
+            up_kernel=5,
+            up_group=1,
+            encoder_kernel=3,
+            encoder_dilation=1,
+            compressed_channels=64)),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=64),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=64),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+evaluation = dict(interval=1, metric='bbox')
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_r50_fpn_carafe_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py b/configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py
new file mode 100644
index 00000000..656bd7c6
--- /dev/null
+++ b/configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py
@@ -0,0 +1,210 @@
+# model settings
+model = dict(
+    type='MaskRCNN',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN_CARAFE',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        start_level=0,
+        end_level=-1,
+        norm_cfg=None,
+        activation=None,
+        order=('conv', 'norm', 'act'),
+        upsample_cfg=dict(
+            type='carafe',
+            up_kernel=5,
+            up_group=1,
+            encoder_kernel=3,
+            encoder_dilation=1,
+            compressed_channels=64)),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        upsample_cfg=dict(
+            type='carafe',
+            scale_factor=2,
+            up_kernel=5,
+            up_group=1,
+            encoder_kernel=3,
+            encoder_dilation=1,
+            compressed_channels=64),
+        loss_mask=dict(
+            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=64),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=64),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+evaluation = dict(interval=1, metric=['bbox', 'segm'])
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/mask_rcnn_r50_fpn_carafe_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md
index 954a4772..c15a00b6 100644
--- a/docs/MODEL_ZOO.md
+++ b/docs/MODEL_ZOO.md
@@ -230,6 +230,9 @@ Please refer to [Weight Standardization](https://github.com/open-mmlab/mmdetecti
 
 Please refer to [Deformable Convolutional Networks](https://github.com/open-mmlab/mmdetection/blob/master/configs/dcn) for details.
 
+### CARAFE: Content-Aware ReAssembly of FEatures
+Please refer to [CARAFE](https://github.com/open-mmlab/mmdetection/blob/master/configs/carafe) for details.
+
 ### Instaboost
 
 Please refer to [Instaboost](https://github.com/open-mmlab/mmdetection/blob/master/configs/instaboost) for details.
diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py
index 6d11cfff..15a9e330 100644
--- a/mmdet/models/mask_heads/fcn_mask_head.py
+++ b/mmdet/models/mask_heads/fcn_mask_head.py
@@ -6,9 +6,10 @@ import torch.nn as nn
 from torch.nn.modules.utils import _pair
 
 from mmdet.core import auto_fp16, force_fp32, mask_target
+from mmdet.ops.carafe import CARAFEPack
 from ..builder import build_loss
 from ..registry import HEADS
-from ..utils import ConvModule
+from ..utils import ConvModule, build_upsample_layer
 
 
 @HEADS.register_module
@@ -20,27 +21,30 @@ class FCNMaskHead(nn.Module):
                  in_channels=256,
                  conv_kernel_size=3,
                  conv_out_channels=256,
-                 upsample_method='deconv',
-                 upsample_ratio=2,
                  num_classes=81,
                  class_agnostic=False,
+                 upsample_cfg=dict(type='deconv', scale_factor=2),
                  conv_cfg=None,
                  norm_cfg=None,
                  loss_mask=dict(
                      type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
         super(FCNMaskHead, self).__init__()
-        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
+        self.upsample_cfg = upsample_cfg.copy()
+        if self.upsample_cfg['type'] not in [
+                None, 'deconv', 'nearest', 'bilinear', 'carafe'
+        ]:
             raise ValueError(
                 'Invalid upsample method {}, accepted methods '
-                'are "deconv", "nearest", "bilinear"'.format(upsample_method))
+                'are "deconv", "nearest", "bilinear", "carafe"'.format(
+                    self.upsample_cfg['type']))
         self.num_convs = num_convs
         # WARN: roi_feat_size is reserved and not used
         self.roi_feat_size = _pair(roi_feat_size)
         self.in_channels = in_channels
         self.conv_kernel_size = conv_kernel_size
         self.conv_out_channels = conv_out_channels
-        self.upsample_method = upsample_method
-        self.upsample_ratio = upsample_ratio
+        self.upsample_method = self.upsample_cfg.get('type')
+        self.scale_factor = self.upsample_cfg.pop('scale_factor')
         self.num_classes = num_classes
         self.class_agnostic = class_agnostic
         self.conv_cfg = conv_cfg
@@ -63,17 +67,27 @@ class FCNMaskHead(nn.Module):
                     norm_cfg=norm_cfg))
         upsample_in_channels = (
             self.conv_out_channels if self.num_convs > 0 else in_channels)
+        upsample_cfg_ = self.upsample_cfg.copy()
         if self.upsample_method is None:
             self.upsample = None
         elif self.upsample_method == 'deconv':
-            self.upsample = nn.ConvTranspose2d(
-                upsample_in_channels,
-                self.conv_out_channels,
-                self.upsample_ratio,
-                stride=self.upsample_ratio)
+            upsample_cfg_.update(
+                in_channels=upsample_in_channels,
+                out_channels=self.conv_out_channels,
+                kernel_size=self.scale_factor,
+                stride=self.scale_factor)
+        elif self.upsample_method == 'carafe':
+            upsample_cfg_.update(
+                channels=upsample_in_channels, scale_factor=self.scale_factor)
         else:
-            self.upsample = nn.Upsample(
-                scale_factor=self.upsample_ratio, mode=self.upsample_method)
+            # suppress warnings
+            align_corners = (None
+                             if self.upsample_method == 'nearest' else False)
+            upsample_cfg_.update(
+                scale_factor=self.scale_factor,
+                mode=self.upsample_method,
+                align_corners=align_corners)
+        if self.upsample_method is not None:
+            self.upsample = build_upsample_layer(upsample_cfg_)
 
         out_channels = 1 if self.class_agnostic else self.num_classes
         logits_in_channel = (
@@ -87,9 +101,12 @@ class FCNMaskHead(nn.Module):
         for m in [self.upsample, self.conv_logits]:
             if m is None:
                 continue
-            nn.init.kaiming_normal_(
-                m.weight, mode='fan_out', nonlinearity='relu')
-            nn.init.constant_(m.bias, 0)
+            elif isinstance(m, CARAFEPack):
+                m.init_weights()
+            else:
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.constant_(m.bias, 0)
 
     @auto_fp16()
     def forward(self, x):
diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py
index fa574044..7844ec78 100644
--- a/mmdet/models/necks/__init__.py
+++ b/mmdet/models/necks/__init__.py
@@ -1,6 +1,7 @@
 from .bfp import BFP
 from .fpn import FPN
+from .fpn_carafe import FPN_CARAFE
 from .hrfpn import HRFPN
 from .nas_fpn import NASFPN
 
-__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN']
+__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE']
diff --git a/mmdet/models/necks/fpn_carafe.py b/mmdet/models/necks/fpn_carafe.py
new file mode 100644
index 00000000..5d01b3c0
--- /dev/null
+++ b/mmdet/models/necks/fpn_carafe.py
@@ -0,0 +1,254 @@
+import torch.nn as nn
+from mmcv.cnn import xavier_init
+
+from mmdet.ops.carafe import CARAFEPack
+from ..registry import NECKS
+from ..utils import ConvModule, build_upsample_layer
+
+
+@NECKS.register_module
+class FPN_CARAFE(nn.Module):
+    """FPN_CARAFE is a more flexible implementation of FPN.
+    It allows more choices of upsample methods in the top-down pathway.
+
+    It can reproduce the performance of the ICCV 2019 paper
+    CARAFE: Content-Aware ReAssembly of FEatures
+    Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+    Args:
+        in_channels (list[int]): Number of channels for each input feature map.
+        out_channels (int): Output channels of feature pyramids.
+        num_outs (int): Number of output stages.
+        start_level (int): Start level of feature pyramids.
+            (Default: 0)
+        end_level (int): End level of feature pyramids.
+            (Default: -1 indicates the last level).
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+        activation (str): Type of activation function in ConvModule
+            (Default: None indicates w/o activation).
+        order (dict): Order of components in ConvModule.
+        upsample_cfg (dict): Dictionary to construct and config upsample layer.
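+
+    Example (a minimal sketch with illustrative shapes; a CUDA device is
+        required because the CARAFE op only provides a CUDA kernel):
+        >>> import torch
+        >>> self = FPN_CARAFE(
+        ...     in_channels=[256, 512, 1024, 2048],
+        ...     out_channels=256,
+        ...     num_outs=5).cuda()
+        >>> feats = [
+        ...     torch.rand(1, c, 64 // 2**i, 64 // 2**i).cuda()
+        ...     for i, c in enumerate([256, 512, 1024, 2048])
+        ... ]
+        >>> outs = self(feats)  # tuple of 5 output feature maps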
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_outs,
+                 start_level=0,
+                 end_level=-1,
+                 norm_cfg=None,
+                 activation=None,
+                 order=('conv', 'norm', 'act'),
+                 upsample_cfg=dict(
+                     type='carafe',
+                     up_kernel=5,
+                     up_group=1,
+                     encoder_kernel=3,
+                     encoder_dilation=1)):
+        super(FPN_CARAFE, self).__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_ins = len(in_channels)
+        self.num_outs = num_outs
+        self.activation = activation
+        self.norm_cfg = norm_cfg
+        self.with_bias = norm_cfg is None
+        self.upsample_cfg = upsample_cfg.copy()
+        self.upsample = self.upsample_cfg.get('type')
+        self.relu = nn.ReLU(inplace=False)
+
+        self.order = order
+        assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')]
+
+        assert self.upsample in [
+            'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None
+        ]
+        if self.upsample in ['deconv', 'pixel_shuffle']:
+            assert hasattr(
+                self.upsample_cfg,
+                'upsample_kernel') and self.upsample_cfg.upsample_kernel > 0
+            self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel')
+
+        if end_level == -1:
+            self.backbone_end_level = self.num_ins
+            assert num_outs >= self.num_ins - start_level
+        else:
+            # if end_level < inputs, no extra level is allowed
+            self.backbone_end_level = end_level
+            assert end_level <= len(in_channels)
+            assert num_outs == end_level - start_level
+        self.start_level = start_level
+        self.end_level = end_level
+
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+        self.upsample_modules = nn.ModuleList()
+
+        for i in range(self.start_level, self.backbone_end_level):
+            l_conv = ConvModule(
+                in_channels[i],
+                out_channels,
+                1,
+                norm_cfg=norm_cfg,
+                bias=self.with_bias,
+                activation=activation,
+                inplace=False,
+                order=self.order)
+            fpn_conv = ConvModule(
+                out_channels,
+                out_channels,
+                3,
+                padding=1,
+                norm_cfg=self.norm_cfg,
+                bias=self.with_bias,
+                activation=activation,
+                inplace=False,
+                order=self.order)
+            if i != self.backbone_end_level - 1:
+                upsample_cfg_ = self.upsample_cfg.copy()
+                if self.upsample == 'deconv':
+                    upsample_cfg_.update(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        kernel_size=self.upsample_kernel,
+                        stride=2,
+                        padding=(self.upsample_kernel - 1) // 2,
+                        output_padding=(self.upsample_kernel - 1) // 2)
+                elif self.upsample == 'pixel_shuffle':
+                    upsample_cfg_.update(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        scale_factor=2,
+                        upsample_kernel=self.upsample_kernel)
+                elif self.upsample == 'carafe':
+                    upsample_cfg_.update(channels=out_channels, scale_factor=2)
+                else:
+                    # suppress warnings
+                    align_corners = (None
+                                     if self.upsample == 'nearest' else False)
+                    upsample_cfg_.update(
+                        scale_factor=2,
+                        mode=self.upsample,
+                        align_corners=align_corners)
+                upsample_module = build_upsample_layer(upsample_cfg_)
+                self.upsample_modules.append(upsample_module)
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        # add extra conv layers (e.g., RetinaNet)
+        extra_out_levels = (
+            num_outs - self.backbone_end_level + self.start_level)
+        if extra_out_levels >= 1:
+            for i in range(extra_out_levels):
+                in_channels = (
+                    self.in_channels[self.backbone_end_level -
+                                     1] if i == 0 else out_channels)
+                extra_l_conv = ConvModule(
+                    in_channels,
+                    out_channels,
+                    3,
+                    stride=2,
+                    padding=1,
+                    norm_cfg=norm_cfg,
+                    bias=self.with_bias,
+                    activation=self.activation,
+                    inplace=False,
+                    order=self.order)
+                if self.upsample == 'deconv':
+                    upsampler_cfg_ = dict(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        kernel_size=self.upsample_kernel,
+                        stride=2,
+                        padding=(self.upsample_kernel - 1) // 2,
+                        output_padding=(self.upsample_kernel - 1) // 2)
+                elif self.upsample == 'pixel_shuffle':
+                    upsampler_cfg_ = dict(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        scale_factor=2,
+                        upsample_kernel=self.upsample_kernel)
+                elif self.upsample == 'carafe':
+                    upsampler_cfg_ = dict(
+                        channels=out_channels,
+                        scale_factor=2,
+                        **self.upsample_cfg)
+                else:
+                    # suppress warnings
+                    align_corners = (None
+                                     if self.upsample == 'nearest' else False)
+                    upsampler_cfg_ = dict(
+                        scale_factor=2,
+                        mode=self.upsample,
+                        align_corners=align_corners)
+                upsampler_cfg_['type'] = self.upsample
+                upsample_module = build_upsample_layer(upsampler_cfg_)
+                extra_fpn_conv = ConvModule(
+                    out_channels,
+                    out_channels,
+                    3,
+                    padding=1,
+                    norm_cfg=self.norm_cfg,
+                    bias=self.with_bias,
+                    activation=activation,
+                    inplace=False,
+                    order=self.order)
+                self.upsample_modules.append(upsample_module)
+                self.fpn_convs.append(extra_fpn_conv)
+                self.lateral_convs.append(extra_l_conv)
+
+    # default init_weights for conv(msra) and norm in ConvModule
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                xavier_init(m, distribution='uniform')
+        for m in self.modules():
+            if isinstance(m, CARAFEPack):
+                m.init_weights()
+
+    def slice_as(self, src, dst):
+        # slice src as dst
+        # src should have the same or larger size than dst
+        assert (src.size(2) >= dst.size(2)) and (src.size(3) >= dst.size(3))
+        if src.size(2) == dst.size(2) and src.size(3) == dst.size(3):
+            return src
+        else:
+            return src[:, :, :dst.size(2), :dst.size(3)]
+
+    def tensor_add(self, a, b):
+        if a.size() == b.size():
+            c = a + b
+        else:
+            c = a + self.slice_as(b, a)
+        return c
+
+    def forward(self, inputs):
+        assert len(inputs) == len(self.in_channels)
+
+        # build laterals
+        laterals = []
+        for i, lateral_conv in enumerate(self.lateral_convs):
+            if i <= self.backbone_end_level - self.start_level:
+                input = inputs[min(i + self.start_level, len(inputs) - 1)]
+            else:
+                input = laterals[-1]
+            lateral = lateral_conv(input)
+            laterals.append(lateral)
+
+        # build top-down path
+        for i in range(len(laterals) - 1, 0, -1):
+            if self.upsample is not None:
+                upsample_feat = self.upsample_modules[i - 1](laterals[i])
+            else:
+                upsample_feat = laterals[i]
+            laterals[i - 1] = self.tensor_add(laterals[i - 1], upsample_feat)
+
+        # build outputs
+        num_conv_outs = len(self.fpn_convs)
+        outs = []
+        for i in range(num_conv_outs):
+            out = self.fpn_convs[i](laterals[i])
+            outs.append(out)
+        return tuple(outs)
diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py
index 3db40920..c1b9ad11 100644
--- a/mmdet/models/utils/__init__.py
+++ b/mmdet/models/utils/__init__.py
@@ -2,11 +2,12 @@ from .conv_module import ConvModule, build_conv_layer
 from .conv_ws import ConvWS2d, conv_ws_2d
 from .norm import build_norm_layer
 from .scale import Scale
+from .upsample import build_upsample_layer
 from .weight_init import (bias_init_with_prob, kaiming_init, normal_init,
                           uniform_init, xavier_init)
 
 __all__ = [
     'conv_ws_2d', 'ConvWS2d', 'build_conv_layer', 'ConvModule',
-    'build_norm_layer', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init', 'bias_init_with_prob', 'Scale'
+    'build_norm_layer', 'build_upsample_layer', 'xavier_init', 'normal_init',
+    'uniform_init', 'kaiming_init', 'bias_init_with_prob', 'Scale'
 ]
diff --git a/mmdet/models/utils/upsample.py b/mmdet/models/utils/upsample.py
new file mode 100644
index 00000000..fd880fb0
--- /dev/null
+++ b/mmdet/models/utils/upsample.py
@@ -0,0 +1,78 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import xavier_init
+
+from mmdet.ops.carafe import CARAFEPack
+
+
+class PixelShufflePack(nn.Module):
+    """ Pixel Shuffle upsample layer
+
+    Args:
+        in_channels (int): Number of input channels
+        out_channels (int): Number of output channels
+        scale_factor (int): Upsample ratio
+        upsample_kernel (int): Kernel size of Conv layer to expand the channels
+
+    Returns:
+        upsampled feature map
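+
+    Example (a minimal sketch; runs on CPU):
+        >>> import torch
+        >>> layer = PixelShufflePack(64, 64, scale_factor=2, upsample_kernel=3)
+        >>> out = layer(torch.rand(1, 64, 8, 8))  # (1, 64, 16, 16)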
+    """
+
+    def __init__(self, in_channels, out_channels, scale_factor,
+                 upsample_kernel):
+        super(PixelShufflePack, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scale_factor = scale_factor
+        self.upsample_kernel = upsample_kernel
+        self.upsample_conv = nn.Conv2d(
+            self.in_channels,
+            self.out_channels * scale_factor * scale_factor,
+            self.upsample_kernel,
+            padding=(self.upsample_kernel - 1) // 2)
+        self.init_weights()
+
+    def init_weights(self):
+        xavier_init(self.upsample_conv, distribution='uniform')
+
+    def forward(self, x):
+        x = self.upsample_conv(x)
+        x = F.pixel_shuffle(x, self.scale_factor)
+        return x
+
+
+upsample_cfg = {
+    # format: layer_type: module
+    'nearest': nn.Upsample,
+    'bilinear': nn.Upsample,
+    'deconv': nn.ConvTranspose2d,
+    'pixel_shuffle': PixelShufflePack,
+    'carafe': CARAFEPack
+}
+
+
+def build_upsample_layer(cfg):
+    """ Build upsample layer
+
+    Args:
+        cfg (dict): cfg should contain:
+            type (str): Identify upsample layer type.
+            scale_factor (int): Upsample ratio.
+            layer args: args needed to instantiate an upsample layer.
+
+    Returns:
+        layer (nn.Module): Created upsample layer
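+
+    Example:
+        >>> # a minimal sketch; a 'deconv' cfg instead takes the usual
+        >>> # nn.ConvTranspose2d arguments (in_channels, kernel_size, ...)
+        >>> layer = build_upsample_layer(
+        ...     dict(type='carafe', channels=256, scale_factor=2))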
+    """
+    assert isinstance(cfg, dict) and 'type' in cfg
+    cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in upsample_cfg:
+        raise KeyError('Unrecognized upsample type {}'.format(layer_type))
+    else:
+        upsample = upsample_cfg[layer_type]
+        if upsample is None:
+            raise NotImplementedError
+
+    layer = upsample(**cfg_)
+    return layer
diff --git a/mmdet/ops/carafe/__init__.py b/mmdet/ops/carafe/__init__.py
new file mode 100644
index 00000000..029038f8
--- /dev/null
+++ b/mmdet/ops/carafe/__init__.py
@@ -0,0 +1,3 @@
+from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+
+__all__ = ['carafe', 'carafe_naive', 'CARAFE', 'CARAFENaive', 'CARAFEPack']
diff --git a/mmdet/ops/carafe/carafe.py b/mmdet/ops/carafe/carafe.py
new file mode 100644
index 00000000..2c81735b
--- /dev/null
+++ b/mmdet/ops/carafe/carafe.py
@@ -0,0 +1,237 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import normal_init, xavier_init
+from torch.autograd import Function
+from torch.nn.modules.module import Module
+
+from . import carafe_cuda, carafe_naive_cuda
+
+
+class CARAFENaiveFunction(Function):
+
+    @staticmethod
+    def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+        assert scale_factor >= 1
+        assert masks.size(1) == kernel_size * kernel_size * group_size
+        assert masks.size(-1) == features.size(-1) * scale_factor
+        assert masks.size(-2) == features.size(-2) * scale_factor
+        assert features.size(1) % group_size == 0
+        assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+        ctx.kernel_size = kernel_size
+        ctx.group_size = group_size
+        ctx.scale_factor = scale_factor
+        ctx.feature_size = features.size()
+        ctx.mask_size = masks.size()
+
+        n, c, h, w = features.size()
+        output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+        if features.is_cuda:
+            carafe_naive_cuda.forward(features, masks, kernel_size, group_size,
+                                      scale_factor, output)
+        else:
+            raise NotImplementedError
+
+        if features.requires_grad or masks.requires_grad:
+            ctx.save_for_backward(features, masks)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        assert grad_output.is_cuda
+
+        features, masks = ctx.saved_tensors
+        kernel_size = ctx.kernel_size
+        group_size = ctx.group_size
+        scale_factor = ctx.scale_factor
+
+        grad_input = torch.zeros_like(features)
+        grad_masks = torch.zeros_like(masks)
+        carafe_naive_cuda.backward(grad_output.contiguous(), features, masks,
+                                   kernel_size, group_size, scale_factor,
+                                   grad_input, grad_masks)
+
+        return grad_input, grad_masks, None, None, None
+
+
+carafe_naive = CARAFENaiveFunction.apply
+
+
+class CARAFENaive(Module):
+
+    def __init__(self, kernel_size, group_size, scale_factor):
+        super(CARAFENaive, self).__init__()
+
+        assert isinstance(kernel_size, int) and isinstance(
+            group_size, int) and isinstance(scale_factor, int)
+        self.kernel_size = kernel_size
+        self.group_size = group_size
+        self.scale_factor = scale_factor
+
+    def forward(self, features, masks):
+        return CARAFENaiveFunction.apply(features, masks, self.kernel_size,
+                                         self.group_size, self.scale_factor)
+
+
+class CARAFEFunction(Function):
+
+    @staticmethod
+    def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+        assert scale_factor >= 1
+        assert masks.size(1) == kernel_size * kernel_size * group_size
+        assert masks.size(-1) == features.size(-1) * scale_factor
+        assert masks.size(-2) == features.size(-2) * scale_factor
+        assert features.size(1) % group_size == 0
+        assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+        ctx.kernel_size = kernel_size
+        ctx.group_size = group_size
+        ctx.scale_factor = scale_factor
+        ctx.feature_size = features.size()
+        ctx.mask_size = masks.size()
+
+        n, c, h, w = features.size()
+        output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+        routput = features.new_zeros(output.size(), requires_grad=False)
+        rfeatures = features.new_zeros(features.size(), requires_grad=False)
+        rmasks = masks.new_zeros(masks.size(), requires_grad=False)
+        if features.is_cuda:
+            carafe_cuda.forward(features, rfeatures, masks, rmasks,
+                                kernel_size, group_size, scale_factor, routput,
+                                output)
+        else:
+            raise NotImplementedError
+
+        if features.requires_grad or masks.requires_grad:
+            ctx.save_for_backward(features, masks, rfeatures)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        assert grad_output.is_cuda
+
+        features, masks, rfeatures = ctx.saved_tensors
+        kernel_size = ctx.kernel_size
+        group_size = ctx.group_size
+        scale_factor = ctx.scale_factor
+
+        rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
+        rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
+        rgrad_input = torch.zeros_like(features, requires_grad=False)
+        rgrad_masks = torch.zeros_like(masks, requires_grad=False)
+        grad_input = torch.zeros_like(features, requires_grad=False)
+        grad_masks = torch.zeros_like(masks, requires_grad=False)
+        carafe_cuda.backward(grad_output.contiguous(), rfeatures, masks,
+                             kernel_size, group_size, scale_factor,
+                             rgrad_output, rgrad_input_hs, rgrad_input,
+                             rgrad_masks, grad_input, grad_masks)
+        return grad_input, grad_masks, None, None, None, None
+
+
+carafe = CARAFEFunction.apply
+
+
+class CARAFE(Module):
+    """ CARAFE: Content-Aware ReAssembly of FEatures
+
+    Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+    Args:
+        kernel_size (int): reassemble kernel size
+        group_size (int): reassemble group size
+        scale_factor (int): upsample ratio
+
+    Returns:
+        upsampled feature map
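+
+    Example (a sketch of the expected mask layout; needs a CUDA device):
+        >>> import torch
+        >>> op = CARAFE(kernel_size=5, group_size=1, scale_factor=2)
+        >>> feat = torch.rand(2, 64, 10, 10).cuda()
+        >>> # masks: (N, k*k*group, H*scale, W*scale), normalized over k*k
+        >>> mask = torch.rand(2, 25, 20, 20).cuda().softmax(dim=1)
+        >>> out = op(feat, mask)  # (2, 64, 20, 20)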
+    """
+
+    def __init__(self, kernel_size, group_size, scale_factor):
+        super(CARAFE, self).__init__()
+
+        assert isinstance(kernel_size, int) and isinstance(
+            group_size, int) and isinstance(scale_factor, int)
+        self.kernel_size = kernel_size
+        self.group_size = group_size
+        self.scale_factor = scale_factor
+
+    def forward(self, features, masks):
+        return CARAFEFunction.apply(features, masks, self.kernel_size,
+                                    self.group_size, self.scale_factor)
+
+
+class CARAFEPack(nn.Module):
+    """ A unified package of CARAFE upsampler that contains:
+    1) channel compressor 2) content encoder 3) CARAFE op
+
+    Official implementation of ICCV 2019 paper
+    CARAFE: Content-Aware ReAssembly of FEatures
+    Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+    Args:
+        channels (int): input feature channels
+        scale_factor (int): upsample ratio
+        up_kernel (int): kernel size of CARAFE op
+        up_group (int): group size of CARAFE op
+        encoder_kernel (int): kernel size of content encoder
+        encoder_dilation (int): dilation of content encoder
+        compressed_channels (int): output channels of the channel compressor
+
+    Returns:
+        upsampled feature map
+    """
+
+    def __init__(self,
+                 channels,
+                 scale_factor,
+                 up_kernel=5,
+                 up_group=1,
+                 encoder_kernel=3,
+                 encoder_dilation=1,
+                 compressed_channels=64):
+        super(CARAFEPack, self).__init__()
+        self.channels = channels
+        self.scale_factor = scale_factor
+        self.up_kernel = up_kernel
+        self.up_group = up_group
+        self.encoder_kernel = encoder_kernel
+        self.encoder_dilation = encoder_dilation
+        self.compressed_channels = compressed_channels
+        self.channel_compressor = nn.Conv2d(channels, self.compressed_channels,
+                                            1)
+        self.content_encoder = nn.Conv2d(
+            self.compressed_channels,
+            self.up_kernel * self.up_kernel * self.up_group *
+            self.scale_factor * self.scale_factor,
+            self.encoder_kernel,
+            padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
+            dilation=self.encoder_dilation,
+            groups=1)
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m, distribution='uniform')
+        normal_init(self.content_encoder, std=0.001)
+
+    def kernel_normalizer(self, mask):
+        mask = F.pixel_shuffle(mask, self.scale_factor)
+        n, mask_c, h, w = mask.size()
+        mask_channel = int(mask_c / (self.up_kernel * self.up_kernel))
+        mask = mask.view(n, mask_channel, -1, h, w)
+
+        mask = F.softmax(mask, dim=2)
+        mask = mask.view(n, mask_c, h, w).contiguous()
+
+        return mask
+
+    def feature_reassemble(self, x, mask):
+        x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor)
+        return x
+
+    def forward(self, x):
+        compressed_x = self.channel_compressor(x)
+        mask = self.content_encoder(compressed_x)
+        mask = self.kernel_normalizer(mask)
+
+        x = self.feature_reassemble(x, mask)
+        return x
diff --git a/mmdet/ops/carafe/grad_check.py b/mmdet/ops/carafe/grad_check.py
new file mode 100644
index 00000000..06820be2
--- /dev/null
+++ b/mmdet/ops/carafe/grad_check.py
@@ -0,0 +1,61 @@
+import os.path as osp
+import sys
+
+import mmcv
+import torch
+from torch.autograd import gradcheck
+
+sys.path.append(osp.abspath(osp.join(__file__, '../../')))
+from mmdet.ops.carafe import CARAFENaive  # noqa: E402, isort:skip
+from mmdet.ops.carafe import carafe_naive  # noqa: E402, isort:skip
+from mmdet.ops.carafe import carafe, CARAFE  # noqa: E402, isort:skip
+
+feat = torch.randn(2, 64, 3, 3, requires_grad=True, device='cuda:0').double()
+mask = torch.randn(
+    2, 100, 6, 6, requires_grad=True, device='cuda:0').sigmoid().double()
+
+print('Gradcheck for carafe...')
+test = gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
+print(test)
+
+print('Gradcheck for carafe naive...')
+test = gradcheck(CARAFENaive(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
+print(test)
+
+feat = torch.randn(
+    2, 1024, 100, 100, requires_grad=True, device='cuda:0').float()
+mask = torch.randn(
+    2, 25, 200, 200, requires_grad=True, device='cuda:0').sigmoid().float()
+loop_num = 500
+
+time_forward = 0
+time_backward = 0
+bar = mmcv.ProgressBar(loop_num)
+timer = mmcv.Timer()
+for i in range(loop_num):
+    x = carafe(feat.clone(), mask.clone(), 5, 1, 2)
+    torch.cuda.synchronize()
+    time_forward += timer.since_last_check()
+    x.sum().backward(retain_graph=True)
+    torch.cuda.synchronize()
+    time_backward += timer.since_last_check()
+    bar.update()
+print('\nCARAFE time forward: {} ms/iter | time backward: {} ms/iter'.format(
+    (time_forward + 1e-3) * 1e3 / loop_num,
+    (time_backward + 1e-3) * 1e3 / loop_num))
+
+time_naive_forward = 0
+time_naive_backward = 0
+bar = mmcv.ProgressBar(loop_num)
+timer = mmcv.Timer()
+for i in range(loop_num):
+    x = carafe_naive(feat.clone(), mask.clone(), 5, 1, 2)
+    torch.cuda.synchronize()
+    time_naive_forward += timer.since_last_check()
+    x.sum().backward(retain_graph=True)
+    torch.cuda.synchronize()
+    time_naive_backward += timer.since_last_check()
+    bar.update()
+print('\nCARAFE naive time forward: {} ms/iter | time backward: {} ms/iter'.
+      format((time_naive_forward + 1e-3) * 1e3 / loop_num,
+             (time_naive_backward + 1e-3) * 1e3 / loop_num))
diff --git a/mmdet/ops/carafe/setup.py b/mmdet/ops/carafe/setup.py
new file mode 100644
index 00000000..7ef1c66d
--- /dev/null
+++ b/mmdet/ops/carafe/setup.py
@@ -0,0 +1,29 @@
+from setuptools import setup
+
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+NVCC_ARGS = [
+    '-D__CUDA_NO_HALF_OPERATORS__',
+    '-D__CUDA_NO_HALF_CONVERSIONS__',
+    '-D__CUDA_NO_HALF2_OPERATORS__',
+]
+
+setup(
+    name='carafe',
+    ext_modules=[
+        CUDAExtension(
+            'carafe_cuda',
+            ['src/carafe_cuda.cpp', 'src/carafe_cuda_kernel.cu'],
+            extra_compile_args={
+                'cxx': [],
+                'nvcc': NVCC_ARGS
+            }),
+        CUDAExtension(
+            'carafe_naive_cuda',
+            ['src/carafe_naive_cuda.cpp', 'src/carafe_naive_cuda_kernel.cu'],
+            extra_compile_args={
+                'cxx': [],
+                'nvcc': NVCC_ARGS
+            })
+    ],
+    cmdclass={'build_ext': BuildExtension})
diff --git a/mmdet/ops/carafe/src/carafe_cuda.cpp b/mmdet/ops/carafe/src/carafe_cuda.cpp
new file mode 100644
index 00000000..9a7c73af
--- /dev/null
+++ b/mmdet/ops/carafe/src/carafe_cuda.cpp
@@ -0,0 +1,113 @@
+#include <ATen/ATen.h>
+#include <torch/extension.h>
+
+#include <cmath>
+#include <vector>
+
+int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks,
+                         const int kernel_size, const int group_size,
+                         const int scale_factor, const int batch_size,
+                         const int channels, const int input_height,
+                         const int input_width, const int output_height,
+                         const int output_width, const int mask_channels,
+                         at::Tensor rfeatures, at::Tensor routput,
+                         at::Tensor rmasks, at::Tensor output);
+
+int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures,
+                          const at::Tensor masks, const int kernel_size,
+                          const int group_size, const int scale_factor,
+                          const int batch_size, const int channels,
+                          const int input_height, const int input_width,
+                          const int output_height, const int output_width,
+                          const int mask_channels, at::Tensor rtop_grad,
+                          at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
+                          at::Tensor rmask_grad, at::Tensor bottom_grad,
+                          at::Tensor mask_grad);
+
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDA tensor ")
+#define CHECK_CONTIGUOUS(x) \
+  AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+
+int carafe_forward_cuda(at::Tensor features, at::Tensor rfeatures,
+                        at::Tensor masks, at::Tensor rmasks, int kernel_size,
+                        int group_size, int scale_factor, at::Tensor routput,
+                        at::Tensor output) {
+  CHECK_INPUT(features);
+  CHECK_INPUT(rfeatures);
+  CHECK_INPUT(masks);
+  CHECK_INPUT(rmasks);
+  CHECK_INPUT(output);
+  CHECK_INPUT(routput);
+  at::DeviceGuard guard(features.device());
+
+  const int batch_size = output.size(0);
+  const int num_channels = output.size(1);
+  const int output_height = output.size(2);
+  const int output_width = output.size(3);
+
+  const int input_height = features.size(2);
+  const int input_width = features.size(3);
+
+  const int mask_channels = masks.size(1);
+
+  rfeatures.resize_({batch_size, input_height, input_width, num_channels});
+  routput.resize_({batch_size, output_height, output_width, num_channels});
+  rmasks.resize_({batch_size, output_height, output_width, mask_channels});
+
+  CARAFEForwardLaucher(features, masks, kernel_size, group_size, scale_factor,
+                       batch_size, num_channels, input_height, input_width,
+                       output_height, output_width, mask_channels, rfeatures,
+                       routput, rmasks, output);
+
+  return 1;
+}
+
+int carafe_backward_cuda(at::Tensor top_grad, at::Tensor rfeatures,
+                         at::Tensor masks, int kernel_size, int group_size,
+                         int scale_factor, at::Tensor rtop_grad,
+                         at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
+                         at::Tensor rmask_grad, at::Tensor bottom_grad,
+                         at::Tensor mask_grad) {
+  CHECK_INPUT(top_grad);
+  CHECK_INPUT(rfeatures);
+  CHECK_INPUT(masks);
+  CHECK_INPUT(rtop_grad);
+  CHECK_INPUT(rbottom_grad_hs);
+  CHECK_INPUT(rbottom_grad);
+  CHECK_INPUT(rmask_grad);
+  CHECK_INPUT(bottom_grad);
+  CHECK_INPUT(mask_grad);
+  at::DeviceGuard guard(top_grad.device());
+
+  const int batch_size = top_grad.size(0);
+  const int num_channels = top_grad.size(1);
+  const int output_height = top_grad.size(2);
+  const int output_width = top_grad.size(3);
+
+  const int input_height = bottom_grad.size(2);
+  const int input_width = bottom_grad.size(3);
+
+  const int mask_channels = masks.size(1);
+
+  rtop_grad.resize_({batch_size, output_height, output_width, num_channels});
+  rbottom_grad.resize_({batch_size, input_height, input_width, num_channels});
+  rbottom_grad_hs.resize_(
+      {batch_size, output_height, output_width, num_channels});
+  rmask_grad.resize_({batch_size, output_height, output_width, mask_channels});
+
+  CARAFEBackwardLaucher(top_grad, rfeatures, masks, kernel_size, group_size,
+                        scale_factor, batch_size, num_channels, input_height,
+                        input_width, output_height, output_width, mask_channels,
+                        rtop_grad, rbottom_grad_hs, rbottom_grad, rmask_grad,
+                        bottom_grad, mask_grad);
+
+  return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &carafe_forward_cuda, "carafe forward (CUDA)");
+  m.def("backward", &carafe_backward_cuda, "carafe backward (CUDA)");
+}
diff --git a/mmdet/ops/carafe/src/carafe_cuda_kernel.cu b/mmdet/ops/carafe/src/carafe_cuda_kernel.cu
new file mode 100644
index 00000000..da627550
--- /dev/null
+++ b/mmdet/ops/carafe/src/carafe_cuda_kernel.cu
@@ -0,0 +1,475 @@
+#include <ATen/ATen.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/Utils.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <THC/THCAtomics.cuh>
+#include <cmath>
+
+using namespace at;
+
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+#define THREADS_PER_BLOCK 1024  // 32 * 32
+#define WARP_SIZE 32
+#define THREADS_PER_PIXEL 32
+#define MAX_SHARED_MEMORY 49152
+#define MAX_SHARED_SCALAR_T 6144  // 49152 / 8 = 6144
+#define MAXIMIZE_KERNEL_SIZE true
+#define kTileDim 32
+#define kBlockRows 8
+#define FULL_MASK 0xffffffff
+
+inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
+
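+// Flattens an (n, c, h, w) coordinate into a linear offset of a contiguous
+// 4-D tensor. The kernels below also reuse it for NHWC buffers by passing the
+// coordinates and extents in (n, h, w, c) order.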
+__device__ inline int Loc2Index(const int n, const int c, const int h,
+                                const int w, const int channel_num,
+                                const int height, const int width) {
+  int index = w + (h + (c + n * channel_num) * height) * width;
+  return index;
+}
+/* TODO: move this to a common place */
+template <typename scalar_t>
+__device__ inline scalar_t min(scalar_t a, scalar_t b) {
+  return a < b ? a : b;
+}
+
+template <typename scalar_t>
+__device__ inline scalar_t max(scalar_t a, scalar_t b) {
+  return a > b ? a : b;
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
+  for (int offset = 16; offset > 0; offset /= 2)
+    val += __shfl_down_sync(FULL_MASK, val, offset);
+  return val;
+}
+
+// Splits the original matrix into submatrices with size 32 * 32.
+// Each block transposes one submatrix by loading it into shared memory.
+// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/
+template <typename scalar_t>
+__global__ void BatchTranspose2DCUDAKernel(const int N, const int H,
+                                           const int W, const int dh,
+                                           const int dw,
+                                           const scalar_t *__restrict__ X,
+                                           scalar_t *__restrict__ Y) {
+  __shared__ scalar_t tile[kTileDim][kTileDim + 1];
+  const int n = blockIdx.x / (dh * dw);
+  const int k = blockIdx.x % (dh * dw);
+  const int r = k / dw;
+  const int c = k % dw;
+  const int offset = n * H * W;
+  int x = c * kTileDim + threadIdx.x;
+  int y = r * kTileDim + threadIdx.y;
+  if (x < W) {
+    for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) {
+      tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x];
+    }
+  }
+  __syncthreads();
+  x = r * kTileDim + threadIdx.x;
+  y = c * kTileDim + threadIdx.y;
+  if (x < H) {
+    for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) {
+      Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
+    }
+  }
+}
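+// CARAFE forward on the NHWC buffers: each output pixel (n, ph, pw) is served
+// by THREADS_PER_PIXEL (32) threads. They first stage that pixel's reassembly
+// weights in shared memory (shared_mask[c * WARP_SIZE + pixel_id], one slot
+// per pixel handled by the block), then split the feature channels among
+// themselves and accumulate a weighted sum over the kernel_size x kernel_size
+// neighbourhood of (ph / scale_factor, pw / scale_factor) in the downsampled
+// feature map.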
+template <typename scalar_t>
+__global__ void CARAFEForward(
+    const int num_kernels, const scalar_t *__restrict__ bottom_data,
+    const scalar_t *__restrict__ bottom_masks, const int kernel_size,
+    const int group_size, const int scale_factor, const int channels,
+    const int down_height, const int down_width, const int height,
+    const int width, const int mask_channels, scalar_t *__restrict__ top_data) {
+#if MAXIMIZE_KERNEL_SIZE
+  __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2];
+#else
+  __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T];
+#endif
+
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index > num_kernels - 1) {
+    return;
+  }
+  const int pixel_id = threadIdx.x / THREADS_PER_PIXEL;
+  const int split_id = threadIdx.x % THREADS_PER_PIXEL;
+  index = index / THREADS_PER_PIXEL;
+  const int pw = index % width;
+  const int ph = (index / width) % height;
+  const int n = index / width / height;
+
+  const int down_pw = pw / scale_factor;
+  const int down_ph = ph / scale_factor;
+
+  const int start_w = down_pw - (kernel_size - 1) / 2;
+  const int end_w = down_pw + (kernel_size - 1) / 2 + 1;
+  const int start_h = down_ph - (kernel_size - 1) / 2;
+  const int end_h = down_ph + (kernel_size - 1) / 2 + 1;
+  for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) {
+    int mask_index = Loc2Index(n, ph, pw, c, height, width, mask_channels);
+    shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index];
+  }
+  __syncthreads();
+
+  const int channels_per_group = ceilf(channels / (float)group_size);
+#pragma unroll
+  for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
+    int mask_group = c / channels_per_group;
+    scalar_t output_val = 0;
+#pragma unroll
+    for (int iy = start_h; iy < end_h; iy++) {
+#pragma unroll
+      for (int ix = start_w; ix < end_w; ix++) {
+        if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
+          continue;
+        }
+        int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
+        int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
+        int mask_c =
+            (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+        int feat_index =
+            Loc2Index(n, iy, ix, c, down_height, down_width, channels);
+
+        output_val += bottom_data[feat_index] *
+                      shared_mask[mask_c * WARP_SIZE + pixel_id];
+      }
+    }
+
+    int top_index = Loc2Index(n, ph, pw, c, height, width, channels);
+    top_data[top_index] = output_val;
+  }
+}
+
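+// Launcher for the optimized forward pass. It runs four kernels on the
+// current CUDA stream: NCHW->NHWC transposes of the features and masks, the
+// CARAFE reassembly itself, and an NHWC->NCHW transpose of the result back
+// into `output`.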
+int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks,
+                         const int kernel_size, const int group_size,
+                         const int scale_factor, const int batch_size,
+                         const int channels, const int input_height,
+                         const int input_width, const int output_height,
+                         const int output_width, const int mask_channels,
+                         at::Tensor rfeatures, at::Tensor routput,
+                         at::Tensor rmasks, at::Tensor output) {
+  // one warp per pixel
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      features.type(), "NCHW2NHWC_Feature", ([&] {
+        const scalar_t *bottom_data = features.data<scalar_t>();
+        scalar_t *top_data = rfeatures.data<scalar_t>();
+        const int dh = divideUP(channels, kTileDim);
+        const int dw = divideUP(input_height * input_width, kTileDim);
+        BatchTranspose2DCUDAKernel<scalar_t>
+            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
+                batch_size, channels, input_height * input_width, dh, dw,
+                bottom_data, top_data);
+      }));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      features.type(), "NCHW2NHWC_Masks", ([&] {
+        const scalar_t *bottom_data = masks.data<scalar_t>();
+        scalar_t *top_data = rmasks.data<scalar_t>();
+        const int dh = divideUP(mask_channels, kTileDim);
+        const int dw = divideUP(output_height * output_width, kTileDim);
+        BatchTranspose2DCUDAKernel<scalar_t>
+            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
+                batch_size, mask_channels, output_height * output_width, dh, dw,
+                bottom_data, top_data);
+      }));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      features.type(), "CARAFELaucherForward", ([&] {
+        const int num_kernels =
+            batch_size * output_height * output_width * THREADS_PER_PIXEL;
+        const scalar_t *bottom_data = rfeatures.data<scalar_t>();
+        const scalar_t *bottom_masks = rmasks.data<scalar_t>();
+        scalar_t *top_data = routput.data<scalar_t>();
+
+        CARAFEForward<scalar_t>
+            <<<at::cuda::ATenCeilDiv(num_kernels, THREADS_PER_BLOCK),
+               THREADS_PER_BLOCK, 0, stream>>>(
+                num_kernels, bottom_data, bottom_masks, kernel_size, group_size,
+                scale_factor, channels, input_height, input_width,
+                output_height, output_width, mask_channels, top_data);
+      }));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      features.type(), "NHWC2NCHW", ([&] {
+        const scalar_t *bottom_data = routput.data<scalar_t>();
+        scalar_t *top_data = output.data<scalar_t>();
+        const int dh = divideUP(output_height * output_width, kTileDim);
+        const int dw = divideUP(channels, kTileDim);
+        BatchTranspose2DCUDAKernel<scalar_t>
+            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
+                batch_size, output_height * output_width, channels, dh, dw,
+                bottom_data, top_data);
+      }));
+  cudaError_t err = cudaGetLastError();
+  if (cudaSuccess != err) {
+    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
+    exit(-1);
+  }
+
+  return 1;
+}
+
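+// Gradient w.r.t. the features, computed at output resolution: every location
+// gathers the top_diff values whose reassembly windows cover the
+// corresponding input pixel, weighted by the matching (spatially mirrored)
+// mask entries. The partial results in rbottom_grad_hs are reduced to input
+// resolution afterwards by FeatureSum.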
+template <typename scalar_t>
+__global__ void CARAFEBackward_Feature(
+    const int num_kernels, const scalar_t *__restrict__ top_diff,
+    const scalar_t *__restrict__ bottom_masks, const int kernel_size,
+    const int group_size, const int scale_factor, const int channels,
+    const int down_height, const int down_width, const int height,
+    const int width, const int mask_channels,
+    scalar_t *__restrict__ bottom_diff) {
+#if MAXIMIZE_KERNEL_SIZE
+  __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2];
+#else
+  __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T];
+#endif
+
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index > num_kernels - 1) {
+    return;
+  }
+
+  const int pixel_id = threadIdx.x / THREADS_PER_PIXEL;
+  const int split_id = threadIdx.x % THREADS_PER_PIXEL;
+  // (n, c, ph, pw) is an element in the bottom_data
+  index = index / THREADS_PER_PIXEL;
+  const int pw = index % width;
+  const int ph = (index / width) % height;
+  const int n = index / width / height;
+
+  const int start_w = pw - (kernel_size - 1) * scale_factor / 2;
+  const int end_w = pw + (kernel_size - 1) * scale_factor / 2 + 1;
+  const int start_h = ph - (kernel_size - 1) * scale_factor / 2;
+  const int end_h = ph + (kernel_size - 1) * scale_factor / 2 + 1;
+  for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) {
+    const int mask_w = (c % kernel_size) * scale_factor;
+    const int mask_h = (c / kernel_size % kernel_size) * scale_factor;
+    const int mask_x = start_w + mask_w;
+    const int mask_y = start_h + mask_h;
+    if (mask_y < 0 || mask_y > height - 1 || mask_x < 0 || mask_x > width - 1) {
+      shared_mask[c * WARP_SIZE + pixel_id] = 0;
+      continue;
+    }
+    const int mask_group = c / (kernel_size * kernel_size);
+    const int mask_c = (2 * mask_group + 1) * kernel_size * kernel_size - c - 1;
+    int mask_index =
+        Loc2Index(n, mask_c, mask_y, mask_x, mask_channels, height, width);
+    shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index];
+  }
+  __syncthreads();
+  const int channels_per_group = ceilf(channels / (float)group_size);
+#pragma unroll
+  for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
+    int mask_group = c / channels_per_group;
+    int top_index = Loc2Index(n, ph, pw, c, height, width, channels);
+    scalar_t output_val = 0;
+#pragma unroll
+    for (int iy = start_h; iy < end_h; iy += scale_factor) {
+#pragma unroll
+      for (int ix = start_w; ix < end_w; ix += scale_factor) {
+        if (iy < 0 || iy > height - 1 || ix < 0 || ix > width - 1) {
+          continue;
+        }
+        int mask_iy =
+            (iy - ph + (kernel_size - 1) * scale_factor / 2) / scale_factor;
+        int mask_ix =
+            (ix - pw + (kernel_size - 1) * scale_factor / 2) / scale_factor;
+        int mask_c =
+            (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+        int feat_index = Loc2Index(n, iy, ix, c, height, width, channels);
+        output_val +=
+            shared_mask[mask_c * WARP_SIZE + pixel_id] * top_diff[feat_index];
+      }
+    }
+    bottom_diff[top_index] = output_val;
+  }
+}
+
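+// Reduces the per-location partials produced above to input resolution by
+// summing each scale_factor x scale_factor block, yielding the gradient
+// w.r.t. the (NHWC) input features.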
+template <typename scalar_t>
+__global__ void FeatureSum(const int num_kernels,
+                           const scalar_t *__restrict__ input_data,
+                           const int scale_factor, const int channels,
+                           const int height, const int width,
+                           scalar_t *__restrict__ output_data) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index > num_kernels - 1) {
+    return;
+  }
+  const int split_id = threadIdx.x % THREADS_PER_PIXEL;
+  index = index / THREADS_PER_PIXEL;
+  const int pw = index % width;
+  const int ph = (index / width) % height;
+  const int n = index / width / height;
+  for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
+    scalar_t output_val = 0;
+    for (int iy = ph * scale_factor; iy < (ph + 1) * scale_factor; iy++) {
+      for (int ix = pw * scale_factor; ix < (pw + 1) * scale_factor; ix++) {
+        int input_id = Loc2Index(n, iy, ix, c, height * scale_factor,
+                                 width * scale_factor, channels);
+        output_val += input_data[input_id];
+      }
+    }
+    const int output_id = Loc2Index(n, ph, pw, c, height, width, channels);
+    output_data[output_id] = output_val;
+  }
+}
+
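+// Gradient w.r.t. the reassembly masks: one warp handles one (pixel, mask
+// channel) pair. Each lane accumulates top_diff * feature products over its
+// share of the channels in the corresponding mask group, and warpReduceSum
+// combines the 32 partial sums before lane 0 writes the result.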
+template <typename scalar_t>
+__global__ void CARAFEBackward_Mask(const int num_kernels,
+                                    const scalar_t *__restrict__ top_diff,
+                                    const scalar_t *__restrict__ bottom_data,
+                                    const int kernel_size, const int group_size,
+                                    const int scale_factor, const int channels,
+                                    const int down_height, const int down_width,
+                                    const int height, const int width,
+                                    const int mask_channels,
+                                    scalar_t *__restrict__ mask_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index > num_kernels - 1) {
+    return;
+  }
+
+  const int lane_id = index % WARP_SIZE;
+  index = index / WARP_SIZE;
+  const int mask_c = index % mask_channels;
+  // (n, c, ph, pw) is an element in the bottom_data
+  index = index / mask_channels;
+  const int pw = index % width;
+  const int ph = (index / width) % height;
+  const int n = index / width / height;
+
+  const int down_pw = pw / scale_factor;
+  const int down_ph = ph / scale_factor;
+
+  const int mask_group = mask_c / (kernel_size * kernel_size);
+  const int mask_loc = mask_c % (kernel_size * kernel_size);
+
+  const int offset_x = mask_loc % kernel_size - (kernel_size - 1) / 2;
+  const int offset_y =
+      mask_loc / kernel_size % kernel_size - (kernel_size - 1) / 2;
+
+  const int down_x = down_pw + offset_x;
+  const int down_y = down_ph + offset_y;
+
+  scalar_t output_val = 0;
+
+  if (down_y >= 0 && down_y <= down_height - 1 && down_x >= 0 &&
+      down_x <= down_width - 1) {
+    const int channels_per_mask = ceilf(channels / (float)group_size);
+    const int start = channels_per_mask * mask_group;
+    const int end = min(channels_per_mask * (mask_group + 1), channels);
+    for (int c = start + lane_id; c < end; c += WARP_SIZE) {
+      int bottom_id =
+          Loc2Index(n, down_y, down_x, c, down_height, down_width, channels);
+      int top_id = Loc2Index(n, ph, pw, c, height, width, channels);
+      output_val += top_diff[top_id] * bottom_data[bottom_id];
+    }
+  }
+  __syncwarp();
+  output_val = warpReduceSum(output_val);
+  if (lane_id == 0) {
+    const int mask_id =
+        Loc2Index(n, ph, pw, mask_c, height, width, mask_channels);
+    mask_diff[mask_id] = output_val;
+  }
+}
+
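+// Launcher for the optimized backward pass: transpose top_grad to NHWC,
+// compute the feature gradient (high-resolution partials followed by
+// FeatureSum), compute the mask gradient, then transpose both gradients back
+// to NCHW.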
+int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures,
+                          const at::Tensor masks, const int kernel_size,
+                          const int group_size, const int scale_factor,
+                          const int batch_size, const int channels,
+                          const int input_height, const int input_width,
+                          const int output_height, const int output_width,
+                          const int mask_channels, at::Tensor rtop_grad,
+                          at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
+                          at::Tensor rmask_grad, at::Tensor bottom_grad,
+                          at::Tensor mask_grad) {
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.type(), "NCHW2NHWC_Top_Grad", ([&] {
+        const scalar_t *bottom_data = top_grad.data<scalar_t>();
+        scalar_t *top_data = rtop_grad.data<scalar_t>();
+        const int dh = divideUP(channels, kTileDim);
+        const int dw = divideUP(output_height * output_width, kTileDim);
+        BatchTranspose2DCUDAKernel<scalar_t>
+            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
+                batch_size, channels, output_height * output_width, dh, dw,
+                bottom_data, top_data);
+      }));
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.type(), "CARAFELaucherBackward_Feature", ([&] {
+        const int num_kernels =
+            batch_size * output_height * output_width * THREADS_PER_PIXEL;
+        const scalar_t *top_diff = rtop_grad.data<scalar_t>();
+        const scalar_t *bottom_masks = masks.data<scalar_t>();
+        scalar_t *bottom_diff = rbottom_grad_hs.data<scalar_t>();
+
+        CARAFEBackward_Feature<scalar_t>
+            <<<at::cuda::ATenCeilDiv(num_kernels, THREADS_PER_BLOCK),
+               THREADS_PER_BLOCK, 0, stream>>>(
+                num_kernels, top_diff, bottom_masks, kernel_size, group_size,
+                scale_factor, channels, input_height, input_width,
+                output_height, output_width, mask_channels, bottom_diff);
+      }));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.type(), "FeatureSum", ([&] {
+        const int num_kernels =
+            batch_size * input_height * input_width * THREADS_PER_PIXEL;
+        const scalar_t *bottom_diff_hs = rbottom_grad_hs.data<scalar_t>();
+        scalar_t *bottom_diff = rbottom_grad.data<scalar_t>();
+
+        FeatureSum<scalar_t>
+            <<<at::cuda::ATenCeilDiv(num_kernels, THREADS_PER_BLOCK),
+               THREADS_PER_BLOCK, 0, stream>>>(
+                num_kernels, bottom_diff_hs, scale_factor, channels,
+                input_height, input_width, bottom_diff);
+      }));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.type(), "NHWC2NCHW_Bottom_Grad", ([&] {
+        const scalar_t *bottom_data = rbottom_grad.data<scalar_t>();
+        scalar_t *top_data = bottom_grad.data<scalar_t>();
+        const int dh = divideUP(input_height * input_width, kTileDim);
+        const int dw = divideUP(channels, kTileDim);
+        BatchTranspose2DCUDAKernel<scalar_t>
+            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
+                batch_size, input_height * input_width, channels, dh, dw,
+                bottom_data, top_data);
+      }));
+
+  AT_DISPATCH_FLOATING_TYPES(
+      top_grad.type(), "CARAFELaucherBackward_Mask", ([&] {
+        const int num_kernels = batch_size * output_height * output_width *
+                                mask_channels * WARP_SIZE;
+        const scalar_t *top_diff = rtop_grad.data<scalar_t>();
+        const scalar_t *bottom_data = rfeatures.data<scalar_t>();
+        scalar_t *mask_diff = rmask_grad.data<scalar_t>();
+
+        CARAFEBackward_Mask<scalar_t>
+            <<<at::cuda::ATenCeilDiv(num_kernels, THREADS_PER_BLOCK),
+               THREADS_PER_BLOCK, 0, stream>>>(
+                num_kernels, top_diff, bottom_data, kernel_size, group_size,
+                scale_factor, channels, input_height, input_width,
+                output_height, output_width, mask_channels, mask_diff);
+      }));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.type(), "NHWC2NCHW_Mask_Grad", ([&] {
+        const scalar_t *bottom_data = rmask_grad.data<scalar_t>();
+        scalar_t *top_data = mask_grad.data<scalar_t>();
+        const int dh = divideUP(output_height * output_width, kTileDim);
+        const int dw = divideUP(mask_channels, kTileDim);
+        BatchTranspose2DCUDAKernel<scalar_t>
+            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
+                batch_size, output_height * output_width, mask_channels, dh, dw,
+                bottom_data, top_data);
+      }));
+  cudaError_t err = cudaGetLastError();
+  if (cudaSuccess != err) {
+    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
+    exit(-1);
+  }
+
+  return 1;
+}
diff --git a/mmdet/ops/carafe/src/carafe_naive_cuda.cpp b/mmdet/ops/carafe/src/carafe_naive_cuda.cpp
new file mode 100644
index 00000000..fbcda80e
--- /dev/null
+++ b/mmdet/ops/carafe/src/carafe_naive_cuda.cpp
@@ -0,0 +1,75 @@
+#include <ATen/ATen.h>
+#include <torch/torch.h>
+
+#include <cmath>
+#include <vector>
+
+int CARAFENAIVEForwardLaucher(const at::Tensor features, const at::Tensor masks,
+                              const int kernel_size, const int group_size,
+                              const int scale_factor, const int batch_size,
+                              const int channels, const int height,
+                              const int width, at::Tensor output);
+
+int CARAFENAIVEBackwardLaucher(const at::Tensor top_grad,
+                               const at::Tensor features,
+                               const at::Tensor masks, const int kernel_size,
+                               const int group_size, const int scale_factor,
+                               const int batch_size, const int channels,
+                               const int height, const int width,
+                               at::Tensor bottom_grad, at::Tensor mask_grad);
+
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+
+int carafe_naive_forward_cuda(at::Tensor features, at::Tensor masks,
+                              int kernel_size, int group_size, int scale_factor,
+                              at::Tensor output) {
+  CHECK_INPUT(features);
+  CHECK_INPUT(masks);
+  CHECK_INPUT(output);
+  at::DeviceGuard guard(features.device());
+
+  int batch_size = output.size(0);
+  int num_channels = output.size(1);
+  int data_height = output.size(2);
+  int data_width = output.size(3);
+
+  CARAFENAIVEForwardLaucher(features, masks, kernel_size, group_size,
+                            scale_factor, batch_size, num_channels, data_height,
+                            data_width, output);
+
+  return 1;
+}
+
+int carafe_naive_backward_cuda(at::Tensor top_grad, at::Tensor features,
+                               at::Tensor masks, int kernel_size,
+                               int group_size, int scale_factor,
+                               at::Tensor bottom_grad, at::Tensor mask_grad) {
+  CHECK_INPUT(top_grad);
+  CHECK_INPUT(features);
+  CHECK_INPUT(masks);
+  CHECK_INPUT(bottom_grad);
+  CHECK_INPUT(mask_grad);
+  at::DeviceGuard guard(top_grad.device());
+
+  int batch_size = top_grad.size(0);
+  int num_channels = top_grad.size(1);
+  int data_height = top_grad.size(2);
+  int data_width = top_grad.size(3);
+
+  CARAFENAIVEBackwardLaucher(top_grad, features, masks, kernel_size, group_size,
+                             scale_factor, batch_size, num_channels,
+                             data_height, data_width, bottom_grad, mask_grad);
+
+  return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &carafe_naive_forward_cuda, "carafe_naive forward (CUDA)");
+  m.def("backward", &carafe_naive_backward_cuda,
+        "carafe_naive backward (CUDA)");
+}
diff --git a/mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu b/mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu
new file mode 100644
index 00000000..3edbae79
--- /dev/null
+++ b/mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu
@@ -0,0 +1,176 @@
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+
+using namespace at;  // temporary fix for pytorch<=0.4.1 (see #9848)
+
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+#define THREADS_PER_BLOCK 1024
+
+inline int GET_BLOCKS(const int N) {
+  int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+  int max_block_num = 65536;
+  return min(optimal_block_num, max_block_num);
+}
+
+__device__ inline int Loc2Index(const int n, const int c, const int h,
+                                const int w, const int channel_num,
+                                const int height, const int width) {
+  int index = w + (h + (c + n * channel_num) * height) * width;
+  return index;
+}
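+// Reference ("naive") CARAFE forward: operates directly on NCHW tensors with
+// one thread per output element and no shared-memory staging, trading speed
+// for simplicity.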
+template <typename scalar_t>
+__global__ void CARAFENAIVEForward(const int nthreads,
+                                   const scalar_t *bottom_data,
+                                   const scalar_t *bottom_masks,
+                                   const int kernel_size, const int group_size,
+                                   const int scale_factor, const int channels,
+                                   const int height, const int width,
+                                   scalar_t *top_data) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the bottom_data
+    int pw = index % width;
+    int ph = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+
+    int mask_channels = kernel_size * kernel_size * group_size;
+    int mask_group = c / (channels / group_size);
+
+    int down_pw = pw / scale_factor;
+    int down_ph = ph / scale_factor;
+    int down_width = width / scale_factor;
+    int down_height = height / scale_factor;
+    int start_w = down_pw - (kernel_size - 1) / 2;
+    int end_w = down_pw + (kernel_size - 1) / 2 + 1;
+    int start_h = down_ph - (kernel_size - 1) / 2;
+    int end_h = down_ph + (kernel_size - 1) / 2 + 1;
+
+    scalar_t output_val = 0;
+    for (int iy = start_h; iy < end_h; iy++) {
+      for (int ix = start_w; ix < end_w; ix++) {
+        if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
+          continue;
+        }
+        int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
+        int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
+        int mask_c =
+            (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+        int feat_index =
+            Loc2Index(n, c, iy, ix, channels, down_height, down_width);
+        int mask_index =
+            Loc2Index(n, mask_c, ph, pw, mask_channels, height, width);
+        output_val += bottom_data[feat_index] * bottom_masks[mask_index];
+      }
+    }
+    top_data[index] = output_val;
+  }
+}
+
+int CARAFENAIVEForwardLaucher(const at::Tensor features, const at::Tensor masks,
+                              const int kernel_size, const int group_size,
+                              const int scale_factor, const int batch_size,
+                              const int channels, const int height,
+                              const int width, at::Tensor output) {
+  const int output_size = batch_size * channels * height * width;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      features.type(), "CARAFENAIVELaucherForward", ([&] {
+        const scalar_t *bottom_data = features.data<scalar_t>();
+        const scalar_t *bottom_masks = masks.data<scalar_t>();
+        scalar_t *top_data = output.data<scalar_t>();
+
+        CARAFENAIVEForward<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+                output_size, bottom_data, bottom_masks, kernel_size, group_size,
+                scale_factor, channels, height, width, top_data);
+      }));
+  cudaError_t err = cudaGetLastError();
+  if (cudaSuccess != err) {
+    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
+    exit(-1);
+  }
+
+  return 1;
+}
+
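+// Naive backward pass: each thread revisits the kernel_size x kernel_size
+// window of its output element and scatters contributions into bottom_diff
+// and mask_diff with atomicAdd, so both gradient buffers must start from
+// zero.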
+template <typename scalar_t>
+__global__ void CARAFENAIVEBackward(
+    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_data,
+    const scalar_t *bottom_masks, const int kernel_size, const int group_size,
+    const int scale_factor, const int channels, const int height,
+    const int width, scalar_t *bottom_diff, scalar_t *mask_diff) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the bottom_data
+    int pw = index % width;
+    int ph = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+
+    int mask_channels = kernel_size * kernel_size * group_size;
+    int mask_group = c / (channels / group_size);
+
+    int down_pw = pw / scale_factor;
+    int down_ph = ph / scale_factor;
+    int down_width = width / scale_factor;
+    int down_height = height / scale_factor;
+    int start_w = down_pw - (kernel_size - 1) / 2;
+    int end_w = down_pw + (kernel_size - 1) / 2 + 1;
+    int start_h = down_ph - (kernel_size - 1) / 2;
+    int end_h = down_ph + (kernel_size - 1) / 2 + 1;
+
+    for (int iy = start_h; iy < end_h; iy++) {
+      for (int ix = start_w; ix < end_w; ix++) {
+        if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
+          continue;
+        }
+        int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
+        int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
+        int mask_c =
+            (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+        int feat_index =
+            Loc2Index(n, c, iy, ix, channels, down_height, down_width);
+        int mask_index =
+            Loc2Index(n, mask_c, ph, pw, mask_channels, height, width);
+        atomicAdd(bottom_diff + feat_index,
+                  bottom_masks[mask_index] * top_diff[index]);
+        atomicAdd(mask_diff + mask_index,
+                  bottom_data[feat_index] * top_diff[index]);
+      }
+    }
+  }
+}
+
+int CARAFENAIVEBackwardLaucher(const at::Tensor top_grad,
+                               const at::Tensor features,
+                               const at::Tensor masks, const int kernel_size,
+                               const int group_size, const int scale_factor,
+                               const int batch_size, const int channels,
+                               const int height, const int width,
+                               at::Tensor bottom_grad, at::Tensor mask_grad) {
+  const int output_size = batch_size * channels * height * width;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.type(), "CARAFENAIVELaucherBackward", ([&] {
+        const scalar_t *top_diff = top_grad.data<scalar_t>();
+        const scalar_t *bottom_data = features.data<scalar_t>();
+        const scalar_t *bottom_masks = masks.data<scalar_t>();
+        scalar_t *bottom_diff = bottom_grad.data<scalar_t>();
+        scalar_t *mask_diff = mask_grad.data<scalar_t>();
+
+        CARAFENAIVEBackward<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+                output_size, top_diff, bottom_data, bottom_masks, kernel_size,
+                group_size, scale_factor, channels, height, width, bottom_diff,
+                mask_diff);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (cudaSuccess != err) {
+    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
+    exit(-1);
+  }
+
+  return 1;
+}
diff --git a/setup.py b/setup.py
index 8353637b..05f4f06f 100755
--- a/setup.py
+++ b/setup.py
@@ -270,6 +270,17 @@ if __name__ == '__main__':
                 sources=[
                     'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu'
                 ]),
+            make_cuda_ext(
+                name='carafe_cuda',
+                module='mmdet.ops.carafe',
+                sources=['src/carafe_cuda.cpp', 'src/carafe_cuda_kernel.cu']),
+            make_cuda_ext(
+                name='carafe_naive_cuda',
+                module='mmdet.ops.carafe',
+                sources=[
+                    'src/carafe_naive_cuda.cpp',
+                    'src/carafe_naive_cuda_kernel.cu'
+                ])
         ],
         cmdclass={'build_ext': BuildExtension},
         zip_safe=False)
-- 
GitLab