Skip to content
Snippets Groups Projects
Unverified Commit b5431092 authored by Jiaqi Wang's avatar Jiaqi Wang Committed by GitHub
Browse files

Code Release: CARAFE: Content-Aware ReAssembly of FEatures (ICCV 2019) (#1583)

* add carafe ops

* rename carafe benchmark

* grad check fix

* update grad check

* update grad check output

* add fpn carafe & mask head carafe

* add ReadMe

* update readme

* add carafe setup

* update naive carafe

* update readme and setup

* readme typo fix

* fix flake8 error

* fix flake 8 error

* fix flake 8

* fix flake 8 more

* flake 8 fix plus

* flake 8 fix

* fix flake 8

* reformat ops files

* update fpn files and cfgs

* update readme

* update fcn_mask_head

* update fpn_carafe

* update kernel

* update

* update

* add docstring in FPN_CARAFE

* reformat with yapf

* update

* update

* add build upsampler

* fix mask head build error

* reformat build upsample layer

* add doc string for CARAFE and PixelShuffle

* update

* update upsample_cfg_

* update

* update doc string

* rm abbr in build upsample layer

* update readme

* update model_zoo

* add link to other features in ReadMe
parent dc8500b4
No related branches found
No related tags found
No related merge requests found
Showing
with 2012 additions and 26 deletions
...@@ -73,14 +73,15 @@ Results and models are available in the [Model zoo](docs/MODEL_ZOO.md). ...@@ -73,14 +73,15 @@ Results and models are available in the [Model zoo](docs/MODEL_ZOO.md).
| ATSS | ✓ | ✓ | ☐ | ✗ | ✓ | | ATSS | ✓ | ✓ | ☐ | ✗ | ✓ |
Other features Other features
- [x] DCNv2 - [x] [CARAFE](configs/carafe/README.md)
- [x] Group Normalization - [x] [DCNv2](configs/dcn/README.md)
- [x] Weight Standardization - [x] [Group Normalization](configs/gn/README.md)
- [x] [Weight Standardization](configs/gn+ws/README.md)
- [x] OHEM - [x] OHEM
- [x] Soft-NMS - [x] Soft-NMS
- [x] Generalized Attention - [x] [Generalized Attention](configs/empirical_attention/README.md)
- [x] GCNet - [x] [GCNet](configs/gcnet/README.md)
- [x] Mixed Precision (FP16) Training - [x] [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmdetection/blob/master/configs/fp16)
- [x] [InstaBoost](configs/instaboost/README.md) - [x] [InstaBoost](configs/instaboost/README.md)
......
# CARAFE: Content-Aware ReAssembly of FEatures
## Introduction
We provide config files to reproduce the object detection & instance segmentation results in the ICCV 2019 Oral paper for [CARAFE: Content-Aware ReAssembly of FEatures](https://arxiv.org/abs/1905.02188).
```
@inproceedings{Wang_2019_ICCV,
title = {CARAFE: Content-Aware ReAssembly of FEatures},
author = {Wang, Jiaqi and Chen, Kai and Xu, Rui and Liu, Ziwei and Loy, Chen Change and Lin, Dahua},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {October},
year = {2019}
}
```
## Results and Models
The results on COCO 2017 val are shown in the table below.
| Method | Backbone | Style | Lr schd | Test Proposal Num| Box AP | Mask AP | Download |
| :--------------------: | :-------------: | :-----: | :-----: | :--------------: | :----: | :--------: |:----------------------------------------------------------------------------------------------------: |
| Faster R-CNN w/ CARAFE | R-50-FPN | pytorch | 1x | 1000 | 37.8 | - | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/carafe/faster_rcnn_r50_fpn_carafe_1x-2ca2d094.pth) |
| - | - | - | - | 2000 | 37.9 | - | - |
| Mask R-CNN w/ CARAFE | R-50-FPN | pytorch | 1x | 1000 | 38.6 | 35.6| [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/carafe/mask_rcnn_r50_fpn_carafe_1x-2cc4b9fe.pth) |
| - | - | - | - | 2000 | 38.6 | 35.7| - |
## Implementation
The CUDA implementation of CARAFE can be found at `mmdet/ops/carafe` in this repository.
## Setup CARAFE
a. Use CARAFE in mmdetection.
Install mmdetection following the official guide.
b. Use CARAFE in your own project.
Git clone mmdetection.
```shell
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
```
Set up CARAFE in your own project.
```shell
cp -r ./mmdet/ops/carafe $Your_Project_Path$
cd $Your_Project_Path$/carafe
python setup.py develop
# or "pip install -v -e ."
cd ..
python ./carafe/grad_check.py
```
# Faster R-CNN (ResNet-50 backbone) with the FPN_CARAFE neck, 12-epoch schedule.
# model settings
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN_CARAFE',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5,
        start_level=0,
        end_level=-1,
        norm_cfg=None,
        activation=None,
        order=('conv', 'norm', 'act'),
        # upsample layer used in the top-down pathway of the neck
        upsample_cfg=dict(
            type='carafe',
            up_kernel=5,
            up_group=1,
            encoder_kernel=3,
            encoder_dilation=1,
            compressed_channels=64)),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
    # soft-nms is also supported for rcnn testing
    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=64),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=64),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# The evaluation settings are defined exactly once: a second, later
# assignment would replace this dict and silently drop ``metric``.
evaluation = dict(interval=1, metric='bbox')
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_r50_fpn_carafe_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# Mask R-CNN (ResNet-50 backbone) with the FPN_CARAFE neck and a CARAFE
# upsampler in the mask head, 12-epoch schedule.
# model settings
model = dict(
    type='MaskRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN_CARAFE',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5,
        start_level=0,
        end_level=-1,
        norm_cfg=None,
        activation=None,
        order=('conv', 'norm', 'act'),
        # upsample layer used in the top-down pathway of the neck
        upsample_cfg=dict(
            type='carafe',
            up_kernel=5,
            up_group=1,
            encoder_kernel=3,
            encoder_dilation=1,
            compressed_channels=64)),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
    mask_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    mask_head=dict(
        type='FCNMaskHead',
        num_convs=4,
        in_channels=256,
        conv_out_channels=256,
        num_classes=81,
        # upsample layer of the mask head
        upsample_cfg=dict(
            type='carafe',
            scale_factor=2,
            up_kernel=5,
            up_group=1,
            encoder_kernel=3,
            encoder_dilation=1,
            compressed_channels=64),
        loss_mask=dict(
            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        mask_size=28,
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_thr=0.5),
        max_per_img=100,
        mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=64),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=64),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# The evaluation settings are defined exactly once: a second, later
