Commit ad75a173 authored by yhcao6
add dcn group support

parent 7640a04b
# model settings
model = dict(
    type='FasterRCNN',
    pretrained='open-mmlab://resnext101_32x4d',
    backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=32,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch',
        dcn=dict(
            modulated=False,
            groups=32,
            deformable_groups=1,
            fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        use_sigmoid_cls=True),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False))
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,
        smoothl1_beta=1 / 9.0,
        debug=False),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
    # soft-nms is also supported for rcnn testing
    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        with_mask=False,
        with_crowd=True,
        with_label=True),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_crowd=True,
        with_label=True),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
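
The two group-related fields in the `dcn` dict are easy to conflate: `groups=32` keeps the grouped (cardinality-32) structure of each ResNeXt bottleneck's 3x3 conv when it is swapped for a deformable one, while `deformable_groups=1` means all channels share a single predicted offset field. The first diff below, in the ResNeXt backbone, consumes both fields. A minimal sketch in plain PyTorch (not part of the commit) of what `groups` alone does to the weight layout:

import torch.nn as nn

# e.g. a res4-stage bottleneck of ResNeXt-101 32x4d has width 512;
# with groups=32 each filter sees only 512 // 32 = 16 input channels.
conv2_plain = nn.Conv2d(512, 512, kernel_size=3, padding=1, groups=32,
                        bias=False)
print(conv2_plain.weight.shape)  # torch.Size([512, 16, 3, 3])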
@@ -2,8 +2,9 @@ import math

 import torch.nn as nn

-from .resnet import ResNet
+from mmdet.ops import DeformConv, ModulatedDeformConv
 from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
 from ..registry import BACKBONES
 from ..utils import build_norm_layer
@@ -22,15 +23,12 @@ class Bottleneck(_Bottleneck):
         else:
             width = math.floor(self.planes * (base_width / 64)) * groups

-        self.norm1_name, norm1 = build_norm_layer(self.normalize,
-                                                  width,
-                                                  postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(self.normalize,
-                                                  width,
-                                                  postfix=2)
-        self.norm3_name, norm3 = build_norm_layer(self.normalize,
-                                                  self.planes * self.expansion,
-                                                  postfix=3)
+        self.norm1_name, norm1 = build_norm_layer(
+            self.normalize, width, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(
+            self.normalize, width, postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(
+            self.normalize, self.planes * self.expansion, postfix=3)

         self.conv1 = nn.Conv2d(
             self.inplanes,
@@ -39,15 +37,46 @@ class Bottleneck(_Bottleneck):
             stride=self.conv1_stride,
             bias=False)
         self.add_module(self.norm1_name, norm1)
-        self.conv2 = nn.Conv2d(
-            width,
-            width,
-            kernel_size=3,
-            stride=self.conv2_stride,
-            padding=self.dilation,
-            dilation=self.dilation,
-            groups=groups,
-            bias=False)
+        fallback_on_stride = False
+        self.with_modulated_dcn = False
+        if self.with_dcn:
+            fallback_on_stride = self.dcn.get('fallback_on_stride', False)
+            self.with_modulated_dcn = self.dcn.get('modulated', False)
+        if not self.with_dcn or fallback_on_stride:
+            self.conv2 = nn.Conv2d(
+                width,
+                width,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=self.dilation,
+                dilation=self.dilation,
+                groups=groups,
+                bias=False)
+        else:
+            deformable_groups = self.dcn.get('deformable_groups', 1)
+            if not self.with_modulated_dcn:
+                conv_op = DeformConv
+                offset_channels = 18
+            else:
+                conv_op = ModulatedDeformConv
+                offset_channels = 27
+            self.conv2_offset = nn.Conv2d(
+                width,
+                deformable_groups * offset_channels,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=self.dilation,
+                dilation=self.dilation)
+            self.conv2 = conv_op(
+                width,
+                width,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=self.dilation,
+                dilation=self.dilation,
+                groups=groups,
+                deformable_groups=deformable_groups,
+                bias=False)
         self.add_module(self.norm2_name, norm2)
         self.conv3 = nn.Conv2d(
             width, self.planes * self.expansion, kernel_size=1, bias=False)
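
The hard-coded `offset_channels` values above follow from the kernel geometry: `DeformConv` learns an (x, y) offset for each of the nine samples of a 3x3 kernel, and `ModulatedDeformConv` adds one mask scalar per sample. A quick arithmetic check (illustrative only):

kernel_points = 3 * 3                    # sampling locations of a 3x3 kernel
deform_channels = 2 * kernel_points      # (x, y) per location -> 18
modulated_channels = 3 * kernel_points   # (x, y, mask) per location -> 27
assert (deform_channels, modulated_channels) == (18, 27)
# conv2_offset then outputs deformable_groups * offset_channels maps,
# exactly as constructed in the diff above.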
@@ -64,7 +93,8 @@ def make_res_layer(block,
                    base_width=4,
                    style='pytorch',
                    with_cp=False,
-                   normalize=dict(type='BN')):
+                   normalize=dict(type='BN'),
+                   dcn=None):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
         downsample = nn.Sequential(
@@ -89,7 +119,8 @@ def make_res_layer(block,
             base_width=base_width,
             style=style,
             with_cp=with_cp,
-            normalize=normalize))
+            normalize=normalize,
+            dcn=dcn))
     inplanes = planes * block.expansion
     for i in range(1, blocks):
         layers.append(
@@ -102,7 +133,8 @@ def make_res_layer(block,
                 base_width=base_width,
                 style=style,
                 with_cp=with_cp,
-                normalize=normalize))
+                normalize=normalize,
+                dcn=dcn))

     return nn.Sequential(*layers)
...
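
`make_res_layer` now just threads the `dcn` dict through to every bottleneck in a stage; whether a stage uses it at all is governed by `stage_with_dcn` in the config. A condensed, hypothetical sketch of that per-stage selection (the real loop lives in the ResNet base backbone and differs in detail):

dcn = dict(modulated=False, groups=32, deformable_groups=1,
           fallback_on_stride=False)
stage_with_dcn = (False, True, True, True)   # from the config: C3-C5 only

for i, with_dcn in enumerate(stage_with_dcn):
    stage_dcn = dcn if with_dcn else None    # res2 keeps plain convs
    # make_res_layer(block, ..., dcn=stage_dcn)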
@@ -15,6 +15,7 @@ class DeformConvFunction(Function):
                 stride=1,
                 padding=0,
                 dilation=1,
+                groups=1,
                 deformable_groups=1,
                 im2col_step=64):
         if input is not None and input.dim() != 4:
@@ -24,6 +25,7 @@ class DeformConvFunction(Function):
         ctx.stride = _pair(stride)
         ctx.padding = _pair(padding)
         ctx.dilation = _pair(dilation)
+        ctx.groups = groups
         ctx.deformable_groups = deformable_groups
         ctx.im2col_step = im2col_step
@@ -45,7 +47,8 @@ class DeformConvFunction(Function):
                 input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
                 weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
                 ctx.padding[1], ctx.padding[0], ctx.dilation[1],
-                ctx.dilation[0], ctx.deformable_groups, cur_im2col_step)
+                ctx.dilation[0], ctx.groups, ctx.deformable_groups,
+                cur_im2col_step)
         return output

     @staticmethod
@@ -69,7 +72,8 @@ class DeformConvFunction(Function):
                     grad_offset, weight, ctx.bufs_[0], weight.size(3),
                     weight.size(2), ctx.stride[1], ctx.stride[0],
                     ctx.padding[1], ctx.padding[0], ctx.dilation[1],
-                    ctx.dilation[0], ctx.deformable_groups, cur_im2col_step)
+                    ctx.dilation[0], ctx.groups, ctx.deformable_groups,
+                    cur_im2col_step)

             if ctx.needs_input_grad[2]:
                 grad_weight = torch.zeros_like(weight)
@@ -78,9 +82,11 @@ class DeformConvFunction(Function):
                     grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                     weight.size(2), ctx.stride[1], ctx.stride[0],
                     ctx.padding[1], ctx.padding[0], ctx.dilation[1],
-                    ctx.dilation[0], ctx.deformable_groups, 1, cur_im2col_step)
+                    ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+                    cur_im2col_step)

-        return grad_input, grad_offset, grad_weight, None, None, None, None
+        return (grad_input, grad_offset, grad_weight, None, None, None, None,
+                None)

     @staticmethod
     def _output_size(input, weight, padding, dilation, stride):
@@ -111,10 +117,12 @@ class ModulatedDeformConvFunction(Function):
                 stride=1,
                 padding=0,
                 dilation=1,
+                groups=1,
                 deformable_groups=1):
         ctx.stride = stride
         ctx.padding = padding
         ctx.dilation = dilation
+        ctx.groups = groups
         ctx.deformable_groups = deformable_groups
         ctx.with_bias = bias is not None
         if not ctx.with_bias:
@@ -131,7 +139,7 @@ class ModulatedDeformConvFunction(Function):
             input, weight, bias, ctx._bufs[0], offset, mask, output,
             ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
             ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
-            ctx.deformable_groups, ctx.with_bias)
+            ctx.groups, ctx.deformable_groups, ctx.with_bias)
         return output

     @staticmethod
@@ -149,12 +157,12 @@ class ModulatedDeformConvFunction(Function):
             grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
             grad_output, weight.shape[2], weight.shape[3], ctx.stride,
             ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
-            ctx.deformable_groups, ctx.with_bias)
+            ctx.groups, ctx.deformable_groups, ctx.with_bias)
         if not ctx.with_bias:
             grad_bias = None

         return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
-                None, None, None, None)
+                None, None, None, None, None)

     @staticmethod
     def _infer_shape(ctx, input, weight):
...
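
The extra trailing `None` in both `backward` methods is required, not cosmetic: `torch.autograd.Function.backward` must return exactly one value per argument of `forward`, and `forward` has just gained a `groups` argument (which needs no gradient). A self-contained illustration of the rule, unrelated to the commit's code:

import torch
from torch.autograd import Function


class Scale(Function):

    @staticmethod
    def forward(ctx, x, factor):      # two forward arguments...
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # ...so backward returns two values; non-tensor args get None.
        return grad_out * ctx.factor, None


x = torch.randn(3, requires_grad=True)
Scale.apply(x, 3.0).sum().backward()
print(x.grad)                         # tensor([3., 3., 3.])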
@@ -16,6 +16,7 @@ class DeformConv(nn.Module):
                  stride=1,
                  padding=0,
                  dilation=1,
+                 groups=1,
                  deformable_groups=1,
                  bias=False):
         assert not bias
@@ -26,10 +27,12 @@ class DeformConv(nn.Module):
         self.stride = _pair(stride)
         self.padding = _pair(padding)
         self.dilation = _pair(dilation)
+        self.groups = groups
         self.deformable_groups = deformable_groups

         self.weight = nn.Parameter(
-            torch.Tensor(out_channels, in_channels, *self.kernel_size))
+            torch.Tensor(out_channels, in_channels // self.groups,
+                         *self.kernel_size))

         self.reset_parameters()
@@ -42,7 +45,8 @@ class DeformConv(nn.Module):
     def forward(self, input, offset):
         return deform_conv(input, offset, self.weight, self.stride,
-                           self.padding, self.dilation, self.deformable_groups)
+                           self.padding, self.dilation, self.groups,
+                           self.deformable_groups)


 class ModulatedDeformConv(nn.Module):
@@ -54,6 +58,7 @@ class ModulatedDeformConv(nn.Module):
                  stride=1,
                  padding=0,
                  dilation=1,
+                 groups=1,
                  deformable_groups=1,
                  bias=True):
         super(ModulatedDeformConv, self).__init__()
@@ -63,11 +68,13 @@ class ModulatedDeformConv(nn.Module):
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
+        self.groups = groups
         self.deformable_groups = deformable_groups
         self.with_bias = bias
         self.weight = nn.Parameter(
-            torch.Tensor(out_channels, in_channels, *self.kernel_size))
+            torch.Tensor(out_channels, in_channels // groups,
+                         *self.kernel_size))
         if bias:
             self.bias = nn.Parameter(torch.Tensor(out_channels))
         else:
@@ -84,9 +91,9 @@ class ModulatedDeformConv(nn.Module):
             self.bias.data.zero_()

     def forward(self, input, offset, mask):
-        return modulated_deform_conv(input, offset, mask, self.weight,
-                                     self.bias, self.stride, self.padding,
-                                     self.dilation, self.deformable_groups)
+        return modulated_deform_conv(
+            input, offset, mask, self.weight, self.bias, self.stride,
+            self.padding, self.dilation, self.groups, self.deformable_groups)


 class ModulatedDeformConvPack(ModulatedDeformConv):
@@ -98,14 +105,15 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
                  stride=1,
                  padding=0,
                  dilation=1,
+                 groups=1,
                  deformable_groups=1,
                  bias=True):
-        super(ModulatedDeformConvPack,
-              self).__init__(in_channels, out_channels, kernel_size, stride,
-                             padding, dilation, deformable_groups, bias)
+        super(ModulatedDeformConvPack, self).__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, deformable_groups, bias)

         self.conv_offset_mask = nn.Conv2d(
-            self.in_channels,
+            self.in_channels // self.groups,
             self.deformable_groups * 3 * self.kernel_size[0] *
             self.kernel_size[1],
             kernel_size=self.kernel_size,
@@ -123,6 +131,6 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
         mask = torch.sigmoid(mask)
-        return modulated_deform_conv(input, offset, mask, self.weight,
-                                     self.bias, self.stride, self.padding,
-                                     self.dilation, self.deformable_groups)
+        return modulated_deform_conv(
+            input, offset, mask, self.weight, self.bias, self.stride,
+            self.padding, self.dilation, self.groups, self.deformable_groups)
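
At the module level the new contract is: `weight` shrinks to `[out_channels, in_channels // groups, kH, kW]`, while the offset tensor is still sized by `deformable_groups` alone. A usage sketch (it needs the compiled `mmdet.ops` CUDA extension, so treat it as illustrative rather than tested):

import torch
from mmdet.ops import DeformConv

dg, k = 1, 3
dconv = DeformConv(512, 512, kernel_size=k, padding=1,
                   groups=32, deformable_groups=dg).cuda()
print(dconv.weight.shape)            # torch.Size([512, 16, 3, 3])

x = torch.randn(2, 512, 32, 32, device='cuda')
# two offset coordinates per kernel sample per deformable group
offset = torch.zeros(2, 2 * dg * k * k, 32, 32, device='cuda')
y = dconv(x, offset)                 # -> [2, 512, 32, 32]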
@@ -62,7 +62,7 @@ void modulated_deformable_col2im_coord_cuda(const at::Tensor data_col, const at:
 void shape_check(at::Tensor input, at::Tensor offset,
                  at::Tensor *gradOutput, at::Tensor weight, int kH, int kW,
                  int dH, int dW, int padH, int padW, int dilationH,
-                 int dilationW, int deformable_group)
+                 int dilationW, int group, int deformable_group)
 {
   AT_CHECK(weight.ndimension() == 4,
@@ -105,7 +105,7 @@ void shape_check(at::Tensor input, at::Tensor offset,
   AT_CHECK(ndim == 3 || ndim == 4,
            "3D or 4D input tensor expected but got: %s", ndim);

-  long nInputPlane = weight.size(1);
+  long nInputPlane = weight.size(1) * group;
   long inputHeight = input.size(dimh);
   long inputWidth = input.size(dimw);
   long nOutputPlane = weight.size(0);
@@ -154,7 +154,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                              at::Tensor offset, at::Tensor output,
                              at::Tensor columns, at::Tensor ones, int kW,
                              int kH, int dW, int dH, int padW, int padH,
-                             int dilationW, int dilationH,
+                             int dilationW, int dilationH, int group,
                              int deformable_group, int im2col_step)
 {
@@ -164,7 +164,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
   // todo: possibly change data indexing because of parallel_imgs

   shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
-              dilationH, dilationW, deformable_group);
+              dilationH, dilationW, group, deformable_group);

   input = input.contiguous();
   offset = offset.contiguous();
@@ -207,6 +207,8 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
   at::Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth}, output.type());

+  output_buffer = output_buffer.view({output_buffer.size(0), group, output_buffer.size(1) / group, output_buffer.size(2), output_buffer.size(3)});
+
   for (int elt = 0; elt < batchSize / im2col_step; elt++)
   {
     deformable_im2col(
@@ -214,10 +216,17 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
         inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
         im2col_step, deformable_group, columns);

-    output_buffer[elt] =
-        output_buffer[elt].flatten(1).addmm_(weight.flatten(1), columns).view_as(output_buffer[elt]);
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++){
+      output_buffer[elt][g] =
+          output_buffer[elt][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output_buffer[elt][g]);
+    }
   }

+  output_buffer = output_buffer.view({output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), output_buffer.size(3), output_buffer.size(4)});
+
   output_buffer = output_buffer.view(
       {batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth});
   output_buffer.transpose_(1, 2);
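
The per-group `addmm_` loop above is the im2col formulation of a grouped convolution: rather than one big `weight @ columns` product, both operands are split into `group` blocks and multiplied blockwise, which is equivalent to multiplying by a block-diagonal weight matrix. A small PyTorch sketch of the equivalence (toy sizes, illustrative only):

import torch

g = 4
c_out, c_in, k2, hw = 8, 16, 9, 10              # toy sizes divisible by g
weight = torch.randn(c_out, (c_in // g) * k2)   # grouped filters, flattened
columns = torch.randn(c_in * k2, hw)            # im2col buffer

# per-group blockwise products, as in the loop above
w = weight.view(g, c_out // g, -1)
col = columns.view(g, -1, hw)
out_grouped = torch.bmm(w, col).reshape(c_out, hw)

# same result as one matmul against a block-diagonal weight matrix
out_blockdiag = torch.block_diag(*w.unbind(0)) @ columns
assert torch.allclose(out_grouped, out_blockdiag, atol=1e-5)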
@@ -241,11 +250,11 @@ int deform_conv_backward_input_cuda(
     at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
     at::Tensor gradInput, at::Tensor gradOffset, at::Tensor weight,
     at::Tensor columns, int kW, int kH, int dW, int dH, int padW, int padH,
-    int dilationW, int dilationH, int deformable_group, int im2col_step)
+    int dilationW, int dilationH, int group, int deformable_group, int im2col_step)
 {
   shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH,
-              padW, dilationH, dilationW, deformable_group);
+              padW, dilationH, dilationW, group, deformable_group);

   input = input.contiguous();
   offset = offset.contiguous();
@@ -292,7 +301,17 @@ int deform_conv_backward_input_cuda(
   for (int elt = 0; elt < batchSize / im2col_step; elt++)
   {
-    columns = columns.addmm_(weight.flatten(1).transpose(0, 1), gradOutput[elt].flatten(1), 0.0f, 1.0f);
+    // divide into groups
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
+    gradOutput = gradOutput.view({gradOutput.size(0), group, gradOutput.size(1) / group, gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+    for (int g = 0; g < group; g++){
+      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    gradOutput = gradOutput.view({gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});

     deformable_col2im_coord(
         columns, input[elt], offset[elt],
@@ -329,7 +348,7 @@ int deform_conv_backward_parameters_cuda(
     at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
     at::Tensor gradWeight, // at::Tensor gradBias,
     at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
-    int padW, int padH, int dilationW, int dilationH, int deformable_group,
+    int padW, int padH, int dilationW, int dilationH, int group, int deformable_group,
     float scale, int im2col_step)
 {
@@ -338,7 +357,7 @@ int deform_conv_backward_parameters_cuda(
   // todo: add im2col_step as input

   shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW,
-              padH, padW, dilationH, dilationW, deformable_group);
+              padH, padW, dilationH, dilationW, group, deformable_group);

   input = input.contiguous();
   offset = offset.contiguous();
@@ -395,9 +414,19 @@ int deform_conv_backward_parameters_cuda(
         inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
         im2col_step, deformable_group, columns);

-    gradWeight = gradWeight.flatten(1).addmm_(
-        gradOutputBuffer[elt].flatten(1), columns.transpose(1, 0), 1.0, scale)
-        .view_as(gradWeight);
+    // divide into group
+    gradOutputBuffer = gradOutputBuffer.view({gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    gradWeight = gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), gradWeight.size(2), gradWeight.size(3)});
+
+    for (int g = 0; g < group; g++){
+      gradWeight[g] = gradWeight[g].flatten(1).addmm_(
+          gradOutputBuffer[elt][g].flatten(1), columns[g].transpose(1, 0), 1.0, scale)
+          .view_as(gradWeight[g]);
+    }
+
+    gradOutputBuffer = gradOutputBuffer.view({gradOutputBuffer.size(0), gradOutputBuffer.size(1) * gradOutputBuffer.size(2), gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), gradWeight.size(2), gradWeight.size(3), gradWeight.size(4)});
   }

   input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
@@ -413,6 +442,7 @@ int deform_conv_backward_parameters_cuda(
   return 1;
 }

 void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
                                         at::Tensor bias, at::Tensor ones,
                                         at::Tensor offset, at::Tensor mask,
@@ -420,7 +450,7 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
                                         int kernel_h, int kernel_w,
                                         const int stride_h, const int stride_w,
                                         const int pad_h, const int pad_w,
-                                        const int dilation_h, const int dilation_w,
+                                        const int dilation_h, const int dilation_w, const int group,
                                         const int deformable_group, const bool with_bias)
 {
   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
@@ -439,9 +469,9 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
   if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
     AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
              kernel_h_, kernel_w, kernel_h_, kernel_w_);
-  if (channels != channels_kernel)
+  if (channels != channels_kernel * group)
     AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
-             channels, channels_kernel);
+             channels, channels_kernel * group);

   const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
   const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
@@ -458,6 +488,8 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
   // resize temporary columns
   columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());

+  output = output.view({output.size(0), group, output.size(1) / group, output.size(2), output.size(3)});
+
   for (int b = 0; b < batch; b++)
   {
     modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
@@ -466,9 +498,20 @@ void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
                                      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                                      deformable_group, columns);

-    output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]);
+    // divide into group
+    weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+    for (int g = 0; g < group; g++){
+      output[b][g] = output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]);
+    }
+
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), weight.size(3), weight.size(4)});
+    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
   }

+  output = output.view({output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)});
+
   if (with_bias){
     output += bias.view({1, bias.size(0), 1, 1});
   }
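
A recurring pattern in these kernels: `view` a buffer as `[group, C / group, ...]`, update each group slice in place, then `view` it back. Because `view` aliases the same storage instead of copying, the in-place `addmm_` results survive the reshape. A toy Python illustration of why this works:

import torch

group = 4
buf = torch.zeros(8, 6)            # stands in for `columns` / `output`
tmp = buf.view(group, 2, 6)        # no copy: same underlying storage
tmp[2].fill_(1.0)                  # per-group in-place update
buf2 = tmp.view(8, 6)              # "view back", as in the C++ code
assert buf2.data_ptr() == buf.data_ptr()
assert buf[4:6].eq(1).all()        # group 2 occupies rows 4..5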
@@ -484,7 +527,7 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
                                          int kernel_h, int kernel_w,
                                          int stride_h, int stride_w,
                                          int pad_h, int pad_w,
-                                         int dilation_h, int dilation_w,
+                                         int dilation_h, int dilation_w, int group,
                                          int deformable_group, const bool with_bias)
 {
   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
@@ -501,9 +544,9 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
   if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
     AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
              kernel_h_, kernel_w, kernel_h_, kernel_w_);
-  if (channels != channels_kernel)
+  if (channels != channels_kernel * group)
     AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
-             channels, channels_kernel);
+             channels, channels_kernel * group);

   const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
   const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
@@ -518,9 +561,20 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
   grad_input = grad_input.view({batch, channels, height, width});
   columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, input.type());

+  grad_output = grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, grad_output.size(2), grad_output.size(3)});
+
   for (int b = 0; b < batch; b++)
   {
-    columns.addmm_(weight.flatten(1).transpose(0, 1), grad_output[b].flatten(1), 0.0f, 1.0f);
+    // divide into group
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++){
+      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), grad_output[b][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), weight.size(3), weight.size(4)});

     // gradient w.r.t. input coordinate data
     modulated_deformable_col2im_coord_cuda(columns, input[b], offset[b], mask[b],
@@ -545,14 +599,25 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
                                        dilation_h, dilation_w, deformable_group,
                                        columns);

-    grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight);
-    if (with_bias){
-      grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    grad_weight = grad_weight.view({group, grad_weight.size(0) / group, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)});
+    grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+    for (int g = 0; g < group; g++){
+      grad_weight[g] = grad_weight[g].flatten(1).addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)).view_as(grad_weight[g]);
+      if (with_bias){
+        grad_bias[g] = grad_bias[g].view({-1, 1}).addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})).view(-1);
+      }
     }
+
+    columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)});
+    grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
   }
+
+  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), grad_output.size(2), grad_output.size(3), grad_output.size(4)});
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
   m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda, "deform forward (CUDA)");
...
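
A natural end-to-end check of the new group support (not part of this commit, and it needs a CUDA build of the extension): with all offsets at zero, the deformable op samples on the regular grid, so it should agree with a plain grouped convolution using the same weights, up to floating-point rounding.

import torch
import torch.nn.functional as F
from mmdet.ops import DeformConv

g, dg, k = 8, 1, 3
dconv = DeformConv(64, 64, kernel_size=k, padding=1,
                   groups=g, deformable_groups=dg).cuda()

x = torch.randn(1, 64, 16, 16, device='cuda')
offset = torch.zeros(1, 2 * dg * k * k, 16, 16, device='cuda')

out_deform = dconv(x, offset)
out_plain = F.conv2d(x, dconv.weight, padding=1, groups=g)
print(torch.allclose(out_deform, out_plain, atol=1e-4))  # expected: True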