From 9e3dcd0da64f3f3e1e81d329876f0c850ee3619c Mon Sep 17 00:00:00 2001
From: Jerry Jiarui XU <xvjiarui0826@gmail.com>
Date: Sat, 21 Mar 2020 18:50:59 +0800
Subject: [PATCH] refactor ops: unify cuda ops api (#2277)

---
 mmdet/ops/affine_grid/affine_grid.py          |   6 +-
 mmdet/ops/affine_grid/src/affine_grid_ext.cpp |  23 +++
 .../affine_grid_cpu.cpp}                      |   7 -
 mmdet/ops/carafe/carafe.py                    |  25 ++-
 mmdet/ops/carafe/src/carafe_ext.cpp           |  57 ++++++
 mmdet/ops/carafe/src/carafe_naive_ext.cpp     |  51 ++++++
 .../ops/carafe/src/{ => cuda}/carafe_cuda.cpp |   5 -
 .../src/{ => cuda}/carafe_cuda_kernel.cu      |   0
 .../src/{ => cuda}/carafe_naive_cuda.cpp      |   6 -
 .../{ => cuda}/carafe_naive_cuda_kernel.cu    |   0
 mmdet/ops/dcn/deform_conv.py                  |  12 +-
 mmdet/ops/dcn/deform_pool.py                  |   6 +-
 .../dcn/src/{ => cuda}/deform_conv_cuda.cpp   |  16 --
 .../src/{ => cuda}/deform_conv_cuda_kernel.cu |   0
 .../dcn/src/{ => cuda}/deform_pool_cuda.cpp   |   8 -
 .../src/{ => cuda}/deform_pool_cuda_kernel.cu |   0
 mmdet/ops/dcn/src/deform_conv_ext.cpp         | 163 ++++++++++++++++++
 mmdet/ops/dcn/src/deform_pool_ext.cpp         |  71 ++++++++
 mmdet/ops/grid_sampler/grid_sampler.py        |  33 +---
 .../src/cudnn/grid_sampler_cudnn.cpp          | 148 ----------------
 ...{grid_sampler.cpp => grid_sampler_ext.cpp} |  71 ++++++--
 mmdet/ops/masked_conv/masked_conv.py          |  16 +-
 .../src/{ => cuda}/masked_conv2d_cuda.cpp     |   7 -
 .../src/{ => cuda}/masked_conv2d_kernel.cu    |   0
 .../ops/masked_conv/src/masked_conv2d_ext.cpp |  54 ++++++
 mmdet/ops/nms/nms_wrapper.py                  |   8 +-
 mmdet/ops/nms/src/{ => cpu}/nms_cpu.cpp       |   9 +-
 mmdet/ops/nms/src/{ => cuda}/nms_cuda.cpp     |  10 +-
 mmdet/ops/nms/src/{ => cuda}/nms_kernel.cu    |   2 +-
 mmdet/ops/nms/src/nms_ext.cpp                 |  38 ++++
 mmdet/ops/roi_align/roi_align.py              |  26 +--
 .../src/{ => cuda}/roi_align_kernel.cu        |   2 +-
 .../src/{ => cuda}/roi_align_kernel_v2.cu     |   0
 .../{roi_align_cuda.cpp => roi_align_ext.cpp} | 110 +++++++-----
 mmdet/ops/roi_pool/roi_pool.py                |  10 +-
 .../src/{ => cuda}/roi_pool_kernel.cu         |   0
 mmdet/ops/roi_pool/src/roi_pool_cuda.cpp      |  88 ----------
 mmdet/ops/roi_pool/src/roi_pool_ext.cpp       | 104 +++++++++++
 .../sigmoid_focal_loss/sigmoid_focal_loss.py  |  10 +-
 .../src/{ => cuda}/sigmoid_focal_loss_cuda.cu |   0
 ...al_loss.cpp => sigmoid_focal_loss_ext.cpp} |  14 +-
 setup.py                                      | 119 +++++++------
 42 files changed, 837 insertions(+), 498 deletions(-)
 create mode 100644 mmdet/ops/affine_grid/src/affine_grid_ext.cpp
 rename mmdet/ops/affine_grid/src/{affine_grid_cuda.cpp => cpu/affine_grid_cpu.cpp} (94%)
 create mode 100644 mmdet/ops/carafe/src/carafe_ext.cpp
 create mode 100644 mmdet/ops/carafe/src/carafe_naive_ext.cpp
 rename mmdet/ops/carafe/src/{ => cuda}/carafe_cuda.cpp (96%)
 rename mmdet/ops/carafe/src/{ => cuda}/carafe_cuda_kernel.cu (100%)
 rename mmdet/ops/carafe/src/{ => cuda}/carafe_naive_cuda.cpp (92%)
 rename mmdet/ops/carafe/src/{ => cuda}/carafe_naive_cuda_kernel.cu (100%)
 rename mmdet/ops/dcn/src/{ => cuda}/deform_conv_cuda.cpp (97%)
 rename mmdet/ops/dcn/src/{ => cuda}/deform_conv_cuda_kernel.cu (100%)
 rename mmdet/ops/dcn/src/{ => cuda}/deform_pool_cuda.cpp (92%)
 rename mmdet/ops/dcn/src/{ => cuda}/deform_pool_cuda_kernel.cu (100%)
 create mode 100644 mmdet/ops/dcn/src/deform_conv_ext.cpp
 create mode 100644 mmdet/ops/dcn/src/deform_pool_ext.cpp
 delete mode 100644 mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp
 rename mmdet/ops/grid_sampler/src/{grid_sampler.cpp => grid_sampler_ext.cpp} (51%)
 rename mmdet/ops/masked_conv/src/{ => cuda}/masked_conv2d_cuda.cpp (91%)
 rename mmdet/ops/masked_conv/src/{ => cuda}/masked_conv2d_kernel.cu (100%)
 create mode 100644 mmdet/ops/masked_conv/src/masked_conv2d_ext.cpp
 rename mmdet/ops/nms/src/{ => cpu}/nms_cpu.cpp (95%)
 rename mmdet/ops/nms/src/{ => cuda}/nms_cuda.cpp (53%)
 rename mmdet/ops/nms/src/{ => cuda}/nms_kernel.cu (98%)
 create mode 100644 mmdet/ops/nms/src/nms_ext.cpp
 rename mmdet/ops/roi_align/src/{ => cuda}/roi_align_kernel.cu (99%)
 rename mmdet/ops/roi_align/src/{ => cuda}/roi_align_kernel_v2.cu (100%)
 rename mmdet/ops/roi_align/src/{roi_align_cuda.cpp => roi_align_ext.cpp} (66%)
 rename mmdet/ops/roi_pool/src/{ => cuda}/roi_pool_kernel.cu (100%)
 delete mode 100644 mmdet/ops/roi_pool/src/roi_pool_cuda.cpp
 create mode 100644 mmdet/ops/roi_pool/src/roi_pool_ext.cpp
 rename mmdet/ops/sigmoid_focal_loss/src/{ => cuda}/sigmoid_focal_loss_cuda.cu (100%)
 rename mmdet/ops/sigmoid_focal_loss/src/{sigmoid_focal_loss.cpp => sigmoid_focal_loss_ext.cpp} (87%)

diff --git a/mmdet/ops/affine_grid/affine_grid.py b/mmdet/ops/affine_grid/affine_grid.py
index 94bacb5e..7c24fa79 100644
--- a/mmdet/ops/affine_grid/affine_grid.py
+++ b/mmdet/ops/affine_grid/affine_grid.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
-from . import affine_grid_cuda
+from . import affine_grid_ext
 
 
 class _AffineGridGenerator(Function):
@@ -15,7 +15,7 @@ class _AffineGridGenerator(Function):
         ctx.size = size
         ctx.align_corners = align_corners
 
-        func = affine_grid_cuda.affine_grid_generator_forward
+        func = affine_grid_ext.affine_grid_generator_forward
 
         output = func(theta, size, align_corners)
 
@@ -28,7 +28,7 @@ class _AffineGridGenerator(Function):
         size = ctx.size
         align_corners = ctx.align_corners
 
-        func = affine_grid_cuda.affine_grid_generator_backward
+        func = affine_grid_ext.affine_grid_generator_backward
 
         grad_input = func(grad_output, theta, size, align_corners)
 
diff --git a/mmdet/ops/affine_grid/src/affine_grid_ext.cpp b/mmdet/ops/affine_grid/src/affine_grid_ext.cpp
new file mode 100644
index 00000000..cc5c80d7
--- /dev/null
+++ b/mmdet/ops/affine_grid/src/affine_grid_ext.cpp
@@ -0,0 +1,23 @@
+// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AffineGridGenerator.cpp
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <torch/extension.h>
+
+namespace mmdetection {
+
+using namespace at;
+
+Tensor affine_grid_generator_forward(const Tensor &theta, IntArrayRef size,
+                                     bool align_corners);
+
+Tensor affine_grid_generator_backward(const Tensor &grad, IntArrayRef size,
+                                      bool align_corners);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("affine_grid_generator_forward", &affine_grid_generator_forward,
+        "affine_grid_generator_forward");
+  m.def("affine_grid_generator_backward", &affine_grid_generator_backward,
+        "affine_grid_generator_backward");
+}
+
+}  // namespace mmdetection
diff --git a/mmdet/ops/affine_grid/src/affine_grid_cuda.cpp b/mmdet/ops/affine_grid/src/cpu/affine_grid_cpu.cpp
similarity index 94%
rename from mmdet/ops/affine_grid/src/affine_grid_cuda.cpp
rename to mmdet/ops/affine_grid/src/cpu/affine_grid_cpu.cpp
index 3874128c..51434604 100644
--- a/mmdet/ops/affine_grid/src/affine_grid_cuda.cpp
+++ b/mmdet/ops/affine_grid/src/cpu/affine_grid_cpu.cpp
@@ -105,11 +105,4 @@ Tensor affine_grid_generator_backward(const Tensor& grad, IntArrayRef size,
                                       size[3], size[4], align_corners);
   }
 }
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("affine_grid_generator_forward", &affine_grid_generator_forward,
-        "affine_grid_generator_forward");
-  m.def("affine_grid_generator_backward", &affine_grid_generator_backward,
-        "affine_grid_generator_backward");
-}
-
 }  // namespace mmdetection
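The three affine_grid diffs above set the pattern repeated for every op in this patch: the Python wrapper imports a single unified `*_ext` module, the backend-specific sources move under cpu/ or cuda/ subdirectories, and the pybind11 registrations move into a new `*_ext.cpp`. A minimal sketch (not part of the patch) of exercising the op through the public wrapper; it assumes a build containing these changes and that `affine_grid` mirrors `torch.nn.functional.affine_grid`'s `(theta, size)` signature, as affine_grid.py suggests:

    import torch
    from mmdet.ops.affine_grid import affine_grid

    # N x 2 x 3 affine matrices, as for torch.nn.functional.affine_grid
    theta = torch.randn(2, 2, 3, requires_grad=True)
    grid = affine_grid(theta, (2, 3, 8, 8))
    print(grid.shape)      # torch.Size([2, 8, 8, 2])
    grid.sum().backward()  # routes through affine_grid_ext.affine_grid_generator_backward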
diff --git a/mmdet/ops/carafe/carafe.py b/mmdet/ops/carafe/carafe.py
index 2c81735b..d8a15c3e 100644
--- a/mmdet/ops/carafe/carafe.py
+++ b/mmdet/ops/carafe/carafe.py
@@ -5,7 +5,7 @@ from mmcv.cnn import normal_init, xavier_init
 from torch.autograd import Function
 from torch.nn.modules.module import Module
 
-from . import carafe_cuda, carafe_naive_cuda
+from . import carafe_ext, carafe_naive_ext
 
 
 class CARAFENaiveFunction(Function):
@@ -27,8 +27,8 @@ class CARAFENaiveFunction(Function):
         n, c, h, w = features.size()
         output = features.new_zeros((n, c, h * scale_factor,
                                      w * scale_factor))
         if features.is_cuda:
-            carafe_naive_cuda.forward(features, masks, kernel_size, group_size,
-                                      scale_factor, output)
+            carafe_naive_ext.forward(features, masks, kernel_size, group_size,
+                                     scale_factor, output)
         else:
             raise NotImplementedError
 
@@ -47,9 +47,9 @@ class CARAFENaiveFunction(Function):
         grad_input = torch.zeros_like(features)
         grad_masks = torch.zeros_like(masks)
-        carafe_naive_cuda.backward(grad_output.contiguous(), features, masks,
-                                   kernel_size, group_size, scale_factor,
-                                   grad_input, grad_masks)
+        carafe_naive_ext.backward(grad_output.contiguous(), features, masks,
+                                  kernel_size, group_size, scale_factor,
+                                  grad_input, grad_masks)
 
         return grad_input, grad_masks, None, None, None
 
@@ -95,9 +95,8 @@ class CARAFEFunction(Function):
         rfeatures = features.new_zeros(features.size(), requires_grad=False)
         rmasks = masks.new_zeros(masks.size(), requires_grad=False)
         if features.is_cuda:
-            carafe_cuda.forward(features, rfeatures, masks, rmasks,
-                                kernel_size, group_size, scale_factor, routput,
-                                output)
+            carafe_ext.forward(features, rfeatures, masks, rmasks, kernel_size,
+                               group_size, scale_factor, routput, output)
         else:
             raise NotImplementedError
 
@@ -120,10 +119,10 @@ class CARAFEFunction(Function):
         rgrad_masks = torch.zeros_like(masks, requires_grad=False)
         grad_input = torch.zeros_like(features, requires_grad=False)
         grad_masks = torch.zeros_like(masks, requires_grad=False)
-        carafe_cuda.backward(grad_output.contiguous(), rfeatures, masks,
-                             kernel_size, group_size, scale_factor,
-                             rgrad_output, rgrad_input_hs, rgrad_input,
-                             rgrad_masks, grad_input, grad_masks)
+        carafe_ext.backward(grad_output.contiguous(), rfeatures, masks,
+                            kernel_size, group_size, scale_factor,
+                            rgrad_output, rgrad_input_hs, rgrad_input,
+                            rgrad_masks, grad_input, grad_masks)
 
         return grad_input, grad_masks, None, None, None, None
 
diff --git a/mmdet/ops/carafe/src/carafe_ext.cpp b/mmdet/ops/carafe/src/carafe_ext.cpp
new file mode 100644
index 00000000..5bee3daf
--- /dev/null
+++ b/mmdet/ops/carafe/src/carafe_ext.cpp
@@ -0,0 +1,57 @@
+#include <ATen/ATen.h>
+#include <torch/extension.h>
+
+#include <cmath>
+#include <vector>
+
+#ifdef WITH_CUDA
+int carafe_forward_cuda(at::Tensor features, at::Tensor rfeatures,
+                        at::Tensor masks, at::Tensor rmasks, int kernel_size,
+                        int group_size, int scale_factor, at::Tensor routput,
+                        at::Tensor output);
+
+int carafe_backward_cuda(at::Tensor top_grad, at::Tensor rfeatures,
+                         at::Tensor masks, int kernel_size, int group_size,
+                         int scale_factor, at::Tensor rtop_grad,
+                         at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
+                         at::Tensor rmask_grad, at::Tensor bottom_grad,
+                         at::Tensor mask_grad);
+#endif
+
+int carafe_forward(at::Tensor features, at::Tensor rfeatures,
+                   at::Tensor masks, at::Tensor rmasks, int kernel_size,
+                   int group_size, int scale_factor, at::Tensor routput,
+                   at::Tensor output) {
+  if (features.type().is_cuda()) {
+#ifdef WITH_CUDA
+    return carafe_forward_cuda(features, rfeatures, masks, rmasks, kernel_size,
+                               group_size, scale_factor, routput, output);
+#else
+    AT_ERROR("carafe is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("carafe is not implemented on CPU");
+}
+
+int carafe_backward(at::Tensor top_grad, at::Tensor rfeatures,
+                    at::Tensor masks, int kernel_size, int group_size,
+                    int scale_factor, at::Tensor rtop_grad,
+                    at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
+                    at::Tensor rmask_grad, at::Tensor bottom_grad,
+                    at::Tensor mask_grad) {
+  if (top_grad.type().is_cuda()) {
+#ifdef WITH_CUDA
+    return carafe_backward_cuda(top_grad, rfeatures, masks, kernel_size,
+                                group_size, scale_factor, rtop_grad,
+                                rbottom_grad_hs, rbottom_grad, rmask_grad,
+                                bottom_grad, mask_grad);
+#else
+    AT_ERROR("carafe is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("carafe is not implemented on CPU");
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &carafe_forward, "carafe forward");
+  m.def("backward", &carafe_backward, "carafe backward");
+}
diff --git a/mmdet/ops/carafe/src/carafe_naive_ext.cpp b/mmdet/ops/carafe/src/carafe_naive_ext.cpp
new file mode 100644
index 00000000..06fe912a
--- /dev/null
+++ b/mmdet/ops/carafe/src/carafe_naive_ext.cpp
@@ -0,0 +1,51 @@
+#include <ATen/ATen.h>
+#include <torch/torch.h>
+
+#include <cmath>
+#include <vector>
+
+#ifdef WITH_CUDA
+int carafe_naive_forward_cuda(at::Tensor features, at::Tensor masks,
+                              int kernel_size, int group_size,
+                              int scale_factor, at::Tensor output);
+
+int carafe_naive_backward_cuda(at::Tensor top_grad, at::Tensor features,
+                               at::Tensor masks, int kernel_size,
+                               int group_size, int scale_factor,
+                               at::Tensor bottom_grad, at::Tensor mask_grad);
+#endif
+
+int carafe_naive_forward(at::Tensor features, at::Tensor masks,
+                         int kernel_size, int group_size, int scale_factor,
+                         at::Tensor output) {
+  if (features.type().is_cuda()) {
+#ifdef WITH_CUDA
+    return carafe_naive_forward_cuda(features, masks, kernel_size, group_size,
+                                     scale_factor, output);
+#else
+    AT_ERROR("carafe naive is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("carafe naive is not implemented on CPU");
+}
+
+int carafe_naive_backward(at::Tensor top_grad, at::Tensor features,
+                          at::Tensor masks, int kernel_size, int group_size,
+                          int scale_factor, at::Tensor bottom_grad,
+                          at::Tensor mask_grad) {
+  if (top_grad.type().is_cuda()) {
+#ifdef WITH_CUDA
+    return carafe_naive_backward_cuda(top_grad, features, masks, kernel_size,
+                                      group_size, scale_factor, bottom_grad,
+                                      mask_grad);
+#else
+    AT_ERROR("carafe naive is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("carafe naive is not implemented on CPU");
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &carafe_naive_forward, "carafe_naive forward");
+  m.def("backward", &carafe_naive_backward, "carafe_naive backward");
+}
diff --git a/mmdet/ops/carafe/src/carafe_cuda.cpp b/mmdet/ops/carafe/src/cuda/carafe_cuda.cpp
similarity index 96%
rename from mmdet/ops/carafe/src/carafe_cuda.cpp
rename to mmdet/ops/carafe/src/cuda/carafe_cuda.cpp
index 9a7c73af..28d890f5 100644
--- a/mmdet/ops/carafe/src/carafe_cuda.cpp
+++ b/mmdet/ops/carafe/src/cuda/carafe_cuda.cpp
@@ -106,8 +106,3 @@ int carafe_backward_cuda(at::Tensor top_grad, at::Tensor rfeatures,
 
   return 1;
 }
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &carafe_forward_cuda, "carafe forward (CUDA)");
-  m.def("backward", &carafe_backward_cuda, "carafe backward (CUDA)");
-}
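Both carafe extensions share the dispatcher shape that recurs throughout the rest of the patch: CUDA declarations guarded by `#ifdef WITH_CUDA`, a runtime branch on `tensor.type().is_cuda()`, and `AT_ERROR` on the unsupported paths. From Python the GPU behavior is unchanged, and the CPU path now fails with an explicit message instead of an import error. A sketch of the GPU path (assuming the existing `CARAFEPack` module and its `channels`/`scale_factor` arguments):

    import torch
    from mmdet.ops.carafe import CARAFEPack

    up = CARAFEPack(channels=16, scale_factor=2).cuda()
    x = torch.randn(2, 16, 8, 8, device='cuda')
    y = up(x)       # calls carafe_ext.forward internally
    print(y.shape)  # torch.Size([2, 16, 16, 16])
    # A CPU tensor would abort with "carafe is not implemented on CPU".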
diff --git a/mmdet/ops/carafe/src/carafe_cuda_kernel.cu b/mmdet/ops/carafe/src/cuda/carafe_cuda_kernel.cu
similarity index 100%
rename from mmdet/ops/carafe/src/carafe_cuda_kernel.cu
rename to mmdet/ops/carafe/src/cuda/carafe_cuda_kernel.cu
diff --git a/mmdet/ops/carafe/src/carafe_naive_cuda.cpp b/mmdet/ops/carafe/src/cuda/carafe_naive_cuda.cpp
similarity index 92%
rename from mmdet/ops/carafe/src/carafe_naive_cuda.cpp
rename to mmdet/ops/carafe/src/cuda/carafe_naive_cuda.cpp
index fbcda80e..611f1d11 100644
--- a/mmdet/ops/carafe/src/carafe_naive_cuda.cpp
+++ b/mmdet/ops/carafe/src/cuda/carafe_naive_cuda.cpp
@@ -67,9 +67,3 @@ int carafe_naive_backward_cuda(at::Tensor top_grad, at::Tensor features,
 
   return 1;
 }
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &carafe_naive_forward_cuda, "carafe_naive forward (CUDA)");
-  m.def("backward", &carafe_naive_backward_cuda,
-        "carafe_naive backward (CUDA)");
-}
diff --git a/mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu b/mmdet/ops/carafe/src/cuda/carafe_naive_cuda_kernel.cu
similarity index 100%
rename from mmdet/ops/carafe/src/carafe_naive_cuda_kernel.cu
rename to mmdet/ops/carafe/src/cuda/carafe_naive_cuda_kernel.cu
diff --git a/mmdet/ops/dcn/deform_conv.py b/mmdet/ops/dcn/deform_conv.py
index c4d3f192..36ab443a 100644
--- a/mmdet/ops/dcn/deform_conv.py
+++ b/mmdet/ops/dcn/deform_conv.py
@@ -8,7 +8,7 @@ from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single
 
 from mmdet.utils import print_log
-from . import deform_conv_cuda
+from . import deform_conv_ext
 
 
 class DeformConvFunction(Function):
@@ -49,7 +49,7 @@ class DeformConvFunction(Function):
             cur_im2col_step = min(ctx.im2col_step, input.shape[0])
             assert (input.shape[0] %
                     cur_im2col_step) == 0, 'im2col step must divide batchsize'
-            deform_conv_cuda.deform_conv_forward_cuda(
+            deform_conv_ext.deform_conv_forward(
                 input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
                 weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
                 ctx.padding[1], ctx.padding[0], ctx.dilation[1],
@@ -74,7 +74,7 @@ class DeformConvFunction(Function):
             if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                 grad_input = torch.zeros_like(input)
                 grad_offset = torch.zeros_like(offset)
-                deform_conv_cuda.deform_conv_backward_input_cuda(
+                deform_conv_ext.deform_conv_backward_input(
                     input, offset, grad_output, grad_input, grad_offset,
                     weight, ctx.bufs_[0], weight.size(3), weight.size(2),
                     ctx.stride[1], ctx.stride[0],
@@ -84,7 +84,7 @@ class DeformConvFunction(Function):
 
             if ctx.needs_input_grad[2]:
                 grad_weight = torch.zeros_like(weight)
-                deform_conv_cuda.deform_conv_backward_parameters_cuda(
+                deform_conv_ext.deform_conv_backward_parameters(
                     input, offset, grad_output, grad_weight, ctx.bufs_[0],
                     ctx.bufs_[1], weight.size(3), weight.size(2),
                     ctx.stride[1], ctx.stride[0],
@@ -142,7 +142,7 @@ class ModulatedDeformConvFunction(Function):
         output = input.new_empty(
             ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
         ctx._bufs = [input.new_empty(0), input.new_empty(0)]
-        deform_conv_cuda.modulated_deform_conv_cuda_forward(
+        deform_conv_ext.modulated_deform_conv_forward(
             input, weight, bias, ctx._bufs[0], offset, mask, output,
             ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
             ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
@@ -160,7 +160,7 @@ class ModulatedDeformConvFunction(Function):
         grad_mask = torch.zeros_like(mask)
         grad_weight = torch.zeros_like(weight)
         grad_bias = torch.zeros_like(bias)
-        deform_conv_cuda.modulated_deform_conv_cuda_backward(
+        deform_conv_ext.modulated_deform_conv_backward(
             input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
             grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
             grad_output, weight.shape[2], weight.shape[3], ctx.stride,
diff --git a/mmdet/ops/dcn/deform_pool.py b/mmdet/ops/dcn/deform_pool.py
index a3eee759..a0ccd607 100644
--- a/mmdet/ops/dcn/deform_pool.py
+++ b/mmdet/ops/dcn/deform_pool.py
@@ -4,7 +4,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
 
-from . import deform_pool_cuda
+from . import deform_pool_ext
 
 
 class DeformRoIPoolingFunction(Function):
@@ -44,7 +44,7 @@ class DeformRoIPoolingFunction(Function):
         n = rois.shape[0]
         output = data.new_empty(n, out_channels, out_size, out_size)
         output_count = data.new_empty(n, out_channels, out_size, out_size)
-        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
+        deform_pool_ext.deform_psroi_pooling_forward(
             data, rois, offset, output, output_count, ctx.no_trans,
             ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
             ctx.part_size, ctx.sample_per_part, ctx.trans_std)
@@ -67,7 +67,7 @@ class DeformRoIPoolingFunction(Function):
         grad_rois = None
         grad_offset = torch.zeros_like(offset)
 
-        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
+        deform_pool_ext.deform_psroi_pooling_backward(
             grad_output, data, rois, offset, output_count, grad_input,
             grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
             ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
diff --git a/mmdet/ops/dcn/src/deform_conv_cuda.cpp b/mmdet/ops/dcn/src/cuda/deform_conv_cuda.cpp
similarity index 97%
rename from mmdet/ops/dcn/src/deform_conv_cuda.cpp
rename to mmdet/ops/dcn/src/cuda/deform_conv_cuda.cpp
index 2321e023..8601eb3b 100644
--- a/mmdet/ops/dcn/src/deform_conv_cuda.cpp
+++ b/mmdet/ops/dcn/src/cuda/deform_conv_cuda.cpp
@@ -683,19 +683,3 @@ void modulated_deform_conv_cuda_backward(
         grad_output.size(2), grad_output.size(3), grad_output.size(4)});
   }
 }
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
-        "deform forward (CUDA)");
-  m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
-        "deform_conv_backward_input (CUDA)");
-  m.def("deform_conv_backward_parameters_cuda",
-        &deform_conv_backward_parameters_cuda,
-        "deform_conv_backward_parameters (CUDA)");
-  m.def("modulated_deform_conv_cuda_forward",
-        &modulated_deform_conv_cuda_forward,
-        "modulated deform conv forward (CUDA)");
-  m.def("modulated_deform_conv_cuda_backward",
-        &modulated_deform_conv_cuda_backward,
-        "modulated deform conv backward (CUDA)");
-}
diff --git a/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu b/mmdet/ops/dcn/src/cuda/deform_conv_cuda_kernel.cu
similarity index 100%
rename from mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
rename to mmdet/ops/dcn/src/cuda/deform_conv_cuda_kernel.cu
diff --git a/mmdet/ops/dcn/src/deform_pool_cuda.cpp b/mmdet/ops/dcn/src/cuda/deform_pool_cuda.cpp
similarity index 92%
rename from mmdet/ops/dcn/src/deform_pool_cuda.cpp
rename to mmdet/ops/dcn/src/cuda/deform_pool_cuda.cpp
index 9e0e3ffc..d7ed3f63 100644
--- a/mmdet/ops/dcn/src/deform_pool_cuda.cpp
+++ b/mmdet/ops/dcn/src/cuda/deform_pool_cuda.cpp
@@ -80,11 +80,3 @@ void deform_psroi_pooling_cuda_backward(
       spatial_scale, output_dim, group_size, pooled_size, part_size,
       sample_per_part, trans_std);
 }
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
-        "deform psroi pooling forward(CUDA)");
-  m.def("deform_psroi_pooling_cuda_backward",
-        &deform_psroi_pooling_cuda_backward,
-        "deform psroi pooling backward(CUDA)");
-}
diff --git a/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu b/mmdet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu
similarity index 100%
rename from mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
rename to mmdet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu
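With the pybind11 blocks stripped from the cuda/ sources above, the bindings reappear in the new deform_conv_ext.cpp below, minus their `_cuda` suffixes. A quick sanity-check sketch (assuming a successful build) that the unified module exposes exactly the renamed symbols used by deform_conv.py:

    from mmdet.ops.dcn import deform_conv_ext

    for name in ('deform_conv_forward', 'deform_conv_backward_input',
                 'deform_conv_backward_parameters',
                 'modulated_deform_conv_forward',
                 'modulated_deform_conv_backward'):
        assert hasattr(deform_conv_ext, name), name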
m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward, - "deform psroi pooling forward(CUDA)"); - m.def("deform_psroi_pooling_cuda_backward", - &deform_psroi_pooling_cuda_backward, - "deform psroi pooling backward(CUDA)"); -} diff --git a/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu b/mmdet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu similarity index 100% rename from mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu rename to mmdet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu diff --git a/mmdet/ops/dcn/src/deform_conv_ext.cpp b/mmdet/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 00000000..2beaeffc --- /dev/null +++ b/mmdet/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,163 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include <torch/extension.h> +#include <ATen/DeviceGuard.h> + +#include <cmath> +#include <vector> + +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor 
offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + 
m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git a/mmdet/ops/dcn/src/deform_pool_ext.cpp b/mmdet/ops/dcn/src/deform_pool_ext.cpp new file mode 100644 index 00000000..f590fabe --- /dev/null +++ b/mmdet/ops/dcn/src/deform_pool_ext.cpp @@ -0,0 +1,71 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c + +// based on +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +#include <torch/extension.h> +#include <ATen/DeviceGuard.h> + +#include <cmath> +#include <vector> + +#ifdef WITH_CUDA +void deform_psroi_pooling_cuda_forward( + at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, + at::Tensor top_count, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void deform_psroi_pooling_cuda_backward( + at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, + at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, + const int no_trans, const float spatial_scale, const int output_dim, + const int group_size, const int pooled_size, const int part_size, + const int sample_per_part, const float trans_std); +#endif + +void deform_psroi_pooling_forward( + at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, + at::Tensor top_count, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_psroi_pooling_cuda_forward(input, bbox, trans, out, top_count, + no_trans, spatial_scale, output_dim, group_size, pooled_size, + part_size, sample_per_part, trans_std); +#else + AT_ERROR("deform psroi pooling is not compiled with GPU support"); +#endif + } + AT_ERROR("deform psroi pooling is not implemented on CPU"); +} + +void deform_psroi_pooling_backward( + at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, + at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, + const int no_trans, const float spatial_scale, const int output_dim, + const int group_size, const int pooled_size, const int part_size, + const int sample_per_part, const float trans_std) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return deform_psroi_pooling_cuda_backward(out_grad, input, bbox, trans, + top_count, input_grad, trans_grad, no_trans, spatial_scale, + output_dim, group_size, pooled_size, part_size, sample_per_part, + trans_std); +#else + AT_ERROR("deform psroi pooling is not compiled with GPU support"); +#endif + } + AT_ERROR("deform psroi pooling is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_psroi_pooling_forward", &deform_psroi_pooling_forward, + "deform psroi pooling forward"); + m.def("deform_psroi_pooling_backward", &deform_psroi_pooling_backward, + "deform psroi pooling backward"); +} diff --git a/mmdet/ops/grid_sampler/grid_sampler.py b/mmdet/ops/grid_sampler/grid_sampler.py index a112fa82..b5c59aa4 100644 --- a/mmdet/ops/grid_sampler/grid_sampler.py +++ b/mmdet/ops/grid_sampler/grid_sampler.py @@ -3,7 +3,7 @@ 
diff --git a/mmdet/ops/grid_sampler/grid_sampler.py b/mmdet/ops/grid_sampler/grid_sampler.py
index a112fa82..b5c59aa4 100644
--- a/mmdet/ops/grid_sampler/grid_sampler.py
+++ b/mmdet/ops/grid_sampler/grid_sampler.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
-from . import grid_sampler_cuda
+from . import grid_sampler_ext
 
 
 class _GridSampler(Function):
@@ -16,18 +16,9 @@ class _GridSampler(Function):
         ctx.padding_mode_enum = padding_mode_enum
         ctx.align_corners = align_corners
 
-        if input.is_cuda:
-            if input.dim() == 4:
-                func = grid_sampler_cuda.grid_sampler_2d_forward_cuda
-            else:
-                func = grid_sampler_cuda.grid_sampler_3d_forward_cuda
-        else:
-            if input.dim() == 4:
-                func = grid_sampler_cuda.grid_sampler_2d_forward_cpu
-            else:
-                func = grid_sampler_cuda.grid_sampler_3d_forward_cpu
-
-        output = func(input, grid, mode_enum, padding_mode_enum, align_corners)
+        output = grid_sampler_ext.grid_sampler_forward(input, grid, mode_enum,
+                                                       padding_mode_enum,
+                                                       align_corners)
 
         return output
 
@@ -39,19 +30,9 @@ class _GridSampler(Function):
         padding_mode_enum = ctx.padding_mode_enum
         align_corners = ctx.align_corners
 
-        if input.is_cuda:
-            if input.dim() == 4:
-                func = grid_sampler_cuda.grid_sampler_2d_backward_cuda
-            else:
-                func = grid_sampler_cuda.grid_sampler_3d_backward_cuda
-        else:
-            if input.dim() == 4:
-                func = grid_sampler_cuda.grid_sampler_2d_backward_cpu
-            else:
-                func = grid_sampler_cuda.grid_sampler_3d_backward_cpu
-
-        grad_input, grad_grid = func(grad_output, input, grid, mode_enum,
-                                     padding_mode_enum, align_corners)
+        grad_input, grad_grid = grid_sampler_ext.grid_sampler_backward(
+            grad_output, input, grid, mode_enum, padding_mode_enum,
+            align_corners)
 
         return grad_input, grad_grid, None, None, None
 
diff --git a/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp b/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp
deleted file mode 100644
index f684abf1..00000000
--- a/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp
+++ /dev/null
@@ -1,148 +0,0 @@
-#include <ATen/ATen.h>
-#include <ATen/NativeFunctions.h>
-#include <ATen/Config.h>
-#include <ATen/cuda/CUDAConfig.h>
-
-#if !AT_CUDNN_ENABLED()
-
-namespace at { namespace native {
-
-// See Note [ATen preprocessor philosophy]
-
-Tensor cudnn_grid_sampler_forward(
-    const Tensor& input_t, const Tensor& grid_t) {
-  AT_ERROR("cudnn_grid_sampler_forward: ATen not compiled with cuDNN support");
-}
-
-std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
-    const Tensor& input_t, const Tensor& grid_t,
-    const Tensor& grad_output_t) {
-  AT_ERROR("cudnn_grid_sampler_backward: ATen not compiled with cuDNN support");
-}
-
-}}
-
-#else // AT_CUDNN_ENABLED
-
-#include <ATen/cudnn/Descriptors.h>
-#include <ATen/cudnn/Types.h>
-#include <ATen/cudnn/Utils.h>
-#include <ATen/cuda/Exceptions.h>
-
-#include <ATen/TensorUtils.h>
-
-// TODO: descriptor checking
-
-namespace mmdetection {
-
-using namespace at;
-
-namespace {
-
-void setSamplerDescriptor(SpatialTransformerDescriptor& desc,
-                          cudnnDataType_t dataType, const at::Tensor& tensor) {
-  int inputSize[4] = {0};
-  for (int i = 0; i < tensor.dim(); ++i) {
-    inputSize[i] = (int)tensor.size(i);
-  }
-  desc.set(dataType, 4, inputSize);
-}
-
-void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input) {
-  // assert size of grid is n*h*w*2
-  // FYI: grid is between [-1, 1], where -1 left most pixel,
-  // 1 represents right most pixel (and hence 0 is the center pixel)
-  // if grid has values >1 or <-1, those values are ignored
-  checkContiguous(c, grid);
-  checkDim(c, grid, 4);
-  // TODO: Maybe more user friendly to report where the expected size
-  // came from
-  checkSize(c, grid, 0, input->size(0));
-  checkSize(c, grid, 3, 2);
-}
-
-}  // namespace
-
-Tensor cudnn_grid_sampler_forward(
-    const Tensor& input_t, const Tensor& grid_t) {
-  TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
-      grid{ grid_t.contiguous(), "grid", 2 };
-  CheckedFrom c = "cudnn_grid_sampler_forward";
-  checkAllSameGPU(c, {input, grid});
-  checkAllSameType(c, {input, grid});
-  checkGridSize(c, grid, input);
-  checkDim(c, input, 4);
-
-  auto output_t = at::empty({0}, input->options());
-  output_t.resize_(
-      {input->size(0), input->size(1), grid->size(1), grid->size(2)});
-
-  TensorDescriptor idesc{ *input };  // input descriptor
-  TensorDescriptor odesc{ output_t };  // output descriptor
-  SpatialTransformerDescriptor desc;  // sampler descriptor
-
-  auto handle = getCudnnHandle();
-  auto dataType = getCudnnDataType(*input);
-  setSamplerDescriptor(desc, dataType, output_t);
-
-  Constant one(dataType, 1);
-  Constant zero(dataType, 0);
-  AT_CUDNN_CHECK(cudnnSpatialTfSamplerForward(
-      handle, desc.desc(),
-      &one, idesc.desc(), input->data_ptr(),
-      grid->data_ptr(),
-      &zero, odesc.desc(), output_t.data_ptr()));
-
-  return output_t;
-}
-
-// NB: CuDNN does not support output mask; you always get both
-// gradients.
-std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
-    const Tensor& input_t, const Tensor& grid_t,
-    const Tensor& grad_output_t) {
-  TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
-      grid{ grid_t.contiguous(), "grid", 2 },
-      grad_output{ contiguousIfZeroInStrides(grad_output_t), "grad_output", 3 };
-  CheckedFrom c = "cudnn_grid_sampler_backward";
-  checkAllSameGPU(c, {input, grad_output, grid});
-  checkGridSize(c, grid, input);
-  checkDim(c, input, 4);
-  checkDim(c, grad_output, 4);
-
-  auto grad_input_t = at::empty({0}, input->options());
-  grad_input_t.resize_(input->sizes());
-  auto grad_grid_t = at::empty({0}, grid->options());
-  grad_grid_t.resize_(grid->sizes());
-
-  TensorDescriptor idesc{ *input };  // input descriptor
-  TensorDescriptor odesc{ *grad_output };  // grad_output descriptor
-  TensorDescriptor gdesc{ grad_input_t };  // grad_input descriptor
-  SpatialTransformerDescriptor desc;  // sampler descriptor
-
-  auto handle = getCudnnHandle();
-  auto dataType = getCudnnDataType(*input);
-  setSamplerDescriptor(desc, dataType, *grad_output);
-
-  Constant one(dataType, 1);
-  Constant zero(dataType, 0);
-  AT_CUDNN_CHECK(cudnnSpatialTfSamplerBackward(
-      handle, desc.desc(),
-      &one, idesc.desc(), input->data_ptr(),
-      &zero, gdesc.desc(), grad_input_t.data_ptr(),
-      &one, odesc.desc(), grad_output->data_ptr(),
-      // intruigingly, the outputs don't need descriptors
-      grid->data_ptr(),
-      &zero, grad_grid_t.data_ptr()));
-
-  return std::tuple<Tensor, Tensor>{ grad_input_t, grad_grid_t };
-}
-
-}  // namespace mmdetection
-
-#endif
diff --git a/mmdet/ops/grid_sampler/src/grid_sampler.cpp b/mmdet/ops/grid_sampler/src/grid_sampler_ext.cpp
similarity index 51%
rename from mmdet/ops/grid_sampler/src/grid_sampler.cpp
rename to mmdet/ops/grid_sampler/src/grid_sampler_ext.cpp
index 009675fe..7e76a7aa 100644
--- a/mmdet/ops/grid_sampler/src/grid_sampler.cpp
+++ b/mmdet/ops/grid_sampler/src/grid_sampler_ext.cpp
@@ -27,6 +27,7 @@ grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input,
                              const Tensor& grid, int64_t interpolation_mode,
                              int64_t padding_mode, bool align_corners);
 
+#ifdef WITH_CUDA
 // No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
 Tensor grid_sampler_2d_forward_cuda(const Tensor& input, const Tensor& grid,
                                     int64_t interpolation_mode,
@@ -48,19 +49,69 @@ std::tuple<Tensor, Tensor>
 grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
                               const Tensor& grid, int64_t interpolation_mode,
                               int64_t padding_mode, bool align_corners);
+#endif
 
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_forward(const Tensor& input, const Tensor& grid,
+                            int64_t interpolation_mode, int64_t padding_mode,
+                            bool align_corners) {
+  if (input.dim() == 4) {
+    if (input.type().is_cuda()) {
+#ifdef WITH_CUDA
+      return grid_sampler_2d_forward_cuda(input, grid, interpolation_mode,
+                                          padding_mode, align_corners);
+#else
+      AT_ERROR("grid_sampler is not compiled with GPU support");
+#endif
+    }
+    return grid_sampler_2d_forward_cpu(input, grid, interpolation_mode,
+                                       padding_mode, align_corners);
+  } else {
+    if (input.type().is_cuda()) {
+#ifdef WITH_CUDA
+      return grid_sampler_3d_forward_cuda(input, grid, interpolation_mode,
+                                          padding_mode, align_corners);
+#else
+      AT_ERROR("grid_sampler is not compiled with GPU support");
+#endif
+    }
+    return grid_sampler_3d_forward_cpu(input, grid, interpolation_mode,
+                                       padding_mode, align_corners);
+  }
+}
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-
-  m.def("grid_sampler_2d_forward_cpu", &grid_sampler_2d_forward_cpu, "grid_sampler_2d_forward (CPU)");
-  m.def("grid_sampler_2d_backward_cpu", &grid_sampler_2d_backward_cpu, "grid_sampler_2d_backward (CPU)");
-  m.def("grid_sampler_3d_forward_cpu", &grid_sampler_3d_forward_cpu, "grid_sampler_3d_forward (CPU)");
-  m.def("grid_sampler_3d_backward_cpu", &grid_sampler_3d_backward_cpu, "grid_sampler_3d_backward (CPU)");
+std::tuple<Tensor, Tensor>
+grid_sampler_backward(const Tensor& grad_output, const Tensor& input,
+                      const Tensor& grid, int64_t interpolation_mode,
+                      int64_t padding_mode, bool align_corners) {
+  if (input.dim() == 4) {
+    if (input.type().is_cuda()) {
+#ifdef WITH_CUDA
+      return grid_sampler_2d_backward_cuda(grad_output, input, grid,
+                                           interpolation_mode, padding_mode,
+                                           align_corners);
+#else
+      AT_ERROR("grid_sampler is not compiled with GPU support");
+#endif
+    }
+    return grid_sampler_2d_backward_cpu(grad_output, input, grid,
+                                        interpolation_mode, padding_mode,
+                                        align_corners);
+  } else {
+    if (input.type().is_cuda()) {
+#ifdef WITH_CUDA
+      return grid_sampler_3d_backward_cuda(grad_output, input, grid,
+                                           interpolation_mode, padding_mode,
+                                           align_corners);
+#else
+      AT_ERROR("grid_sampler is not compiled with GPU support");
+#endif
+    }
+    return grid_sampler_3d_backward_cpu(grad_output, input, grid,
+                                        interpolation_mode, padding_mode,
+                                        align_corners);
+  }
+}
 
-  m.def("grid_sampler_2d_forward_cuda", &grid_sampler_2d_forward_cuda, "grid_sampler_2d_forward (CUDA)");
-  m.def("grid_sampler_2d_backward_cuda", &grid_sampler_2d_backward_cuda, "grid_sampler_2d_backward (CUDA)");
-  m.def("grid_sampler_3d_forward_cuda", &grid_sampler_3d_forward_cuda, "grid_sampler_3d_forward (CUDA)");
-  m.def("grid_sampler_3d_backward_cuda", &grid_sampler_3d_backward_cuda, "grid_sampler_3d_backward (CUDA)");
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("grid_sampler_forward", &grid_sampler_forward, "grid_sampler_forward");
+  m.def("grid_sampler_backward", &grid_sampler_backward,
+        "grid_sampler_backward");
 }
 
 }  // namespace mmdetection
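The four-way Python branch collapses into one call; the C++ dispatcher now selects the 2-D or 3-D kernel from `input.dim()` and the device from `input.type().is_cuda()`. A sketch (assuming the module's public `grid_sample` wrapper keeps `F.grid_sample`'s argument order):

    import torch
    from mmdet.ops.grid_sampler import grid_sample

    x = torch.randn(1, 3, 16, 16)          # 4-D input -> 2-D sampler
    grid = torch.rand(1, 8, 8, 2) * 2 - 1  # normalized coords in [-1, 1]
    out = grid_sample(x, grid, mode='bilinear', padding_mode='zeros',
                      align_corners=False)
    print(out.shape)  # torch.Size([1, 3, 8, 8])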
diff --git a/mmdet/ops/masked_conv/masked_conv.py b/mmdet/ops/masked_conv/masked_conv.py
index 7d84f503..06c25a5c 100644
--- a/mmdet/ops/masked_conv/masked_conv.py
+++ b/mmdet/ops/masked_conv/masked_conv.py
@@ -6,7 +6,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
 
-from . import masked_conv2d_cuda
+from . import masked_conv2d_ext
 
 
 class MaskedConv2dFunction(Function):
@@ -40,16 +40,16 @@ class MaskedConv2dFunction(Function):
             mask_w_idx = mask_inds[:, 1].contiguous()
             data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                           mask_inds.size(0))
-            masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx,
-                                                     mask_w_idx, kernel_h,
-                                                     kernel_w, pad_h, pad_w,
-                                                     data_col)
+            masked_conv2d_ext.masked_im2col_forward(features, mask_h_idx,
+                                                    mask_w_idx, kernel_h,
+                                                    kernel_w, pad_h, pad_w,
+                                                    data_col)
 
             masked_output = torch.addmm(1, bias[:, None], 1,
                                         weight.view(out_channel, -1), data_col)
-            masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx,
-                                                     mask_w_idx, out_h, out_w,
-                                                     out_channel, output)
+            masked_conv2d_ext.masked_col2im_forward(masked_output, mask_h_idx,
+                                                    mask_w_idx, out_h, out_w,
+                                                    out_channel, output)
         return output
 
     @staticmethod
diff --git a/mmdet/ops/masked_conv/src/masked_conv2d_cuda.cpp b/mmdet/ops/masked_conv/src/cuda/masked_conv2d_cuda.cpp
similarity index 91%
rename from mmdet/ops/masked_conv/src/masked_conv2d_cuda.cpp
rename to mmdet/ops/masked_conv/src/cuda/masked_conv2d_cuda.cpp
index 6c2a8f6a..b2850d91 100644
--- a/mmdet/ops/masked_conv/src/masked_conv2d_cuda.cpp
+++ b/mmdet/ops/masked_conv/src/cuda/masked_conv2d_cuda.cpp
@@ -67,10 +67,3 @@ int masked_col2im_forward_cuda(const at::Tensor col,
 
   return 1;
 }
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("masked_im2col_forward", &masked_im2col_forward_cuda,
-        "masked_im2col forward (CUDA)");
-  m.def("masked_col2im_forward", &masked_col2im_forward_cuda,
-        "masked_col2im forward (CUDA)");
-}
diff --git a/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu b/mmdet/ops/masked_conv/src/cuda/masked_conv2d_kernel.cu
similarity index 100%
rename from mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
rename to mmdet/ops/masked_conv/src/cuda/masked_conv2d_kernel.cu
diff --git a/mmdet/ops/masked_conv/src/masked_conv2d_ext.cpp b/mmdet/ops/masked_conv/src/masked_conv2d_ext.cpp
new file mode 100644
index 00000000..972589fd
--- /dev/null
+++ b/mmdet/ops/masked_conv/src/masked_conv2d_ext.cpp
@@ -0,0 +1,54 @@
+#include <torch/extension.h>
+
+#include <cmath>
+#include <vector>
+
+#ifdef WITH_CUDA
+int masked_im2col_forward_cuda(const at::Tensor im,
+                               const at::Tensor mask_h_idx,
+                               const at::Tensor mask_w_idx, const int kernel_h,
+                               const int kernel_w, const int pad_h,
+                               const int pad_w, at::Tensor col);
+
+int masked_col2im_forward_cuda(const at::Tensor col,
+                               const at::Tensor mask_h_idx,
+                               const at::Tensor mask_w_idx, int height,
+                               int width, int channels, at::Tensor im);
+#endif
+
+int masked_im2col_forward(const at::Tensor im, const at::Tensor mask_h_idx,
+                          const at::Tensor mask_w_idx, const int kernel_h,
+                          const int kernel_w, const int pad_h,
+                          const int pad_w, at::Tensor col) {
+  if (im.type().is_cuda()) {
+#ifdef WITH_CUDA
+    return masked_im2col_forward_cuda(im, mask_h_idx, mask_w_idx, kernel_h,
+                                      kernel_w, pad_h, pad_w, col);
+#else
+    AT_ERROR("masked_im2col is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("masked_im2col is not implemented on CPU");
+}
+
+int masked_col2im_forward(const at::Tensor col, const at::Tensor mask_h_idx,
+                          const at::Tensor mask_w_idx, int height, int width,
+                          int channels, at::Tensor im) {
+  if (col.type().is_cuda()) {
+#ifdef WITH_CUDA
+    return masked_col2im_forward_cuda(col, mask_h_idx, mask_w_idx, height,
+                                      width, channels, im);
+#else
+    AT_ERROR("masked_col2im is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("masked_col2im is not implemented on CPU");
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("masked_im2col_forward", &masked_im2col_forward,
+        "masked_im2col forward");
+  m.def("masked_col2im_forward", &masked_col2im_forward,
+        "masked_col2im forward");
+}
diff --git a/mmdet/ops/nms/nms_wrapper.py b/mmdet/ops/nms/nms_wrapper.py
index 145a218e..a9ebac22 100644
--- a/mmdet/ops/nms/nms_wrapper.py
+++ b/mmdet/ops/nms/nms_wrapper.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 
-from . import nms_cpu, nms_cuda
+from . import nms_ext
 
 
 def nms(dets, iou_thr, device_id=None):
@@ -51,9 +51,9 @@ def nms(dets, iou_thr, device_id=None):
         inds = dets_th.new_zeros(0, dtype=torch.long)
     else:
         if dets_th.is_cuda:
-            inds = nms_cuda.nms(dets_th, iou_thr)
+            inds = nms_ext.nms(dets_th, iou_thr)
         else:
-            inds = nms_cpu.nms(dets_th, iou_thr)
+            inds = nms_ext.nms(dets_th, iou_thr)
 
     if is_numpy:
         inds = inds.cpu().numpy()
@@ -103,7 +103,7 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
     method_codes = {'linear': 1, 'gaussian': 2}
     if method not in method_codes:
         raise ValueError('Invalid method for SoftNMS: {}'.format(method))
-    results = nms_cpu.soft_nms(dets_t, iou_thr, method_codes[method], sigma,
+    results = nms_ext.soft_nms(dets_t, iou_thr, method_codes[method], sigma,
                                min_score)
 
     new_dets = results[:, :5]
diff --git a/mmdet/ops/nms/src/nms_cpu.cpp b/mmdet/ops/nms/src/cpu/nms_cpu.cpp
similarity index 95%
rename from mmdet/ops/nms/src/nms_cpu.cpp
rename to mmdet/ops/nms/src/cpu/nms_cpu.cpp
index 7a59b32d..1fa589dc 100644
--- a/mmdet/ops/nms/src/nms_cpu.cpp
+++ b/mmdet/ops/nms/src/cpu/nms_cpu.cpp
@@ -59,7 +59,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) {
   return at::nonzero(suppressed_t == 0).squeeze(1);
 }
 
-at::Tensor nms(const at::Tensor& dets, const float threshold) {
+at::Tensor nms_cpu(const at::Tensor& dets, const float threshold) {
   at::Tensor result;
   AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
     result = nms_cpu_kernel<scalar_t>(dets, threshold);
@@ -200,7 +200,7 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
   return result;
 }
 
-at::Tensor soft_nms(const at::Tensor& dets, const float threshold,
+at::Tensor soft_nms_cpu(const at::Tensor& dets, const float threshold,
                     const unsigned char method, const float sigma,
                     const float min_score) {
   at::Tensor result;
   AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "soft_nms", [&] {
@@ -208,8 +208,3 @@ at::Tensor soft_nms(const at::Tensor& dets, const float threshold,
   });
   return result;
 }
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("nms", &nms, "non-maximum suppression");
-  m.def("soft_nms", &soft_nms, "soft non-maximum suppression");
-}
diff --git a/mmdet/ops/nms/src/nms_cuda.cpp b/mmdet/ops/nms/src/cuda/nms_cuda.cpp
similarity index 53%
rename from mmdet/ops/nms/src/nms_cuda.cpp
rename to mmdet/ops/nms/src/cuda/nms_cuda.cpp
index 274c7248..61ca93a2 100644
--- a/mmdet/ops/nms/src/nms_cuda.cpp
+++ b/mmdet/ops/nms/src/cuda/nms_cuda.cpp
@@ -3,15 +3,11 @@
 #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
 
-at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
+at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh);
 
-at::Tensor nms(const at::Tensor& dets, const float threshold) {
+at::Tensor nms_cuda(const at::Tensor& dets, const float threshold) {
   CHECK_CUDA(dets);
   if (dets.numel() == 0)
     return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
-  return nms_cuda(dets, threshold);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("nms", &nms, "non-maximum suppression");
+  return nms_cuda_forward(dets, threshold);
 }
diff --git a/mmdet/ops/nms/src/nms_kernel.cu b/mmdet/ops/nms/src/cuda/nms_kernel.cu
similarity index 98%
rename from mmdet/ops/nms/src/nms_kernel.cu
rename to mmdet/ops/nms/src/cuda/nms_kernel.cu
index ada9bea2..8dc98be1 100644
--- a/mmdet/ops/nms/src/nms_kernel.cu
+++ b/mmdet/ops/nms/src/cuda/nms_kernel.cu
@@ -68,7 +68,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
 }
 
 // boxes is a N x 5 tensor
-at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
+at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh) {
   // Ensure CUDA uses the input tensor device.
   at::DeviceGuard guard(boxes.device());
 
diff --git a/mmdet/ops/nms/src/nms_ext.cpp b/mmdet/ops/nms/src/nms_ext.cpp
new file mode 100644
index 00000000..6d95303a
--- /dev/null
+++ b/mmdet/ops/nms/src/nms_ext.cpp
@@ -0,0 +1,38 @@
+// Modified from https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx, Soft-NMS is added
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <torch/extension.h>
+
+at::Tensor nms_cpu(const at::Tensor& dets, const float threshold);
+
+at::Tensor soft_nms_cpu(const at::Tensor& dets, const float threshold,
+                        const unsigned char method, const float sigma,
+                        const float min_score);
+
+#ifdef WITH_CUDA
+at::Tensor nms_cuda(const at::Tensor& dets, const float threshold);
+#endif
+
+at::Tensor nms(const at::Tensor& dets, const float threshold) {
+  if (dets.type().is_cuda()) {
+#ifdef WITH_CUDA
+    return nms_cuda(dets, threshold);
+#else
+    AT_ERROR("nms is not compiled with GPU support");
+#endif
+  }
+  return nms_cpu(dets, threshold);
+}
+
+at::Tensor soft_nms(const at::Tensor& dets, const float threshold,
+                    const unsigned char method, const float sigma,
+                    const float min_score) {
+  if (dets.type().is_cuda()) {
+    AT_ERROR("soft_nms is not implemented on GPU");
+  }
+  return soft_nms_cpu(dets, threshold, method, sigma, min_score);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("nms", &nms, "non-maximum suppression");
+  m.def("soft_nms", &soft_nms, "soft non-maximum suppression");
+}
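After this change a single `nms_ext.nms` entry point serves both devices, while `soft_nms` stays CPU-only and raises on GPU tensors inside C++ rather than in Python. A usage sketch against the public wrappers, which accept numpy arrays or tensors:

    import numpy as np
    from mmdet.ops import nms, soft_nms

    dets = np.array([[10, 10, 50, 50, 0.9],
                     [12, 12, 52, 52, 0.8],
                     [100, 100, 140, 140, 0.7]], dtype=np.float32)
    kept, inds = nms(dets, 0.5)       # C++ picks the CPU or CUDA kernel
    new_dets, inds = soft_nms(dets, 0.5, method='gaussian')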
diff --git a/mmdet/ops/roi_align/roi_align.py b/mmdet/ops/roi_align/roi_align.py
index e28cb5f9..203c1152 100644
--- a/mmdet/ops/roi_align/roi_align.py
+++ b/mmdet/ops/roi_align/roi_align.py
@@ -3,7 +3,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
 
-from . import roi_align_cuda
+from . import roi_align_ext
 
 
 class RoIAlignFunction(Function):
@@ -32,12 +32,12 @@ class RoIAlignFunction(Function):
 
                 output = features.new_zeros(num_rois, num_channels, out_h,
                                             out_w)
-                roi_align_cuda.forward_v1(features, rois, out_h, out_w,
-                                          spatial_scale, sample_num, output)
+                roi_align_ext.forward_v1(features, rois, out_h, out_w,
+                                         spatial_scale, sample_num, output)
             else:
-                output = roi_align_cuda.forward_v2(features, rois,
-                                                   spatial_scale, out_h, out_w,
-                                                   sample_num, aligned)
+                output = roi_align_ext.forward_v2(features, rois,
+                                                  spatial_scale, out_h, out_w,
+                                                  sample_num, aligned)
         else:
             raise NotImplementedError
 
@@ -62,13 +62,15 @@ class RoIAlignFunction(Function):
             if ctx.needs_input_grad[0]:
                 grad_input = rois.new_zeros(batch_size, num_channels,
                                             data_height, data_width)
-                roi_align_cuda.backward_v1(grad_output.contiguous(), rois,
-                                           out_h, out_w, spatial_scale,
-                                           sample_num, grad_input)
+                roi_align_ext.backward_v1(grad_output.contiguous(), rois,
+                                          out_h, out_w, spatial_scale,
+                                          sample_num, grad_input)
             else:
-                grad_input = roi_align_cuda.backward_v2(
-                    grad_output, rois, spatial_scale, out_h, out_w, batch_size,
-                    num_channels, data_height, data_width, sample_num, aligned)
+                grad_input = roi_align_ext.backward_v2(grad_output, rois,
+                                                       spatial_scale, out_h,
+                                                       out_w, batch_size,
+                                                       num_channels,
+                                                       data_height, data_width,
+                                                       sample_num, aligned)
 
         return grad_input, grad_rois, None, None, None, None
 
diff --git a/mmdet/ops/roi_align/src/roi_align_kernel.cu b/mmdet/ops/roi_align/src/cuda/roi_align_kernel.cu
similarity index 99%
rename from mmdet/ops/roi_align/src/roi_align_kernel.cu
rename to mmdet/ops/roi_align/src/cuda/roi_align_kernel.cu
index b2ac72e3..113fc110 100644
--- a/mmdet/ops/roi_align/src/roi_align_kernel.cu
+++ b/mmdet/ops/roi_align/src/cuda/roi_align_kernel.cu
@@ -280,4 +280,4 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
       }));
   THCudaCheck(cudaGetLastError());
   return 1;
-}
\ No newline at end of file
+}
diff --git a/mmdet/ops/roi_align/src/roi_align_kernel_v2.cu b/mmdet/ops/roi_align/src/cuda/roi_align_kernel_v2.cu
similarity index 100%
rename from mmdet/ops/roi_align/src/roi_align_kernel_v2.cu
rename to mmdet/ops/roi_align/src/cuda/roi_align_kernel_v2.cu
diff --git a/mmdet/ops/roi_align/src/roi_align_cuda.cpp b/mmdet/ops/roi_align/src/roi_align_ext.cpp
similarity index 66%
rename from mmdet/ops/roi_align/src/roi_align_cuda.cpp
rename to mmdet/ops/roi_align/src/roi_align_ext.cpp
index 268f6907..50454d25 100644
--- a/mmdet/ops/roi_align/src/roi_align_cuda.cpp
+++ b/mmdet/ops/roi_align/src/roi_align_ext.cpp
@@ -44,56 +44,70 @@ at::Tensor ROIAlignBackwardV2Laucher(
 int ROIAlign_forwardV1(at::Tensor features, at::Tensor rois, int pooled_height,
                        int pooled_width, float spatial_scale, int sample_num,
                        at::Tensor output) {
-  CHECK_INPUT(features);
-  CHECK_INPUT(rois);
-  CHECK_INPUT(output);
-  at::DeviceGuard guard(features.device());
-
-  // Number of ROIs
-  int num_rois = rois.size(0);
-  int size_rois = rois.size(1);
-
-  if (size_rois != 5) {
-    printf("wrong roi size\n");
-    return 0;
-  }
+  if (features.type().is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_INPUT(features);
+    CHECK_INPUT(rois);
+    CHECK_INPUT(output);
+    at::DeviceGuard guard(features.device());
+
+    // Number of ROIs
+    int num_rois = rois.size(0);
+    int size_rois = rois.size(1);
+
+    if (size_rois != 5) {
+      printf("wrong roi size\n");
+      return 0;
+    }
 
-  int num_channels = features.size(1);
-  int data_height = features.size(2);
-  int data_width = features.size(3);
+    int num_channels = features.size(1);
+    int data_height = features.size(2);
+    int data_width = features.size(3);
 
-  ROIAlignForwardLaucher(features, rois, spatial_scale, sample_num,
-                         num_channels, data_height, data_width, num_rois,
-                         pooled_height, pooled_width, output);
+    ROIAlignForwardLaucher(features, rois, spatial_scale, sample_num,
+                           num_channels, data_height, data_width, num_rois,
+                           pooled_height, pooled_width, output);
 
-  return 1;
+    return 1;
+#else
+    AT_ERROR("ROIAlign is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("ROIAlign is not implemented on CPU");
 }
 
 int ROIAlign_backwardV1(at::Tensor top_grad, at::Tensor rois,
                         int pooled_height, int pooled_width,
                         float spatial_scale, int sample_num,
                         at::Tensor bottom_grad) {
-  CHECK_INPUT(top_grad);
-  CHECK_INPUT(rois);
-  CHECK_INPUT(bottom_grad);
-  at::DeviceGuard guard(top_grad.device());
-
-  // Number of ROIs
-  int num_rois = rois.size(0);
-  int size_rois = rois.size(1);
-  if (size_rois != 5) {
-    printf("wrong roi size\n");
-    return 0;
+  if (top_grad.type().is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_INPUT(top_grad);
+    CHECK_INPUT(rois);
+    CHECK_INPUT(bottom_grad);
+    at::DeviceGuard guard(top_grad.device());
+
+    // Number of ROIs
+    int num_rois = rois.size(0);
+    int size_rois = rois.size(1);
+    if (size_rois != 5) {
+      printf("wrong roi size\n");
+      return 0;
+    }
+
+    int num_channels = bottom_grad.size(1);
+    int data_height = bottom_grad.size(2);
+    int data_width = bottom_grad.size(3);
+
+    ROIAlignBackwardLaucher(top_grad, rois, spatial_scale, sample_num,
+                            num_channels, data_height, data_width, num_rois,
+                            pooled_height, pooled_width, bottom_grad);
+
+    return 1;
+#else
+    AT_ERROR("ROIAlign is not compiled with GPU support");
+#endif
   }
-
-  int num_channels = bottom_grad.size(1);
-  int data_height = bottom_grad.size(2);
-  int data_width = bottom_grad.size(3);
-
-  ROIAlignBackwardLaucher(top_grad, rois, spatial_scale, sample_num,
-                          num_channels, data_height, data_width, num_rois,
-                          pooled_height, pooled_width, bottom_grad);
-
-  return 1;
+  AT_ERROR("ROIAlign is not implemented on CPU");
 }
 
 // Interface for Python
@@ -108,9 +122,10 @@ inline at::Tensor ROIAlign_forwardV2(const at::Tensor& input,
     return ROIAlignForwardV2Laucher(input, rois, spatial_scale, pooled_height,
                                     pooled_width, sampling_ratio, aligned);
 #else
-    AT_ERROR("Not compiled with GPU support");
+    AT_ERROR("ROIAlignV2 is not compiled with GPU support");
 #endif
   }
+  AT_ERROR("ROIAlignV2 is not implemented on CPU");
 }
 
 inline at::Tensor ROIAlign_backwardV2(
@@ -124,14 +139,15 @@ inline at::Tensor ROIAlign_backwardV2(
                                      pooled_width, batch_size, channels,
                                      height, width, sampling_ratio, aligned);
 #else
-    AT_ERROR("Not compiled with GPU support");
+    AT_ERROR("ROIAlignV2 is not compiled with GPU support");
 #endif
   }
+  AT_ERROR("ROIAlignV2 is not implemented on CPU");
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward_v1", &ROIAlign_forwardV1, "Roi_Align V1 forward (CUDA)");
-  m.def("backward_v1", &ROIAlign_backwardV1, "Roi_Align V1 backward (CUDA)");
-  m.def("forward_v2", &ROIAlign_forwardV2, "Roi_Align V2 forward (CUDA)");
-  m.def("backward_v2", &ROIAlign_backwardV2, "Roi_Align V2 backward (CUDA)");
+  m.def("forward_v1", &ROIAlign_forwardV1, "Roi_Align V1 forward");
+  m.def("backward_v1", &ROIAlign_backwardV1, "Roi_Align V1 backward");
+  m.def("forward_v2", &ROIAlign_forwardV2, "Roi_Align V2 forward");
+  m.def("backward_v2", &ROIAlign_backwardV2, "Roi_Align V2 backward");
 }
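The v1/v2 split survives the refactor: v1 is the original CUDA-only kernel, v2 adds the `aligned` behavior, and both are now reachable from the one extension. A sketch (assuming the constructor signature `RoIAlign(out_size, spatial_scale, sample_num)` from roi_align.py and a CUDA build; rois are `(batch_idx, x1, y1, x2, y2)`, matching the `size_rois != 5` check):

    import torch
    from mmdet.ops import RoIAlign

    feats = torch.randn(1, 16, 32, 32, device='cuda')
    rois = torch.tensor([[0., 4., 4., 20., 20.]], device='cuda')
    roi_align = RoIAlign(7, spatial_scale=1.0, sample_num=2)
    out = roi_align(feats, rois)
    print(out.shape)  # torch.Size([1, 16, 7, 7])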
+++ b/mmdet/ops/roi_pool/roi_pool.py @@ -4,7 +4,7 @@ from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair -from . import roi_pool_cuda +from . import roi_pool_ext class RoIPoolFunction(Function): @@ -20,8 +20,8 @@ class RoIPoolFunction(Function): out_size = (num_rois, num_channels, out_h, out_w) output = features.new_zeros(out_size) argmax = features.new_zeros(out_size, dtype=torch.int) - roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale, - output, argmax) + roi_pool_ext.forward(features, rois, out_h, out_w, spatial_scale, + output, argmax) ctx.spatial_scale = spatial_scale ctx.feature_size = features.size() ctx.argmax = argmax @@ -41,8 +41,8 @@ class RoIPoolFunction(Function): grad_input = grad_rois = None if ctx.needs_input_grad[0]: grad_input = grad_output.new_zeros(feature_size) - roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax, - spatial_scale, grad_input) + roi_pool_ext.backward(grad_output.contiguous(), rois, argmax, + spatial_scale, grad_input) return grad_input, grad_rois, None, None diff --git a/mmdet/ops/roi_pool/src/roi_pool_kernel.cu b/mmdet/ops/roi_pool/src/cuda/roi_pool_kernel.cu similarity index 100% rename from mmdet/ops/roi_pool/src/roi_pool_kernel.cu rename to mmdet/ops/roi_pool/src/cuda/roi_pool_kernel.cu diff --git a/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp b/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp deleted file mode 100644 index 87e39be8..00000000 --- a/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp +++ /dev/null @@ -1,88 +0,0 @@ -#include <torch/extension.h> - -#include <cmath> -#include <vector> - -int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois, - const float spatial_scale, const int channels, - const int height, const int width, const int num_rois, - const int pooled_h, const int pooled_w, - at::Tensor output, at::Tensor argmax); - -int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois, - const at::Tensor argmax, const float spatial_scale, - const int batch_size, const int channels, - const int height, const int width, - const int num_rois, const int pooled_h, - const int pooled_w, at::Tensor bottom_grad); - -#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") -#define CHECK_CONTIGUOUS(x) \ - AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -int roi_pooling_forward_cuda(at::Tensor features, at::Tensor rois, - int pooled_height, int pooled_width, - float spatial_scale, at::Tensor output, - at::Tensor argmax) { - CHECK_INPUT(features); - CHECK_INPUT(rois); - CHECK_INPUT(output); - CHECK_INPUT(argmax); - at::DeviceGuard guard(features.device()); - - // Number of ROIs - int num_rois = rois.size(0); - int size_rois = rois.size(1); - - if (size_rois != 5) { - printf("wrong roi size\n"); - return 0; - } - - int channels = features.size(1); - int height = features.size(2); - int width = features.size(3); - - ROIPoolForwardLaucher(features, rois, spatial_scale, channels, height, width, - num_rois, pooled_height, pooled_width, output, argmax); - - return 1; -} - -int roi_pooling_backward_cuda(at::Tensor top_grad, at::Tensor rois, - at::Tensor argmax, float spatial_scale, - at::Tensor bottom_grad) { - CHECK_INPUT(top_grad); - CHECK_INPUT(rois); - CHECK_INPUT(argmax); - CHECK_INPUT(bottom_grad); - at::DeviceGuard guard(top_grad.device()); - - int pooled_height = top_grad.size(2); - int pooled_width = top_grad.size(3); - 
int num_rois = rois.size(0); - int size_rois = rois.size(1); - - if (size_rois != 5) { - printf("wrong roi size\n"); - return 0; - } - int batch_size = bottom_grad.size(0); - int channels = bottom_grad.size(1); - int height = bottom_grad.size(2); - int width = bottom_grad.size(3); - - ROIPoolBackwardLaucher(top_grad, rois, argmax, spatial_scale, batch_size, - channels, height, width, num_rois, pooled_height, - pooled_width, bottom_grad); - - return 1; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &roi_pooling_forward_cuda, "Roi_Pooling forward (CUDA)"); - m.def("backward", &roi_pooling_backward_cuda, "Roi_Pooling backward (CUDA)"); -} diff --git a/mmdet/ops/roi_pool/src/roi_pool_ext.cpp b/mmdet/ops/roi_pool/src/roi_pool_ext.cpp new file mode 100644 index 00000000..af7bd855 --- /dev/null +++ b/mmdet/ops/roi_pool/src/roi_pool_ext.cpp @@ -0,0 +1,104 @@ +#include <torch/extension.h> + +#include <cmath> +#include <vector> + +#ifdef WITH_CUDA +int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois, + const float spatial_scale, const int channels, + const int height, const int width, const int num_rois, + const int pooled_h, const int pooled_w, + at::Tensor output, at::Tensor argmax); + +int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois, + const at::Tensor argmax, const float spatial_scale, + const int batch_size, const int channels, + const int height, const int width, + const int num_rois, const int pooled_h, + const int pooled_w, at::Tensor bottom_grad); +#endif + +#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") +#define CHECK_CONTIGUOUS(x) \ + AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +int roi_pooling_forward(at::Tensor features, at::Tensor rois, + int pooled_height, int pooled_width, + float spatial_scale, at::Tensor output, + at::Tensor argmax) { + if (features.type().is_cuda()) { +#ifdef WITH_CUDA + CHECK_INPUT(features); + CHECK_INPUT(rois); + CHECK_INPUT(output); + CHECK_INPUT(argmax); + at::DeviceGuard guard(features.device()); + + // Number of ROIs + int num_rois = rois.size(0); + int size_rois = rois.size(1); + + if (size_rois != 5) { + printf("wrong roi size\n"); + return 0; + } + + int channels = features.size(1); + int height = features.size(2); + int width = features.size(3); + + ROIPoolForwardLaucher(features, rois, spatial_scale, channels, height, width, + num_rois, pooled_height, pooled_width, output, argmax); + + return 1; +#else + AT_ERROR("roi_pool is not compiled with GPU support"); +#endif + } + AT_ERROR("roi_pool is not implemented on CPU"); +} + +int roi_pooling_backward(at::Tensor top_grad, at::Tensor rois, + at::Tensor argmax, float spatial_scale, + at::Tensor bottom_grad) { + if (top_grad.type().is_cuda()) { +#ifdef WITH_CUDA + CHECK_INPUT(top_grad); + CHECK_INPUT(rois); + CHECK_INPUT(argmax); + CHECK_INPUT(bottom_grad); + at::DeviceGuard guard(top_grad.device()); + + int pooled_height = top_grad.size(2); + int pooled_width = top_grad.size(3); + int num_rois = rois.size(0); + int size_rois = rois.size(1); + + if (size_rois != 5) { + printf("wrong roi size\n"); + return 0; + } + int batch_size = bottom_grad.size(0); + int channels = bottom_grad.size(1); + int height = bottom_grad.size(2); + int width = bottom_grad.size(3); + + ROIPoolBackwardLaucher(top_grad, rois, argmax, spatial_scale, batch_size, + channels, height, width, num_rois, pooled_height, + pooled_width, bottom_grad); + + 
return 1; +#else + AT_ERROR("roi_pool is not compiled with GPU support"); +#endif + } + AT_ERROR("roi_pool is not implemented on CPU"); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &roi_pooling_forward, "Roi_Pooling forward"); + m.def("backward", &roi_pooling_backward, "Roi_Pooling backward"); +} diff --git a/mmdet/ops/sigmoid_focal_loss/sigmoid_focal_loss.py b/mmdet/ops/sigmoid_focal_loss/sigmoid_focal_loss.py index 8298f433..62e584e6 100644 --- a/mmdet/ops/sigmoid_focal_loss/sigmoid_focal_loss.py +++ b/mmdet/ops/sigmoid_focal_loss/sigmoid_focal_loss.py @@ -2,7 +2,7 @@ import torch.nn as nn from torch.autograd import Function from torch.autograd.function import once_differentiable -from . import sigmoid_focal_loss_cuda +from . import sigmoid_focal_loss_ext class SigmoidFocalLossFunction(Function): @@ -15,8 +15,8 @@ class SigmoidFocalLossFunction(Function): ctx.gamma = gamma ctx.alpha = alpha - loss = sigmoid_focal_loss_cuda.forward(input, target, num_classes, - gamma, alpha) + loss = sigmoid_focal_loss_ext.forward(input, target, num_classes, + gamma, alpha) return loss @staticmethod @@ -27,8 +27,8 @@ class SigmoidFocalLossFunction(Function): gamma = ctx.gamma alpha = ctx.alpha d_loss = d_loss.contiguous() - d_input = sigmoid_focal_loss_cuda.backward(input, target, d_loss, - num_classes, gamma, alpha) + d_input = sigmoid_focal_loss_ext.backward(input, target, d_loss, + num_classes, gamma, alpha) return d_input, None, None, None, None diff --git a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu b/mmdet/ops/sigmoid_focal_loss/src/cuda/sigmoid_focal_loss_cuda.cu similarity index 100% rename from mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu rename to mmdet/ops/sigmoid_focal_loss/src/cuda/sigmoid_focal_loss_cuda.cu diff --git a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_ext.cpp similarity index 87% rename from mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp rename to mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_ext.cpp index 46d04eae..faf2e787 100644 --- a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss.cpp +++ b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_ext.cpp @@ -2,6 +2,7 @@ // https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h #include <torch/extension.h> +#ifdef WITH_CUDA at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits, const at::Tensor &targets, const int num_classes, @@ -12,6 +13,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits, const at::Tensor &d_losses, const int num_classes, const float gamma, const float alpha); +#endif // Interface for Python at::Tensor SigmoidFocalLoss_forward(const at::Tensor &logits, @@ -19,9 +21,13 @@ at::Tensor SigmoidFocalLoss_forward(const at::Tensor &logits, const int num_classes, const float gamma, const float alpha) { if (logits.type().is_cuda()) { +#ifdef WITH_CUDA at::DeviceGuard guard(logits.device()); return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha); +#else + AT_ERROR("SigmoidFocalLoss is not compiled with GPU support"); +#endif } AT_ERROR("SigmoidFocalLoss is not implemented on the CPU"); } @@ -32,16 +38,20 @@ at::Tensor SigmoidFocalLoss_backward(const at::Tensor &logits, const int num_classes, const float gamma, const float alpha) { if (logits.type().is_cuda()) { +#ifdef WITH_CUDA at::DeviceGuard guard(logits.device()); return SigmoidFocalLoss_backward_cuda(logits, targets, 
d_losses, num_classes, gamma, alpha); +#else + AT_ERROR("SigmoidFocalLoss is not compiled with GPU support"); +#endif } AT_ERROR("SigmoidFocalLoss is not implemented on the CPU"); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &SigmoidFocalLoss_forward, - "SigmoidFocalLoss forward (CUDA)"); + "SigmoidFocalLoss forward"); m.def("backward", &SigmoidFocalLoss_backward, - "SigmoidFocalLoss backward (CUDA)"); + "SigmoidFocalLoss backward"); } diff --git a/setup.py b/setup.py index 5a5ddbe0..7c01e6be 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,8 @@ import time from setuptools import find_packages, setup import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension +from torch.utils.cpp_extension import (BuildExtension, CppExtension, + CUDAExtension) def readme(): @@ -87,27 +88,30 @@ def get_version(): return locals()['__version__'] -def make_cuda_ext(name, module, sources): +def make_cuda_ext(name, module, sources, sources_cuda=[]): define_macros = [] + extra_compile_args = {'cxx': []} if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': define_macros += [('WITH_CUDA', None)] + extension = CUDAExtension + extra_compile_args['nvcc'] = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + sources += sources_cuda else: - raise EnvironmentError('CUDA is required to compile MMDetection!') + print('Compiling {} without CUDA'.format(name)) + extension = CppExtension + # raise EnvironmentError('CUDA is required to compile MMDetection!') - return CUDAExtension( + return extension( name='{}.{}'.format(module, name), sources=[os.path.join(*module.split('.'), p) for p in sources], define_macros=define_macros, - extra_compile_args={ - 'cxx': [], - 'nvcc': [ - '-D__CUDA_NO_HALF_OPERATORS__', - '-D__CUDA_NO_HALF_CONVERSIONS__', - '-D__CUDA_NO_HALF2_OPERATORS__', - ] - }) + extra_compile_args=extra_compile_args) def parse_requirements(fname='requirements.txt', with_version=True): @@ -227,73 +231,82 @@ if __name__ == '__main__': module='mmdet.ops.utils', sources=['src/compiling_info.cpp']), make_cuda_ext( - name='nms_cpu', + name='nms_ext', module='mmdet.ops.nms', - sources=['src/nms_cpu.cpp']), - make_cuda_ext( - name='nms_cuda', - module='mmdet.ops.nms', - sources=['src/nms_cuda.cpp', 'src/nms_kernel.cu']), + sources=['src/nms_ext.cpp', 'src/cpu/nms_cpu.cpp'], + sources_cuda=[ + 'src/cuda/nms_cuda.cpp', 'src/cuda/nms_kernel.cu' + ]), make_cuda_ext( - name='roi_align_cuda', + name='roi_align_ext', module='mmdet.ops.roi_align', - sources=[ - 'src/roi_align_cuda.cpp', - 'src/roi_align_kernel.cu', - 'src/roi_align_kernel_v2.cu', + sources=['src/roi_align_ext.cpp'], + sources_cuda=[ + 'src/cuda/roi_align_kernel.cu', + 'src/cuda/roi_align_kernel_v2.cu' ]), make_cuda_ext( - name='roi_pool_cuda', + name='roi_pool_ext', module='mmdet.ops.roi_pool', - sources=['src/roi_pool_cuda.cpp', 'src/roi_pool_kernel.cu']), + sources=['src/roi_pool_ext.cpp'], + sources_cuda=['src/cuda/roi_pool_kernel.cu']), make_cuda_ext( - name='deform_conv_cuda', + name='deform_conv_ext', module='mmdet.ops.dcn', - sources=[ - 'src/deform_conv_cuda.cpp', - 'src/deform_conv_cuda_kernel.cu' + sources=['src/deform_conv_ext.cpp'], + sources_cuda=[ + 'src/cuda/deform_conv_cuda.cpp', + 'src/cuda/deform_conv_cuda_kernel.cu' ]), make_cuda_ext( - name='deform_pool_cuda', + name='deform_pool_ext', module='mmdet.ops.dcn', - sources=[ - 'src/deform_pool_cuda.cpp', - 'src/deform_pool_cuda_kernel.cu' + sources=['src/deform_pool_ext.cpp'], + 
sources_cuda=[ + 'src/cuda/deform_pool_cuda.cpp', + 'src/cuda/deform_pool_cuda_kernel.cu' ]), make_cuda_ext( - name='sigmoid_focal_loss_cuda', + name='sigmoid_focal_loss_ext', module='mmdet.ops.sigmoid_focal_loss', - sources=[ - 'src/sigmoid_focal_loss.cpp', - 'src/sigmoid_focal_loss_cuda.cu' - ]), + sources=['src/sigmoid_focal_loss_ext.cpp'], + sources_cuda=['src/cuda/sigmoid_focal_loss_cuda.cu']), make_cuda_ext( - name='masked_conv2d_cuda', + name='masked_conv2d_ext', module='mmdet.ops.masked_conv', - sources=[ - 'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu' + sources=['src/masked_conv2d_ext.cpp'], + sources_cuda=[ + 'src/cuda/masked_conv2d_cuda.cpp', + 'src/cuda/masked_conv2d_kernel.cu' ]), make_cuda_ext( - name='affine_grid_cuda', + name='affine_grid_ext', module='mmdet.ops.affine_grid', - sources=['src/affine_grid_cuda.cpp']), + sources=[ + 'src/affine_grid_ext.cpp', 'src/cpu/affine_grid_cpu.cpp' + ]), make_cuda_ext( - name='grid_sampler_cuda', + name='grid_sampler_ext', module='mmdet.ops.grid_sampler', sources=[ - 'src/cpu/grid_sampler_cpu.cpp', - 'src/cuda/grid_sampler_cuda.cu', 'src/grid_sampler.cpp' - ]), + 'src/grid_sampler_ext.cpp', 'src/cpu/grid_sampler_cpu.cpp' + ], + sources_cuda=['src/cuda/grid_sampler_cuda.cu']), make_cuda_ext( - name='carafe_cuda', + name='carafe_ext', module='mmdet.ops.carafe', - sources=['src/carafe_cuda.cpp', 'src/carafe_cuda_kernel.cu']), + sources=['src/carafe_ext.cpp'], + sources_cuda=[ + 'src/cuda/carafe_cuda.cpp', + 'src/cuda/carafe_cuda_kernel.cu' + ]), make_cuda_ext( - name='carafe_naive_cuda', + name='carafe_naive_ext', module='mmdet.ops.carafe', - sources=[ - 'src/carafe_naive_cuda.cpp', - 'src/carafe_naive_cuda_kernel.cu' + sources=['src/carafe_naive_ext.cpp'], + sources_cuda=[ + 'src/cuda/carafe_naive_cuda.cpp', + 'src/cuda/carafe_naive_cuda_kernel.cu' ]) ], cmdclass={'build_ext': BuildExtension}, -- GitLab
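Note on the pattern above: the new *_ext.cpp entry points (roi_align_ext.cpp, roi_pool_ext.cpp, sigmoid_focal_loss_ext.cpp) all share one dispatch scheme: a runtime `tensor.type().is_cuda()` check, a compile-time `#ifdef WITH_CUDA` guard around the CUDA launcher, and `AT_ERROR` fallbacks for the two failure modes (built without GPU support, or no CPU kernel at all). The sketch below distills that scheme into one minimal translation unit; `my_op` and its launcher are hypothetical names used only for illustration. `WITH_CUDA` comes from `make_cuda_ext` in setup.py, which adds it to `define_macros` (and appends `sources_cuda`) only when CUDA is available or `FORCE_CUDA=1`, so the same source file builds under both `CppExtension` and `CUDAExtension`.

    // Minimal sketch of the unified dispatch pattern (hypothetical op "my_op").
    #include <torch/extension.h>

    #ifdef WITH_CUDA
    // Implemented in a file listed under sources_cuda in setup.py; only
    // declared (and compiled) when the build defines WITH_CUDA.
    at::Tensor my_op_forward_cuda(const at::Tensor &input);
    #endif

    at::Tensor my_op_forward(const at::Tensor &input) {
      if (input.type().is_cuda()) {
    #ifdef WITH_CUDA
        // Make sure the kernel launches on the tensor's own device.
        at::DeviceGuard guard(input.device());
        return my_op_forward_cuda(input);
    #else
        // Built via CppExtension: the CUDA launcher was never compiled in.
        AT_ERROR("my_op is not compiled with GPU support");
    #endif
      }
      // This op has no CPU kernel, so a CPU tensor is an error, not a fallback.
      AT_ERROR("my_op is not implemented on CPU");
    }

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("forward", &my_op_forward, "my_op forward");
    }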
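Ops that ship a real CPU kernel dispatch to it instead of raising. In the setup.py hunk above, `nms_ext` lists `src/cpu/nms_cpu.cpp` under plain `sources` (always compiled) and puts the CUDA files under `sources_cuda`; `affine_grid_ext` and `grid_sampler_ext` carry `src/cpu/` files the same way. Below is a sketch of that variant with illustrative signatures; nms_ext.cpp itself is not reproduced in this excerpt, so the exact argument list is an assumption.

    // Sketch of the CPU-fallback variant (signatures are illustrative, not
    // copied from nms_ext.cpp).
    #include <torch/extension.h>

    // Always compiled: listed in plain `sources` (e.g. src/cpu/nms_cpu.cpp).
    at::Tensor nms_cpu(const at::Tensor &dets, float iou_threshold);

    #ifdef WITH_CUDA
    // Compiled only when setup.py appends sources_cuda (src/cuda/nms_cuda.cpp).
    at::Tensor nms_cuda(const at::Tensor &dets, float iou_threshold);
    #endif

    at::Tensor nms(const at::Tensor &dets, float iou_threshold) {
      if (dets.type().is_cuda()) {
    #ifdef WITH_CUDA
        return nms_cuda(dets, iou_threshold);
    #else
        AT_ERROR("nms is not compiled with GPU support");
    #endif
      }
      // Genuine CPU path: dispatch to the CPU kernel instead of erroring.
      return nms_cpu(dets, iou_threshold);
    }

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("nms", &nms, "non-maximum suppression");
    }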