Skip to content
Snippets Groups Projects
Commit 03c38bdc authored by yhcao6's avatar yhcao6
Browse files

separate dcn and dpool cpp, restructure some code

parent 66489d6b
No related branches found
No related tags found
No related merge requests found
......@@ -2,11 +2,12 @@ from .functions.deform_conv import deform_conv, modulated_deform_conv
from .functions.deform_pool import deform_roi_pooling
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
ModulatedDeformConvPack)
from .modules.deform_pool import (DeformRoIPooling,
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
ModulatedDeformRoIPoolingPack)
__all__ = [
'DeformConv', 'DeformRoIPooling', 'ModulatedDeformRoIPoolingPack',
'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv', 'deform_roi_pooling'
]
......@@ -3,7 +3,6 @@ from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import deform_conv_cuda
from .. import modulated_dcn_cuda as _backend
class DeformConvFunction(Function):
......@@ -124,7 +123,7 @@ class ModulatedDeformConvFunction(Function):
output = input.new(
*ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new(), input.new()]
_backend.modulated_deform_conv_cuda_forward(
deform_conv_cuda.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
......@@ -141,7 +140,7 @@ class ModulatedDeformConvFunction(Function):
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
_backend.modulated_deform_conv_cuda_backward(
deform_conv_cuda.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
......
import torch
from torch.autograd import Function
from .. import modulated_dcn_cuda as _backend
from .. import deform_pool_cuda
class DeformRoIPoolingFunction(Function):
......@@ -36,7 +36,7 @@ class DeformRoIPoolingFunction(Function):
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
output_count = data.new(
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
_backend.deform_psroi_pooling_cuda_forward(
deform_pool_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.output_dim, ctx.group_size, ctx.pooled_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
......@@ -63,7 +63,7 @@ class DeformRoIPoolingFunction(Function):
grad_input = torch.zeros_like(data)
grad_offset = torch.zeros_like(offset)
_backend.deform_psroi_pooling_cuda_backward(
deform_pool_cuda.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input,
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.output_dim,
ctx.group_size, ctx.pooled_size, ctx.part_size,
......
......@@ -16,7 +16,7 @@ class DeformConv(nn.Module):
stride=1,
padding=0,
dilation=1,
num_deformable_groups=1):
deformable_groups=1):
super(DeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -24,7 +24,7 @@ class DeformConv(nn.Module):
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.num_deformable_groups = num_deformable_groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
......@@ -41,7 +41,7 @@ class DeformConv(nn.Module):
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation,
self.num_deformable_groups)
self.deformable_groups)
class ModulatedDeformConv(nn.Module):
......@@ -54,7 +54,7 @@ class ModulatedDeformConv(nn.Module):
padding,
dilation=1,
deformable_groups=1,
no_bias=True):
bias=False):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -63,13 +63,12 @@ class ModulatedDeformConv(nn.Module):
self.padding = padding
self.dilation = dilation
self.deformable_groups = deformable_groups
self.no_bias = no_bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
self.bias = nn.Parameter(torch.zeros(out_channels))
self.reset_parameters()
if self.no_bias:
if not bias:
self.bias.requires_grad = False
def reset_parameters(self):
......@@ -96,10 +95,10 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
padding,
dilation=1,
deformable_groups=1,
no_bias=False):
bias=True):
super(ModulatedDeformConvPack,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, deformable_groups, no_bias)
padding, dilation, deformable_groups, bias)
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
......
......@@ -7,7 +7,7 @@ class DeformRoIPooling(nn.Module):
def __init__(self,
spatial_scale,
pooled_size,
out_size,
output_dim,
no_trans,
group_size=1,
......@@ -16,12 +16,11 @@ class DeformRoIPooling(nn.Module):
trans_std=.0):
super(DeformRoIPooling, self).__init__()
self.spatial_scale = spatial_scale
self.pooled_size = pooled_size
self.out_size = pooled_size
self.out_size = out_size
self.output_dim = output_dim
self.no_trans = no_trans
self.group_size = group_size
self.part_size = pooled_size if part_size is None else part_size
self.part_size = out_size if part_size is None else part_size
self.sample_per_part = sample_per_part
self.trans_std = trans_std
......@@ -29,7 +28,7 @@ class DeformRoIPooling(nn.Module):
if self.no_trans:
offset = data.new()
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
......@@ -38,7 +37,7 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
def __init__(self,
spatial_scale,
pooled_size,
out_size,
output_dim,
no_trans,
group_size=1,
......@@ -47,7 +46,7 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
trans_std=.0,
deform_fc_dim=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, pooled_size, output_dim, no_trans, group_size,
spatial_scale, out_size, output_dim, no_trans, group_size,
part_size, sample_per_part, trans_std)
self.deform_fc_dim = deform_fc_dim
......@@ -55,20 +54,20 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(
self.pooled_size * self.pooled_size * self.output_dim,
self.out_size * self.out_size * self.output_dim,
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim,
self.pooled_size * self.pooled_size * 2))
self.out_size * self.out_size * 2))
self.offset_fc[4].weight.data.zero_()
self.offset_fc[4].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(
self.pooled_size * self.pooled_size * self.output_dim,
self.out_size * self.out_size * self.output_dim,
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim,
self.pooled_size * self.pooled_size * 1),
self.out_size * self.out_size * 1),
nn.Sigmoid())
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
......@@ -80,19 +79,72 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
n = rois.shape[0]
offset = data.new()
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.pooled_size, self.output_dim, True,
self.out_size, self.output_dim, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.pooled_size, self.pooled_size)
offset = offset.view(n, 2, self.out_size, self.out_size)
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.pooled_size, self.pooled_size)
mask = mask.view(n, 1, self.out_size, self.out_size)
feat = deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask
return feat
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
class DeformRoIPoolingPack(DeformRoIPooling):
    """Deformable RoI pooling with a learned offset branch.

    A first, non-deformable pooling pass produces features from which a
    small fully-connected head predicts 2 offset values per output bin;
    pooling is then repeated with those predicted offsets. When
    ``no_trans`` is set, the offset branch is skipped entirely and a
    single pooling pass with an empty offset tensor is performed.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 deform_fc_dim=1024):
        super(DeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, output_dim, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.deform_fc_dim = deform_fc_dim

        if not no_trans:
            num_bins = self.out_size * self.out_size
            # Three-layer MLP mapping pooled features to 2 offsets per bin.
            self.offset_fc = nn.Sequential(
                nn.Linear(num_bins * self.output_dim, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, num_bins * 2))
            # Zero-init the last layer so the module initially behaves like
            # plain (offset-free) RoI pooling.
            self.offset_fc[4].weight.data.zero_()
            self.offset_fc[4].bias.data.zero_()

    def forward(self, data, rois):
        empty_offset = data.new()
        if self.no_trans:
            # No offset branch: single pooling pass with an empty offset.
            return deform_roi_pooling(
                data, rois, empty_offset, self.spatial_scale, self.out_size,
                self.output_dim, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)
        num_rois = rois.shape[0]
        # Pass 1: pool with no_trans=True (no offsets) to obtain the
        # features fed to the offset predictor.
        pooled = deform_roi_pooling(data, rois, empty_offset,
                                    self.spatial_scale, self.out_size,
                                    self.output_dim, True, self.group_size,
                                    self.part_size, self.sample_per_part,
                                    self.trans_std)
        offset = self.offset_fc(pooled.view(num_rois, -1))
        offset = offset.view(num_rois, 2, self.out_size, self.out_size)
        # Pass 2: pool again, now driven by the predicted offsets.
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.output_dim, self.no_trans, self.group_size, self.part_size,
            self.sample_per_part, self.trans_std)
......@@ -2,11 +2,14 @@ from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='deform_conv_cuda',
name='deform_conv',
ext_modules=[
CUDAExtension('deform_conv_cuda', [
'src/deform_conv_cuda.cpp',
'src/deform_conv_cuda_kernel.cu',
]),
CUDAExtension('deform_pool_cuda', [
'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'
]),
],
cmdclass={'build_ext': BuildExtension})
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Build script for the `modulated_dcn_cuda` CUDA extension: compiles the
# modulated deformable-convolution im2col kernels and the deformable
# PS-RoI pooling kernels into a single Python extension module.
# NOTE(review): the original indentation was lost in this copy of the
# source; line content is preserved as-is.
setup(
name='modulated_dcn_cuda',
ext_modules=[
CUDAExtension('modulated_dcn_cuda', [
'src/modulated_dcn_cuda.cpp',
'src/modulated_deform_im2col_cuda.cu',
'src/deform_psroi_pooling_cuda.cu'
]),
],
cmdclass={'build_ext': BuildExtension})
......@@ -33,6 +33,32 @@ void deformable_col2im_coord(const at::Tensor data_col,
const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col);
void modulated_deformable_col2im_cuda(const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset,
at::Tensor *gradOutput, at::Tensor weight, int kH, int kW,
int dH, int dW, int padH, int padW, int dilationH,
......@@ -256,16 +282,6 @@ int deform_conv_backward_input_cuda(
{batchSize / im2col_step, im2col_step, nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer = gradOutputBuffer.view(
{batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer = gradOutputBuffer.view(
{batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput = gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view(
{batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth});
......@@ -276,7 +292,7 @@ int deform_conv_backward_input_cuda(
for (int elt = 0; elt < batchSize / im2col_step; elt++)
{
columns = columns.addmm_(weight.flatten(1).transpose(0, 1), gradOutputBuffer[elt].flatten(1), 0.0f, 1.0f);
columns = columns.addmm_(weight.flatten(1).transpose(0, 1), gradOutput[elt].flatten(1), 0.0f, 1.0f);
deformable_col2im_coord(
columns, input[elt], offset[elt],
......@@ -289,6 +305,9 @@ int deform_conv_backward_input_cuda(
deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput = gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
......@@ -394,6 +413,148 @@ int deform_conv_backward_parameters_cuda(
return 1;
}
// Forward pass of modulated deformable convolution (DCNv2) on the GPU.
//
// For each sample b, the input is unfolded with the learned per-location
// offsets and modulation mask (modulated_deformable_im2col_cuda) into
// `columns`, then output[b] = weight * columns + bias (two addmm_ calls).
//
// `ones` and `columns` are caller-provided scratch tensors; they are
// re-allocated here when too small for the current output size.
// NOTE(review): `output` is only view()ed, not resized — the caller must
// pre-allocate it with batch * channels_out * height_out * width_out
// elements.
void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
                                        at::Tensor bias, at::Tensor ones,
                                        at::Tensor offset, at::Tensor mask,
                                        at::Tensor output, at::Tensor columns,
                                        int kernel_h, int kernel_w,
                                        const int stride_h, const int stride_w,
                                        const int pad_h, const int pad_w,
                                        const int dilation_h, const int dilation_w,
                                        const int deformable_group)
{
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    // BUG FIX: the first pair previously printed (kernel_h_, kernel_w) — a
    // mix of the weight's and the requested sizes — so the message never
    // showed the actual mismatch. Print requested vs. actual.
    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
             kernel_h, kernel_w, kernel_h_, kernel_w_);

  if (channels != channels_kernel)
    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
             channels, channels_kernel);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out)
  {
    // Resize plane and fill with ones (used to broadcast the bias).
    ones = at::ones({height_out, width_out}, input.type());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out});
  // resize temporary columns
  columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());

  for (int b = 0; b < batch; b++)
  {
    // Do Bias first:
    // M,N,K are dims of matrix A and B
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    // (N x 1) (1 x M)
    output[b] = output[b].flatten(1).addmm_(bias.view({-1, 1}), ones.view({1, -1}), 0.0f, 1.0f).view_as(output[b]);

    modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                                     deformable_group, columns);

    //(k * m) x (m * n)
    // Y = WC
    output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]);
  }
}
// Backward pass of modulated deformable convolution (DCNv2).
//
// For each sample b:
//   columns = W^T * grad_output[b]                 (addmm_ with beta=0,
//                                                   so columns is overwritten)
//   -> modulated_deformable_col2im_coord_cuda      grad_offset, grad_mask
//   -> modulated_deformable_col2im_cuda            grad_input
//   then a fresh im2col of the forward input feeds grad_weight and
//   grad_bias, which accumulate across the batch (default beta=1).
//
// `ones` and `columns` are caller-provided scratch tensors, re-allocated
// here when too small.
void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
                                         at::Tensor bias, at::Tensor ones,
                                         at::Tensor offset, at::Tensor mask,
                                         at::Tensor columns,
                                         at::Tensor grad_input, at::Tensor grad_weight,
                                         at::Tensor grad_bias, at::Tensor grad_offset,
                                         at::Tensor grad_mask, at::Tensor grad_output,
                                         int kernel_h, int kernel_w,
                                         int stride_h, int stride_w,
                                         int pad_h, int pad_w,
                                         int dilation_h, int dilation_w,
                                         int deformable_group)
{
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    // BUG FIX: the first pair previously printed (kernel_h_, kernel_w) — a
    // mix of the weight's and the requested sizes — so the message never
    // showed the actual mismatch. Print requested vs. actual.
    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
             kernel_h, kernel_w, kernel_h_, kernel_w_);

  if (channels != channels_kernel)
    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
             channels, channels_kernel);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out)
  {
    // Resize plane and fill with ones (used to reduce grad_bias).
    ones = at::ones({height_out, width_out}, input.type());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, input.type());

  for (int b = 0; b < batch; b++)
  {
    columns.addmm_(weight.flatten(1).transpose(0, 1), grad_output[b].flatten(1), 0.0f, 1.0f);

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(columns, input[b], offset[b], mask[b],
                                           1, channels, height, width,
                                           height_out, width_out, kernel_h, kernel_w,
                                           pad_h, pad_w, stride_h, stride_w,
                                           dilation_h, dilation_w, deformable_group,
                                           grad_offset[b], grad_mask[b]);

    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(columns, offset[b], mask[b],
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w,
                                     dilation_h, dilation_w, deformable_group,
                                     grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and group
    modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w,
                                     dilation_h, dilation_w, deformable_group,
                                     columns);

    grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight);

    grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
  }
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda, "deform forward (CUDA)");
......@@ -401,4 +562,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"deform_conv_backward_input (CUDA)");
m.def("deform_conv_backward_parameters_cuda", &deform_conv_backward_parameters_cuda,
"deform_conv_backward_parameters (CUDA)");
m.def("modulated_deform_conv_cuda_forward", &modulated_deform_conv_cuda_forward,
"modulated deform conv forward (CUDA)");
m.def("modulated_deform_conv_cuda_backward", &modulated_deform_conv_cuda_backward,
"modulated deform conv backward (CUDA)");
}
This diff is collapsed.
......@@ -8,32 +8,6 @@
#include <cmath>
#include <vector>
void modulated_deformable_im2col_cuda(const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col);
void modulated_deformable_col2im_cuda(const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void DeformablePSROIPoolForward(const at::Tensor data,
const at::Tensor bbox,
const at::Tensor trans,
......@@ -76,148 +50,6 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
const int sample_per_part,
const float trans_std);
// Forward pass of modulated deformable convolution (DCNv2) on the GPU.
//
// For each sample b, the input is unfolded with the learned per-location
// offsets and modulation mask (modulated_deformable_im2col_cuda) into
// `columns`, then output[b] = weight * columns + bias (two addmm_ calls).
//
// `ones` and `columns` are caller-provided scratch tensors; they are
// re-allocated here when too small for the current output size.
// NOTE(review): `output` is only view()ed, not resized — the caller must
// pre-allocate it with batch * channels_out * height_out * width_out
// elements.
void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
                                        at::Tensor bias, at::Tensor ones,
                                        at::Tensor offset, at::Tensor mask,
                                        at::Tensor output, at::Tensor columns,
                                        int kernel_h, int kernel_w,
                                        const int stride_h, const int stride_w,
                                        const int pad_h, const int pad_w,
                                        const int dilation_h, const int dilation_w,
                                        const int deformable_group)
{
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    // BUG FIX: the first pair previously printed (kernel_h_, kernel_w) — a
    // mix of the weight's and the requested sizes. Print requested vs.
    // actual.
    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
             kernel_h, kernel_w, kernel_h_, kernel_w_);

  if (channels != channels_kernel)
    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
             channels, channels_kernel);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out)
  {
    // Resize plane and fill with ones (used to broadcast the bias).
    ones = at::ones({height_out, width_out}, input.type());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out});
  // resize temporary columns
  columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());

  for (int b = 0; b < batch; b++)
  {
    // Do Bias first:
    // M,N,K are dims of matrix A and B
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    // (N x 1) (1 x M)
    output[b] = output[b].flatten(1).addmm_(bias.view({-1, 1}), ones.view({1, -1}), 0.0f, 1.0f).view_as(output[b]);

    modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                                     deformable_group, columns);

    //(k * m) x (m * n)
    // Y = WC
    output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]);
  }
}
// Backward pass of modulated deformable convolution (DCNv2).
//
// For each sample b:
//   columns = W^T * grad_output[b]                 (addmm_ with beta=0,
//                                                   so columns is overwritten)
//   -> modulated_deformable_col2im_coord_cuda      grad_offset, grad_mask
//   -> modulated_deformable_col2im_cuda            grad_input
//   then a fresh im2col of the forward input feeds grad_weight and
//   grad_bias, which accumulate across the batch (default beta=1).
//
// `ones` and `columns` are caller-provided scratch tensors, re-allocated
// here when too small.
void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
                                         at::Tensor bias, at::Tensor ones,
                                         at::Tensor offset, at::Tensor mask,
                                         at::Tensor columns,
                                         at::Tensor grad_input, at::Tensor grad_weight,
                                         at::Tensor grad_bias, at::Tensor grad_offset,
                                         at::Tensor grad_mask, at::Tensor grad_output,
                                         int kernel_h, int kernel_w,
                                         int stride_h, int stride_w,
                                         int pad_h, int pad_w,
                                         int dilation_h, int dilation_w,
                                         int deformable_group)
{
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    // BUG FIX: the first pair previously printed (kernel_h_, kernel_w) — a
    // mix of the weight's and the requested sizes. Print requested vs.
    // actual.
    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
             kernel_h, kernel_w, kernel_h_, kernel_w_);

  if (channels != channels_kernel)
    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
             channels, channels_kernel);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out)
  {
    // Resize plane and fill with ones (used to reduce grad_bias).
    ones = at::ones({height_out, width_out}, input.type());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, input.type());

  for (int b = 0; b < batch; b++)
  {
    columns.addmm_(weight.flatten(1).transpose(0, 1), grad_output[b].flatten(1), 0.0f, 1.0f);

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(columns, input[b], offset[b], mask[b],
                                           1, channels, height, width,
                                           height_out, width_out, kernel_h, kernel_w,
                                           pad_h, pad_w, stride_h, stride_w,
                                           dilation_h, dilation_w, deformable_group,
                                           grad_offset[b], grad_mask[b]);

    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(columns, offset[b], mask[b],
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w,
                                     dilation_h, dilation_w, deformable_group,
                                     grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and group
    modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w,
                                     dilation_h, dilation_w, deformable_group,
                                     columns);

    grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight);

    grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
  }
}
void deform_psroi_pooling_cuda_forward(at::Tensor input, at::Tensor bbox,
at::Tensor trans,
at::Tensor out, at::Tensor top_count,
......@@ -305,10 +137,6 @@ void deform_psroi_pooling_cuda_backward(at::Tensor out_grad,
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("modulated_deform_conv_cuda_forward", &modulated_deform_conv_cuda_forward,
"modulated deform conv forward (CUDA)");
m.def("modulated_deform_conv_cuda_backward", &modulated_deform_conv_cuda_backward,
"modulated deform conv backward (CUDA)");
m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
"deform psroi pooling forward(CUDA)");
m.def("deform_psroi_pooling_cuda_backward", &deform_psroi_pooling_cuda_backward,
......
......@@ -289,18 +289,18 @@ void DeformablePSROIPoolForward(const at::Tensor data,
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data.type(), "deformable_psroi_pool_forward", ([&] {
const scalar_t *bottom_data = data.data<scalar_t>();
const scalar_t *bottom_rois = bbox.data<scalar_t>();
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
scalar_t *top_data = out.data<scalar_t>();
scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
}));
data.type(), "deformable_psroi_pool_forward", ([&] {
const scalar_t *bottom_data = data.data<scalar_t>();
const scalar_t *bottom_rois = bbox.data<scalar_t>();
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
scalar_t *top_data = out.data<scalar_t>();
scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
......@@ -356,7 +356,6 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
group_size, part_size, num_classes, channels_each_class);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment