diff --git a/mmdet/ops/grid_sampler/__init__.py b/mmdet/ops/grid_sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..868617a6b3f42049d8b78253bf639f07e47ec981 --- /dev/null +++ b/mmdet/ops/grid_sampler/__init__.py @@ -0,0 +1,3 @@ +from .grid_sampler import grid_sample + +__all__ = ['grid_sample'] diff --git a/mmdet/ops/grid_sampler/grid_sampler.py b/mmdet/ops/grid_sampler/grid_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..d359d8dfe953f5bd14cf6f5bb90cde856a11b30a --- /dev/null +++ b/mmdet/ops/grid_sampler/grid_sampler.py @@ -0,0 +1,117 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from . import grid_sampler_cuda + + +class _GridSampler(Function): + + @staticmethod + def forward(ctx, input, grid, mode_enum, padding_mode_enum, align_corners): + + ctx.save_for_backward(input, grid) + ctx.mode_enum = mode_enum + ctx.padding_mode_enum = padding_mode_enum + ctx.align_corners = align_corners + + if input.is_cuda: + if input.dim() == 4: + func = grid_sampler_cuda.grid_sampler_2d_forward_cuda + else: + func = grid_sampler_cuda.grid_sampler_3d_forward_cuda + else: + if input.dim() == 4: + func = grid_sampler_cuda.grid_sampler_2d_forward_cpu + else: + func = grid_sampler_cuda.grid_sampler_3d_forward_cpu + + output = func(input, grid, mode_enum, padding_mode_enum, align_corners) + + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + mode_enum = ctx.mode_enum + padding_mode_enum = ctx.padding_mode_enum + align_corners = ctx.align_corners + + if input.is_cuda: + if input.dim() == 4: + func = grid_sampler_cuda.grid_sampler_2d_backward_cuda + else: + func = grid_sampler_cuda.grid_sampler_3d_backward_cuda + else: + if input.dim() == 4: + func = grid_sampler_cuda.grid_sampler_2d_backward_cpu + else: + func = grid_sampler_cuda.grid_sampler_3d_backward_cpu + + grad_input, grad_grid = func(grad_output, input, grid, mode_enum, + padding_mode_enum, align_corners) + + return grad_input, grad_grid, None, None, None + + +def grid_sample(input, + grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False): + if torch.__version__ >= '1.3' or align_corners: + return F.grid_sample(input, grid, mode, padding_mode, align_corners) + else: + + # use self-compiled grid_sampler to support align_corners=False + + assert mode in ['bilinear', 'nearest'], \ + 'expected mode to be bilinear or nearest, but got: {}'.format(mode) + + assert padding_mode in ['zeros', 'border', 'reflection'], \ + 'expected padding_mode to be zeros, border, or reflection, ' \ + 'but got: {}'.format(padding_mode) + + if mode == 'bilinear': + mode_enum = 0 + else: + mode_enum = 1 + + if padding_mode == 'zeros': + padding_mode_enum = 0 + elif padding_mode == 'border': + padding_mode_enum = 1 + else: + padding_mode_enum = 2 + + # shape check + assert input.device == grid.device, \ + 'expected input and grid to be on same device, ' \ + 'but input is on {} and grid is on {}'.format( + input.device, grid.device) + assert input.dtype == grid.dtype, \ + 'expected input and grid to have the same dtype, ' \ + 'but input has {} and grid has {}'.format( + input.dtype, grid.dtype) + assert input.dim() == 4 or input.dim() == 5, \ + 'expected 4D or 5D input and grid with same number of dimensions' \ + 'but got input with sizes {} and grid with sizes {}'.format( + input.size(), grid.size()) + assert input.size(0) == 
grid.size(0), \ + 'expected input and grid to have the same batch size, ' \ + 'but got input with sizes {} and grid with sizes {}'.format( + input.size(), grid.size()) + assert grid.size(-1) == input.dim() - 2, \ + 'expected grid to have size {} in last {} dimension, ' \ + 'but got grid with sizes '.format( + input.dim() - 2, grid.size()) + for i in range(2, input.dim()): + assert input.size(i) > 0, \ + 'expected input to have non-empty spatial dimensions, ' \ + 'but input has sizes {} with dimension {} being empty'.format( + input.sizes(), i) + + return _GridSampler.apply(input, grid, mode_enum, padding_mode_enum, + align_corners) diff --git a/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.cpp b/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf1776ed1d7de573c97dd3254ef08d2352ca4377 --- /dev/null +++ b/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.cpp @@ -0,0 +1,692 @@ +// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/GridSampler.cpp + +#include <torch/extension.h> +#include "grid_sampler_cpu.h" +#include <ATen/ATen.h> +#include <ATen/Device.h> +#include <ATen/NativeFunctions.h> +#include <c10/core/Layout.h> +#include <c10/util/Exception.h> + +#ifdef _OPENMP +#include <omp.h> +#endif + +namespace mmdetection { + +using namespace at; +using mmdetection::detail::GridSamplerInterpolation; +using mmdetection::detail::GridSamplerPadding; + +namespace { + + template<typename scalar_t> + Tensor grid_sampler_2d_forward_cpu_impl(const Tensor& input, const Tensor& grid, + GridSamplerInterpolation interpolation_mode, + GridSamplerPadding padding_mode, + bool align_corners) { + int64_t N = input.size(0); + int64_t C = input.size(1); + int64_t inp_H = input.size(2); + int64_t inp_W = input.size(3); + int64_t out_H = grid.size(1); + int64_t out_W = grid.size(2); + auto output = at::empty({N, C, out_H, out_W}, input.options()); + int64_t inp_sN = input.stride(0); + int64_t inp_sC = input.stride(1); + int64_t inp_sH = input.stride(2); + int64_t inp_sW = input.stride(3); + int64_t grid_sN = grid.stride(0); + int64_t grid_sH = grid.stride(1); + int64_t grid_sW = grid.stride(2); + int64_t grid_sCoor = grid.stride(3); + int64_t out_sN = output.stride(0); + int64_t out_sC = output.stride(1); + int64_t out_sH = output.stride(2); + int64_t out_sW = output.stride(3); + scalar_t *inp_ptr = input.data<scalar_t>(); + scalar_t *out_ptr = output.data<scalar_t>(); + scalar_t *grid_ptr = grid.data<scalar_t>(); + // loop over each output pixel + #ifdef _OPENMP + #pragma omp parallel for + #endif + for (int64_t n = 0; n < N; ++n) { + scalar_t *grid_ptr_N = grid_ptr + n * grid_sN; + scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; + for (int64_t h = 0; h < out_H; ++h) { + for (int64_t w = 0; w < out_W; ++w) { + // get the corresponding input x, y, z co-ordinates from grid + scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; + scalar_t ix = *grid_ptr_NHW; + scalar_t iy = grid_ptr_NHW[grid_sCoor]; + + ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int64_t ix_nw = static_cast<int64_t>(std::floor(ix)); + int64_t iy_nw = static_cast<int64_t>(std::floor(iy)); + + int64_t ix_ne = ix_nw + 1; + int64_t iy_ne = 
iy_nw; + + int64_t ix_sw = ix_nw; + int64_t iy_sw = iy_nw + 1; + + int64_t ix_se = ix_nw + 1; + int64_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy) ; + scalar_t ne = (ix - ix_sw) * (iy_sw - iy) ; + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + // calculate bilinear weighted pixel value and set output pixel + scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; + scalar_t *inp_ptr_NC = inp_ptr_N; + for (int c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse + *out_ptr_NCHW = static_cast<scalar_t>(0); + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + } + } + } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { + int64_t ix_nearest = static_cast<int64_t>(std::round(ix)); + int64_t iy_nearest = static_cast<int64_t>(std::round(iy)); + + // assign nearest neighor pixel value to output pixel + scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; + scalar_t *inp_ptr_NC = inp_ptr_N; + for (int c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) { + if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { + *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCHW = static_cast<scalar_t>(0); + } + } + } + } + } + } + + return output; + } + + template<typename scalar_t> + Tensor grid_sampler_3d_forward_cpu_impl(const Tensor& input, const Tensor& grid, + GridSamplerInterpolation interpolation_mode, + GridSamplerPadding padding_mode, + bool align_corners) { + int64_t N = input.size(0); + int64_t C = input.size(1); + int64_t inp_D = input.size(2); + int64_t inp_H = input.size(3); + int64_t inp_W = input.size(4); + int64_t out_D = grid.size(1); + int64_t out_H = grid.size(2); + int64_t out_W = grid.size(3); + auto output = at::empty({N, C, out_D, out_H, out_W}, input.options()); + int64_t inp_sN = input.stride(0); + int64_t inp_sC = input.stride(1); + int64_t inp_sD = input.stride(2); + int64_t inp_sH = input.stride(3); + int64_t inp_sW = input.stride(4); + int64_t grid_sN = grid.stride(0); + int64_t grid_sD = grid.stride(1); + int64_t grid_sH = grid.stride(2); + int64_t grid_sW = grid.stride(3); + int64_t grid_sCoor = grid.stride(4); + int64_t out_sN = output.stride(0); + int64_t out_sC = output.stride(1); + int64_t out_sD = output.stride(2); + int64_t out_sH = output.stride(3); + int64_t out_sW = output.stride(4); + scalar_t *inp_ptr = input.data<scalar_t>(); + scalar_t *out_ptr = output.data<scalar_t>(); + scalar_t *grid_ptr = grid.data<scalar_t>(); + // loop over each output pixel + #ifdef _OPENMP + #pragma omp parallel for + #endif + for (int64_t n = 0; n < N; ++n) { + scalar_t *grid_ptr_N = grid_ptr + n * grid_sN; + 
scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; + for (int64_t d = 0; d < out_D; ++d) { + for (int64_t h = 0; h < out_H; ++h) { + for (int64_t w = 0; w < out_W; ++w) { + // get the corresponding input x, y, z co-ordinates from grid + scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; + scalar_t ix = *grid_ptr_NDHW; + scalar_t iy = grid_ptr_NDHW[grid_sCoor]; + scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor]; + + ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int64_t ix_tnw = static_cast<int64_t>(std::floor(ix)); + int64_t iy_tnw = static_cast<int64_t>(std::floor(iy)); + int64_t iz_tnw = static_cast<int64_t>(std::floor(iz)); + + int64_t ix_tne = ix_tnw + 1; + int64_t iy_tne = iy_tnw; + int64_t iz_tne = iz_tnw; + + int64_t ix_tsw = ix_tnw; + int64_t iy_tsw = iy_tnw + 1; + int64_t iz_tsw = iz_tnw; + + int64_t ix_tse = ix_tnw + 1; + int64_t iy_tse = iy_tnw + 1; + int64_t iz_tse = iz_tnw; + + int64_t ix_bnw = ix_tnw; + int64_t iy_bnw = iy_tnw; + int64_t iz_bnw = iz_tnw + 1; + + int64_t ix_bne = ix_tnw + 1; + int64_t iy_bne = iy_tnw; + int64_t iz_bne = iz_tnw + 1; + + int64_t ix_bsw = ix_tnw; + int64_t iy_bsw = iy_tnw + 1; + int64_t iz_bsw = iz_tnw + 1; + + int64_t ix_bse = ix_tnw + 1; + int64_t iy_bse = iy_tnw + 1; + int64_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + // calculate bilinear weighted pixel value and set output pixel + scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + scalar_t *inp_ptr_NC = inp_ptr_N; + for (int c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse + *out_ptr_NCDHW = static_cast<scalar_t>(0); + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, 
inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + } + } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { + int64_t ix_nearest = static_cast<int64_t>(std::round(ix)); + int64_t iy_nearest = static_cast<int64_t>(std::round(iy)); + int64_t iz_nearest = static_cast<int64_t>(std::round(iz)); + + // assign nearest neighor pixel value to output pixel + scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + scalar_t *inp_ptr_NC = inp_ptr_N; + for (int c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) { + if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCDHW = static_cast<scalar_t>(0); + } + } + } + } + } + } + } + return output; + } + + template<typename scalar_t> + std::tuple<Tensor, Tensor> + grid_sampler_2d_backward_cpu_impl(const Tensor& grad_output, + const Tensor& input, const Tensor& grid, + GridSamplerInterpolation interpolation_mode, + GridSamplerPadding padding_mode, + bool align_corners) { + auto grad_input = at::zeros_like(input); + auto grad_grid = at::empty_like(grid); + // If interpolation mode is Nearest, then grad_grid is not filled in the + // loop below. 
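To make the nw/ne/sw/se surface weights used by the 2D forward pass above, and reused by the backward pass below, concrete, here is a minimal pure-Python sketch for a single output location. The coordinate (1.3, 2.6) is an arbitrary illustrative value, not taken from the patch:

import math

ix, iy = 1.3, 2.6                        # source coords after unnormalization
ix_nw, iy_nw = math.floor(ix), math.floor(iy)
ix_ne, iy_ne = ix_nw + 1, iy_nw
ix_sw, iy_sw = ix_nw, iy_nw + 1
ix_se, iy_se = ix_nw + 1, iy_nw + 1

# surface areas opposite each corner, mirroring the C++ expressions above
nw = (ix_se - ix) * (iy_se - iy)         # 0.7 * 0.4 = 0.28
ne = (ix - ix_sw) * (iy_sw - iy)         # 0.3 * 0.4 = 0.12
sw = (ix_ne - ix) * (iy - iy_ne)         # 0.7 * 0.6 = 0.42
se = (ix - ix_nw) * (iy - iy_nw)         # 0.3 * 0.6 = 0.18
assert abs(nw + ne + sw + se - 1.0) < 1e-9   # the four weights always sum to 1

Corners that fall outside the input are simply skipped via within_bounds_2d, which is how padding_mode='zeros' is realized; border and reflection padding are applied earlier, when the source index is computed.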
+ if (interpolation_mode == GridSamplerInterpolation::Nearest) { + grad_grid.zero_(); + } + int64_t N = input.size(0); + int64_t C = input.size(1); + int64_t inp_H = input.size(2); + int64_t inp_W = input.size(3); + int64_t out_H = grid.size(1); + int64_t out_W = grid.size(2); + int64_t inp_sN = input.stride(0); + int64_t inp_sC = input.stride(1); + int64_t inp_sH = input.stride(2); + int64_t inp_sW = input.stride(3); + int64_t grid_sN = grid.stride(0); + int64_t grid_sH = grid.stride(1); + int64_t grid_sW = grid.stride(2); + int64_t grid_sCoor = grid.stride(3); + int64_t gOut_sN = grad_output.stride(0); + int64_t gOut_sC = grad_output.stride(1); + int64_t gOut_sH = grad_output.stride(2); + int64_t gOut_sW = grad_output.stride(3); + int64_t gInp_sN = grad_input.stride(0); + int64_t gInp_sC = grad_input.stride(1); + int64_t gInp_sH = grad_input.stride(2); + int64_t gInp_sW = grad_input.stride(3); + int64_t gGrid_sN = grad_grid.stride(0); + int64_t gGrid_sW = grad_grid.stride(2); + scalar_t *inp_ptr = input.data<scalar_t>(); + scalar_t *grid_ptr = grid.data<scalar_t>(); + scalar_t *gOut_ptr = grad_output.data<scalar_t>(); + scalar_t *gInp_ptr = grad_input.data<scalar_t>(); + scalar_t *gGrid_ptr = grad_grid.data<scalar_t>(); + // loop over each output pixel + #ifdef _OPENMP + #pragma omp parallel for + #endif + for (int64_t n = 0; n < N; ++n) { + scalar_t *grid_ptr_N = grid_ptr + n * grid_sN; + scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; + scalar_t *gGrid_ptr_NHW = gGrid_ptr + n * gGrid_sN; + for (int64_t h = 0; h < out_H; ++h) { + for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) { + // get the corresponding input x, y, z co-ordinates from grid + scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; + scalar_t ix = *grid_ptr_NHW; + scalar_t iy = grid_ptr_NHW[grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + scalar_t gix_mult, giy_mult; + ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int64_t ix_nw = static_cast<int64_t>(std::floor(ix)); + int64_t iy_nw = static_cast<int64_t>(std::floor(iy)); + + int64_t ix_ne = ix_nw + 1; + int64_t iy_ne = iy_nw; + + int64_t ix_sw = ix_nw; + int64_t iy_sw = iy_nw + 1; + + int64_t ix_se = ix_nw + 1; + int64_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy) ; + scalar_t ne = (ix - ix_sw) * (iy_sw - iy) ; + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0); + scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; + scalar_t *inp_ptr_NC = inp_ptr_N; + // calculate bilinear weighted pixel value and set output pixel + for (int c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + // calculate and set grad_input + safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut); + safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut); + safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, 
inp_W, sw * gOut); + safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut); + + // calculate grad_grid + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]; + gix -= nw_val * (iy_se - iy) * gOut; + giy -= nw_val * (ix_se - ix) * gOut; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]; + gix += ne_val * (iy_sw - iy) * gOut; + giy -= ne_val * (ix - ix_sw) * gOut; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]; + gix -= sw_val * (iy - iy_ne) * gOut; + giy += sw_val * (ix_ne - ix) * gOut; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { + scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]; + gix += se_val * (iy - iy_nw) * gOut; + giy += se_val * (ix - ix_nw) * gOut; + } + } + + // assuming grad_grid is contiguous + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; + } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { + int64_t ix_nearest = static_cast<int64_t>(std::round(ix)); + int64_t iy_nearest = static_cast<int64_t>(std::round(iy)); + + // assign nearest neighor pixel value to output pixel + scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; + for (int c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC) { + // calculate and set grad_input + safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, + gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW); + } + } + } + } + } + return std::make_tuple(grad_input, grad_grid); + } + + template<typename scalar_t> + std::tuple<Tensor, Tensor> + grid_sampler_3d_backward_cpu_impl(const Tensor& grad_output, + const Tensor& input, const Tensor& grid, + GridSamplerInterpolation interpolation_mode, + GridSamplerPadding padding_mode, + bool align_corners) { + auto grad_input = at::zeros_like(input); + auto grad_grid = at::empty_like(grid); + // If interpolation mode is Nearest, then grad_grid is not filled in the + // loop below. 
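The backward implementations can be sanity-checked numerically against finite differences. Below is a sketch, not part of this diff, assuming the extension has been built so that mmdet.ops.grid_sampler is importable; note that on PyTorch >= 1.3 the wrapper falls back to F.grid_sample, so the self-compiled path here is only exercised on older versions:

import torch
from torch.autograd import gradcheck
from mmdet.ops.grid_sampler import grid_sample

# double precision and small sizes; a 4D input exercises the 2D kernels,
# a 5D input/grid pair would exercise the 3D kernels the same way
inp = torch.randn(1, 2, 5, 5, dtype=torch.double, requires_grad=True)
grid = (torch.rand(1, 4, 4, 2, dtype=torch.double) * 2 - 1).requires_grad_()

ok = gradcheck(
    lambda x, g: grid_sample(x, g, 'bilinear', 'zeros', align_corners=False),
    (inp, grid), eps=1e-6, atol=1e-4)
print('gradcheck passed:', ok)
# with mode='nearest' the kernels intentionally leave grad_grid at zero,
# so only grad_input is meaningful in that case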
+ if (interpolation_mode == GridSamplerInterpolation::Nearest) { + grad_grid.zero_(); + } + int64_t N = input.size(0); + int64_t C = input.size(1); + int64_t inp_D = input.size(2); + int64_t inp_H = input.size(3); + int64_t inp_W = input.size(4); + int64_t out_D = grid.size(1); + int64_t out_H = grid.size(2); + int64_t out_W = grid.size(3); + int64_t inp_sN = input.stride(0); + int64_t inp_sC = input.stride(1); + int64_t inp_sD = input.stride(2); + int64_t inp_sH = input.stride(3); + int64_t inp_sW = input.stride(4); + int64_t grid_sN = grid.stride(0); + int64_t grid_sD = grid.stride(1); + int64_t grid_sH = grid.stride(2); + int64_t grid_sW = grid.stride(3); + int64_t grid_sCoor = grid.stride(4); + int64_t gOut_sN = grad_output.stride(0); + int64_t gOut_sC = grad_output.stride(1); + int64_t gOut_sD = grad_output.stride(2); + int64_t gOut_sH = grad_output.stride(3); + int64_t gOut_sW = grad_output.stride(4); + int64_t gInp_sN = grad_input.stride(0); + int64_t gInp_sC = grad_input.stride(1); + int64_t gInp_sD = grad_input.stride(2); + int64_t gInp_sH = grad_input.stride(3); + int64_t gInp_sW = grad_input.stride(4); + int64_t gGrid_sN = grad_grid.stride(0); + int64_t gGrid_sW = grad_grid.stride(3); + scalar_t *inp_ptr = input.data<scalar_t>(); + scalar_t *grid_ptr = grid.data<scalar_t>(); + scalar_t *gOut_ptr = grad_output.data<scalar_t>(); + scalar_t *gInp_ptr = grad_input.data<scalar_t>(); + scalar_t *gGrid_ptr = grad_grid.data<scalar_t>(); + // loop over each output pixel + #ifdef _OPENMP + #pragma omp parallel for + #endif + for (int64_t n = 0; n < N; ++n) { + scalar_t *grid_ptr_N = grid_ptr + n * grid_sN; + scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; + scalar_t *gGrid_ptr_NDHW = gGrid_ptr + n * gGrid_sN; + for (int64_t d = 0; d < out_D; ++d) { + for (int64_t h = 0; h < out_H; ++h) { + for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) { + // get the corresponding input x, y, z co-ordinates from grid + scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; + scalar_t ix = *grid_ptr_NDHW; + scalar_t iy = grid_ptr_NDHW[grid_sCoor]; + scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + scalar_t gix_mult, giy_mult, giz_mult; + ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int64_t ix_tnw = static_cast<int64_t>(std::floor(ix)); + int64_t iy_tnw = static_cast<int64_t>(std::floor(iy)); + int64_t iz_tnw = static_cast<int64_t>(std::floor(iz)); + + int64_t ix_tne = ix_tnw + 1; + int64_t iy_tne = iy_tnw; + int64_t iz_tne = iz_tnw; + + int64_t ix_tsw = ix_tnw; + int64_t iy_tsw = iy_tnw + 1; + int64_t iz_tsw = iz_tnw; + + int64_t ix_tse = ix_tnw + 1; + int64_t iy_tse = iy_tnw + 1; + int64_t iz_tse = iz_tnw; + + int64_t ix_bnw = ix_tnw; + int64_t iy_bnw = iy_tnw; + int64_t iz_bnw = iz_tnw + 1; + + int64_t ix_bne = ix_tnw + 1; + int64_t iy_bne = iy_tnw; + int64_t iz_bne = iz_tnw + 1; + + int64_t ix_bsw = ix_tnw; + int64_t iy_bsw = iy_tnw + 1; + int64_t iz_bsw = iz_tnw + 1; + + int64_t ix_bse = ix_tnw + 1; + int64_t iy_bse = iy_tnw + 1; + int64_t iz_bse = 
iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0); + scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; + scalar_t *inp_ptr_NC = inp_ptr_N; + // calculate bilinear weighted pixel value and set output pixel + for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) { + scalar_t gOut = *gOut_ptr_NCDHW; + + // calculate and set grad_input + safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut); + safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut); + safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut); + safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut); + safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut); + safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut); + safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut); + safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut); + + // calculate grad_grid + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * 
gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + + // assuming grad_grid is contiguous + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + gGrid_ptr_NDHW[2] = giz_mult * giz; + } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { + int64_t ix_nearest = static_cast<int64_t>(std::round(ix)); + int64_t iy_nearest = static_cast<int64_t>(std::round(iy)); + int64_t iz_nearest = static_cast<int64_t>(std::round(iz)); + + // assign nearest neighor pixel value to output pixel + scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; + for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) { + // calculate and set grad_input + safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest, + gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW); + } + } + } + } + } + } + return std::make_tuple(grad_input, grad_grid); + } + +} // namespace + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +Tensor grid_sampler_2d_forward_cpu(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) { + return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_forward_cpu", [&] { + return grid_sampler_2d_forward_cpu_impl<scalar_t>( + input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode), + static_cast<GridSamplerPadding>(padding_mode), align_corners); + }); +} + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +Tensor grid_sampler_3d_forward_cpu(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) { + return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_forward_cpu", [&] { + return grid_sampler_3d_forward_cpu_impl<scalar_t>( + input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode), + static_cast<GridSamplerPadding>(padding_mode), align_corners); + }); +} + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. 
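The dispatch functions below (and their CUDA counterparts later in this diff) are what the Python wrapper at the top of the patch reaches through the compiled grid_sampler_cuda module. A minimal usage sketch, assuming the extension builds cleanly:

import torch
import torch.nn.functional as F
from mmdet.ops.grid_sampler import grid_sample

inp = torch.randn(2, 3, 8, 8)
# grid holds normalized (x, y) coords in [-1, 1], shape (N, out_H, out_W, 2)
grid = torch.rand(2, 6, 6, 2) * 2 - 1

out = grid_sample(inp, grid, mode='bilinear', padding_mode='zeros',
                  align_corners=False)            # shape (2, 3, 6, 6)

# on PyTorch >= 1.3 the wrapper simply falls back to F.grid_sample,
# so the two results should match up to floating point
ref = F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros',
                    align_corners=False)
print(torch.allclose(out, ref))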
+std::tuple<Tensor, Tensor> +grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_backward_cpu", [&] { + return grid_sampler_2d_backward_cpu_impl<scalar_t>( + grad_output, input, grid, + static_cast<GridSamplerInterpolation>(interpolation_mode), + static_cast<GridSamplerPadding>(padding_mode), align_corners); + }); +} + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +std::tuple<Tensor, Tensor> +grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] { + return grid_sampler_3d_backward_cpu_impl<scalar_t>( + grad_output, input, grid, + static_cast<GridSamplerInterpolation>(interpolation_mode), + static_cast<GridSamplerPadding>(padding_mode), align_corners); + }); +} + +} // namespace mmdetection diff --git a/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.h b/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..3c9ae45063bf212d715abcf54c0bdccdb23958fc --- /dev/null +++ b/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.h @@ -0,0 +1,225 @@ +// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/GridSampler.h + +#pragma once + +#include <ATen/ATen.h> +#include <ATen/NativeFunctions.h> + +namespace mmdetection { + +namespace detail { + + enum class GridSamplerInterpolation {Bilinear, Nearest}; + enum class GridSamplerPadding {Zeros, Border, Reflection}; + +} // namespace detail + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). +// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template <typename scalar_t> +static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size, + bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1) * size - 1) / 2; + } +} + +// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize +// except that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. 
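The two unnormalization rules spelled out in the comments above map directly to code; a small illustrative Python sketch of the same formulas:

def unnormalize(coord, size, align_corners):
    # align_corners=True : -1 -> 0,    +1 -> size - 1    (scale (size - 1) / 2)
    # align_corners=False: -1 -> -0.5, +1 -> size - 0.5  (scale size / 2)
    if align_corners:
        return (coord + 1) / 2 * (size - 1)
    return ((coord + 1) * size - 1) / 2

assert unnormalize(-1.0, 10, True) == 0.0 and unnormalize(1.0, 10, True) == 9.0
assert unnormalize(-1.0, 10, False) == -0.5 and unnormalize(1.0, 10, False) == 9.5

The *grad_in value written by grid_sampler_unnormalize_set_grad is just the constant scale factor of this mapping, (size - 1) / 2 with align_corners and size / 2 without, which the backward pass multiplies into the grid gradient.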
+template <typename scalar_t> +static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size, + bool align_corners, scalar_t *grad_in) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + *grad_in = static_cast<scalar_t>(size - 1) / 2; + return ((coord + 1) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + *grad_in = static_cast<scalar_t>(size) / 2; + return ((coord + 1) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template<typename scalar_t> +static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) { + return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0))); +} + +// clip_coordinates_set_grad works similarly to clip_coordinates except that +// it also returns the `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template<typename scalar_t> +static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit, + scalar_t *grad_in) { + if (in < static_cast<scalar_t>(0)) { + *grad_in = static_cast<scalar_t>(0); + return static_cast<scalar_t>(0); + } else { + scalar_t max = static_cast<scalar_t>(clip_limit - 1); + if (in > max) { + *grad_in = static_cast<scalar_t>(0); + return max; + } else { + *grad_in = static_cast<scalar_t>(1); + return in; + } + } +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template<typename scalar_t> +static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low, + int64_t twice_high) { + if (twice_low == twice_high) { + return static_cast<scalar_t>(0); + } + scalar_t min = static_cast<scalar_t>(twice_low) / 2; + scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2; + in = std::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = std::fmod(in, span); + int flips = static_cast<int>(std::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// reflect_coordinates_set_grad works similarly to reflect_coordinates except +// that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template<typename scalar_t> +static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low, + int64_t twice_high, scalar_t *grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast<scalar_t>(0); + return static_cast<scalar_t>(0); + } + int grad_in_mult_; + scalar_t min = static_cast<scalar_t>(twice_low) / 2; + scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast<scalar_t>(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + // `fmod` returns same sign as `in`, which is positive after the `if` above. 
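For reflection padding, the doubled-bound trick described above can be illustrated with a short Python sketch of reflect_coordinates (illustrative only; the test values are arbitrary):

import math

def reflect(coord, twice_low, twice_high):
    # mirror coord back into [twice_low / 2, twice_high / 2]; the bounds are
    # passed doubled so that half-integer bounds (align_corners=False) stay integral
    if twice_low == twice_high:
        return 0.0
    low = twice_low / 2.0
    span = (twice_high - twice_low) / 2.0
    coord = abs(coord - low)
    extra = math.fmod(coord, span)
    flips = int(math.floor(coord / span))
    return extra + low if flips % 2 == 0 else span - extra + low

# align_corners=True on a size-5 axis: valid range [0, 4], bounds passed as 0 and 8
assert reflect(5.0, 0, 8) == 3.0   # one step past the top edge reflects back to 3
assert reflect(-1.0, 0, 8) == 1.0  # one step below the bottom edge reflects to 1

With align_corners the bounds are 0 and 2*(size - 1); without align_corners they are -1 and 2*size - 1 and the result is additionally clipped, as grid_sampler_compute_source_index below shows.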
+ scalar_t extra = std::fmod(in, span); + int flips = static_cast<int>(std::floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast<scalar_t>(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast<scalar_t>(-grad_in_mult_); + return span - extra + min; + } +} + +// Computes the pixel source index value for a grid coordinate +template <typename scalar_t> +static inline scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates(coord, 0, 2*(size - 1)); + } else { + coord = reflect_coordinates(coord, -1, 2*size - 1); + // when align_corners=False, reflection does not auto clip coords + coord = clip_coordinates(coord, size); + } + } + return coord; +} + +// grid_sampler_compute_source_index_set_grad works similarly to +// grid_sampler_compute_source_index except that it also returns the +// `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template <typename scalar_t> +static inline scalar_t grid_sampler_compute_source_index_set_grad( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners, + scalar_t *grad_in) { + scalar_t grad_clip, grad_refl; + coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl); + *grad_in = (*grad_in) * grad_refl; + } else { + coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl); + // when align_corners=False, reflection does not auto clip coords + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + } + return coord; +} + +static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +template<typename scalar_t> +static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w, + int64_t sH, int64_t sW, int64_t H, int64_t W, + scalar_t delta) { + if (within_bounds_2d(h, w, H, W)) { + data[h * sH + w * sW] += delta; + } +} + +template<typename scalar_t> +static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w, + int64_t sD, int64_t sH, int64_t sW, + int64_t D, int64_t H, int64_t W, + scalar_t delta) { + if (within_bounds_3d(d, h, w, D, H, W)) { + data[d * sD + h * sH + w * sW] += delta; + } +} + +} // namespace mmdetection diff --git a/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cu b/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..2d747a0b897dda1b0a29998a1825832b6b5eb99c --- /dev/null +++ 
b/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cu @@ -0,0 +1,718 @@ +// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cu + +#include <ATen/ATen.h> +#include "grid_sampler_cuda.cuh" +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> +#include <ATen/cuda/detail/TensorInfo.cuh> +#include <ATen/cuda/detail/IndexUtils.cuh> +#include <ATen/cuda/detail/KernelUtils.h> +#include <c10/macros/Macros.h> + +namespace mmdetection { + +using namespace at::cuda::detail; + +using mmdetection::detail::GridSamplerInterpolation; +using mmdetection::detail::GridSamplerPadding; + +namespace { + template <typename scalar_t> + C10_LAUNCH_BOUNDS_1(1024) + __global__ void grid_sampler_2d_forward_kernel_cuda( + const int nthreads, + TensorInfo<scalar_t, int> input, + TensorInfo<scalar_t, int> grid, + TensorInfo<scalar_t, int> output, + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, + bool align_corners) { + + int C = input.sizes[1]; + int inp_H = input.sizes[2]; + int inp_W = input.sizes[3]; + int out_H = grid.sizes[1]; + int out_W = grid.sizes[2]; + int inp_sN = input.strides[0]; + int inp_sC = input.strides[1]; + int inp_sH = input.strides[2]; + int inp_sW = input.strides[3]; + int grid_sN = grid.strides[0]; + int grid_sH = grid.strides[1]; + int grid_sW = grid.strides[2]; + int grid_sCoor = grid.strides[3]; + int out_sN = output.strides[0]; + int out_sC = output.strides[1]; + int out_sH = output.strides[2]; + int out_sW = output.strides[3]; + + CUDA_KERNEL_LOOP(index, nthreads) { + const int w = index % out_W; + const int h = (index / out_W) % out_H; + const int n = index / (out_H * out_W); + const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + + ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) { + // get NE, NW, SE, SW pixel values from (x, y) + int ix_nw = static_cast<int>(::floor(ix)); + int iy_nw = static_cast<int>(::floor(iy)); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + // calculate bilinear weighted pixel value and set output pixel + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + *out_ptr_NCHW = static_cast<scalar_t>(0); + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + } + } + } else if (interpolation_mode == 
GridSamplerInterpolation::Nearest) { + int ix_nearest = static_cast<int>(::round(ix)); + int iy_nearest = static_cast<int>(::round(iy)); + + // assign nearest neighor pixel value to output pixel + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { + *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCHW = static_cast<scalar_t>(0); + } + } + } + } + } + + template <typename scalar_t> + C10_LAUNCH_BOUNDS_1(1024) + __global__ void grid_sampler_3d_forward_kernel_cuda( + const int nthreads, + TensorInfo<scalar_t, int> input, + TensorInfo<scalar_t, int> grid, + TensorInfo<scalar_t, int> output, + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, + bool align_corners) { + + int C = input.sizes[1]; + int inp_D = input.sizes[2]; + int inp_H = input.sizes[3]; + int inp_W = input.sizes[4]; + int out_D = grid.sizes[1]; + int out_H = grid.sizes[2]; + int out_W = grid.sizes[3]; + int inp_sN = input.strides[0]; + int inp_sC = input.strides[1]; + int inp_sD = input.strides[2]; + int inp_sH = input.strides[3]; + int inp_sW = input.strides[4]; + int grid_sN = grid.strides[0]; + int grid_sD = grid.strides[1]; + int grid_sH = grid.strides[2]; + int grid_sW = grid.strides[3]; + int grid_sCoor = grid.strides[4]; + int out_sN = output.strides[0]; + int out_sC = output.strides[1]; + int out_sD = output.strides[2]; + int out_sH = output.strides[3]; + int out_sW = output.strides[4]; + + CUDA_KERNEL_LOOP(index, nthreads) { + const int w = index % out_W; + const int h = (index / out_W) % out_H; + const int d = (index / (out_H * out_W)) % out_D; + const int n = index / (out_D * out_H * out_W); + const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; + + ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int ix_tnw = static_cast<int>(::floor(ix)); + int iy_tnw = static_cast<int>(::floor(iy)); + int iz_tnw = static_cast<int>(::floor(iz)); + + int ix_tne = ix_tnw + 1; + int iy_tne = iy_tnw; + int iz_tne = iz_tnw; + + int ix_tsw = ix_tnw; + int iy_tsw = iy_tnw + 1; + int iz_tsw = iz_tnw; + + int ix_tse = ix_tnw + 1; + int iy_tse = iy_tnw + 1; + int iz_tse = iz_tnw; + + int ix_bnw = ix_tnw; + int iy_bnw = iy_tnw; + int iz_bnw = iz_tnw + 1; + + int ix_bne = ix_tnw + 1; + int iy_bne = iy_tnw; + int iz_bne = iz_tnw + 1; + + int ix_bsw = ix_tnw; + int iy_bsw = iy_tnw + 1; + int iz_bsw = iz_tnw + 1; + + int ix_bse = ix_tnw + 1; + int iy_bse = iy_tnw + 1; + int iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - 
ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse + *out_ptr_NCDHW = static_cast<scalar_t>(0); + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + } + } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { + int ix_nearest = static_cast<int>(::round(ix)); + int iy_nearest = static_cast<int>(::round(iy)); + int iz_nearest = static_cast<int>(::round(iz)); + + // assign nearest neighor pixel value to output pixel + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCDHW = static_cast<scalar_t>(0); + } + } + } + } + } + + template <typename scalar_t> + C10_LAUNCH_BOUNDS_1(1024) + __global__ void grid_sampler_2d_backward_kernel_cuda( + const int nthreads, + TensorInfo<scalar_t, int> grad_output, + TensorInfo<scalar_t, int> input, + TensorInfo<scalar_t, int> grid, + TensorInfo<scalar_t, int> grad_input, // initialized to zeros + TensorInfo<scalar_t, int> grad_grid, // initialized to empty + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, + bool align_corners) { + + int C = input.sizes[1]; + int inp_H = input.sizes[2]; + int inp_W = input.sizes[3]; + int out_H = grid.sizes[1]; + 
int out_W = grid.sizes[2]; + int inp_sN = input.strides[0]; + int inp_sC = input.strides[1]; + int inp_sH = input.strides[2]; + int inp_sW = input.strides[3]; + int grid_sN = grid.strides[0]; + int grid_sH = grid.strides[1]; + int grid_sW = grid.strides[2]; + int grid_sCoor = grid.strides[3]; + int gOut_sN = grad_output.strides[0]; + int gOut_sC = grad_output.strides[1]; + int gOut_sH = grad_output.strides[2]; + int gOut_sW = grad_output.strides[3]; + int gInp_sN = grad_input.strides[0]; + int gInp_sC = grad_input.strides[1]; + int gInp_sH = grad_input.strides[2]; + int gInp_sW = grad_input.strides[3]; + int gGrid_sW = grad_grid.strides[2]; + + CUDA_KERNEL_LOOP(index, nthreads) { + const int w = index % out_W; + const int h = (index / out_W) % out_H; + const int n = index / (out_H * out_W); + const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + + // multipliers for gradients on ix and iy + scalar_t gix_mult, giy_mult; + ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) { + // get NE, NW, SE, SW pixel values from (x, y) + int ix_nw = static_cast<int>(::floor(ix)); + int iy_nw = static_cast<int>(::floor(iy)); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0); + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN; + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, gInp_ptr_NC += gInp_sC, gOut_ptr_NCHW += gOut_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + // calculate and set grad_input + safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut); + safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut); + safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut); + safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut); + + // calculate grad_grid + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]; + gix -= nw_val * (iy_se - iy) * gOut; + giy -= nw_val * (ix_se - ix) * gOut; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]; + gix += ne_val * (iy_sw - iy) * gOut; + giy -= ne_val * (ix - ix_sw) * gOut; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]; + gix -= sw_val * (iy - iy_ne) * gOut; + giy += sw_val * (ix_ne - ix) * gOut; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { + scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]; + gix += se_val * (iy - iy_nw) * gOut; + giy += se_val * (ix - ix_nw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + 
// 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW + // 2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1] + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; + } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { + int ix_nearest = static_cast<int>(::round(ix)); + int iy_nearest = static_cast<int>(::round(iy)); + + // assign nearest neighor pixel value to output pixel + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN; + for (int c = 0; c < C; ++c, gInp_ptr_NC += gInp_sC, gOut_ptr_NCHW += gOut_sC) { + // calculate and set grad_input + safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW); + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW + // 2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1] + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NHW[0] = static_cast<scalar_t>(0); + gGrid_ptr_NHW[1] = static_cast<scalar_t>(0); + } + } + } + + template <typename scalar_t> + C10_LAUNCH_BOUNDS_1(1024) + __global__ void grid_sampler_3d_backward_kernel_cuda( + const int nthreads, + TensorInfo<scalar_t, int> grad_output, + TensorInfo<scalar_t, int> input, + TensorInfo<scalar_t, int> grid, + TensorInfo<scalar_t, int> grad_input, // initialized to zeros + TensorInfo<scalar_t, int> grad_grid, // initialized to empty + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, + bool align_corners) { + + int C = input.sizes[1]; + int inp_D = input.sizes[2]; + int inp_H = input.sizes[3]; + int inp_W = input.sizes[4]; + int out_D = grid.sizes[1]; + int out_H = grid.sizes[2]; + int out_W = grid.sizes[3]; + int inp_sN = input.strides[0]; + int inp_sC = input.strides[1]; + int inp_sD = input.strides[2]; + int inp_sH = input.strides[3]; + int inp_sW = input.strides[4]; + int grid_sN = grid.strides[0]; + int grid_sD = grid.strides[1]; + int grid_sH = grid.strides[2]; + int grid_sW = grid.strides[3]; + int grid_sCoor = grid.strides[4]; + int gOut_sN = grad_output.strides[0]; + int gOut_sC = grad_output.strides[1]; + int gOut_sD = grad_output.strides[2]; + int gOut_sH = grad_output.strides[3]; + int gOut_sW = grad_output.strides[4]; + int gInp_sN = grad_input.strides[0]; + int gInp_sC = grad_input.strides[1]; + int gInp_sD = grad_input.strides[2]; + int gInp_sH = grad_input.strides[3]; + int gInp_sW = grad_input.strides[4]; + int gGrid_sW = grad_grid.strides[3]; + + CUDA_KERNEL_LOOP(index, nthreads) { + const int w = index % out_W; + const int h = (index / out_W) % out_H; + const int d = (index / (out_H * out_W)) % out_D; + const int n = index / (out_D * out_H * out_W); + const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + scalar_t gix_mult, giy_mult, giz_mult; + ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, 
padding_mode, align_corners, &giz_mult); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int ix_tnw = static_cast<int>(::floor(ix)); + int iy_tnw = static_cast<int>(::floor(iy)); + int iz_tnw = static_cast<int>(::floor(iz)); + + int ix_tne = ix_tnw + 1; + int iy_tne = iy_tnw; + int iz_tne = iz_tnw; + + int ix_tsw = ix_tnw; + int iy_tsw = iy_tnw + 1; + int iz_tsw = iz_tnw; + + int ix_tse = ix_tnw + 1; + int iy_tse = iy_tnw + 1; + int iz_tse = iz_tnw; + + int ix_bnw = ix_tnw; + int iy_bnw = iy_tnw; + int iz_bnw = iz_tnw + 1; + + int ix_bne = ix_tnw + 1; + int iy_bne = iy_tnw; + int iz_bne = iz_tnw + 1; + + int ix_bsw = ix_tnw; + int iy_bsw = iy_tnw + 1; + int iz_bsw = iz_tnw + 1; + + int ix_bse = ix_tnw + 1; + int iy_bse = iy_tnw + 1; + int iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0); + scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN; + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + // calculate bilinear weighted pixel value and set output pixel + for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) { + scalar_t gOut = *gOut_ptr_NCDHW; + + // calculate and set grad_input + safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut); + safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut); + safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut); + safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut); + safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut); + safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut); + safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut); + safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut); + + // calculate grad_grid + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz 
-= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + gGrid_ptr_NDHW[2] = giz_mult * giz; + } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { + int ix_nearest = static_cast<int>(::round(ix)); + int iy_nearest = static_cast<int>(::round(iy)); + int iz_nearest = static_cast<int>(::round(iz)); + + // accumulate grad_output into grad_input at the nearest input voxel + scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN; + for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) { + // calculate and set grad_input + safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest, + gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW); + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2.
directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NDHW[0] = static_cast<scalar_t>(0); + gGrid_ptr_NDHW[1] = static_cast<scalar_t>(0); + gGrid_ptr_NDHW[2] = static_cast<scalar_t>(0); + } + } + } +} // namespace + +using namespace at; +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +Tensor grid_sampler_2d_forward_cuda(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) { + auto N = input.size(0); + auto H = grid.size(1); + auto W = grid.size(2); + auto output = at::empty({N, input.size(1), H, W}, input.options()); + int count = static_cast<int>(N * H * W); + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_forward_cuda", [&] { + grid_sampler_2d_forward_kernel_cuda<scalar_t> + <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>( + count, + getTensorInfo<scalar_t, int>(input), + getTensorInfo<scalar_t, int>(grid), + getTensorInfo<scalar_t, int>(output), + static_cast<GridSamplerInterpolation>(interpolation_mode), + static_cast<GridSamplerPadding>(padding_mode), + align_corners); + }); + } + return output; +} + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +Tensor grid_sampler_3d_forward_cuda(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) { + auto N = input.size(0); + auto D = grid.size(1); + auto H = grid.size(2); + auto W = grid.size(3); + auto output = at::empty({N, input.size(1), D, H, W}, input.options()); + int count = static_cast<int>(N * D * H * W); + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_forward_cuda", [&] { + grid_sampler_3d_forward_kernel_cuda<scalar_t> + <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>( + count, + getTensorInfo<scalar_t, int>(input), + getTensorInfo<scalar_t, int>(grid), + getTensorInfo<scalar_t, int>(output), + static_cast<GridSamplerInterpolation>(interpolation_mode), + static_cast<GridSamplerPadding>(padding_mode), + align_corners); + }); + } + return output; +} + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +std::tuple<Tensor, Tensor> +grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input, + const Tensor& grid, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners) { + auto N = input.size(0); + auto H = grid.size(1); + auto W = grid.size(2); + auto grad_input = at::zeros_like(input); + auto grad_grid = at::empty_like(grid); + int count = static_cast<int>(N * H * W); + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] { + grid_sampler_2d_backward_kernel_cuda<scalar_t> + <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>( + count, + getTensorInfo<scalar_t, int>(grad_output), + getTensorInfo<scalar_t, int>(input), + getTensorInfo<scalar_t, int>(grid), + getTensorInfo<scalar_t, int>(grad_input), + getTensorInfo<scalar_t, int>(grad_grid), + static_cast<GridSamplerInterpolation>(interpolation_mode), + static_cast<GridSamplerPadding>(padding_mode), + align_corners); + }); + } + return std::make_tuple(grad_input, grad_grid); +} + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. 
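+// Note on the launchers: each backward kernel iterates over count = N * D * H * W
+// output locations (N * H * W in the 2D case) via CUDA_KERNEL_LOOP. grad_input is
+// allocated with zeros_like because safe_add_2d/safe_add_3d accumulate into it with
+// atomicAdd, while grad_grid can be empty_like since the kernel overwrites every element.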
+std::tuple<Tensor, Tensor> +grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input, + const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) { + auto N = input.size(0); + auto D = grid.size(1); + auto H = grid.size(2); + auto W = grid.size(3); + auto grad_input = at::zeros_like(input); + auto grad_grid = at::empty_like(grid); + int count = static_cast<int>(N * D * H * W); + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] { + grid_sampler_3d_backward_kernel_cuda<scalar_t> + <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>( + count, + getTensorInfo<scalar_t, int>(grad_output), + getTensorInfo<scalar_t, int>(input), + getTensorInfo<scalar_t, int>(grid), + getTensorInfo<scalar_t, int>(grad_input), + getTensorInfo<scalar_t, int>(grad_grid), + static_cast<GridSamplerInterpolation>(interpolation_mode), + static_cast<GridSamplerPadding>(padding_mode), + align_corners); + }); + } + return std::make_tuple(grad_input, grad_grid); +} + +} // namespace mmdetection diff --git a/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cuh b/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a84fa7c076ecd8302aacddf6c350196cc5ce964e --- /dev/null +++ b/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cuh @@ -0,0 +1,233 @@ +// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cuh + +#include <ATen/ATen.h> +#include <ATen/NativeFunctions.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> +#include <THC/THCAtomics.cuh> + +namespace mmdetection { + +namespace detail { + + enum class GridSamplerInterpolation {Bilinear, Nearest}; + enum class GridSamplerPadding {Zeros, Border, Reflection}; + +} // namespace detail + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). +// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template <typename scalar_t> +static __forceinline__ __device__ +scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1.f) * size - 1) / 2; + } +} + +// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize +// except that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. 
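+// Worked example (illustrative): for size = 4 and coord = 0.5,
+//   align_corners=true : ((0.5 + 1) / 2) * (4 - 1) = 2.25, grad_in = (4 - 1) / 2 = 1.5
+//   align_corners=false: ((0.5 + 1) * 4 - 1) / 2   = 2.5,  grad_in = 4 / 2 = 2.0
+// i.e. grad_in is simply the constant scale factor of the affine unnormalization.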
+template <typename scalar_t> +static __forceinline__ __device__ +scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size, + bool align_corners, scalar_t *grad_in) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + *grad_in = static_cast<scalar_t>(size - 1) / 2; + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + *grad_in = static_cast<scalar_t>(size) / 2; + return ((coord + 1.f) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template <typename scalar_t> +static __forceinline__ __device__ +scalar_t clip_coordinates(scalar_t in, int clip_limit) { + return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0))); +} + +// clip_coordinates_set_grad works similarly to clip_coordinates except that +// it also returns the `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template <typename scalar_t> +static __forceinline__ __device__ +scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) { + if (in < static_cast<scalar_t>(0)) { + *grad_in = static_cast<scalar_t>(0); + return static_cast<scalar_t>(0); + } else { + scalar_t max = static_cast<scalar_t>(clip_limit - 1); + if (in > max) { + *grad_in = static_cast<scalar_t>(0); + return max; + } else { + *grad_in = static_cast<scalar_t>(1); + return in; + } + } +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template <typename scalar_t> +static __forceinline__ __device__ +scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) { + if (twice_low == twice_high) { + return static_cast<scalar_t>(0); + } + scalar_t min = static_cast<scalar_t>(twice_low) / 2; + scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2; + in = ::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = ::fmod(in, span); + int flips = static_cast<int>(::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// reflect_coordinates_set_grad works similarly to reflect_coordinates except +// that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template <typename scalar_t> +static __forceinline__ __device__ +scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high, + scalar_t *grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast<scalar_t>(0); + return static_cast<scalar_t>(0); + } + int grad_in_mult_; + scalar_t min = static_cast<scalar_t>(twice_low) / 2; + scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast<scalar_t>(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + // `fmod` returns same sign as `in`, which is positive after the `if` above. 
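+  // Example (illustrative): with align_corners=false the callers below pass
+  // twice_low = -1 and twice_high = 2*size - 1, so min = -0.5 and span = size.
+  // An even number of flips keeps the slope of the mapping, an odd number negates
+  // it, and grad_in_mult_ carries the sign flip from mirroring `in` above.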
+ scalar_t extra = ::fmod(in, span); + int flips = static_cast<int>(::floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast<scalar_t>(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast<scalar_t>(-grad_in_mult_); + return span - extra + min; + } +} + +// Computes the pixel source index value for a grid coordinate +template <typename scalar_t> +static __forceinline__ __device__ +scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates(coord, 0, 2*(size - 1)); + } else { + coord = reflect_coordinates(coord, -1, 2*size - 1); + // when align_corners=False, reflection does not auto clip coords + coord = clip_coordinates(coord, size); + } + } + return coord; +} + +// grid_sampler_compute_source_index_set_grad works similarly to +// grid_sampler_compute_source_index except that it also returns the +// `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template <typename scalar_t> +static __forceinline__ __device__ +scalar_t grid_sampler_compute_source_index_set_grad( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners, + scalar_t *grad_in) { + scalar_t grad_clip, grad_refl; + coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl); + *grad_in = (*grad_in) * grad_refl; + } else { + coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl); + // when align_corners=False, reflection does not auto clip coords + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + } + return coord; +} + +static __forceinline__ __device__ +bool within_bounds_2d(int h, int w, int H, int W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +static __forceinline__ __device__ +bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +template<typename scalar_t> +static __forceinline__ __device__ +void safe_add_2d(scalar_t *data, int h, int w, + int sH, int sW, int H, int W, + scalar_t delta) { + if (within_bounds_2d(h, w, H, W)) { + atomicAdd(data + h * sH + w * sW, delta); + } +} + +template<typename scalar_t> +static __forceinline__ __device__ +void safe_add_3d(scalar_t *data, int d, int h, int w, + int sD, int sH, int sW, int D, int H, int W, + scalar_t delta) { + if (within_bounds_3d(d, h, w, D, H, W)) { + atomicAdd(data + d * sD + h * sH + w * sW, delta); + } +} + +} // namespace mmdetection diff --git a/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp b/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp new file mode 100644 index
0000000000000000000000000000000000000000..f684abf12914ec3bdd02651013a3ccb8c3d32a48 --- /dev/null +++ b/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp @@ -0,0 +1,148 @@ +#include <ATen/ATen.h> +#include <ATen/NativeFunctions.h> +#include <ATen/Config.h> +#include <ATen/cuda/CUDAConfig.h> + +#if !AT_CUDNN_ENABLED() + +namespace at { namespace native { + +// See Note [ATen preprocessor philosophy] + +Tensor cudnn_grid_sampler_forward( + const Tensor& input_t, const Tensor& grid_t) { + AT_ERROR("cudnn_grid_sampler_forward: ATen not compiled with cuDNN support"); +} + +std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward( + const Tensor& input_t, const Tensor& grid_t, + const Tensor& grad_output_t) { + AT_ERROR("cudnn_grid_sampler_backward: ATen not compiled with cuDNN support"); +} + +}} + +#else // AT_CUDNN_ENABLED + +#include <ATen/cudnn/Descriptors.h> +#include <ATen/cudnn/Types.h> +#include <ATen/cudnn/Utils.h> +#include <ATen/cuda/Exceptions.h> + +#include <ATen/TensorUtils.h> + +// TODO: descriptor checking + + +namespace mmdetection { + +using namespace at; + +namespace { + +void setSamplerDescriptor(SpatialTransformerDescriptor& desc, cudnnDataType_t dataType, const at::Tensor& tensor) +{ + int inputSize[4] = {0}; + for (int i = 0; i < tensor.dim(); ++i) { + inputSize[i] = (int) tensor.size(i); + } + desc.set(dataType, 4, inputSize); +} + +void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input) +{ + // assert size of grid is n*h*w*2 + // FYI: grid is between [-1, 1], where -1 left most pixel, + // 1 represents right most pixel (and hence 0 is the center pixel) + // if grid has values >1 or <-1, those values are ignored + checkContiguous(c, grid); + checkDim(c, grid, 4); + // TODO: Maybe more user friendly to report where the expected size + // came from + checkSize(c, grid, 0, input->size(0)); + checkSize(c, grid, 3, 2); +} + +} // namespace + +Tensor cudnn_grid_sampler_forward( + const Tensor& input_t, const Tensor& grid_t) +{ + TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 }, + grid{ grid_t.contiguous(), "grid", 2 }; + CheckedFrom c = "cudnn_grid_sampler_forward"; + checkAllSameGPU(c, {input, grid}); + checkAllSameType(c, {input, grid}); + checkGridSize(c, grid, input); + checkDim(c, input, 4); + + auto output_t = at::empty({0}, input->options()); + output_t.resize_({input->size(0), input->size(1), grid->size(1), grid->size(2)}); + + TensorDescriptor idesc{ *input }; // input descriptor + TensorDescriptor odesc{ output_t }; // output descriptor + SpatialTransformerDescriptor desc; // sampler descriptor + + auto handle = getCudnnHandle(); + auto dataType = getCudnnDataType(*input); + setSamplerDescriptor(desc, dataType, output_t); + + Constant one(dataType, 1); + Constant zero(dataType, 0); + AT_CUDNN_CHECK(cudnnSpatialTfSamplerForward( + handle, desc.desc(), + &one, idesc.desc(), input->data_ptr(), + grid->data_ptr(), + &zero, odesc.desc(), output_t.data_ptr() + )); + + return output_t; +} + +// NB: CuDNN does not support output mask; you always get both +// gradients. 
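+// Accordingly, the backward entry point below always materializes both grad_input
+// and grad_grid (sized to match input and grid respectively) and returns them as a
+// tuple; a caller that needs only one of the gradients simply ignores the other.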
+std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward( + const Tensor& input_t, const Tensor& grid_t, + const Tensor& grad_output_t) +{ + TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 }, + grid{ grid_t.contiguous(), "grid", 2 }, + grad_output{ contiguousIfZeroInStrides(grad_output_t), "grad_output", 3 }; + CheckedFrom c = "cudnn_grid_sampler_backward"; + checkAllSameGPU(c, {input, grad_output, grid}); + checkGridSize(c, grid, input); + checkDim(c, input, 4); + checkDim(c, grad_output, 4); + + auto grad_input_t = at::empty({0}, input->options()); + grad_input_t.resize_(input->sizes()); + auto grad_grid_t = at::empty({0}, grid->options()); + grad_grid_t.resize_(grid->sizes()); + + TensorDescriptor idesc{ *input }; // input descriptor + TensorDescriptor odesc{ *grad_output }; // grad_output descriptor + TensorDescriptor gdesc{ grad_input_t }; // grad_input descriptor + SpatialTransformerDescriptor desc; // sampler descriptor + + auto handle = getCudnnHandle(); + auto dataType = getCudnnDataType(*input); + setSamplerDescriptor(desc, dataType, *grad_output); + + Constant one(dataType, 1); + Constant zero(dataType, 0); + AT_CUDNN_CHECK(cudnnSpatialTfSamplerBackward( + handle, desc.desc(), + &one, idesc.desc(), input->data_ptr(), + &zero, gdesc.desc(), grad_input_t.data_ptr(), + &one, odesc.desc(), grad_output->data_ptr(), + // intriguingly, the outputs don't need descriptors + grid->data_ptr(), + &zero, grad_grid_t.data_ptr() + )); + + return std::tuple<Tensor, Tensor>{ grad_input_t, grad_grid_t }; +} + +} // namespace mmdetection + +#endif diff --git a/mmdet/ops/grid_sampler/src/grid_sampler.cpp b/mmdet/ops/grid_sampler/src/grid_sampler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..009675fe7ea2ec3d248f528bbe9fe3166785b411 --- /dev/null +++ b/mmdet/ops/grid_sampler/src/grid_sampler.cpp @@ -0,0 +1,66 @@ +#include <torch/extension.h> +#include <ATen/DeviceGuard.h> + +namespace mmdetection { + +using namespace at; + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +Tensor grid_sampler_2d_forward_cpu(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +Tensor grid_sampler_3d_forward_cpu(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +std::tuple<Tensor, Tensor> +grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, + const Tensor& grid, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners); + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +std::tuple<Tensor, Tensor> +grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, + const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +Tensor grid_sampler_2d_forward_cuda(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +Tensor grid_sampler_3d_forward_cuda(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + +// No shape checking needed here.
See # NOTE [ grid_sampler Native Functions ]. +std::tuple<Tensor, Tensor> +grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input, + const Tensor& grid, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners); + +// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ]. +std::tuple<Tensor, Tensor> +grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input, + const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + + m.def("grid_sampler_2d_forward_cpu", &grid_sampler_2d_forward_cpu, "grid_sampler_2d_forward (CPU)"); + m.def("grid_sampler_2d_backward_cpu", &grid_sampler_2d_backward_cpu, "grid_sampler_2d_backward (CPU)"); + m.def("grid_sampler_3d_forward_cpu", &grid_sampler_3d_forward_cpu, "grid_sampler_3d_forward (CPU)"); + m.def("grid_sampler_3d_backward_cpu", &grid_sampler_3d_backward_cpu, "grid_sampler_3d_backward (CPU)"); + + m.def("grid_sampler_2d_forward_cuda", &grid_sampler_2d_forward_cuda, "grid_sampler_2d_forward (CUDA)"); + m.def("grid_sampler_2d_backward_cuda", &grid_sampler_2d_backward_cuda, "grid_sampler_2d_backward (CUDA)"); + m.def("grid_sampler_3d_forward_cuda", &grid_sampler_3d_forward_cuda, "grid_sampler_3d_forward (CUDA)"); + m.def("grid_sampler_3d_backward_cuda", &grid_sampler_3d_backward_cuda, "grid_sampler_3d_backward (CUDA)"); +} + +} // namespace mmdetection diff --git a/setup.py b/setup.py index 89e79f95a544983a74e76d70134f2679ebafff00..1b3445c75230e35f642a37d80875fceeadffc0c8 100755 --- a/setup.py +++ b/setup.py @@ -270,6 +270,13 @@ if __name__ == '__main__': sources=[ 'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu' ]), + make_cuda_ext( + name='grid_sampler_cuda', + module='mmdet.ops.grid_sampler', + sources=[ + 'src/cpu/grid_sampler_cpu.cpp', + 'src/cuda/grid_sampler_cuda.cu', 'src/grid_sampler.cpp' + ]), make_cuda_ext( name='carafe_cuda', module='mmdet.ops.carafe',