Unverified commit 3b6ae96d authored by Kai Chen, committed by GitHub

Merge pull request #299 from yhcao6/dcn_cpp_extension

Add dcn group param support
Parents: e2227ddb, 2b6104d3
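This merge threads a new `groups` argument through the whole DCN stack (config, ResNeXt backbone, Python ops, CUDA kernels) so the deformable conv2 layers can be grouped like ResNeXt's regular 3x3 convolutions; `deformable_groups` stays a separate knob for the offset field. A minimal sketch of where the new key surfaces in a backbone config, mirroring the full config later in this diff:

# Sketch of the new dcn config keys (values taken from the config below).
dcn=dict(
    modulated=False,           # plain DeformConv; True selects ModulatedDeformConv
    groups=32,                 # new: grouped convolution, matching ResNeXt-32x4d
    deformable_groups=1,       # grouping of the offset field, independent of groups
    fallback_on_stride=False)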
@@ -193,6 +193,7 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m
| R-50-FPN | Faster | pytorch | - | dpool | 1x | 4.6 | 0.714 | 8.7 | 37.9 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dpool_r50_fpn_1x_20190125-f4fc1d70.pth) |
| R-50-FPN | Faster | pytorch | - | mdpool | 1x | 5.2 | 0.769 | 8.2 | 38.1 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_mdpool_r50_fpn_1x_20190125-473d0f3d.pth) |
| R-101-FPN | Faster | pytorch | dconv(c3-c5) | - | 1x | 5.8 | 0.811 | 8.0 | 42.1 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-a7e31b65.pth) |
| X-101-32x4d-FPN | Faster | pytorch | dconv(c3-c5) | - | 1x | 7.1 | 1.126 | 6.6 | 43.5 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x_20190201-6d46376f.pth) |
| R-50-FPN | Mask | pytorch | dconv(c3-c5) | - | 1x | 4.5 | 0.712 | 7.7 | 41.1 | 37.2 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/mask_rcnn_dconv_c3-c5_r50_fpn_1x_20190125-4f94ff79.pth) |
| R-50-FPN | Mask | pytorch | mdconv(c3-c5) | - | 1x | 4.5 | 0.712 | 7.7 | 41.4 | 37.4 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/mask_rcnn_mdconv_c3-c5_r50_fpn_1x_20190125-c5601dc3.pth) |
| R-101-FPN | Mask | pytorch | dconv(c3-c5) | - | 1x | 6.4 | 0.939 | 6.5 | 43.2 | 38.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/mask_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-decb6db5.pth) |
......
# model settings
model = dict(
    type='FasterRCNN',
    pretrained='open-mmlab://resnext101_32x4d',
    backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=32,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch',
        dcn=dict(
            modulated=False,
            groups=32,
            deformable_groups=1,
            fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        use_sigmoid_cls=True),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False))
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,
        smoothl1_beta=1 / 9.0,
        debug=False),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
    # soft-nms is also supported for rcnn testing
    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        with_mask=False,
        with_crowd=True,
        with_label=True),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_crowd=True,
        with_label=True),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
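To exercise this config end to end, something like the following should work with the mmdetection of this period; a hedged sketch in which the config path is guessed from `work_dir` above, and `Config.fromfile` / `build_detector` are the mmcv/mmdet entry points to verify against your checkout.

# Hedged sketch: building the detector from the config above.
from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile(
    'configs/dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x.py')  # assumed path
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)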
@@ -2,8 +2,9 @@ import math
import torch.nn as nn
from .resnet import ResNet
from mmdet.ops import DeformConv, ModulatedDeformConv
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
from ..registry import BACKBONES
from ..utils import build_norm_layer
@@ -22,15 +23,12 @@ class Bottleneck(_Bottleneck):
else:
width = math.floor(self.planes * (base_width / 64)) * groups
self.norm1_name, norm1 = build_norm_layer(self.normalize,
width,
postfix=1)
self.norm2_name, norm2 = build_norm_layer(self.normalize,
width,
postfix=2)
self.norm3_name, norm3 = build_norm_layer(self.normalize,
self.planes * self.expansion,
postfix=3)
self.norm1_name, norm1 = build_norm_layer(
self.normalize, width, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.normalize, width, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.normalize, self.planes * self.expansion, postfix=3)
self.conv1 = nn.Conv2d(
self.inplanes,
@@ -39,15 +37,47 @@ class Bottleneck(_Bottleneck):
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = nn.Conv2d(
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
fallback_on_stride = False
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = self.dcn.get('fallback_on_stride', False)
self.with_modulated_dcn = self.dcn.get('modulated', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = nn.Conv2d(
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
else:
groups = self.dcn.get('groups', 1)
deformable_groups = self.dcn.get('deformable_groups', 1)
if not self.with_modulated_dcn:
conv_op = DeformConv
offset_channels = 18
else:
conv_op = ModulatedDeformConv
offset_channels = 27
self.conv2_offset = nn.Conv2d(
width,
deformable_groups * offset_channels,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation)
self.conv2 = conv_op(
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
deformable_groups=deformable_groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = nn.Conv2d(
width, self.planes * self.expansion, kernel_size=1, bias=False)
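The 18/27 constants above are plain offset-field arithmetic: two offsets (dy, dx) per sampling location for DeformConv, plus one modulation scalar for ModulatedDeformConv, times kH * kW, times deformable_groups. A small sketch of that bookkeeping:

# Sketch of the offset-channel arithmetic behind the 18/27 constants above.
def offset_channels(kernel_size=3, deformable_groups=1, modulated=False):
    per_location = 3 if modulated else 2  # (dy, dx), plus a mask if modulated
    return deformable_groups * per_location * kernel_size ** 2

assert offset_channels(modulated=False) == 18  # DeformConv, 3x3, one deformable group
assert offset_channels(modulated=True) == 27   # ModulatedDeformConv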
@@ -64,7 +94,8 @@ def make_res_layer(block,
base_width=4,
style='pytorch',
with_cp=False,
normalize=dict(type='BN')):
normalize=dict(type='BN'),
dcn=None):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
@@ -89,7 +120,8 @@
base_width=base_width,
style=style,
with_cp=with_cp,
normalize=normalize))
normalize=normalize,
dcn=dcn))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
@@ -102,7 +134,8 @@
base_width=base_width,
style=style,
with_cp=with_cp,
normalize=normalize))
normalize=normalize,
dcn=dcn))
return nn.Sequential(*layers)
@@ -150,6 +183,7 @@ class ResNeXt(ResNet):
for i, num_blocks in enumerate(self.stage_blocks):
stride = self.strides[i]
dilation = self.dilations[i]
dcn = self.dcn if self.stage_with_dcn[i] else None
planes = 64 * 2**i
res_layer = make_res_layer(
self.block,
@@ -162,7 +196,8 @@ class ResNeXt(ResNet):
base_width=self.base_width,
style=self.style,
with_cp=self.with_cp,
normalize=self.normalize)
normalize=self.normalize,
dcn=dcn)
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
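The per-stage `stage_with_dcn` gate above is also what the `dconv(c3-c5)` notation in the benchmark table refers to; a quick sketch of the mapping:

# Sketch: stage_with_dcn indexes res-layers layer1..layer4, i.e. stages c2..c5.
stage_with_dcn = (False, True, True, True)  # from the config in this diff
stages = ['c2/layer1', 'c3/layer2', 'c4/layer3', 'c5/layer4']
print([s for s, on in zip(stages, stage_with_dcn) if on])
# ['c3/layer2', 'c4/layer3', 'c5/layer4']  -> hence "dconv(c3-c5)"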
......
@@ -15,6 +15,7 @@ class DeformConvFunction(Function):
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64):
if input is not None and input.dim() != 4:
@@ -24,6 +25,7 @@ class DeformConvFunction(Function):
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
@@ -45,7 +47,8 @@
input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.deformable_groups, cur_im2col_step)
ctx.dilation[0], ctx.groups, ctx.deformable_groups,
cur_im2col_step)
return output
@staticmethod
@@ -69,7 +72,8 @@
grad_offset, weight, ctx.bufs_[0], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.deformable_groups, cur_im2col_step)
ctx.dilation[0], ctx.groups, ctx.deformable_groups,
cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
@@ -78,9 +82,11 @@
grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.deformable_groups, 1, cur_im2col_step)
ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
cur_im2col_step)
return grad_input, grad_offset, grad_weight, None, None, None, None
return (grad_input, grad_offset, grad_weight, None, None, None, None,
None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
@@ -111,10 +117,12 @@ class ModulatedDeformConvFunction(Function):
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
@@ -131,7 +139,7 @@
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups, ctx.with_bias)
ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
@@ -149,12 +157,12 @@
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups, ctx.with_bias)
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None)
None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
......
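One detail worth calling out: autograd requires `backward` to return one gradient (or None) per `forward` argument, which is why both Functions above now return an extra trailing None after gaining `groups`. Since `deform_conv` is `DeformConvFunction.apply`, arguments are positional; a hedged call sketch (the `mmdet.ops` export name is an assumption, and the op is CUDA-only):

# Hedged sketch: calling the updated functional interface directly.
# Positional args: input, offset, weight, stride, padding, dilation,
# groups, deformable_groups (im2col_step keeps its default of 64).
import torch
from mmdet.ops import deform_conv  # assumed export; verify in your checkout

x = torch.randn(2, 64, 32, 32, device='cuda')
offset = torch.randn(2, 18, 32, 32, device='cuda')  # deformable_groups(1) * 2 * 3 * 3
w = torch.randn(64, 64 // 4, 3, 3, device='cuda')   # grouped weight for groups=4
y = deform_conv(x, offset, w, 1, 1, 1, 4, 1)        # groups=4, deformable_groups=1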
@@ -16,20 +16,30 @@ class DeformConv(nn.Module):
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
assert not bias
super(DeformConv, self).__init__()
assert in_channels % groups == 0, \
'in_channels {} is not divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} is not divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
torch.Tensor(out_channels, in_channels // self.groups,
*self.kernel_size))
self.reset_parameters()
@@ -42,7 +52,8 @@
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation, self.deformable_groups)
self.padding, self.dilation, self.groups,
self.deformable_groups)
class ModulatedDeformConv(nn.Module):
@@ -54,6 +65,7 @@ class ModulatedDeformConv(nn.Module):
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv, self).__init__()
@@ -63,11 +75,13 @@
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
torch.Tensor(out_channels, in_channels // groups,
*self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
@@ -84,9 +98,9 @@
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
@@ -98,14 +112,15 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConvPack,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, deformable_groups, bias)
super(ModulatedDeformConvPack, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, deformable_groups, bias)
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
self.in_channels // self.groups,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
@@ -123,6 +138,6 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
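At the module level, the visible changes are the divisibility asserts and the grouped weight shape (out_channels, in_channels // groups, kH, kW); the offset input is unaffected by `groups`. A hedged shape check (CUDA-only op):

# Hedged sketch: DeformConv with the new groups argument.
import torch
from mmdet.ops import DeformConv

conv = DeformConv(64, 64, kernel_size=3, padding=1,
                  groups=4, deformable_groups=2).cuda()
assert tuple(conv.weight.shape) == (64, 16, 3, 3)  # in_channels // groups
x = torch.randn(2, 64, 32, 32, device='cuda')
offset = torch.randn(2, 2 * 2 * 9, 32, 32, device='cuda')  # dg * 2 * kH * kW
assert conv(x, offset).shape == (2, 64, 32, 32)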
@@ -62,7 +62,7 @@ void modulated_deformable_col2im_coord_cuda(const at::Tensor data_col, const at:
void shape_check(at::Tensor input, at::Tensor offset,
at::Tensor *gradOutput, at::Tensor weight, int kH, int kW,
int dH, int dW, int padH, int padW, int dilationH,
int dilationW, int deformable_group)
int dilationW, int group, int deformable_group)
{
AT_CHECK(weight.ndimension() == 4,
@@ -105,7 +105,7 @@ void shape_check(at::Tensor input, at::Tensor offset,
AT_CHECK(ndim == 3 || ndim == 4,
"3D or 4D input tensor expected but got: %d", ndim);
long nInputPlane = weight.size(1);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
@@ -154,7 +154,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step)
{
@@ -164,7 +164,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
// todo: possibly change data indexing because of parallel_imgs
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, deformable_group);
dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
@@ -207,6 +207,8 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth}, output.type());
output_buffer = output_buffer.view({output_buffer.size(0), group, output_buffer.size(1) / group, output_buffer.size(2), output_buffer.size(3)});
for (int elt = 0; elt < batchSize / im2col_step; elt++)
{
deformable_im2col(
@@ -214,10 +216,17 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
im2col_step, deformable_group, columns);
output_buffer[elt] =
output_buffer[elt].flatten(1).addmm_(weight.flatten(1), columns).view_as(output_buffer[elt]);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++){
output_buffer[elt][g] =
output_buffer[elt][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output_buffer[elt][g]);
}
}
output_buffer = output_buffer.view({output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view(
{batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
@@ -241,11 +250,11 @@ int deform_conv_backward_input_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradInput, at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int deformable_group, int im2col_step)
int dilationW, int dilationH, int group, int deformable_group, int im2col_step)
{
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, deformable_group);
padW, dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
@@ -292,7 +301,17 @@ int deform_conv_backward_input_cuda(
for (int elt = 0; elt < batchSize / im2col_step; elt++)
{
columns = columns.addmm_(weight.flatten(1).transpose(0, 1), gradOutput[elt].flatten(1), 0.0f, 1.0f);
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
gradOutput = gradOutput.view({gradOutput.size(0), group, gradOutput.size(1) / group, gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
for (int g = 0; g < group; g++){
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
}
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view({gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
deformable_col2im_coord(
columns, input[elt], offset[elt],
@@ -329,7 +348,7 @@ int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int deformable_group,
int padW, int padH, int dilationW, int dilationH, int group, int deformable_group,
float scale, int im2col_step)
{
@@ -338,7 +357,7 @@ int deform_conv_backward_parameters_cuda(
// todo: add im2col_step as input
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW,
padH, padW, dilationH, dilationW, deformable_group);
padH, padW, dilationH, dilationW, group, deformable_group);
input = input.contiguous();
offset = offset.contiguous();
@@ -395,9 +414,19 @@ int deform_conv_backward_parameters_cuda(
inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
im2col_step, deformable_group, columns);
gradWeight = gradWeight.flatten(1).addmm_(
gradOutputBuffer[elt].flatten(1), columns.transpose(1, 0), 1.0, scale)
.view_as(gradWeight);
// divide into groups
gradOutputBuffer = gradOutputBuffer.view({gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight = gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), gradWeight.size(2), gradWeight.size(3)});
for (int g = 0; g < group; g++){
gradWeight[g] = gradWeight[g].flatten(1).addmm_(
gradOutputBuffer[elt][g].flatten(1), columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view({gradOutputBuffer.size(0), gradOutputBuffer.size(1) * gradOutputBuffer.size(2), gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), gradWeight.size(2), gradWeight.size(3), gradWeight.size(4)});
}
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
@@ -413,6 +442,7 @@ int deform_conv_backward_parameters_cuda(
return 1;
}
void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask,
@@ -420,7 +450,7 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
int kernel_h, int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const int dilation_h, const int dilation_w, const int group,
const int deformable_group, const bool with_bias)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
@@ -439,9 +469,9 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel)
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel);
channels, channels_kernel * group);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
@@ -458,6 +488,8 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
// resize temporary columns
columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());
output = output.view({output.size(0), group, output.size(1) / group, output.size(2), output.size(3)});
for (int b = 0; b < batch; b++)
{
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
@@ -466,9 +498,20 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
deformable_group, columns);
output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]);
// divide into groups
weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++){
output[b][g] = output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), weight.size(3), weight.size(4)});
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)});
if (with_bias){
output += bias.view({1, bias.size(0), 1, 1});
}
@@ -484,7 +527,7 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int dilation_h, int dilation_w, int group,
int deformable_group, const bool with_bias)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
@@ -501,9 +544,9 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel)
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel);
channels, channels_kernel * group);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
@@ -518,9 +561,20 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, input.type());
grad_output = grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++)
{
columns.addmm_(weight.flatten(1).transpose(0, 1), grad_output[b].flatten(1), 0.0f, 1.0f);
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++){
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(columns, input[b], offset[b], mask[b],
@@ -545,14 +599,27 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
dilation_h, dilation_w, deformable_group,
columns);
grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
if (with_bias){
grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
for (int g = 0; g < group; g++){
grad_weight[g] = grad_weight[g].flatten(1).addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)).view_as(grad_weight[g]);
if (with_bias){
grad_bias[g] = grad_bias[g].view({-1, 1}).addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})).view(-1);
}
}
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), grad_output.size(2), grad_output.size(3), grad_output.size(4)});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda, "deform forward (CUDA)");
......
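On the CUDA side, every dense `addmm_` over the im2col buffer becomes a per-group loop over reshaped views, as in the hunks above. The reshaping is easier to see in plain PyTorch; a sketch of the equivalence with illustrative shapes:

# Sketch: the grouped matmul the CUDA per-group loops implement.
import torch

g = 4                                        # group
out_c, L = 64, 1024                          # output channels, H_out * W_out
k_len = (64 // g) * 9                        # (in_c / g) * kH * kW per group
columns = torch.randn(g * k_len, L)          # im2col output, groups stacked
weight = torch.randn(out_c, k_len)           # flattened weight (out_c, in_c/g * k * k)

cols_g = columns.view(g, k_len, L)
w_g = weight.view(g, out_c // g, k_len)
out = torch.bmm(w_g, cols_g).reshape(out_c, L)  # equals the per-group addmm_ loop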