# assignment would replace this dict and silently drop the 'segm' metric.
evaluation = dict(interval=1, metric=['bbox', 'segm'])
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/mask_rcnn_r50_fpn_carafe_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
...@@ -230,6 +230,9 @@ Please refer to [Weight Standardization](https://github.com/open-mmlab/mmdetecti ...@@ -230,6 +230,9 @@ Please refer to [Weight Standardization](https://github.com/open-mmlab/mmdetecti
Please refer to [Deformable Convolutional Networks](https://github.com/open-mmlab/mmdetection/blob/master/configs/dcn) for details. Please refer to [Deformable Convolutional Networks](https://github.com/open-mmlab/mmdetection/blob/master/configs/dcn) for details.
### CARAFE: Content-Aware ReAssembly of FEatures
Please refer to [CARAFE](https://github.com/open-mmlab/mmdetection/blob/master/configs/carafe) for details.
### Instaboost ### Instaboost
Please refer to [Instaboost](https://github.com/open-mmlab/mmdetection/blob/master/configs/instaboost) for details. Please refer to [Instaboost](https://github.com/open-mmlab/mmdetection/blob/master/configs/instaboost) for details.
......
...@@ -6,9 +6,10 @@ import torch.nn as nn ...@@ -6,9 +6,10 @@ import torch.nn as nn
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from mmdet.core import auto_fp16, force_fp32, mask_target from mmdet.core import auto_fp16, force_fp32, mask_target
from mmdet.ops.carafe import CARAFEPack
from ..builder import build_loss from ..builder import build_loss
from ..registry import HEADS from ..registry import HEADS
from ..utils import ConvModule from ..utils import ConvModule, build_upsample_layer
@HEADS.register_module @HEADS.register_module
...@@ -20,27 +21,30 @@ class FCNMaskHead(nn.Module): ...@@ -20,27 +21,30 @@ class FCNMaskHead(nn.Module):
in_channels=256, in_channels=256,
conv_kernel_size=3, conv_kernel_size=3,
conv_out_channels=256, conv_out_channels=256,
upsample_method='deconv',
upsample_ratio=2,
num_classes=81, num_classes=81,
class_agnostic=False, class_agnostic=False,
upsample_cfg=dict(type='deconv', scale_factor=2),
conv_cfg=None, conv_cfg=None,
norm_cfg=None, norm_cfg=None,
loss_mask=dict( loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)): type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
super(FCNMaskHead, self).__init__() super(FCNMaskHead, self).__init__()
if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']: self.upsample_cfg = upsample_cfg.copy()
if self.upsample_cfg['type'] not in [
None, 'deconv', 'nearest', 'bilinear', 'carafe'
]:
raise ValueError( raise ValueError(
'Invalid upsample method {}, accepted methods ' 'Invalid upsample method {}, accepted methods '
'are "deconv", "nearest", "bilinear"'.format(upsample_method)) 'are "deconv", "nearest", "bilinear", "carafe"'.format(
self.upsample_cfg['type']))
self.num_convs = num_convs self.num_convs = num_convs
# WARN: roi_feat_size is reserved and not used # WARN: roi_feat_size is reserved and not used
self.roi_feat_size = _pair(roi_feat_size) self.roi_feat_size = _pair(roi_feat_size)
self.in_channels = in_channels self.in_channels = in_channels
self.conv_kernel_size = conv_kernel_size self.conv_kernel_size = conv_kernel_size
self.conv_out_channels = conv_out_channels self.conv_out_channels = conv_out_channels
self.upsample_method = upsample_method self.upsample_method = self.upsample_cfg.get('type')
self.upsample_ratio = upsample_ratio self.scale_factor = self.upsample_cfg.pop('scale_factor')
self.num_classes = num_classes self.num_classes = num_classes
self.class_agnostic = class_agnostic self.class_agnostic = class_agnostic
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
...@@ -63,17 +67,27 @@ class FCNMaskHead(nn.Module): ...@@ -63,17 +67,27 @@ class FCNMaskHead(nn.Module):
norm_cfg=norm_cfg)) norm_cfg=norm_cfg))
upsample_in_channels = ( upsample_in_channels = (
self.conv_out_channels if self.num_convs > 0 else in_channels) self.conv_out_channels if self.num_convs > 0 else in_channels)
upsample_cfg_ = self.upsample_cfg.copy()
if self.upsample_method is None: if self.upsample_method is None:
self.upsample = None self.upsample = None
elif self.upsample_method == 'deconv': elif self.upsample_method == 'deconv':
self.upsample = nn.ConvTranspose2d( upsample_cfg_.update(
upsample_in_channels, in_channels=upsample_in_channels,
self.conv_out_channels, out_channels=self.conv_out_channels,
self.upsample_ratio, kernel_size=self.scale_factor,
stride=self.upsample_ratio) stride=self.scale_factor)
elif self.upsample_method == 'carafe':
upsample_cfg_.update(
channels=upsample_in_channels, scale_factor=self.scale_factor)
else: else:
self.upsample = nn.Upsample( # suppress warnings
scale_factor=self.upsample_ratio, mode=self.upsample_method) align_corners = (None
if self.upsample_method == 'nearest' else False)
upsample_cfg_.update(
scale_factor=self.scale_factor,
mode=self.upsample_method,
align_corners=align_corners)
self.upsample = build_upsample_layer(upsample_cfg_)
out_channels = 1 if self.class_agnostic else self.num_classes out_channels = 1 if self.class_agnostic else self.num_classes
logits_in_channel = ( logits_in_channel = (
...@@ -87,9 +101,12 @@ class FCNMaskHead(nn.Module): ...@@ -87,9 +101,12 @@ class FCNMaskHead(nn.Module):
for m in [self.upsample, self.conv_logits]: for m in [self.upsample, self.conv_logits]:
if m is None: if m is None:
continue continue
nn.init.kaiming_normal_( elif isinstance(m, CARAFEPack):
m.weight, mode='fan_out', nonlinearity='relu') m.init_weights()
nn.init.constant_(m.bias, 0) else:
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(m.bias, 0)
@auto_fp16() @auto_fp16()
def forward(self, x): def forward(self, x):
......
from .bfp import BFP from .bfp import BFP
from .fpn import FPN from .fpn import FPN
from .fpn_carafe import FPN_CARAFE
from .hrfpn import HRFPN from .hrfpn import HRFPN
from .nas_fpn import NASFPN from .nas_fpn import NASFPN
__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN'] __all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE']
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmdet.ops.carafe import CARAFEPack
from ..registry import NECKS
from ..utils import ConvModule, build_upsample_layer
@NECKS.register_module
class FPN_CARAFE(nn.Module):
    """FPN_CARAFE is a more flexible implementation of FPN.

    It allows more choice for upsample methods during the top-down pathway.
    It can reproduce the performance of the ICCV 2019 paper
    CARAFE: Content-Aware ReAssembly of FEatures.
    Please refer to https://arxiv.org/abs/1905.02188 for more details.

    Args:
        in_channels (list[int]): Number of channels for each input feature map.
        out_channels (int): Output channels of feature pyramids.
        num_outs (int): Number of output stages.
        start_level (int): Start level of feature pyramids.
            (Default: 0)
        end_level (int): End level of feature pyramids.
            (Default: -1 indicates the last level).
        norm_cfg (dict): Dictionary to construct and config norm layer.
        activation (str): Type of activation function in ConvModule
            (Default: None indicates w/o activation).
        order (tuple[str]): Order of components in ConvModule, either
            ``('conv', 'norm', 'act')`` or ``('act', 'conv', 'norm')``.
        upsample_cfg (dict): Dictionary to construct and config upsample layer.
            Its ``type`` must be one of 'nearest', 'bilinear', 'deconv',
            'pixel_shuffle', 'carafe' or None. 'deconv' and 'pixel_shuffle'
            additionally require a positive ``upsample_kernel`` entry.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 norm_cfg=None,
                 activation=None,
                 order=('conv', 'norm', 'act'),
                 upsample_cfg=dict(
                     type='carafe',
                     up_kernel=5,
                     up_group=1,
                     encoder_kernel=3,
                     encoder_dilation=1)):
        super(FPN_CARAFE, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.activation = activation
        self.norm_cfg = norm_cfg
        self.with_bias = norm_cfg is None
        # work on a copy so the caller's (or the default) dict is never mutated
        self.upsample_cfg = upsample_cfg.copy()
        self.upsample = self.upsample_cfg.get('type')
        self.relu = nn.ReLU(inplace=False)

        self.order = order
        assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')]

        assert self.upsample in [
            'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None
        ]
        if self.upsample in ['deconv', 'pixel_shuffle']:
            # Key lookup (not attribute access) so that a plain ``dict`` is
            # accepted as well as an attribute-style config dict.
            assert ('upsample_kernel' in self.upsample_cfg
                    and self.upsample_cfg['upsample_kernel'] > 0)
            self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel')

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        self.upsample_modules = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                norm_cfg=norm_cfg,
                bias=self.with_bias,
                activation=activation,
                inplace=False,
                order=self.order)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                bias=self.with_bias,
                activation=activation,
                inplace=False,
                order=self.order)
            # the last backbone level gets no upsample module here
            if i != self.backbone_end_level - 1:
                upsample_cfg_ = self.upsample_cfg.copy()
                if self.upsample == 'deconv':
                    upsample_cfg_.update(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel_size=self.upsample_kernel,
                        stride=2,
                        padding=(self.upsample_kernel - 1) // 2,
                        output_padding=(self.upsample_kernel - 1) // 2)
                elif self.upsample == 'pixel_shuffle':
                    upsample_cfg_.update(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        scale_factor=2,
                        upsample_kernel=self.upsample_kernel)
                elif self.upsample == 'carafe':
                    upsample_cfg_.update(channels=out_channels, scale_factor=2)
                else:
                    # suppress warnings
                    align_corners = (None
                                     if self.upsample == 'nearest' else False)
                    upsample_cfg_.update(
                        scale_factor=2,
                        mode=self.upsample,
                        align_corners=align_corners)
                upsample_module = build_upsample_layer(upsample_cfg_)
                self.upsample_modules.append(upsample_module)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # add extra conv layers (e.g., RetinaNet)
        extra_out_levels = (
            num_outs - self.backbone_end_level + self.start_level)
        if extra_out_levels >= 1:
            for i in range(extra_out_levels):
                # the first extra level reads the last backbone input; later
                # ones read the previous extra lateral
                extra_in_channels = (
                    self.in_channels[self.backbone_end_level -
                                     1] if i == 0 else out_channels)
                extra_l_conv = ConvModule(
                    extra_in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=self.with_bias,
                    activation=self.activation,
                    inplace=False,
                    order=self.order)
                if self.upsample == 'deconv':
                    upsampler_cfg_ = dict(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel_size=self.upsample_kernel,
                        stride=2,
                        padding=(self.upsample_kernel - 1) // 2,
                        output_padding=(self.upsample_kernel - 1) // 2)
                elif self.upsample == 'pixel_shuffle':
                    upsampler_cfg_ = dict(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        scale_factor=2,
                        upsample_kernel=self.upsample_kernel)
                elif self.upsample == 'carafe':
                    upsampler_cfg_ = dict(
                        channels=out_channels,
                        scale_factor=2,
                        **self.upsample_cfg)
                else:
                    # suppress warnings
                    align_corners = (None
                                     if self.upsample == 'nearest' else False)
                    upsampler_cfg_ = dict(
                        scale_factor=2,
                        mode=self.upsample,
                        align_corners=align_corners)
                upsampler_cfg_['type'] = self.upsample
                upsample_module = build_upsample_layer(upsampler_cfg_)
                extra_fpn_conv = ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    bias=self.with_bias,
                    activation=activation,
                    inplace=False,
                    order=self.order)
                self.upsample_modules.append(upsample_module)
                self.fpn_convs.append(extra_fpn_conv)
                self.lateral_convs.append(extra_l_conv)

    # default init_weights for conv(msra) and norm in ConvModule
    def init_weights(self):
        """Initialize conv/deconv layers with Xavier-uniform weights.

        The CARAFEPack pass runs afterwards so that each CARAFEPack's own
        ``init_weights`` has the final say over its sub-layers.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                xavier_init(m, distribution='uniform')
        for m in self.modules():
            if isinstance(m, CARAFEPack):
                m.init_weights()

    def slice_as(self, src, dst):
        """Crop ``src`` (from the top-left) to the spatial size of ``dst``.

        ``src`` must be at least as large as ``dst`` along dims 2 and 3.
        """
        # slice src as dst
        # src should have the same or larger size than dst
        assert (src.size(2) >= dst.size(2)) and (src.size(3) >= dst.size(3))
        if src.size(2) == dst.size(2) and src.size(3) == dst.size(3):
            return src
        else:
            return src[:, :, :dst.size(2), :dst.size(3)]

    def tensor_add(self, a, b):
        """Return ``a + b``, cropping ``b`` to the size of ``a`` if needed."""
        if a.size() == b.size():
            c = a + b
        else:
            c = a + self.slice_as(b, a)
        return c

    def forward(self, inputs):
        """Compute the pyramid outputs.

        Args:
            inputs (sequence): One feature map per entry of ``in_channels``.

        Returns:
            tuple: One output per entry of ``self.fpn_convs``.
        """
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = []
        for i, lateral_conv in enumerate(self.lateral_convs):
            if i <= self.backbone_end_level - self.start_level:
                # index is clamped so the first extra level reuses the last
                # backbone input
                feat = inputs[min(i + self.start_level, len(inputs) - 1)]
            else:
                feat = laterals[-1]
            lateral = lateral_conv(feat)
            laterals.append(lateral)

        # build top-down path
        for i in range(len(laterals) - 1, 0, -1):
            if self.upsample is not None:
                upsample_feat = self.upsample_modules[i - 1](laterals[i])
            else:
                upsample_feat = laterals[i]
            laterals[i - 1] = self.tensor_add(laterals[i - 1], upsample_feat)

        # build outputs
        num_conv_outs = len(self.fpn_convs)
        outs = []
        for i in range(num_conv_outs):
            out = self.fpn_convs[i](laterals[i])
            outs.append(out)
        return tuple(outs)
...@@ -2,11 +2,12 @@ from .conv_module import ConvModule, build_conv_layer ...@@ -2,11 +2,12 @@ from .conv_module import ConvModule, build_conv_layer
from .conv_ws import ConvWS2d, conv_ws_2d from .conv_ws import ConvWS2d, conv_ws_2d
from .norm import build_norm_layer from .norm import build_norm_layer
from .scale import Scale from .scale import Scale
from .upsample import build_upsample_layer
from .weight_init import (bias_init_with_prob, kaiming_init, normal_init, from .weight_init import (bias_init_with_prob, kaiming_init, normal_init,
uniform_init, xavier_init) uniform_init, xavier_init)
__all__ = [ __all__ = [
'conv_ws_2d', 'ConvWS2d', 'build_conv_layer', 'ConvModule', 'conv_ws_2d', 'ConvWS2d', 'build_conv_layer', 'ConvModule',
'build_norm_layer', 'xavier_init', 'normal_init', 'uniform_init', 'build_norm_layer', 'build_upsample_layer', 'xavier_init', 'normal_init',
'kaiming_init', 'bias_init_with_prob', 'Scale' 'uniform_init', 'kaiming_init', 'bias_init_with_prob', 'Scale'
] ]
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init
from mmdet.ops.carafe import CARAFEPack
class PixelShufflePack(nn.Module):
    """Upsample layer built from a convolution followed by pixel shuffle.

    The convolution maps ``in_channels`` to
    ``out_channels * scale_factor * scale_factor`` channels; its output is
    then passed through ``F.pixel_shuffle`` with ``scale_factor``.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        scale_factor (int): Upsample ratio
        upsample_kernel (int): Kernel size of Conv layer to expand the channels

    Returns:
        upsampled feature map
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super(PixelShufflePack, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel

        expanded_channels = out_channels * scale_factor * scale_factor
        half_kernel = (upsample_kernel - 1) // 2
        self.upsample_conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            upsample_kernel,
            padding=half_kernel)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialisation of the expanding convolution."""
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x):
        """Expand channels with the conv, then pixel-shuffle the result."""
        expanded = self.upsample_conv(x)
        return F.pixel_shuffle(expanded, self.scale_factor)
# Lookup table used by build_upsample_layer.
upsample_cfg = {
    # format: layer_type: module class
    # ('nearest' and 'bilinear' share nn.Upsample)
    'nearest': nn.Upsample,
    'bilinear': nn.Upsample,
    'deconv': nn.ConvTranspose2d,
    'pixel_shuffle': PixelShufflePack,
    'carafe': CARAFEPack
}
def build_upsample_layer(cfg):
    """ Build upsample layer

    Args:
        cfg (dict): cfg should contain:
            type (str): Identify upsample layer type; must be a key of the
                module-level ``upsample_cfg`` table.
            layer args: every remaining item is passed as a keyword argument
                to the selected layer class (e.g. ``scale_factor``).

    Returns:
        layer (nn.Module): Created upsample layer

    Raises:
        KeyError: If ``type`` is not a key of ``upsample_cfg``.
        NotImplementedError: If the table entry for ``type`` is None.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    # copy so that popping 'type' does not modify the caller's dict
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in upsample_cfg:
        raise KeyError('Unrecognized upsample type {}'.format(layer_type))

    upsample = upsample_cfg[layer_type]
    if upsample is None:
        raise NotImplementedError

    if upsample is nn.Upsample:
        # 'nearest' and 'bilinear' share nn.Upsample. Without an explicit
        # ``mode`` the layer would fall back to nn.Upsample's own default,
        # regardless of the requested type, so default it to the type name.
        cfg_.setdefault('mode', layer_type)

    layer = upsample(**cfg_)
    return layer
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
__all__ = ['carafe', 'carafe_naive', 'CARAFE', 'CARAFENaive', 'CARAFEPack']
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init, xavier_init
from torch.autograd import Function
from torch.nn.modules.module import Module
from . import carafe_cuda, carafe_naive_cuda
class CARAFENaiveFunction(Function):
    """Autograd function wrapping the ``carafe_naive_cuda`` extension."""

    @staticmethod
    def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
        """Run ``carafe_naive_cuda.forward`` and return its output tensor.

        The output is allocated with shape
        ``(n, c, h * scale_factor, w * scale_factor)`` where
        ``(n, c, h, w)`` is the size of ``features``. Only CUDA tensors are
        supported; otherwise ``NotImplementedError`` is raised.
        """
        # argument validation
        assert scale_factor >= 1
        assert masks.size(1) == kernel_size * kernel_size * group_size
        assert masks.size(-1) == features.size(-1) * scale_factor
        assert masks.size(-2) == features.size(-2) * scale_factor
        assert features.size(1) % group_size == 0
        assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1

        # stash the non-tensor arguments for backward
        ctx.kernel_size = kernel_size
        ctx.group_size = group_size
        ctx.scale_factor = scale_factor
        ctx.feature_size = features.size()
        ctx.mask_size = masks.size()

        batch, channels, height, width = features.size()
        out_shape = (batch, channels, height * scale_factor,
                     width * scale_factor)
        output = features.new_zeros(out_shape)

        if not features.is_cuda:
            raise NotImplementedError
        carafe_naive_cuda.forward(features, masks, kernel_size, group_size,
                                  scale_factor, output)

        needs_grad = features.requires_grad or masks.requires_grad
        if needs_grad:
            ctx.save_for_backward(features, masks)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run ``carafe_naive_cuda.backward``.

        Returns gradients for ``features`` and ``masks``; the three
        non-tensor forward arguments receive ``None``.
        """
        assert grad_output.is_cuda

        features, masks = ctx.saved_tensors

        # buffers that the extension call receives as its last two arguments
        grad_input = torch.zeros_like(features)
        grad_masks = torch.zeros_like(masks)
        carafe_naive_cuda.backward(grad_output.contiguous(), features, masks,
                                   ctx.kernel_size, ctx.group_size,
                                   ctx.scale_factor, grad_input, grad_masks)

        return grad_input, grad_masks, None, None, None


carafe_naive = CARAFENaiveFunction.apply
class CARAFENaive(Module):
    """Module wrapper that stores the integer arguments of
    ``CARAFENaiveFunction`` and applies it in ``forward``.

    Args:
        kernel_size (int): Passed through to ``CARAFENaiveFunction``.
        group_size (int): Passed through to ``CARAFENaiveFunction``.
        scale_factor (int): Passed through to ``CARAFENaiveFunction``.
    """

    def __init__(self, kernel_size, group_size, scale_factor):
        super(CARAFENaive, self).__init__()

        # all three settings must be plain integers
        assert all(
            isinstance(value, int)
            for value in (kernel_size, group_size, scale_factor))
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features, masks):
        """Apply ``CARAFENaiveFunction`` with the stored settings."""
        settings = (self.kernel_size, self.group_size, self.scale_factor)
        return CARAFENaiveFunction.apply(features, masks, *settings)
class CARAFEFunction(Function):
    """Autograd bridge to the ``carafe_cuda`` extension.

    Besides the real inputs and outputs, the extension needs several scratch
    tensors (prefixed with ``r``). They are allocated here and resized by the
    extension itself. Only CUDA tensors are supported.
    """

    @staticmethod
    def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
        """Reassemble ``features`` with ``masks``.

        Args:
            ctx: autograd context used to stash hyper-parameters and tensors.
            features: 4-D input tensor, unpacked below as (n, c, h, w).
            masks: tensor whose channel count must be
                ``kernel_size * kernel_size * group_size`` and whose last two
                dimensions must be ``scale_factor`` times those of
                ``features``.
            kernel_size: odd, positive reassembly kernel size.
            group_size: number of channel groups; must divide
                ``features.size(1)``.
            scale_factor: upsampling ratio, at least 1.

        Returns:
            Tensor of shape (n, c, h * scale_factor, w * scale_factor).

        Raises:
            AssertionError: if any of the shape/value checks fails.
            NotImplementedError: if ``features`` is not a CUDA tensor.
        """
        assert scale_factor >= 1
        assert masks.size(1) == kernel_size * kernel_size * group_size
        assert masks.size(-1) == features.size(-1) * scale_factor
        assert masks.size(-2) == features.size(-2) * scale_factor
        assert features.size(1) % group_size == 0
        assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
        ctx.kernel_size = kernel_size
        ctx.group_size = group_size
        ctx.scale_factor = scale_factor
        ctx.feature_size = features.size()
        ctx.mask_size = masks.size()

        n, c, h, w = features.size()
        output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
        # Scratch buffers for the extension; none of them takes part in
        # autograd.
        routput = features.new_zeros(output.size(), requires_grad=False)
        rfeatures = features.new_zeros(features.size(), requires_grad=False)
        rmasks = masks.new_zeros(masks.size(), requires_grad=False)
        if features.is_cuda:
            carafe_cuda.forward(features, rfeatures, masks, rmasks,
                                kernel_size, group_size, scale_factor, routput,
                                output)
        else:
            raise NotImplementedError

        if features.requires_grad or masks.requires_grad:
            # rfeatures is reused by backward(), so keep it alongside the
            # real inputs.
            ctx.save_for_backward(features, masks, rfeatures)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. ``features`` and ``masks``.

        Returns exactly one entry per ``forward`` argument (features, masks,
        kernel_size, group_size, scale_factor); the three integer
        hyper-parameters receive ``None``.
        """
        assert grad_output.is_cuda

        features, masks, rfeatures = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        group_size = ctx.group_size
        scale_factor = ctx.scale_factor

        # Scratch buffers for the extension.
        rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
        rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
        rgrad_input = torch.zeros_like(features, requires_grad=False)
        rgrad_masks = torch.zeros_like(masks, requires_grad=False)
        # Actual results.
        grad_input = torch.zeros_like(features, requires_grad=False)
        grad_masks = torch.zeros_like(masks, requires_grad=False)
        carafe_cuda.backward(grad_output.contiguous(), rfeatures, masks,
                             kernel_size, group_size, scale_factor,
                             rgrad_output, rgrad_input_hs, rgrad_input,
                             rgrad_masks, grad_input, grad_masks)
        # forward() takes five arguments, so five gradients are returned
        # (the original returned a sixth, superfluous None).
        return grad_input, grad_masks, None, None, None
# Functional entry point: carafe(features, masks, kernel_size, group_size,
# scale_factor) runs CARAFEFunction through autograd.
carafe = CARAFEFunction.apply
class CARAFE(Module):
    """CARAFE: Content-Aware ReAssembly of FEatures.

    Thin module wrapper around :class:`CARAFEFunction`.
    Please refer to https://arxiv.org/abs/1905.02188 for more details.

    Args:
        kernel_size (int): reassemble kernel size
        group_size (int): reassemble group size
        scale_factor (int): upsample ratio

    Returns:
        upsampled feature map
    """

    def __init__(self, kernel_size, group_size, scale_factor):
        super(CARAFE, self).__init__()

        # All three hyper-parameters must be plain integers.
        hyper_params = (kernel_size, group_size, scale_factor)
        assert all(isinstance(value, int) for value in hyper_params)

        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features, masks):
        """Upsample ``features`` using the reassembly kernels in ``masks``."""
        op_args = (self.kernel_size, self.group_size, self.scale_factor)
        return CARAFEFunction.apply(features, masks, *op_args)
class CARAFEPack(nn.Module):
    """A unified package of the CARAFE upsampler.

    It bundles three stages: 1) channel compressor, 2) content encoder and
    3) the CARAFE op itself.

    Official implementation of the ICCV 2019 paper
    "CARAFE: Content-Aware ReAssembly of FEatures".
    Please refer to https://arxiv.org/abs/1905.02188 for more details.

    Args:
        channels (int): input feature channels
        scale_factor (int): upsample ratio
        up_kernel (int): kernel size of CARAFE op
        up_group (int): group size of CARAFE op
        encoder_kernel (int): kernel size of content encoder
        encoder_dilation (int): dilation of content encoder
        compressed_channels (int): output channels of channels compressor

    Returns:
        upsampled feature map
    """

    def __init__(self,
                 channels,
                 scale_factor,
                 up_kernel=5,
                 up_group=1,
                 encoder_kernel=3,
                 encoder_dilation=1,
                 compressed_channels=64):
        super(CARAFEPack, self).__init__()
        self.channels = channels
        self.scale_factor = scale_factor
        self.up_kernel = up_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels

        # 1x1 conv that shrinks the channel dimension before encoding.
        self.channel_compressor = nn.Conv2d(
            channels, self.compressed_channels, 1)

        # One reassembly kernel (up_kernel^2 weights per group) is predicted
        # for each of the scale_factor^2 sub-positions of an input pixel.
        kernel_area = self.up_kernel * self.up_kernel
        upsample_area = self.scale_factor * self.scale_factor
        encoder_out_channels = kernel_area * self.up_group * upsample_area
        encoder_padding = int(
            (self.encoder_kernel - 1) * self.encoder_dilation / 2)
        self.content_encoder = nn.Conv2d(
            self.compressed_channels,
            encoder_out_channels,
            self.encoder_kernel,
            padding=encoder_padding,
            dilation=self.encoder_dilation,
            groups=1)

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise every conv, then re-initialise the encoder."""
        conv_layers = [
            layer for layer in self.modules() if isinstance(layer, nn.Conv2d)
        ]
        for conv in conv_layers:
            xavier_init(conv, distribution='uniform')
        # The content encoder gets a small-std normal init on top.
        normal_init(self.content_encoder, std=0.001)

    def kernel_normalizer(self, mask):
        """Turn raw encoder output into normalised reassembly kernels."""
        # Move the scale_factor^2 factor from channels to the spatial dims.
        shuffled = F.pixel_shuffle(mask, self.scale_factor)
        batch, total_channels, height, width = shuffled.size()
        groups = int(total_channels / (self.up_kernel * self.up_kernel))
        # Softmax over the up_kernel^2 weights of each group.
        grouped = shuffled.view(batch, groups, -1, height, width)
        normalized = F.softmax(grouped, dim=2)
        return normalized.view(batch, total_channels, height,
                               width).contiguous()

    def feature_reassemble(self, x, mask):
        """Run the CARAFE op on ``x`` with the normalised ``mask``."""
        return carafe(x, mask, self.up_kernel, self.up_group,
                      self.scale_factor)

    def forward(self, x):
        """Compress, encode, normalise and reassemble ``x``."""
        raw_mask = self.content_encoder(self.channel_compressor(x))
        kernels = self.kernel_normalizer(raw_mask)
        return self.feature_reassemble(x, kernels)
