From b5431092505f7dcd7de616c8a79eba4d2532fbc8 Mon Sep 17 00:00:00 2001 From: Jiaqi Wang <1155098160@link.cuhk.edu.hk> Date: Fri, 21 Feb 2020 20:57:33 +0800 Subject: [PATCH] Code Release: CARAFE: Content-Aware ReAssembly of FEatures (ICCV 2019) (#1583) * add carafe ops * rename carafe benchmark * grad check fix * update grad check * update grad check output * add fpn carafe & mask head carafe * add ReadMe * update readme * add carafe setup * update naive carafe * update readme and setup * readme typo fix * fix flake8 error * fix flake 8 error * fix flake 8 * fix flake 8 more * flake 8 fix plus * flake 8 fix * fix flake 8 * reformat ops files * update fpn files and cfgs * update readme * update fcn_mask_head * update fpn_carafe * update kernel * update * update * add docstring in FPN_CARAFE * reformat with yapf * update * update * add build upsampler * fix mask head build error * reformat build upsample layer * add doc string for CARAFE and PixelShuffle * update * update upsample_cfg_ * update * update doc string * rm abbr in build upsample layer * update readme * update model_zoo * add link to other features in ReadMe --- README.md | 13 +- configs/carafe/README.md | 53 ++ .../carafe/faster_rcnn_r50_fpn_carafe_1x.py | 188 +++++++ configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py | 210 ++++++++ docs/MODEL_ZOO.md | 3 + mmdet/models/mask_heads/fcn_mask_head.py | 51 +- mmdet/models/necks/__init__.py | 3 +- mmdet/models/necks/fpn_carafe.py | 254 ++++++++++ mmdet/models/utils/__init__.py | 5 +- mmdet/models/utils/upsample.py | 78 +++ mmdet/ops/carafe/__init__.py | 3 + mmdet/ops/carafe/carafe.py | 237 +++++++++ mmdet/ops/carafe/grad_check.py | 61 +++ mmdet/ops/carafe/setup.py | 29 ++ mmdet/ops/carafe/src/carafe_cuda.cpp | 113 +++++ mmdet/ops/carafe/src/carafe_cuda_kernel.cu | 475 ++++++++++++++++++ mmdet/ops/carafe/src/carafe_naive_cuda.cpp | 75 +++ .../carafe/src/carafe_naive_cuda_kernel.cu | 176 +++++++ setup.py | 11 + 19 files changed, 2012 insertions(+), 26 deletions(-) create mode 100644 configs/carafe/README.md create mode 100644 configs/carafe/faster_rcnn_r50_fpn_carafe_1x.py create mode 100644 configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py create mode 100644 mmdet/models/necks/fpn_carafe.py create mode 100644 mmdet/models/utils/upsample.py create mode 100644 mmdet/ops/carafe/__init__.py create mode 100644 mmdet/ops/carafe/carafe.py create mode 100644 mmdet/ops/carafe/grad_check.py create mode 100644 mmdet/ops/carafe/setup.py create mode 100644 mmdet/ops/carafe/src/carafe_cuda.cpp create mode 100644 mmdet/ops/carafe/src/carafe_cuda_kernel.cu create mode 100644 mmdet/ops/carafe/src/carafe_naive_cuda.cpp create mode 100644 mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu diff --git a/README.md b/README.md index 6b97890b..9c5ffe65 100644 --- a/README.md +++ b/README.md @@ -73,14 +73,15 @@ Results and models are available in the [Model zoo](docs/MODEL_ZOO.md). 
| ATSS | ✓ | ✓ | ☆ | ✓ | ✓ |

Other features
-- [x] DCNv2
-- [x] Group Normalization
-- [x] Weight Standardization
+- [x] [CARAFE](configs/carafe/README.md)
+- [x] [DCNv2](configs/dcn/README.md)
+- [x] [Group Normalization](configs/gn/README.md)
+- [x] [Weight Standardization](configs/gn+ws/README.md)
 - [x] OHEM
 - [x] Soft-NMS
-- [x] Generalized Attention
-- [x] GCNet
-- [x] Mixed Precision (FP16) Training
+- [x] [Generalized Attention](configs/empirical_attention/README.md)
+- [x] [GCNet](configs/gcnet/README.md)
+- [x] [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmdetection/blob/master/configs/fp16)
 - [x] [InstaBoost](configs/instaboost/README.md)

diff --git a/configs/carafe/README.md b/configs/carafe/README.md
new file mode 100644
index 00000000..a1be11da
--- /dev/null
+++ b/configs/carafe/README.md
@@ -0,0 +1,53 @@
+# CARAFE: Content-Aware ReAssembly of FEatures
+
+## Introduction
+
+We provide config files to reproduce the object detection & instance segmentation results of the ICCV 2019 Oral paper [CARAFE: Content-Aware ReAssembly of FEatures](https://arxiv.org/abs/1905.02188).
+
+```
+@inproceedings{Wang_2019_ICCV,
+    title = {CARAFE: Content-Aware ReAssembly of FEatures},
+    author = {Wang, Jiaqi and Chen, Kai and Xu, Rui and Liu, Ziwei and Loy, Chen Change and Lin, Dahua},
+    booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
+    month = {October},
+    year = {2019}
+}
+```
+
+## Results and Models
+
+The results on COCO 2017 val are shown in the table below.
+
+| Method | Backbone | Style | Lr schd | Test Proposal Num | Box AP | Mask AP | Download |
+| :--------------------: | :------: | :-----: | :-----: | :---------------: | :----: | :-----: | :------: |
+| Faster R-CNN w/ CARAFE | R-50-FPN | pytorch | 1x | 1000 | 37.8 | - | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/carafe/faster_rcnn_r50_fpn_carafe_1x-2ca2d094.pth) |
+| - | - | - | - | 2000 | 37.9 | - | - |
+| Mask R-CNN w/ CARAFE | R-50-FPN | pytorch | 1x | 1000 | 38.6 | 35.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/carafe/mask_rcnn_r50_fpn_carafe_1x-2cc4b9fe.pth) |
+| - | - | - | - | 2000 | 38.6 | 35.7 | - |
+
+## Implementation
+
+The CUDA implementation of CARAFE can be found at `mmdet/ops/carafe` under this repository.
+
+## Setup CARAFE
+
+a. Use CARAFE in mmdetection.
+
+Install mmdetection following the official guide.
+
+b. Use CARAFE in your own project.
+
+Clone the mmdetection repository.
+```shell
+git clone https://github.com/open-mmlab/mmdetection.git
+cd mmdetection
+```
+Set up CARAFE in your project.
+```shell
+cp -r ./mmdet/ops/carafe $Your_Project_Path$
+cd $Your_Project_Path$/carafe
+python setup.py develop
+# or "pip install -v -e ."
+cd ..
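+```
+
+Once built, the op can be used directly from Python. Below is a minimal usage sketch of the packaged upsampler (illustrative only: the input shape is made up, the import path assumes the copied `carafe` package is on your `PYTHONPATH`, and the op requires a CUDA-enabled build):
+
+```python
+import torch
+
+from carafe import CARAFEPack
+
+# content-aware 2x upsampler for 256-channel feature maps
+upsampler = CARAFEPack(channels=256, scale_factor=2).cuda()
+x = torch.rand(2, 256, 24, 32, device='cuda')
+out = upsampler(x)  # -> shape (2, 256, 48, 64)
+```
+
+Run the gradient check to verify and benchmark the op:
+
+```shell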
+python ./carafe/grad_check.py +``` diff --git a/configs/carafe/faster_rcnn_r50_fpn_carafe_1x.py b/configs/carafe/faster_rcnn_r50_fpn_carafe_1x.py new file mode 100644 index 00000000..94c8a0fc --- /dev/null +++ b/configs/carafe/faster_rcnn_r50_fpn_carafe_1x.py @@ -0,0 +1,188 @@ +# model settings +model = dict( + type='FasterRCNN', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN_CARAFE', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5, + start_level=0, + end_level=-1, + norm_cfg=None, + activation=None, + order=('conv', 'norm', 'act'), + upsample_cfg=dict( + type='carafe', + up_kernel=5, + up_group=1, + encoder_kernel=3, + encoder_dilation=1, + compressed_channels=64)), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05) +) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=64), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + 
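+        # test-time augmentation wrapper: a single scale of (1333, 800), no flip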
type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=64),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+evaluation = dict(interval=1, metric='bbox')
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_r50_fpn_carafe_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py b/configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py
new file mode 100644
index 00000000..656bd7c6
--- /dev/null
+++ b/configs/carafe/mask_rcnn_r50_fpn_carafe_1x.py
@@ -0,0 +1,210 @@
+# model settings
+model = dict(
+    type='MaskRCNN',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN_CARAFE',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        start_level=0,
+        end_level=-1,
+        norm_cfg=None,
+        activation=None,
+        order=('conv', 'norm', 'act'),
+        upsample_cfg=dict(
+            type='carafe',
+            up_kernel=5,
+            up_group=1,
+            encoder_kernel=3,
+            encoder_dilation=1,
+            compressed_channels=64)),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
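+    # NOTE (added for clarity): the mask head below swaps the default deconv
+    # upsampler for CARAFE; every upsample_cfg key except 'type' and
+    # 'scale_factor' is forwarded to CARAFEPack unchanged.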
mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        upsample_cfg=dict(
+            type='carafe',
+            scale_factor=2,
+            up_kernel=5,
+            up_group=1,
+            encoder_kernel=3,
+            encoder_dilation=1,
+            compressed_channels=64),
+        loss_mask=dict(
+            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=64),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=64),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+evaluation = dict(interval=1, metric=['bbox', 'segm'])
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level =
'INFO' +work_dir = './work_dirs/mask_rcnn_r50_fpn_carafe_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md index 954a4772..c15a00b6 100644 --- a/docs/MODEL_ZOO.md +++ b/docs/MODEL_ZOO.md @@ -230,6 +230,9 @@ Please refer to [Weight Standardization](https://github.com/open-mmlab/mmdetecti Please refer to [Deformable Convolutional Networks](https://github.com/open-mmlab/mmdetection/blob/master/configs/dcn) for details. +### CARAFE: Content-Aware ReAssembly of FEatures +Please refer to [CARAFE](https://github.com/open-mmlab/mmdetection/blob/master/configs/carafe) for details. + ### Instaboost Please refer to [Instaboost](https://github.com/open-mmlab/mmdetection/blob/master/configs/instaboost) for details. diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py index 6d11cfff..15a9e330 100644 --- a/mmdet/models/mask_heads/fcn_mask_head.py +++ b/mmdet/models/mask_heads/fcn_mask_head.py @@ -6,9 +6,10 @@ import torch.nn as nn from torch.nn.modules.utils import _pair from mmdet.core import auto_fp16, force_fp32, mask_target +from mmdet.ops.carafe import CARAFEPack from ..builder import build_loss from ..registry import HEADS -from ..utils import ConvModule +from ..utils import ConvModule, build_upsample_layer @HEADS.register_module @@ -20,27 +21,30 @@ class FCNMaskHead(nn.Module): in_channels=256, conv_kernel_size=3, conv_out_channels=256, - upsample_method='deconv', - upsample_ratio=2, num_classes=81, class_agnostic=False, + upsample_cfg=dict(type='deconv', scale_factor=2), conv_cfg=None, norm_cfg=None, loss_mask=dict( type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)): super(FCNMaskHead, self).__init__() - if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']: + self.upsample_cfg = upsample_cfg.copy() + if self.upsample_cfg['type'] not in [ + None, 'deconv', 'nearest', 'bilinear', 'carafe' + ]: raise ValueError( 'Invalid upsample method {}, accepted methods ' - 'are "deconv", "nearest", "bilinear"'.format(upsample_method)) + 'are "deconv", "nearest", "bilinear", "carafe"'.format( + self.upsample_cfg['type'])) self.num_convs = num_convs # WARN: roi_feat_size is reserved and not used self.roi_feat_size = _pair(roi_feat_size) self.in_channels = in_channels self.conv_kernel_size = conv_kernel_size self.conv_out_channels = conv_out_channels - self.upsample_method = upsample_method - self.upsample_ratio = upsample_ratio + self.upsample_method = self.upsample_cfg.get('type') + self.scale_factor = self.upsample_cfg.pop('scale_factor') self.num_classes = num_classes self.class_agnostic = class_agnostic self.conv_cfg = conv_cfg @@ -63,17 +67,27 @@ class FCNMaskHead(nn.Module): norm_cfg=norm_cfg)) upsample_in_channels = ( self.conv_out_channels if self.num_convs > 0 else in_channels) + upsample_cfg_ = self.upsample_cfg.copy() if self.upsample_method is None: self.upsample = None elif self.upsample_method == 'deconv': - self.upsample = nn.ConvTranspose2d( - upsample_in_channels, - self.conv_out_channels, - self.upsample_ratio, - stride=self.upsample_ratio) + upsample_cfg_.update( + in_channels=upsample_in_channels, + out_channels=self.conv_out_channels, + kernel_size=self.scale_factor, + stride=self.scale_factor) + elif self.upsample_method == 'carafe': + upsample_cfg_.update( + channels=upsample_in_channels, scale_factor=self.scale_factor) else: - self.upsample = nn.Upsample( - scale_factor=self.upsample_ratio, mode=self.upsample_method) + # suppress warnings + align_corners = 
(None
                              if self.upsample_method == 'nearest' else False)
+            upsample_cfg_.update(
+                scale_factor=self.scale_factor,
+                mode=self.upsample_method,
+                align_corners=align_corners)
+        self.upsample = build_upsample_layer(upsample_cfg_)

         out_channels = 1 if self.class_agnostic else self.num_classes
         logits_in_channel = (
@@ -87,9 +101,12 @@ class FCNMaskHead(nn.Module):
         for m in [self.upsample, self.conv_logits]:
             if m is None:
                 continue
-            nn.init.kaiming_normal_(
-                m.weight, mode='fan_out', nonlinearity='relu')
-            nn.init.constant_(m.bias, 0)
+            elif isinstance(m, CARAFEPack):
+                m.init_weights()
+            else:
+                nn.init.kaiming_normal_(
+                    m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.constant_(m.bias, 0)

     @auto_fp16()
     def forward(self, x):
diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py
index fa574044..7844ec78 100644
--- a/mmdet/models/necks/__init__.py
+++ b/mmdet/models/necks/__init__.py
@@ -1,6 +1,7 @@
 from .bfp import BFP
 from .fpn import FPN
+from .fpn_carafe import FPN_CARAFE
 from .hrfpn import HRFPN
 from .nas_fpn import NASFPN

-__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN']
+__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE']
diff --git a/mmdet/models/necks/fpn_carafe.py b/mmdet/models/necks/fpn_carafe.py
new file mode 100644
index 00000000..5d01b3c0
--- /dev/null
+++ b/mmdet/models/necks/fpn_carafe.py
@@ -0,0 +1,254 @@
+import torch.nn as nn
+from mmcv.cnn import xavier_init
+
+from mmdet.ops.carafe import CARAFEPack
+from ..registry import NECKS
+from ..utils import ConvModule, build_upsample_layer
+
+
+@NECKS.register_module
+class FPN_CARAFE(nn.Module):
+    """FPN_CARAFE is a more flexible implementation of FPN.
+    It allows more choices of upsample methods during the top-down pathway.
+
+    It can reproduce the performance of the ICCV 2019 paper
+    CARAFE: Content-Aware ReAssembly of FEatures.
+    Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+    Args:
+        in_channels (list[int]): Number of channels for each input feature map.
+        out_channels (int): Output channels of feature pyramids.
+        num_outs (int): Number of output stages.
+        start_level (int): Start level of feature pyramids.
+            (Default: 0)
+        end_level (int): End level of feature pyramids.
+            (Default: -1 indicates the last level).
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+        activation (str): Type of activation function in ConvModule.
+            (Default: None indicates w/o activation)
+        order (tuple[str]): Order of components in ConvModule.
+        upsample_cfg (dict): Dictionary to construct and config upsample layer.
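+
+    Example (an illustrative sketch, not part of the original docs; the input
+        shapes are assumptions and the CARAFE op requires a CUDA build):
+        >>> import torch
+        >>> in_channels = [256, 512, 1024, 2048]
+        >>> inputs = [
+        ...     torch.rand(1, c, 64 // 2**i, 64 // 2**i, device='cuda')
+        ...     for i, c in enumerate(in_channels)
+        ... ]
+        >>> neck = FPN_CARAFE(in_channels, out_channels=256, num_outs=5).cuda()
+        >>> outs = neck(inputs)
+        >>> [tuple(o.shape) for o in outs]
+        [(1, 256, 64, 64), (1, 256, 32, 32), (1, 256, 16, 16),
+         (1, 256, 8, 8), (1, 256, 4, 4)]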
+ """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + norm_cfg=None, + activation=None, + order=('conv', 'norm', 'act'), + upsample_cfg=dict( + type='carafe', + up_kernel=5, + up_group=1, + encoder_kernel=3, + encoder_dilation=1)): + super(FPN_CARAFE, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.activation = activation + self.norm_cfg = norm_cfg + self.with_bias = norm_cfg is None + self.upsample_cfg = upsample_cfg.copy() + self.upsample = self.upsample_cfg.get('type') + self.relu = nn.ReLU(inplace=False) + + self.order = order + assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')] + + assert self.upsample in [ + 'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None + ] + if self.upsample in ['deconv', 'pixel_shuffle']: + assert hasattr( + self.upsample_cfg, + 'upsample_kernel') and self.upsample_cfg.upsample_kernel > 0 + self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel') + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + self.upsample_modules = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + norm_cfg=norm_cfg, + bias=self.with_bias, + activation=activation, + inplace=False, + order=self.order) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + bias=self.with_bias, + activation=activation, + inplace=False, + order=self.order) + if i != self.backbone_end_level - 1: + upsample_cfg_ = self.upsample_cfg.copy() + if self.upsample == 'deconv': + upsample_cfg_.update( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=self.upsample_kernel, + stride=2, + padding=(self.upsample_kernel - 1) // 2, + output_padding=(self.upsample_kernel - 1) // 2) + elif self.upsample == 'pixel_shuffle': + upsample_cfg_.update( + in_channels=out_channels, + out_channels=out_channels, + scale_factor=2, + upsample_kernel=self.upsample_kernel) + elif self.upsample == 'carafe': + upsample_cfg_.update(channels=out_channels, scale_factor=2) + else: + # suppress warnings + align_corners = (None + if self.upsample == 'nearest' else False) + upsample_cfg_.update( + scale_factor=2, + mode=self.upsample, + align_corners=align_corners) + upsample_module = build_upsample_layer(upsample_cfg_) + self.upsample_modules.append(upsample_module) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_out_levels = ( + num_outs - self.backbone_end_level + self.start_level) + if extra_out_levels >= 1: + for i in range(extra_out_levels): + in_channels = ( + self.in_channels[self.backbone_end_level - + 1] if i == 0 else out_channels) + extra_l_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + bias=self.with_bias, + activation=self.activation, + inplace=False, + order=self.order) + if self.upsample == 'deconv': + upsampler_cfg_ = dict( + 
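+                        # NOTE (added): a stride-2 deconv with this symmetric
+                        # padding/output_padding roughly doubles the spatial
+                        # size; any surplus is cropped later by slice_as in
+                        # tensor_add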
in_channels=out_channels, + out_channels=out_channels, + kernel_size=self.upsample_kernel, + stride=2, + padding=(self.upsample_kernel - 1) // 2, + output_padding=(self.upsample_kernel - 1) // 2) + elif self.upsample == 'pixel_shuffle': + upsampler_cfg_ = dict( + in_channels=out_channels, + out_channels=out_channels, + scale_factor=2, + upsample_kernel=self.upsample_kernel) + elif self.upsample == 'carafe': + upsampler_cfg_ = dict( + channels=out_channels, + scale_factor=2, + **self.upsample_cfg) + else: + # suppress warnings + align_corners = (None + if self.upsample == 'nearest' else False) + upsampler_cfg_ = dict( + scale_factor=2, + mode=self.upsample, + align_corners=align_corners) + upsampler_cfg_['type'] = self.upsample + upsample_module = build_upsample_layer(upsampler_cfg_) + extra_fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + bias=self.with_bias, + activation=activation, + inplace=False, + order=self.order) + self.upsample_modules.append(upsample_module) + self.fpn_convs.append(extra_fpn_conv) + self.lateral_convs.append(extra_l_conv) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + xavier_init(m, distribution='uniform') + for m in self.modules(): + if isinstance(m, CARAFEPack): + m.init_weights() + + def slice_as(self, src, dst): + # slice src as dst + # src should have the same or larger size than dst + assert (src.size(2) >= dst.size(2)) and (src.size(3) >= dst.size(3)) + if src.size(2) == dst.size(2) and src.size(3) == dst.size(3): + return src + else: + return src[:, :, :dst.size(2), :dst.size(3)] + + def tensor_add(self, a, b): + if a.size() == b.size(): + c = a + b + else: + c = a + self.slice_as(b, a) + return c + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [] + for i, lateral_conv in enumerate(self.lateral_convs): + if i <= self.backbone_end_level - self.start_level: + input = inputs[min(i + self.start_level, len(inputs) - 1)] + else: + input = laterals[-1] + lateral = lateral_conv(input) + laterals.append(lateral) + + # build top-down path + for i in range(len(laterals) - 1, 0, -1): + if self.upsample is not None: + upsample_feat = self.upsample_modules[i - 1](laterals[i]) + else: + upsample_feat = laterals[i] + laterals[i - 1] = self.tensor_add(laterals[i - 1], upsample_feat) + + # build outputs + num_conv_outs = len(self.fpn_convs) + outs = [] + for i in range(num_conv_outs): + out = self.fpn_convs[i](laterals[i]) + outs.append(out) + return tuple(outs) diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py index 3db40920..c1b9ad11 100644 --- a/mmdet/models/utils/__init__.py +++ b/mmdet/models/utils/__init__.py @@ -2,11 +2,12 @@ from .conv_module import ConvModule, build_conv_layer from .conv_ws import ConvWS2d, conv_ws_2d from .norm import build_norm_layer from .scale import Scale +from .upsample import build_upsample_layer from .weight_init import (bias_init_with_prob, kaiming_init, normal_init, uniform_init, xavier_init) __all__ = [ 'conv_ws_2d', 'ConvWS2d', 'build_conv_layer', 'ConvModule', - 'build_norm_layer', 'xavier_init', 'normal_init', 'uniform_init', - 'kaiming_init', 'bias_init_with_prob', 'Scale' + 'build_norm_layer', 'build_upsample_layer', 'xavier_init', 'normal_init', + 'uniform_init', 'kaiming_init', 'bias_init_with_prob', 'Scale' ] diff --git a/mmdet/models/utils/upsample.py 
b/mmdet/models/utils/upsample.py
new file mode 100644
index 00000000..fd880fb0
--- /dev/null
+++ b/mmdet/models/utils/upsample.py
@@ -0,0 +1,78 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import xavier_init
+
+from mmdet.ops.carafe import CARAFEPack
+
+
+class PixelShufflePack(nn.Module):
+    """ Pixel Shuffle upsample layer
+
+    Args:
+        in_channels (int): Number of input channels
+        out_channels (int): Number of output channels
+        scale_factor (int): Upsample ratio
+        upsample_kernel (int): Kernel size of the conv layer used to expand
+            the channels
+
+    Returns:
+        upsampled feature map
+    """
+
+    def __init__(self, in_channels, out_channels, scale_factor,
+                 upsample_kernel):
+        super(PixelShufflePack, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scale_factor = scale_factor
+        self.upsample_kernel = upsample_kernel
+        self.upsample_conv = nn.Conv2d(
+            self.in_channels,
+            self.out_channels * scale_factor * scale_factor,
+            self.upsample_kernel,
+            padding=(self.upsample_kernel - 1) // 2)
+        self.init_weights()
+
+    def init_weights(self):
+        xavier_init(self.upsample_conv, distribution='uniform')
+
+    def forward(self, x):
+        x = self.upsample_conv(x)
+        x = F.pixel_shuffle(x, self.scale_factor)
+        return x
+
+
+upsample_cfg = {
+    # format: layer_type: module
+    'nearest': nn.Upsample,
+    'bilinear': nn.Upsample,
+    'deconv': nn.ConvTranspose2d,
+    'pixel_shuffle': PixelShufflePack,
+    'carafe': CARAFEPack
+}
+
+
+def build_upsample_layer(cfg):
+    """ Build upsample layer
+
+    Args:
+        cfg (dict): cfg should contain:
+            type (str): Identifies the upsample layer type.
+            scale_factor (int): Upsample ratio.
+            layer args: args needed to instantiate an upsample layer.
+
+    Returns:
+        layer (nn.Module): Created upsample layer
+    """
+    assert isinstance(cfg, dict) and 'type' in cfg
+    cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in upsample_cfg:
+        raise KeyError('Unrecognized upsample type {}'.format(layer_type))
+    else:
+        upsample = upsample_cfg[layer_type]
+        if upsample is None:
+            raise NotImplementedError
+
+    layer = upsample(**cfg_)
+    return layer
diff --git a/mmdet/ops/carafe/__init__.py b/mmdet/ops/carafe/__init__.py
new file mode 100644
index 00000000..029038f8
--- /dev/null
+++ b/mmdet/ops/carafe/__init__.py
@@ -0,0 +1,3 @@
+from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+
+__all__ = ['carafe', 'carafe_naive', 'CARAFE', 'CARAFENaive', 'CARAFEPack']
diff --git a/mmdet/ops/carafe/carafe.py b/mmdet/ops/carafe/carafe.py
new file mode 100644
index 00000000..2c81735b
--- /dev/null
+++ b/mmdet/ops/carafe/carafe.py
@@ -0,0 +1,237 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import normal_init, xavier_init
+from torch.autograd import Function
+from torch.nn.modules.module import Module
+
+from .
import carafe_cuda, carafe_naive_cuda + + +class CARAFENaiveFunction(Function): + + @staticmethod + def forward(ctx, features, masks, kernel_size, group_size, scale_factor): + assert scale_factor >= 1 + assert masks.size(1) == kernel_size * kernel_size * group_size + assert masks.size(-1) == features.size(-1) * scale_factor + assert masks.size(-2) == features.size(-2) * scale_factor + assert features.size(1) % group_size == 0 + assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1 + ctx.kernel_size = kernel_size + ctx.group_size = group_size + ctx.scale_factor = scale_factor + ctx.feature_size = features.size() + ctx.mask_size = masks.size() + + n, c, h, w = features.size() + output = features.new_zeros((n, c, h * scale_factor, w * scale_factor)) + if features.is_cuda: + carafe_naive_cuda.forward(features, masks, kernel_size, group_size, + scale_factor, output) + else: + raise NotImplementedError + + if features.requires_grad or masks.requires_grad: + ctx.save_for_backward(features, masks) + return output + + @staticmethod + def backward(ctx, grad_output): + assert grad_output.is_cuda + + features, masks = ctx.saved_tensors + kernel_size = ctx.kernel_size + group_size = ctx.group_size + scale_factor = ctx.scale_factor + + grad_input = torch.zeros_like(features) + grad_masks = torch.zeros_like(masks) + carafe_naive_cuda.backward(grad_output.contiguous(), features, masks, + kernel_size, group_size, scale_factor, + grad_input, grad_masks) + + return grad_input, grad_masks, None, None, None + + +carafe_naive = CARAFENaiveFunction.apply + + +class CARAFENaive(Module): + + def __init__(self, kernel_size, group_size, scale_factor): + super(CARAFENaive, self).__init__() + + assert isinstance(kernel_size, int) and isinstance( + group_size, int) and isinstance(scale_factor, int) + self.kernel_size = kernel_size + self.group_size = group_size + self.scale_factor = scale_factor + + def forward(self, features, masks): + return CARAFENaiveFunction.apply(features, masks, self.kernel_size, + self.group_size, self.scale_factor) + + +class CARAFEFunction(Function): + + @staticmethod + def forward(ctx, features, masks, kernel_size, group_size, scale_factor): + assert scale_factor >= 1 + assert masks.size(1) == kernel_size * kernel_size * group_size + assert masks.size(-1) == features.size(-1) * scale_factor + assert masks.size(-2) == features.size(-2) * scale_factor + assert features.size(1) % group_size == 0 + assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1 + ctx.kernel_size = kernel_size + ctx.group_size = group_size + ctx.scale_factor = scale_factor + ctx.feature_size = features.size() + ctx.mask_size = masks.size() + + n, c, h, w = features.size() + output = features.new_zeros((n, c, h * scale_factor, w * scale_factor)) + routput = features.new_zeros(output.size(), requires_grad=False) + rfeatures = features.new_zeros(features.size(), requires_grad=False) + rmasks = masks.new_zeros(masks.size(), requires_grad=False) + if features.is_cuda: + carafe_cuda.forward(features, rfeatures, masks, rmasks, + kernel_size, group_size, scale_factor, routput, + output) + else: + raise NotImplementedError + + if features.requires_grad or masks.requires_grad: + ctx.save_for_backward(features, masks, rfeatures) + return output + + @staticmethod + def backward(ctx, grad_output): + assert grad_output.is_cuda + + features, masks, rfeatures = ctx.saved_tensors + kernel_size = ctx.kernel_size + group_size = ctx.group_size + scale_factor = ctx.scale_factor + + rgrad_output = torch.zeros_like(grad_output, 
requires_grad=False) + rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False) + rgrad_input = torch.zeros_like(features, requires_grad=False) + rgrad_masks = torch.zeros_like(masks, requires_grad=False) + grad_input = torch.zeros_like(features, requires_grad=False) + grad_masks = torch.zeros_like(masks, requires_grad=False) + carafe_cuda.backward(grad_output.contiguous(), rfeatures, masks, + kernel_size, group_size, scale_factor, + rgrad_output, rgrad_input_hs, rgrad_input, + rgrad_masks, grad_input, grad_masks) + return grad_input, grad_masks, None, None, None, None + + +carafe = CARAFEFunction.apply + + +class CARAFE(Module): + """ CARAFE: Content-Aware ReAssembly of FEatures + + Please refer to https://arxiv.org/abs/1905.02188 for more details. + + Args: + kernel_size (int): reassemble kernel size + group_size (int): reassemble group size + scale_factor (int): upsample ratio + + Returns: + upsampled feature map + """ + + def __init__(self, kernel_size, group_size, scale_factor): + super(CARAFE, self).__init__() + + assert isinstance(kernel_size, int) and isinstance( + group_size, int) and isinstance(scale_factor, int) + self.kernel_size = kernel_size + self.group_size = group_size + self.scale_factor = scale_factor + + def forward(self, features, masks): + return CARAFEFunction.apply(features, masks, self.kernel_size, + self.group_size, self.scale_factor) + + +class CARAFEPack(nn.Module): + """ A unified package of CARAFE upsampler that contains: + 1) channel compressor 2) content encoder 3) CARAFE op + + Official implementation of ICCV 2019 paper + CARAFE: Content-Aware ReAssembly of FEatures + Please refer to https://arxiv.org/abs/1905.02188 for more details. + + Args: + channels (int): input feature channels + scale_factor (int): upsample ratio + up_kernel (int): kernel size of CARAFE op + up_group (int): group size of CARAFE op + encoder_kernel (int): kernel size of content encoder + encoder_dilation (int): dilation of content encoder + compressed_channels (int): output channels of channels compressor + + Returns: + upsampled feature map + """ + + def __init__(self, + channels, + scale_factor, + up_kernel=5, + up_group=1, + encoder_kernel=3, + encoder_dilation=1, + compressed_channels=64): + super(CARAFEPack, self).__init__() + self.channels = channels + self.scale_factor = scale_factor + self.up_kernel = up_kernel + self.up_group = up_group + self.encoder_kernel = encoder_kernel + self.encoder_dilation = encoder_dilation + self.compressed_channels = compressed_channels + self.channel_compressor = nn.Conv2d(channels, self.compressed_channels, + 1) + self.content_encoder = nn.Conv2d( + self.compressed_channels, + self.up_kernel * self.up_kernel * self.up_group * + self.scale_factor * self.scale_factor, + self.encoder_kernel, + padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2), + dilation=self.encoder_dilation, + groups=1) + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + normal_init(self.content_encoder, std=0.001) + + def kernel_normalizer(self, mask): + mask = F.pixel_shuffle(mask, self.scale_factor) + n, mask_c, h, w = mask.size() + mask_channel = int(mask_c / (self.up_kernel * self.up_kernel)) + mask = mask.view(n, mask_channel, -1, h, w) + + mask = F.softmax(mask, dim=2) + mask = mask.view(n, mask_c, h, w).contiguous() + + return mask + + def feature_reassemble(self, x, mask): + x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor) + 
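+        # x is now the reassembled map: each output pixel is a weighted sum
+        # over an up_kernel x up_kernel input neighborhood, so the shape goes
+        # from (n, c, h, w) to (n, c, h * scale_factor, w * scale_factor)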
return x
+
+    def forward(self, x):
+        compressed_x = self.channel_compressor(x)
+        mask = self.content_encoder(compressed_x)
+        mask = self.kernel_normalizer(mask)
+
+        x = self.feature_reassemble(x, mask)
+        return x
diff --git a/mmdet/ops/carafe/grad_check.py b/mmdet/ops/carafe/grad_check.py
new file mode 100644
index 00000000..06820be2
--- /dev/null
+++ b/mmdet/ops/carafe/grad_check.py
@@ -0,0 +1,61 @@
+import os.path as osp
+import sys
+
+import mmcv
+import torch
+from torch.autograd import gradcheck
+
+sys.path.append(osp.abspath(osp.join(__file__, '../../')))
+from mmdet.ops.carafe import CARAFENaive  # noqa: E402, isort:skip
+from mmdet.ops.carafe import carafe_naive  # noqa: E402, isort:skip
+from mmdet.ops.carafe import carafe, CARAFE  # noqa: E402, isort:skip
+
+feat = torch.randn(2, 64, 3, 3, requires_grad=True, device='cuda:0').double()
+mask = torch.randn(
+    2, 100, 6, 6, requires_grad=True, device='cuda:0').sigmoid().double()
+
+print('Gradcheck for carafe...')
+test = gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
+print(test)
+
+print('Gradcheck for carafe naive...')
+test = gradcheck(CARAFENaive(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
+print(test)
+
+feat = torch.randn(
+    2, 1024, 100, 100, requires_grad=True, device='cuda:0').float()
+mask = torch.randn(
+    2, 25, 200, 200, requires_grad=True, device='cuda:0').sigmoid().float()
+loop_num = 500
+
+time_forward = 0
+time_backward = 0
+bar = mmcv.ProgressBar(loop_num)
+timer = mmcv.Timer()
+for i in range(loop_num):
+    x = carafe(feat.clone(), mask.clone(), 5, 1, 2)
+    torch.cuda.synchronize()
+    time_forward += timer.since_last_check()
+    x.sum().backward(retain_graph=True)
+    torch.cuda.synchronize()
+    time_backward += timer.since_last_check()
+    bar.update()
+print('\nCARAFE time forward: {} ms/iter | time backward: {} ms/iter'.format(
+    (time_forward + 1e-3) * 1e3 / loop_num,
+    (time_backward + 1e-3) * 1e3 / loop_num))
+
+time_naive_forward = 0
+time_naive_backward = 0
+bar = mmcv.ProgressBar(loop_num)
+timer = mmcv.Timer()
+for i in range(loop_num):
+    x = carafe_naive(feat.clone(), mask.clone(), 5, 1, 2)
+    torch.cuda.synchronize()
+    time_naive_forward += timer.since_last_check()
+    x.sum().backward(retain_graph=True)
+    torch.cuda.synchronize()
+    time_naive_backward += timer.since_last_check()
+    bar.update()
+print('\nCARAFE naive time forward: {} ms/iter | time backward: {} ms/iter'.
+ format((time_naive_forward + 1e-3) * 1e3 / loop_num, + (time_naive_backward + 1e-3) * 1e3 / loop_num)) diff --git a/mmdet/ops/carafe/setup.py b/mmdet/ops/carafe/setup.py new file mode 100644 index 00000000..7ef1c66d --- /dev/null +++ b/mmdet/ops/carafe/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup + +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +NVCC_ARGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', +] + +setup( + name='carafe', + ext_modules=[ + CUDAExtension( + 'carafe_cuda', + ['src/carafe_cuda.cpp', 'src/carafe_cuda_kernel.cu'], + extra_compile_args={ + 'cxx': [], + 'nvcc': NVCC_ARGS + }), + CUDAExtension( + 'carafe_naive_cuda', + ['src/carafe_naive_cuda.cpp', 'src/carafe_naive_cuda_kernel.cu'], + extra_compile_args={ + 'cxx': [], + 'nvcc': NVCC_ARGS + }) + ], + cmdclass={'build_ext': BuildExtension}) diff --git a/mmdet/ops/carafe/src/carafe_cuda.cpp b/mmdet/ops/carafe/src/carafe_cuda.cpp new file mode 100644 index 00000000..9a7c73af --- /dev/null +++ b/mmdet/ops/carafe/src/carafe_cuda.cpp @@ -0,0 +1,113 @@ +#include <ATen/ATen.h> +#include <torch/extension.h> + +#include <cmath> +#include <vector> + +int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks, + const int kernel_size, const int group_size, + const int scale_factor, const int batch_size, + const int channels, const int input_height, + const int input_width, const int output_height, + const int output_width, const int mask_channels, + at::Tensor rfeatures, at::Tensor routput, + at::Tensor rmasks, at::Tensor output); + +int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures, + const at::Tensor masks, const int kernel_size, + const int group_size, const int scale_factor, + const int batch_size, const int channels, + const int input_height, const int input_width, + const int output_height, const int output_width, + const int mask_channels, at::Tensor rtop_grad, + at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad, + at::Tensor rmask_grad, at::Tensor bottom_grad, + at::Tensor mask_grad); + +#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") +#define CHECK_CONTIGUOUS(x) \ + AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +int carafe_forward_cuda(at::Tensor features, at::Tensor rfeatures, + at::Tensor masks, at::Tensor rmasks, int kernel_size, + int group_size, int scale_factor, at::Tensor routput, + at::Tensor output) { + CHECK_INPUT(features); + CHECK_INPUT(rfeatures); + CHECK_INPUT(masks); + CHECK_INPUT(rmasks); + CHECK_INPUT(output); + CHECK_INPUT(routput); + at::DeviceGuard guard(features.device()); + + const int batch_size = output.size(0); + const int num_channels = output.size(1); + const int output_height = output.size(2); + const int output_width = output.size(3); + + const int input_height = features.size(2); + const int input_width = features.size(3); + + const int mask_channels = masks.size(1); + + rfeatures.resize_({batch_size, input_height, input_width, num_channels}); + routput.resize_({batch_size, output_height, output_width, num_channels}); + rmasks.resize_({batch_size, output_height, output_width, mask_channels}); + + CARAFEForwardLaucher(features, masks, kernel_size, group_size, scale_factor, + batch_size, num_channels, input_height, input_width, + output_height, output_width, mask_channels, rfeatures, + routput, rmasks, output); + + return 1; +} + +int 
carafe_backward_cuda(at::Tensor top_grad, at::Tensor rfeatures, + at::Tensor masks, int kernel_size, int group_size, + int scale_factor, at::Tensor rtop_grad, + at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad, + at::Tensor rmask_grad, at::Tensor bottom_grad, + at::Tensor mask_grad) { + CHECK_INPUT(top_grad); + CHECK_INPUT(rfeatures); + CHECK_INPUT(masks); + CHECK_INPUT(rtop_grad); + CHECK_INPUT(rbottom_grad_hs); + CHECK_INPUT(rbottom_grad); + CHECK_INPUT(rmask_grad); + CHECK_INPUT(bottom_grad); + CHECK_INPUT(mask_grad); + at::DeviceGuard guard(top_grad.device()); + + const int batch_size = top_grad.size(0); + const int num_channels = top_grad.size(1); + const int output_height = top_grad.size(2); + const int output_width = top_grad.size(3); + + const int input_height = bottom_grad.size(2); + const int input_width = bottom_grad.size(3); + + const int mask_channels = masks.size(1); + + rtop_grad.resize_({batch_size, output_height, output_width, num_channels}); + rbottom_grad.resize_({batch_size, input_height, input_width, num_channels}); + rbottom_grad_hs.resize_( + {batch_size, output_height, output_width, num_channels}); + rmask_grad.resize_({batch_size, output_height, output_width, mask_channels}); + + CARAFEBackwardLaucher(top_grad, rfeatures, masks, kernel_size, group_size, + scale_factor, batch_size, num_channels, input_height, + input_width, output_height, output_width, mask_channels, + rtop_grad, rbottom_grad_hs, rbottom_grad, rmask_grad, + bottom_grad, mask_grad); + + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &carafe_forward_cuda, "carafe forward (CUDA)"); + m.def("backward", &carafe_backward_cuda, "carafe backward (CUDA)"); +} diff --git a/mmdet/ops/carafe/src/carafe_cuda_kernel.cu b/mmdet/ops/carafe/src/carafe_cuda_kernel.cu new file mode 100644 index 00000000..da627550 --- /dev/null +++ b/mmdet/ops/carafe/src/carafe_cuda_kernel.cu @@ -0,0 +1,475 @@ +#include <ATen/ATen.h> +#include <ATen/TensorUtils.h> +#include <ATen/Utils.h> +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> +#include <THC/THCAtomics.cuh> +#include <cmath> + +using namespace at; + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + +#define THREADS_PER_BLOCK 1024 // 32 * 32 +#define WARP_SIZE 32 +#define THREADS_PER_PIXEL 32 +#define MAX_SHARED_MEMORY 49152 +#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144 +#define MAXIMIZE_KERNEL_SIZE true +#define kTileDim 32 +#define kBlockRows 8 +#define FULL_MASK 0xffffffff + +inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); } + +__device__ inline int Loc2Index(const int n, const int c, const int h, + const int w, const int channel_num, + const int height, const int width) { + int index = w + (h + (c + n * channel_num) * height) * width; + return index; +} +/* TODO: move this to a common place */ +template <typename scalar_t> +__device__ inline scalar_t min(scalar_t a, scalar_t b) { + return a < b ? a : b; +} + +template <typename scalar_t> +__device__ inline scalar_t max(scalar_t a, scalar_t b) { + return a > b ? a : b; +} + +template <typename scalar_t> +__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) { + for (int offset = 16; offset > 0; offset /= 2) + val += __shfl_down_sync(FULL_MASK, val, offset); + return val; +} + +// Splits the original matrix into submatrices with size 32 * 32. +// Each block transposes one submatrix by loading it into shared memory. 
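+// (The shared tile is kTileDim x (kTileDim + 1); the extra column avoids
+// shared-memory bank conflicts when writing the transposed elements.)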
+// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/ +template <typename scalar_t> +__global__ void BatchTranspose2DCUDAKernel(const int N, const int H, + const int W, const int dh, + const int dw, + const scalar_t *__restrict__ X, + scalar_t *__restrict__ Y) { + __shared__ scalar_t tile[kTileDim][kTileDim + 1]; + const int n = blockIdx.x / (dh * dw); + const int k = blockIdx.x % (dh * dw); + const int r = k / dw; + const int c = k % dw; + const int offset = n * H * W; + int x = c * kTileDim + threadIdx.x; + int y = r * kTileDim + threadIdx.y; + if (x < W) { + for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) { + tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x]; + } + } + __syncthreads(); + x = r * kTileDim + threadIdx.x; + y = c * kTileDim + threadIdx.y; + if (x < H) { + for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) { + Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i]; + } + } +} +template <typename scalar_t> +__global__ void CARAFEForward( + const int num_kernels, const scalar_t *__restrict__ bottom_data, + const scalar_t *__restrict__ bottom_masks, const int kernel_size, + const int group_size, const int scale_factor, const int channels, + const int down_height, const int down_width, const int height, + const int width, const int mask_channels, scalar_t *__restrict__ top_data) { +#if MAXIMIZE_KERNEL_SIZE + __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2]; +#else + __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T]; +#endif + + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index > num_kernels - 1) { + return; + } + const int pixel_id = threadIdx.x / THREADS_PER_PIXEL; + const int split_id = threadIdx.x % THREADS_PER_PIXEL; + index = index / THREADS_PER_PIXEL; + const int pw = index % width; + const int ph = (index / width) % height; + const int n = index / width / height; + + const int down_pw = pw / scale_factor; + const int down_ph = ph / scale_factor; + + const int start_w = down_pw - (kernel_size - 1) / 2; + const int end_w = down_pw + (kernel_size - 1) / 2 + 1; + const int start_h = down_ph - (kernel_size - 1) / 2; + const int end_h = down_ph + (kernel_size - 1) / 2 + 1; + for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) { + int mask_index = Loc2Index(n, ph, pw, c, height, width, mask_channels); + shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index]; + } + __syncthreads(); + + const int channels_per_group = ceilf(channels / (float)group_size); +#pragma unroll + for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) { + int mask_group = c / channels_per_group; + scalar_t output_val = 0; +#pragma unroll + for (int iy = start_h; iy < end_h; iy++) { +#pragma unroll + for (int ix = start_w; ix < end_w; ix++) { + if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) { + continue; + } + int mask_iy = iy - down_ph + (kernel_size - 1) / 2; + int mask_ix = ix - down_pw + (kernel_size - 1) / 2; + int mask_c = + (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix; + int feat_index = + Loc2Index(n, iy, ix, c, down_height, down_width, channels); + + output_val += bottom_data[feat_index] * + shared_mask[mask_c * WARP_SIZE + pixel_id]; + } + } + + int top_index = Loc2Index(n, ph, pw, c, height, width, channels); + top_data[top_index] = output_val; + } +} + +int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks, + const int kernel_size, const int group_size, + const int scale_factor, const int 
batch_size, + const int channels, const int input_height, + const int input_width, const int output_height, + const int output_width, const int mask_channels, + at::Tensor rfeatures, at::Tensor routput, + at::Tensor rmasks, at::Tensor output) { + // one warp per pixel + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features.type(), "NCHW2NHWC_Feature", ([&] { + const scalar_t *bottom_data = features.data<scalar_t>(); + scalar_t *top_data = rfeatures.data<scalar_t>(); + const int dh = divideUP(channels, kTileDim); + const int dw = divideUP(input_height * input_width, kTileDim); + BatchTranspose2DCUDAKernel<scalar_t> + <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>( + batch_size, channels, input_height * input_width, dh, dw, + bottom_data, top_data); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features.type(), "NCHW2NHWC_Masks", ([&] { + const scalar_t *bottom_data = masks.data<scalar_t>(); + scalar_t *top_data = rmasks.data<scalar_t>(); + const int dh = divideUP(mask_channels, kTileDim); + const int dw = divideUP(output_height * output_width, kTileDim); + BatchTranspose2DCUDAKernel<scalar_t> + <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>( + batch_size, mask_channels, output_height * output_width, dh, dw, + bottom_data, top_data); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features.type(), "CARAFELaucherForward", ([&] { + const int num_kernels = + batch_size * output_height * output_width * THREADS_PER_PIXEL; + const scalar_t *bottom_data = rfeatures.data<scalar_t>(); + const scalar_t *bottom_masks = rmasks.data<scalar_t>(); + scalar_t *top_data = routput.data<scalar_t>(); + + CARAFEForward<scalar_t> + <<<at::cuda::ATenCeilDiv(num_kernels, THREADS_PER_BLOCK), + THREADS_PER_BLOCK, 0, stream>>>( + num_kernels, bottom_data, bottom_masks, kernel_size, group_size, + scale_factor, channels, input_height, input_width, + output_height, output_width, mask_channels, top_data); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + features.type(), "NHWC2NCHW", ([&] { + const scalar_t *bottom_data = routput.data<scalar_t>(); + scalar_t *top_data = output.data<scalar_t>(); + const int dh = divideUP(output_height * output_width, kTileDim); + const int dw = divideUP(channels, kTileDim); + BatchTranspose2DCUDAKernel<scalar_t> + <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>( + batch_size, output_height * output_width, channels, dh, dw, + bottom_data, top_data); + })); + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + + return 1; +} + +template <typename scalar_t> +__global__ void CARAFEBackward_Feature( + const int num_kernels, const scalar_t *__restrict__ top_diff, + const scalar_t *__restrict__ bottom_masks, const int kernel_size, + const int group_size, const int scale_factor, const int channels, + const int down_height, const int down_width, const int height, + const int width, const int mask_channels, + scalar_t *__restrict__ bottom_diff) { +#if MAXIMIZE_KERNEL_SIZE + __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2]; +#else + __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T]; +#endif + + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index > num_kernels - 1) { + return; + } + + const int pixel_id = threadIdx.x / THREADS_PER_PIXEL; + const int split_id = threadIdx.x % THREADS_PER_PIXEL; + // (n, c, ph, pw) is an element in the bottom_data + index = index / 
THREADS_PER_PIXEL; + const int pw = index % width; + const int ph = (index / width) % height; + const int n = index / width / height; + + const int start_w = pw - (kernel_size - 1) * scale_factor / 2; + const int end_w = pw + (kernel_size - 1) * scale_factor / 2 + 1; + const int start_h = ph - (kernel_size - 1) * scale_factor / 2; + const int end_h = ph + (kernel_size - 1) * scale_factor / 2 + 1; + for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) { + const int mask_w = (c % kernel_size) * scale_factor; + const int mask_h = (c / kernel_size % kernel_size) * scale_factor; + const int mask_x = start_w + mask_w; + const int mask_y = start_h + mask_h; + if (mask_y < 0 || mask_y > height - 1 || mask_x < 0 || mask_x > width - 1) { + shared_mask[c * WARP_SIZE + pixel_id] = 0; + continue; + } + const int mask_group = c / (kernel_size * kernel_size); + const int mask_c = (2 * mask_group + 1) * kernel_size * kernel_size - c - 1; + int mask_index = + Loc2Index(n, mask_c, mask_y, mask_x, mask_channels, height, width); + shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index]; + } + __syncthreads(); + const int channels_per_group = ceilf(channels / (float)group_size); +#pragma unroll + for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) { + int mask_group = c / channels_per_group; + int top_index = Loc2Index(n, ph, pw, c, height, width, channels); + scalar_t output_val = 0; +#pragma unroll + for (int iy = start_h; iy < end_h; iy += scale_factor) { +#pragma unroll + for (int ix = start_w; ix < end_w; ix += scale_factor) { + if (iy < 0 || iy > height - 1 || ix < 0 || ix > width - 1) { + continue; + } + int mask_iy = + (iy - ph + (kernel_size - 1) * scale_factor / 2) / scale_factor; + int mask_ix = + (ix - pw + (kernel_size - 1) * scale_factor / 2) / scale_factor; + int mask_c = + (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix; + int feat_index = Loc2Index(n, iy, ix, c, height, width, channels); + output_val += + shared_mask[mask_c * WARP_SIZE + pixel_id] * top_diff[feat_index]; + } + } + bottom_diff[top_index] = output_val; + } +} + +template <typename scalar_t> +__global__ void FeatureSum(const int num_kernels, + const scalar_t *__restrict__ input_data, + const int scale_factor, const int channels, + const int height, const int width, + scalar_t *__restrict__ output_data) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index > num_kernels - 1) { + return; + } + const int split_id = threadIdx.x % THREADS_PER_PIXEL; + index = index / THREADS_PER_PIXEL; + const int pw = index % width; + const int ph = (index / width) % height; + const int n = index / width / height; + for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) { + scalar_t output_val = 0; + for (int iy = ph * scale_factor; iy < (ph + 1) * scale_factor; iy++) { + for (int ix = pw * scale_factor; ix < (pw + 1) * scale_factor; ix++) { + int input_id = Loc2Index(n, iy, ix, c, height * scale_factor, + width * scale_factor, channels); + output_val += input_data[input_id]; + } + } + const int output_id = Loc2Index(n, ph, pw, c, height, width, channels); + output_data[output_id] = output_val; + } +} + +template <typename scalar_t> +__global__ void CARAFEBackward_Mask(const int num_kernels, + const scalar_t *__restrict__ top_diff, + const scalar_t *__restrict__ bottom_data, + const int kernel_size, const int group_size, + const int scale_factor, const int channels, + const int down_height, const int down_width, + const int height, const int width, + const int mask_channels, + scalar_t 
*__restrict__ mask_diff) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index > num_kernels - 1) { + return; + } + + const int lane_id = index % WARP_SIZE; + index = index / WARP_SIZE; + const int mask_c = index % mask_channels; + // (n, c, ph, pw) is an element in the bottom_data + index = index / mask_channels; + const int pw = index % width; + const int ph = (index / width) % height; + const int n = index / width / height; + + const int down_pw = pw / scale_factor; + const int down_ph = ph / scale_factor; + + const int mask_group = mask_c / (kernel_size * kernel_size); + const int mask_loc = mask_c % (kernel_size * kernel_size); + + const int offset_x = mask_loc % kernel_size - (kernel_size - 1) / 2; + const int offset_y = + mask_loc / kernel_size % kernel_size - (kernel_size - 1) / 2; + + const int down_x = down_pw + offset_x; + const int down_y = down_ph + offset_y; + + scalar_t output_val = 0; + + if (down_y >= 0 && down_y <= down_height - 1 && down_x >= 0 && + down_x <= down_width - 1) { + const int channels_per_mask = ceilf(channels / (float)group_size); + const int start = channels_per_mask * mask_group; + const int end = min(channels_per_mask * (mask_group + 1), channels); + for (int c = start + lane_id; c < end; c += WARP_SIZE) { + int bottom_id = + Loc2Index(n, down_y, down_x, c, down_height, down_width, channels); + int top_id = Loc2Index(n, ph, pw, c, height, width, channels); + output_val += top_diff[top_id] * bottom_data[bottom_id]; + } + } + __syncwarp(); + output_val = warpReduceSum(output_val); + if (lane_id == 0) { + const int mask_id = + Loc2Index(n, ph, pw, mask_c, height, width, mask_channels); + mask_diff[mask_id] = output_val; + } +} + +int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures, + const at::Tensor masks, const int kernel_size, + const int group_size, const int scale_factor, + const int batch_size, const int channels, + const int input_height, const int input_width, + const int output_height, const int output_width, + const int mask_channels, at::Tensor rtop_grad, + at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad, + at::Tensor rmask_grad, at::Tensor bottom_grad, + at::Tensor mask_grad) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.type(), "NCHW2NHWC_Top_Grad", ([&] { + const scalar_t *bottom_data = top_grad.data<scalar_t>(); + scalar_t *top_data = rtop_grad.data<scalar_t>(); + const int dh = divideUP(channels, kTileDim); + const int dw = divideUP(output_height * output_width, kTileDim); + BatchTranspose2DCUDAKernel<scalar_t> + <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>( + batch_size, channels, output_height * output_width, dh, dw, + bottom_data, top_data); + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.type(), "CARAFELaucherBackward_Feature", ([&] { + const int num_kernels = + batch_size * output_height * output_width * THREADS_PER_PIXEL; + const scalar_t *top_diff = rtop_grad.data<scalar_t>(); + const scalar_t *bottom_masks = masks.data<scalar_t>(); + scalar_t *bottom_diff = rbottom_grad_hs.data<scalar_t>(); + + CARAFEBackward_Feature<scalar_t> + <<<at::cuda::ATenCeilDiv(num_kernels, THREADS_PER_BLOCK), + THREADS_PER_BLOCK, 0, stream>>>( + num_kernels, top_diff, bottom_masks, kernel_size, group_size, + scale_factor, channels, input_height, input_width, + output_height, output_width, mask_channels, bottom_diff); + })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.type(), "FeatureSum", ([&] { + const int 
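+        // rbottom_grad_hs holds feature gradients at the upsampled
+        // resolution; FeatureSum folds every scale_factor x scale_factor
+        // window back onto its source pixel of the input-resolution grid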
num_kernels =
+            batch_size * input_height * input_width * THREADS_PER_PIXEL;
+        const scalar_t *bottom_diff_hs = rbottom_grad_hs.data<scalar_t>();
+        scalar_t *bottom_diff = rbottom_grad.data<scalar_t>();
+
+        FeatureSum<scalar_t>
+            <<<at::cuda::ATenCeilDiv(num_kernels, THREADS_PER_BLOCK),
+               THREADS_PER_BLOCK, 0, stream>>>(
+                num_kernels, bottom_diff_hs, scale_factor, channels,
+                input_height, input_width, bottom_diff);
+      }));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.type(), "NHWC2NCHW_Bottom_Grad", ([&] {
+        const scalar_t *bottom_data = rbottom_grad.data<scalar_t>();
+        scalar_t *top_data = bottom_grad.data<scalar_t>();
+        const int dh = divideUP(input_height * input_width, kTileDim);
+        const int dw = divideUP(channels, kTileDim);
+        BatchTranspose2DCUDAKernel<scalar_t>
+            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
+                batch_size, input_height * input_width, channels, dh, dw,
+                bottom_data, top_data);
+      }));
+
+  // note: unlike the other launchers in this file, half precision is not
+  // dispatched for the mask gradient
+  AT_DISPATCH_FLOATING_TYPES(
+      top_grad.type(), "CARAFELaucherBackward_Mask", ([&] {
+        const int num_kernels = batch_size * output_height * output_width *
+                                mask_channels * WARP_SIZE;
+        const scalar_t *top_diff = rtop_grad.data<scalar_t>();
+        const scalar_t *bottom_data = rfeatures.data<scalar_t>();
+        scalar_t *mask_diff = rmask_grad.data<scalar_t>();
+
+        CARAFEBackward_Mask<scalar_t>
+            <<<at::cuda::ATenCeilDiv(num_kernels, THREADS_PER_BLOCK),
+               THREADS_PER_BLOCK, 0, stream>>>(
+                num_kernels, top_diff, bottom_data, kernel_size, group_size,
+                scale_factor, channels, input_height, input_width,
+                output_height, output_width, mask_channels, mask_diff);
+      }));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      top_grad.type(), "NHWC2NCHW_Mask_Grad", ([&] {
+        const scalar_t *bottom_data = rmask_grad.data<scalar_t>();
+        scalar_t *top_data = mask_grad.data<scalar_t>();
+        const int dh = divideUP(output_height * output_width, kTileDim);
+        const int dw = divideUP(mask_channels, kTileDim);
+        BatchTranspose2DCUDAKernel<scalar_t>
+            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
+                batch_size, output_height * output_width, mask_channels, dh, dw,
+                bottom_data, top_data);
+      }));
+  cudaError_t err = cudaGetLastError();
+  if (cudaSuccess != err) {
+    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
+    exit(-1);
+  }
+
+  return 1;
+}
diff --git a/mmdet/ops/carafe/src/carafe_naive_cuda.cpp b/mmdet/ops/carafe/src/carafe_naive_cuda.cpp
new file mode 100644
index 00000000..fbcda80e
--- /dev/null
+++ b/mmdet/ops/carafe/src/carafe_naive_cuda.cpp
@@ -0,0 +1,75 @@
+#include <ATen/ATen.h>
+#include <torch/torch.h>
+
+#include <cmath>
+#include <vector>
+
+int CARAFENAIVEForwardLaucher(const at::Tensor features, const at::Tensor masks,
+                              const int kernel_size, const int group_size,
+                              const int scale_factor, const int batch_size,
+                              const int channels, const int height,
+                              const int width, at::Tensor output);
+
+int CARAFENAIVEBackwardLaucher(const at::Tensor top_grad,
+                               const at::Tensor features,
+                               const at::Tensor masks, const int kernel_size,
+                               const int group_size, const int scale_factor,
+                               const int batch_size, const int channels,
+                               const int height, const int width,
+                               at::Tensor bottom_grad, at::Tensor mask_grad);
+
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDA tensor ")
+#define CHECK_CONTIGUOUS(x) \
+  AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+
+int carafe_naive_forward_cuda(at::Tensor features, at::Tensor masks,
+                              int kernel_size, int
group_size, int scale_factor,
+                              at::Tensor output) {
+  CHECK_INPUT(features);
+  CHECK_INPUT(masks);
+  CHECK_INPUT(output);
+  at::DeviceGuard guard(features.device());
+
+  // shapes are inferred from the pre-allocated (upsampled) output tensor
+  int batch_size = output.size(0);
+  int num_channels = output.size(1);
+  int data_height = output.size(2);
+  int data_width = output.size(3);
+
+  CARAFENAIVEForwardLaucher(features, masks, kernel_size, group_size,
+                            scale_factor, batch_size, num_channels, data_height,
+                            data_width, output);
+
+  return 1;
+}
+
+int carafe_naive_backward_cuda(at::Tensor top_grad, at::Tensor features,
+                               at::Tensor masks, int kernel_size,
+                               int group_size, int scale_factor,
+                               at::Tensor bottom_grad, at::Tensor mask_grad) {
+  CHECK_INPUT(top_grad);
+  CHECK_INPUT(features);
+  CHECK_INPUT(masks);
+  CHECK_INPUT(bottom_grad);
+  CHECK_INPUT(mask_grad);
+  at::DeviceGuard guard(top_grad.device());
+
+  int batch_size = top_grad.size(0);
+  int num_channels = top_grad.size(1);
+  int data_height = top_grad.size(2);
+  int data_width = top_grad.size(3);
+
+  CARAFENAIVEBackwardLaucher(top_grad, features, masks, kernel_size, group_size,
+                             scale_factor, batch_size, num_channels,
+                             data_height, data_width, bottom_grad, mask_grad);
+
+  return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &carafe_naive_forward_cuda, "carafe_naive forward (CUDA)");
+  m.def("backward", &carafe_naive_backward_cuda,
+        "carafe_naive backward (CUDA)");
+}
diff --git a/mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu b/mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu
new file mode 100644
index 00000000..3edbae79
--- /dev/null
+++ b/mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu
@@ -0,0 +1,176 @@
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+
+using namespace at;  // temporary fix for pytorch<=0.4.1 (see #9848)
+
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+#define THREADS_PER_BLOCK 1024
+
+inline int GET_BLOCKS(const int N) {
+  int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+  int max_block_num = 65536;
+  return min(optimal_block_num, max_block_num);
+}
+
+__device__ inline int Loc2Index(const int n, const int c, const int h,
+                                const int w, const int channel_num,
+                                const int height, const int width) {
+  // flat NCHW offset: w + W * (h + H * (c + C * n))
+  int index = w + (h + (c + n * channel_num) * height) * width;
+  return index;
+}
+template <typename scalar_t>
+__global__ void CARAFENAIVEForward(const int nthreads,
+                                   const scalar_t *bottom_data,
+                                   const scalar_t *bottom_masks,
+                                   const int kernel_size, const int group_size,
+                                   const int scale_factor, const int channels,
+                                   const int height, const int width,
+                                   scalar_t *top_data) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the output (top_data)
+    int pw = index % width;
+    int ph = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+
+    int mask_channels = kernel_size * kernel_size * group_size;
+    int mask_group = c / (channels / group_size);
+
+    int down_pw = pw / scale_factor;
+    int down_ph = ph / scale_factor;
+    int down_width = width / scale_factor;
+    int down_height = height / scale_factor;
+    int start_w = down_pw - (kernel_size - 1) / 2;
+    int end_w = down_pw + (kernel_size - 1) / 2 + 1;
+    int start_h = down_ph - (kernel_size - 1) / 2;
+    int end_h = down_ph + (kernel_size - 1) / 2 + 1;
+
+    scalar_t output_val = 0;
+    for (int iy = start_h; iy < end_h; iy++) {
+      for (int ix = start_w; ix < end_w; ix++) {
+        if (iy < 0 || iy > down_height - 1 ||
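+            // out-of-range source pixels are skipped, which is equivalent
+            // to zero-padding the downsampled feature map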
ix < 0 || ix > down_width - 1) {
+          continue;
+        }
+        int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
+        int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
+        int mask_c =
+            (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+        int feat_index =
+            Loc2Index(n, c, iy, ix, channels, down_height, down_width);
+        int mask_index =
+            Loc2Index(n, mask_c, ph, pw, mask_channels, height, width);
+        output_val += bottom_data[feat_index] * bottom_masks[mask_index];
+      }
+    }
+    top_data[index] = output_val;
+  }
+}
+
+int CARAFENAIVEForwardLaucher(const at::Tensor features, const at::Tensor masks,
+                              const int kernel_size, const int group_size,
+                              const int scale_factor, const int batch_size,
+                              const int channels, const int height,
+                              const int width, at::Tensor output) {
+  const int output_size = batch_size * channels * height * width;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      features.type(), "CARAFENAIVELaucherForward", ([&] {
+        const scalar_t *bottom_data = features.data<scalar_t>();
+        const scalar_t *bottom_masks = masks.data<scalar_t>();
+        scalar_t *top_data = output.data<scalar_t>();
+
+        CARAFENAIVEForward<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+                output_size, bottom_data, bottom_masks, kernel_size, group_size,
+                scale_factor, channels, height, width, top_data);
+      }));
+  cudaError_t err = cudaGetLastError();
+  if (cudaSuccess != err) {
+    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
+    exit(-1);
+  }
+
+  return 1;
+}
+
+template <typename scalar_t>
+__global__ void CARAFENAIVEBackward(
+    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_data,
+    const scalar_t *bottom_masks, const int kernel_size, const int group_size,
+    const int scale_factor, const int channels, const int height,
+    const int width, scalar_t *bottom_diff, scalar_t *mask_diff) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the output gradient (top_diff)
+    int pw = index % width;
+    int ph = (index / width) % height;
+    int c = (index / width / height) % channels;
+    int n = index / width / height / channels;
+
+    int mask_channels = kernel_size * kernel_size * group_size;
+    int mask_group = c / (channels / group_size);
+
+    int down_pw = pw / scale_factor;
+    int down_ph = ph / scale_factor;
+    int down_width = width / scale_factor;
+    int down_height = height / scale_factor;
+    int start_w = down_pw - (kernel_size - 1) / 2;
+    int end_w = down_pw + (kernel_size - 1) / 2 + 1;
+    int start_h = down_ph - (kernel_size - 1) / 2;
+    int end_h = down_ph + (kernel_size - 1) / 2 + 1;
+
+    for (int iy = start_h; iy < end_h; iy++) {
+      for (int ix = start_w; ix < end_w; ix++) {
+        if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
+          continue;
+        }
+        int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
+        int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
+        int mask_c =
+            (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+        int feat_index =
+            Loc2Index(n, c, iy, ix, channels, down_height, down_width);
+        int mask_index =
+            Loc2Index(n, mask_c, ph, pw, mask_channels, height, width);
+        atomicAdd(bottom_diff + feat_index,
+                  bottom_masks[mask_index] * top_diff[index]);
+        atomicAdd(mask_diff + mask_index,
+                  bottom_data[feat_index] * top_diff[index]);
+      }
+    }
+  }
+}
+
+int CARAFENAIVEBackwardLaucher(const at::Tensor top_grad,
+                               const at::Tensor features,
+                               const at::Tensor masks, const int kernel_size,
+                               const int group_size, const int scale_factor,
+                               const int batch_size, const int channels,
+                               const int height, const int width,
+                               at::Tensor
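+                               // a single pass over the upsampled grid
+                               // computes both gradients; contributions from
+                               // overlapping reassembly windows are merged
+                               // with atomicAdd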
bottom_grad, at::Tensor mask_grad) { + const int output_size = batch_size * channels * height * width; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.type(), "CARAFENAIVELaucherBackward", ([&] { + const scalar_t *top_diff = top_grad.data<scalar_t>(); + const scalar_t *bottom_data = features.data<scalar_t>(); + const scalar_t *bottom_masks = masks.data<scalar_t>(); + scalar_t *bottom_diff = bottom_grad.data<scalar_t>(); + scalar_t *mask_diff = mask_grad.data<scalar_t>(); + + CARAFENAIVEBackward<scalar_t> + <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>( + output_size, top_diff, bottom_data, bottom_masks, kernel_size, + group_size, scale_factor, channels, height, width, bottom_diff, + mask_diff); + })); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + + return 1; +} diff --git a/setup.py b/setup.py index 8353637b..05f4f06f 100755 --- a/setup.py +++ b/setup.py @@ -270,6 +270,17 @@ if __name__ == '__main__': sources=[ 'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu' ]), + make_cuda_ext( + name='carafe_cuda', + module='mmdet.ops.carafe', + sources=['src/carafe_cuda.cpp', 'src/carafe_cuda_kernel.cu']), + make_cuda_ext( + name='carafe_naive_cuda', + module='mmdet.ops.carafe', + sources=[ + 'src/carafe_naive_cuda.cpp', + 'src/carafe_naive_cuda_kernel.cu' + ]) ], cmdclass={'build_ext': BuildExtension}, zip_safe=False) -- GitLab