# Gradient check and micro-benchmark for the CARAFE ops. Requires a CUDA
# device ('cuda:0') and the compiled mmdet.ops.carafe extensions.
import os.path as osp
import sys

import mmcv
import torch
from torch.autograd import gradcheck

# Add the directory one level above this script's folder to the import path
# so that `mmdet` can be imported when the script is run in place.
sys.path.append(osp.abspath(osp.join(__file__, '../../')))
# The module class is spelled `CARAFENaive`; importing `CARAFENAIVE` fails.
from mmdet.ops.carafe import CARAFENaive  # noqa: E402, isort:skip
from mmdet.ops.carafe import carafe_naive  # noqa: E402, isort:skip
from mmdet.ops.carafe import carafe, CARAFE  # noqa: E402, isort:skip

# ---------------------------------------------------------------------------
# Numerical gradient check (double precision, small tensors).
# mask has 5 * 5 * 4 = 100 channels and twice the spatial size of feat.
# ---------------------------------------------------------------------------
feat = torch.randn(2, 64, 3, 3, requires_grad=True, device='cuda:0').double()
mask = torch.randn(
    2, 100, 6, 6, requires_grad=True, device='cuda:0').sigmoid().double()

print('Gradcheck for carafe...')
test = gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
print(test)

print('Gradcheck for carafe naive...')
test = gradcheck(CARAFENaive(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
print(test)

# ---------------------------------------------------------------------------
# Timing comparison (single precision, larger tensors).
# ---------------------------------------------------------------------------
feat = torch.randn(
    2, 1024, 100, 100, requires_grad=True, device='cuda:0').float()
mask = torch.randn(
    2, 25, 200, 200, requires_grad=True, device='cuda:0').sigmoid().float()
loop_num = 500

time_forward = 0
time_backward = 0
bar = mmcv.ProgressBar(loop_num)
timer = mmcv.Timer()
for i in range(loop_num):
    x = carafe(feat.clone(), mask.clone(), 5, 1, 2)
    # Wait for the GPU so the timer measures the kernel, not just the launch.
    torch.cuda.synchronize()
    time_forward += timer.since_last_check()
    x.sum().backward(retain_graph=True)
    torch.cuda.synchronize()
    time_backward += timer.since_last_check()
    bar.update()
print('\nCARAFE time forward: {} ms/iter | time backward: {} ms/iter'.format(
    (time_forward + 1e-3) * 1e3 / loop_num,
    (time_backward + 1e-3) * 1e3 / loop_num))

time_naive_forward = 0
time_naive_backward = 0
bar = mmcv.ProgressBar(loop_num)
timer = mmcv.Timer()
for i in range(loop_num):
    x = carafe_naive(feat.clone(), mask.clone(), 5, 1, 2)
    torch.cuda.synchronize()
    time_naive_forward += timer.since_last_check()
    x.sum().backward(retain_graph=True)
    torch.cuda.synchronize()
    time_naive_backward += timer.since_last_check()
    bar.update()
print('\nCARAFE naive time forward: {} ms/iter | time backward: {} ms/iter'.
      format((time_naive_forward + 1e-3) * 1e3 / loop_num,
             (time_naive_backward + 1e-3) * 1e3 / loop_num))
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Preprocessor defines handed to nvcc for every extension built below.
# NOTE(review): these look like the macros CUDA's half-precision header
# consults to suppress its built-in half operators/conversions -- confirm
# against the CUDA toolkit in use.
NVCC_ARGS = [
    '-D__CUDA_NO_HALF_OPERATORS__',
    '-D__CUDA_NO_HALF_CONVERSIONS__',
    '-D__CUDA_NO_HALF2_OPERATORS__',
]

# (extension name, source files) for each CUDA extension of the package.
_CARAFE_EXTENSIONS = (
    ('carafe_cuda', ['src/carafe_cuda.cpp', 'src/carafe_cuda_kernel.cu']),
    ('carafe_naive_cuda',
     ['src/carafe_naive_cuda.cpp', 'src/carafe_naive_cuda_kernel.cu']),
)

setup(
    name='carafe',
    ext_modules=[
        CUDAExtension(
            ext_name,
            ext_sources,
            extra_compile_args={
                'cxx': [],
                'nvcc': NVCC_ARGS
            }) for ext_name, ext_sources in _CARAFE_EXTENSIONS
    ],
    cmdclass={'build_ext': BuildExtension})
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cmath>
#include <vector>
// Forward declarations of the kernel launchers called by the wrappers below.
// They are defined outside this translation unit (presumably in
// src/carafe_cuda_kernel.cu, which the build script compiles alongside this
// file -- confirm).

// Launches the CARAFE forward pass. `features` and `masks` are the inputs,
// `output` receives the result; `rfeatures`, `routput` and `rmasks` are
// scratch tensors supplied by the caller.
int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks,
const int kernel_size, const int group_size,
const int scale_factor, const int batch_size,
const int channels, const int input_height,
const int input_width, const int output_height,
const int output_width, const int mask_channels,
at::Tensor rfeatures, at::Tensor routput,
at::Tensor rmasks, at::Tensor output);
// Launches the CARAFE backward pass. `top_grad` is the gradient of the
// output; `bottom_grad` and `mask_grad` receive the gradients of the inputs;
// the `r*` tensors are scratch tensors supplied by the caller.
int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures,
const at::Tensor masks, const int kernel_size,
const int group_size, const int scale_factor,
const int batch_size, const int channels,
const int input_height, const int input_width,
const int output_height, const int output_width,
const int mask_channels, at::Tensor rtop_grad,
at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
at::Tensor rmask_grad, at::Tensor bottom_grad,
at::Tensor mask_grad);
// Argument validation helpers. Each one raises through AT_CHECK, naming the
// offending argument via the stringified macro parameter.
#define CHECK_CUDA(x) \
  AT_CHECK((x).type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
  AT_CHECK((x).is_contiguous(), #x, " must be contiguous ")
// Runs both checks. Wrapped in do { } while (0) so the macro behaves as a
// single statement: without it, `if (cond) CHECK_INPUT(t);` would guard only
// the first check and an `else` branch would not compile.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
// Host-side entry point for the CARAFE forward pass (exposed to Python as
// "forward"). Validates every tensor, resizes the scratch tensors to the
// channel-last shapes the launcher works with, and starts the kernels.
// Always returns 1.
int carafe_forward_cuda(at::Tensor features, at::Tensor rfeatures,
                        at::Tensor masks, at::Tensor rmasks, int kernel_size,
                        int group_size, int scale_factor, at::Tensor routput,
                        at::Tensor output) {
  // Every tensor must live on the GPU and be contiguous.
  CHECK_INPUT(features);
  CHECK_INPUT(rfeatures);
  CHECK_INPUT(masks);
  CHECK_INPUT(rmasks);
  CHECK_INPUT(output);
  CHECK_INPUT(routput);

  // Make the device that holds `features` current for the rest of the call.
  at::DeviceGuard guard(features.device());

  // Sizes are read from `output` (N, C, H_out, W_out), `features` (H_in,
  // W_in) and `masks` (mask channel count).
  const int n = output.size(0);
  const int c = output.size(1);
  const int out_h = output.size(2);
  const int out_w = output.size(3);
  const int in_h = features.size(2);
  const int in_w = features.size(3);
  const int mask_c = masks.size(1);

  // Scratch tensors carry the channel dimension last.
  rfeatures.resize_({n, in_h, in_w, c});
  routput.resize_({n, out_h, out_w, c});
  rmasks.resize_({n, out_h, out_w, mask_c});

  CARAFEForwardLaucher(features, masks, kernel_size, group_size, scale_factor,
                       n, c, in_h, in_w, out_h, out_w, mask_c, rfeatures,
                       routput, rmasks, output);
  return 1;
}
// Host-side entry point for the CARAFE backward pass (exposed to Python as
// "backward"). Validates every tensor, resizes the scratch tensors to the
// channel-last shapes the launcher works with, and starts the kernels.
// Always returns 1.
int carafe_backward_cuda(at::Tensor top_grad, at::Tensor rfeatures,
                         at::Tensor masks, int kernel_size, int group_size,
                         int scale_factor, at::Tensor rtop_grad,
                         at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
                         at::Tensor rmask_grad, at::Tensor bottom_grad,
                         at::Tensor mask_grad) {
  // Every tensor must live on the GPU and be contiguous.
  CHECK_INPUT(top_grad);
  CHECK_INPUT(rfeatures);
  CHECK_INPUT(masks);
  CHECK_INPUT(rtop_grad);
  CHECK_INPUT(rbottom_grad_hs);
  CHECK_INPUT(rbottom_grad);
  CHECK_INPUT(rmask_grad);
  CHECK_INPUT(bottom_grad);
  CHECK_INPUT(mask_grad);

  // Make the device that holds `top_grad` current for the rest of the call.
  at::DeviceGuard guard(top_grad.device());

  // Sizes are read from `top_grad` (N, C, H_out, W_out), `bottom_grad`
  // (H_in, W_in) and `masks` (mask channel count).
  const int n = top_grad.size(0);
  const int c = top_grad.size(1);
  const int out_h = top_grad.size(2);
  const int out_w = top_grad.size(3);
  const int in_h = bottom_grad.size(2);
  const int in_w = bottom_grad.size(3);
  const int mask_c = masks.size(1);

  // Scratch tensors carry the channel dimension last.
  rtop_grad.resize_({n, out_h, out_w, c});
  rbottom_grad.resize_({n, in_h, in_w, c});
  rbottom_grad_hs.resize_({n, out_h, out_w, c});
  rmask_grad.resize_({n, out_h, out_w, mask_c});

  CARAFEBackwardLaucher(top_grad, rfeatures, masks, kernel_size, group_size,
                        scale_factor, n, c, in_h, in_w, out_h, out_w, mask_c,
                        rtop_grad, rbottom_grad_hs, rbottom_grad, rmask_grad,
                        bottom_grad, mask_grad);
  return 1;
}
// Python bindings: the extension module exposes the two wrappers above as
// `forward` and `backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &carafe_forward_cuda, "carafe forward (CUDA)");
m.def("backward", &carafe_backward_cuda, "carafe backward (CUDA)");
}
This diff is collapsed.
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <cmath>
#include <vector>
// Forward declarations of the naive CARAFE kernel launchers called by the
// wrappers below; their definitions live in the CUDA source of this
// extension.

// Launches the naive forward kernel. `height` and `width` are the spatial
// size of `output`, which receives the result.
int CARAFENAIVEForwardLaucher(const at::Tensor features, const at::Tensor masks,
const int kernel_size, const int group_size,
const int scale_factor, const int batch_size,
const int channels, const int height,
const int width, at::Tensor output);
// Launches the naive backward kernel. `top_grad` is the gradient of the
// output; `bottom_grad` and `mask_grad` receive the gradients of `features`
// and `masks`.
int CARAFENAIVEBackwardLaucher(const at::Tensor top_grad,
const at::Tensor features,
const at::Tensor masks, const int kernel_size,
const int group_size, const int scale_factor,
const int batch_size, const int channels,
const int height, const int width,
at::Tensor bottom_grad, at::Tensor mask_grad);
// Argument validation helpers. Each one raises through AT_CHECK, naming the
// offending argument via the stringified macro parameter.
#define CHECK_CUDA(x) \
  AT_CHECK((x).type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
  AT_CHECK((x).is_contiguous(), #x, " must be contiguous ")
// Runs both checks. Wrapped in do { } while (0) so the macro behaves as a
// single statement: without it, `if (cond) CHECK_INPUT(t);` would guard only
// the first check and an `else` branch would not compile.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
// Host-side entry point for the naive CARAFE forward pass (exposed to Python
// as "forward"). Validates the tensors and starts the kernel; the sizes
// passed on are those of `output`. Always returns 1.
int carafe_naive_forward_cuda(at::Tensor features, at::Tensor masks,
                              int kernel_size, int group_size, int scale_factor,
                              at::Tensor output) {
  // Every tensor must live on the GPU and be contiguous.
  CHECK_INPUT(features);
  CHECK_INPUT(masks);
  CHECK_INPUT(output);

  // Make the device that holds `features` current for the rest of the call.
  at::DeviceGuard guard(features.device());

  const int n = output.size(0);
  const int c = output.size(1);
  const int out_h = output.size(2);
  const int out_w = output.size(3);

  CARAFENAIVEForwardLaucher(features, masks, kernel_size, group_size,
                            scale_factor, n, c, out_h, out_w, output);
  return 1;
}
// Host-side entry point for the naive CARAFE backward pass (exposed to
// Python as "backward"). Validates the tensors and starts the kernel; the
// sizes passed on are those of `top_grad`. Always returns 1.
int carafe_naive_backward_cuda(at::Tensor top_grad, at::Tensor features,
                               at::Tensor masks, int kernel_size,
                               int group_size, int scale_factor,
                               at::Tensor bottom_grad, at::Tensor mask_grad) {
  // Every tensor must live on the GPU and be contiguous.
  CHECK_INPUT(top_grad);
  CHECK_INPUT(features);
  CHECK_INPUT(masks);
  CHECK_INPUT(bottom_grad);
  CHECK_INPUT(mask_grad);

  // Make the device that holds `top_grad` current for the rest of the call.
  at::DeviceGuard guard(top_grad.device());

  const int n = top_grad.size(0);
  const int c = top_grad.size(1);
  const int out_h = top_grad.size(2);
  const int out_w = top_grad.size(3);

  CARAFENAIVEBackwardLaucher(top_grad, features, masks, kernel_size, group_size,
                             scale_factor, n, c, out_h, out_w, bottom_grad,
                             mask_grad);
  return 1;
}
// Python bindings: the extension module exposes the two wrappers above as
// `forward` and `backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &carafe_naive_forward_cuda, "carafe_naive forward (CUDA)");
m.def("backward", &carafe_naive_backward_cuda,
"carafe_naive backward (CUDA)");
}
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848)
// Grid-stride loop over [0, n): each thread starts at its global index and
// advances by the total number of launched threads. `n` is parenthesised so
// that an expression argument (e.g. a conditional) is compared as a whole.
#define CUDA_1D_KERNEL_LOOP(i, n)                               \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);  \
       i += blockDim.x * gridDim.x)

// Number of threads in every block launched by this file.
#define THREADS_PER_BLOCK 1024

// Returns the number of blocks to launch for N work items: enough blocks to
// give each item its own thread, capped at 65536 (the grid-stride loop picks
// up the remainder) and never fewer than one.
//
// A result of 0 is avoided on purpose: launching a kernel with zero blocks is
// an invalid configuration, which would turn an empty input into a CUDA
// error. With one block and N <= 0 the loop body simply never executes.
inline int GET_BLOCKS(const int N) {
  const int max_block_num = 65536;
  if (N <= 0) {
    return 1;
  }
  // Ceiling division written so that it cannot overflow for large N.
  const int optimal_block_num = (N - 1) / THREADS_PER_BLOCK + 1;
  return optimal_block_num < max_block_num ? optimal_block_num
                                           : max_block_num;
}
// Flattens the coordinate (n, c, h, w) into a linear offset for a tensor
// whose last three extents are (channel_num, height, width), with `w`
// varying fastest.
__device__ inline int Loc2Index(const int n, const int c, const int h,
                                const int w, const int channel_num,
                                const int height, const int width) {
  const int plane = n * channel_num + c;
  const int row = plane * height + h;
  return row * width + w;
}
// Naive CARAFE forward kernel: one loop iteration per output element.
//
// `height`/`width` are the extents of the output; the input extents are
// obtained by dividing them by `scale_factor`. For output element
// (n, c, out_y, out_x) the kernel walks a window of input positions centred
// on (out_y / scale_factor, out_x / scale_factor), multiplies each in-range
// input value by the matching mask weight stored at the output location, and
// writes the sum to `top_data`.
template <typename scalar_t>
__global__ void CARAFENAIVEForward(const int nthreads,
                                   const scalar_t *bottom_data,
                                   const scalar_t *bottom_masks,
                                   const int kernel_size, const int group_size,
                                   const int scale_factor, const int channels,
                                   const int height, const int width,
                                   scalar_t *top_data) {
  // Window radius and the number of taps actually visited per axis.
  const int radius = (kernel_size - 1) / 2;
  const int taps = 2 * radius + 1;
  const int mask_channels = kernel_size * kernel_size * group_size;

  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Decompose the flat index into (n, c, out_y, out_x).
    const int out_x = index % width;
    const int flat_row = index / width;
    const int out_y = flat_row % height;
    const int flat_plane = flat_row / height;
    const int c = flat_plane % channels;
    const int n = flat_plane / channels;

    // Channel group whose masks apply to channel c.
    const int group = c / (channels / group_size);

    // Input extents and the window centre in input coordinates.
    const int in_width = width / scale_factor;
    const int in_height = height / scale_factor;
    const int centre_x = out_x / scale_factor;
    const int centre_y = out_y / scale_factor;

    scalar_t acc = 0;
    for (int ky = 0; ky < taps; ++ky) {
      const int src_y = centre_y - radius + ky;
      if (src_y < 0 || src_y >= in_height) {
        continue;  // window row falls outside the input
      }
      for (int kx = 0; kx < taps; ++kx) {
        const int src_x = centre_x - radius + kx;
        if (src_x < 0 || src_x >= in_width) {
          continue;  // window column falls outside the input
        }
        const int mask_c = (group * kernel_size + ky) * kernel_size + kx;
        const int feat_index =
            Loc2Index(n, c, src_y, src_x, channels, in_height, in_width);
        const int mask_index =
            Loc2Index(n, mask_c, out_y, out_x, mask_channels, height, width);
        acc += bottom_data[feat_index] * bottom_masks[mask_index];
      }
    }
    top_data[index] = acc;
  }
}
// Launches CARAFENAIVEForward for every element of `output`.
//
// `batch_size`, `channels`, `height` and `width` describe `output`; their
// product is the number of work items. The kernel is instantiated for the
// scalar type of `features` (floating types and half).
//
// Returns 1 on success. If the launch reports a CUDA error, an exception is
// raised through AT_ERROR instead of terminating the host process, so the
// Python caller sees an error it can handle.
int CARAFENAIVEForwardLaucher(const at::Tensor features, const at::Tensor masks,
                              const int kernel_size, const int group_size,
                              const int scale_factor, const int batch_size,
                              const int channels, const int height,
                              const int width, at::Tensor output) {
  const int output_size = batch_size * channels * height * width;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features.type(), "CARAFENAIVELaucherForward", ([&] {
        const scalar_t *bottom_data = features.data<scalar_t>();
        const scalar_t *bottom_masks = masks.data<scalar_t>();
        scalar_t *top_data = output.data<scalar_t>();

        CARAFENAIVEForward<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, bottom_data, bottom_masks, kernel_size, group_size,
                scale_factor, channels, height, width, top_data);
      }));

  // Surface launch failures (bad configuration, earlier asynchronous errors).
  const cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    AT_ERROR("cudaCheckError() failed : ", cudaGetErrorString(err));
  }
  return 1;
}
// Naive CARAFE backward kernel: one loop iteration per output-gradient
// element.
//
// `height`/`width` are the extents of the output gradient; the input extents
// are obtained by dividing them by `scale_factor`. For element
// (n, c, out_y, out_x) the kernel revisits the same window of input
// positions as the forward pass and, for every in-range tap, atomically adds
//   mask weight * top_diff  to the input gradient, and
//   input value * top_diff  to the mask gradient.
// Atomics are required because several output elements touch the same input
// and mask locations.
template <typename scalar_t>
__global__ void CARAFENAIVEBackward(
    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_data,
    const scalar_t *bottom_masks, const int kernel_size, const int group_size,
    const int scale_factor, const int channels, const int height,
    const int width, scalar_t *bottom_diff, scalar_t *mask_diff) {
  // Window radius and the number of taps actually visited per axis.
  const int radius = (kernel_size - 1) / 2;
  const int taps = 2 * radius + 1;
  const int mask_channels = kernel_size * kernel_size * group_size;

  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Decompose the flat index into (n, c, out_y, out_x).
    const int out_x = index % width;
    const int flat_row = index / width;
    const int out_y = flat_row % height;
    const int flat_plane = flat_row / height;
    const int c = flat_plane % channels;
    const int n = flat_plane / channels;

    // Channel group whose masks apply to channel c.
    const int group = c / (channels / group_size);

    // Input extents and the window centre in input coordinates.
    const int in_width = width / scale_factor;
    const int in_height = height / scale_factor;
    const int centre_x = out_x / scale_factor;
    const int centre_y = out_y / scale_factor;

    for (int ky = 0; ky < taps; ++ky) {
      const int src_y = centre_y - radius + ky;
      if (src_y < 0 || src_y >= in_height) {
        continue;  // window row falls outside the input
      }
      for (int kx = 0; kx < taps; ++kx) {
        const int src_x = centre_x - radius + kx;
        if (src_x < 0 || src_x >= in_width) {
          continue;  // window column falls outside the input
        }
        const int mask_c = (group * kernel_size + ky) * kernel_size + kx;
        const int feat_index =
            Loc2Index(n, c, src_y, src_x, channels, in_height, in_width);
        const int mask_index =
            Loc2Index(n, mask_c, out_y, out_x, mask_channels, height, width);
        atomicAdd(bottom_diff + feat_index,
                  bottom_masks[mask_index] * top_diff[index]);
        atomicAdd(mask_diff + mask_index,
                  bottom_data[feat_index] * top_diff[index]);
      }
    }
  }
}
// Launches CARAFENAIVEBackward for every element of `top_grad`.
//
// `batch_size`, `channels`, `height` and `width` describe `top_grad`; their
// product is the number of work items. The kernel accumulates into
// `bottom_grad` and `mask_grad`, so the caller is expected to pass them
// zero-initialised. The kernel is instantiated for the scalar type of
// `top_grad` (floating types and half).
//
// Returns 1 on success. If the launch reports a CUDA error, an exception is
// raised through AT_ERROR instead of terminating the host process, so the
// Python caller sees an error it can handle.
int CARAFENAIVEBackwardLaucher(const at::Tensor top_grad,
                               const at::Tensor features,
                               const at::Tensor masks, const int kernel_size,
                               const int group_size, const int scale_factor,
                               const int batch_size, const int channels,
                               const int height, const int width,
                               at::Tensor bottom_grad, at::Tensor mask_grad) {
  const int output_size = batch_size * channels * height * width;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.type(), "CARAFENAIVELaucherBackward", ([&] {
        const scalar_t *top_diff = top_grad.data<scalar_t>();
        const scalar_t *bottom_data = features.data<scalar_t>();
        const scalar_t *bottom_masks = masks.data<scalar_t>();
        scalar_t *bottom_diff = bottom_grad.data<scalar_t>();
        scalar_t *mask_diff = mask_grad.data<scalar_t>();

        CARAFENAIVEBackward<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, top_diff, bottom_data, bottom_masks, kernel_size,
                group_size, scale_factor, channels, height, width, bottom_diff,
                mask_diff);
      }));

  // Surface launch failures (bad configuration, earlier asynchronous errors).
  const cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    AT_ERROR("cudaCheckError() failed : ", cudaGetErrorString(err));
  }
  return 1;
}
...@@ -270,6 +270,17 @@ if __name__ == '__main__': ...@@ -270,6 +270,17 @@ if __name__ == '__main__':
sources=[ sources=[
'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu' 'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu'
]), ]),
make_cuda_ext(
name='carafe_cuda',
module='mmdet.ops.carafe',
sources=['src/carafe_cuda.cpp', 'src/carafe_cuda_kernel.cu']),
make_cuda_ext(
name='carafe_naive_cuda',
module='mmdet.ops.carafe',
sources=[
'src/carafe_naive_cuda.cpp',
'src/carafe_naive_cuda_kernel.cu'
])
], ],
cmdclass={'build_ext': BuildExtension}, cmdclass={'build_ext': BuildExtension},
zip_safe=False) zip_safe=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment