Unverified commit d7a38be9, authored by Jerry Jiarui XU, committed by GitHub

add grid sampler (#2126)

* add grid sampler

* add doc

* remove comment
parent 639f934a
from .grid_sampler import grid_sample
__all__ = ['grid_sample']
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from . import grid_sampler_cuda
class _GridSampler(Function):

    @staticmethod
    def forward(ctx, input, grid, mode_enum, padding_mode_enum, align_corners):
        ctx.save_for_backward(input, grid)
        ctx.mode_enum = mode_enum
        ctx.padding_mode_enum = padding_mode_enum
        ctx.align_corners = align_corners
        if input.is_cuda:
            if input.dim() == 4:
                func = grid_sampler_cuda.grid_sampler_2d_forward_cuda
            else:
                func = grid_sampler_cuda.grid_sampler_3d_forward_cuda
        else:
            if input.dim() == 4:
                func = grid_sampler_cuda.grid_sampler_2d_forward_cpu
            else:
                func = grid_sampler_cuda.grid_sampler_3d_forward_cpu
        output = func(input, grid, mode_enum, padding_mode_enum, align_corners)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        mode_enum = ctx.mode_enum
        padding_mode_enum = ctx.padding_mode_enum
        align_corners = ctx.align_corners
        if input.is_cuda:
            if input.dim() == 4:
                func = grid_sampler_cuda.grid_sampler_2d_backward_cuda
            else:
                func = grid_sampler_cuda.grid_sampler_3d_backward_cuda
        else:
            if input.dim() == 4:
                func = grid_sampler_cuda.grid_sampler_2d_backward_cpu
            else:
                func = grid_sampler_cuda.grid_sampler_3d_backward_cpu
        grad_input, grad_grid = func(grad_output, input, grid, mode_enum,
                                     padding_mode_enum, align_corners)
        return grad_input, grad_grid, None, None, None
def grid_sample(input,
                grid,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False):
    if torch.__version__ >= '1.3' or align_corners:
        return F.grid_sample(input, grid, mode, padding_mode, align_corners)
    else:
        # use self-compiled grid_sampler to support align_corners=False
        assert mode in ['bilinear', 'nearest'], \
            'expected mode to be bilinear or nearest, but got: {}'.format(mode)
        assert padding_mode in ['zeros', 'border', 'reflection'], \
            'expected padding_mode to be zeros, border, or reflection, ' \
            'but got: {}'.format(padding_mode)
        if mode == 'bilinear':
            mode_enum = 0
        else:
            mode_enum = 1
        if padding_mode == 'zeros':
            padding_mode_enum = 0
        elif padding_mode == 'border':
            padding_mode_enum = 1
        else:
            padding_mode_enum = 2
        # shape check
        assert input.device == grid.device, \
            'expected input and grid to be on same device, ' \
            'but input is on {} and grid is on {}'.format(
                input.device, grid.device)
        assert input.dtype == grid.dtype, \
            'expected input and grid to have the same dtype, ' \
            'but input has {} and grid has {}'.format(
                input.dtype, grid.dtype)
        assert input.dim() == 4 or input.dim() == 5, \
            'expected 4D or 5D input and grid with same number of dimensions, ' \
            'but got input with sizes {} and grid with sizes {}'.format(
                input.size(), grid.size())
        assert input.size(0) == grid.size(0), \
            'expected input and grid to have the same batch size, ' \
            'but got input with sizes {} and grid with sizes {}'.format(
                input.size(), grid.size())
        assert grid.size(-1) == input.dim() - 2, \
            'expected grid to have size {} in last dimension, ' \
            'but got grid with sizes {}'.format(
                input.dim() - 2, grid.size())
        for i in range(2, input.dim()):
            assert input.size(i) > 0, \
                'expected input to have non-empty spatial dimensions, ' \
                'but input has sizes {} with dimension {} being empty'.format(
                    input.size(), i)
        return _GridSampler.apply(input, grid, mode_enum, padding_mode_enum,
                                  align_corners)
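
For context, a minimal usage sketch of the wrapper above (the import path `mmdet.ops` is assumed here and may differ from the actual package layout):

import torch
from mmdet.ops import grid_sample  # assumed import path

# sample a 4D feature map on a uniform grid spanning [-1, 1]
x = torch.randn(1, 3, 8, 8, requires_grad=True)
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 8), torch.linspace(-1, 1, 8))
grid = torch.stack((xs, ys), dim=-1).unsqueeze(0)  # (N, H_out, W_out, 2), last dim is (x, y)
out = grid_sample(x, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
out.sum().backward()  # exercises _GridSampler.backward when torch < 1.3
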
// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/GridSampler.cpp
#include <torch/extension.h>
#include "grid_sampler_cpu.h"
#include <ATen/ATen.h>
#include <ATen/Device.h>
#include <ATen/NativeFunctions.h>
#include <c10/core/Layout.h>
#include <c10/util/Exception.h>
#ifdef _OPENMP
#include <omp.h>
#endif
namespace mmdetection {
using namespace at;
using mmdetection::detail::GridSamplerInterpolation;
using mmdetection::detail::GridSamplerPadding;
namespace {
template<typename scalar_t>
Tensor grid_sampler_2d_forward_cpu_impl(const Tensor& input, const Tensor& grid,
GridSamplerInterpolation interpolation_mode,
GridSamplerPadding padding_mode,
bool align_corners) {
int64_t N = input.size(0);
int64_t C = input.size(1);
int64_t inp_H = input.size(2);
int64_t inp_W = input.size(3);
int64_t out_H = grid.size(1);
int64_t out_W = grid.size(2);
auto output = at::empty({N, C, out_H, out_W}, input.options());
int64_t inp_sN = input.stride(0);
int64_t inp_sC = input.stride(1);
int64_t inp_sH = input.stride(2);
int64_t inp_sW = input.stride(3);
int64_t grid_sN = grid.stride(0);
int64_t grid_sH = grid.stride(1);
int64_t grid_sW = grid.stride(2);
int64_t grid_sCoor = grid.stride(3);
int64_t out_sN = output.stride(0);
int64_t out_sC = output.stride(1);
int64_t out_sH = output.stride(2);
int64_t out_sW = output.stride(3);
scalar_t *inp_ptr = input.data<scalar_t>();
scalar_t *out_ptr = output.data<scalar_t>();
scalar_t *grid_ptr = grid.data<scalar_t>();
// loop over each output pixel
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t n = 0; n < N; ++n) {
scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
for (int64_t h = 0; h < out_H; ++h) {
for (int64_t w = 0; w < out_W; ++w) {
// get the corresponding input x, y, z co-ordinates from grid
scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NHW;
scalar_t iy = grid_ptr_NHW[grid_sCoor];
ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
int64_t iy_nw = static_cast<int64_t>(std::floor(iy));
int64_t ix_ne = ix_nw + 1;
int64_t iy_ne = iy_nw;
int64_t ix_sw = ix_nw;
int64_t iy_sw = iy_nw + 1;
int64_t ix_se = ix_nw + 1;
int64_t iy_se = iy_nw + 1;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - ix) * (iy_se - iy) ;
scalar_t ne = (ix - ix_sw) * (iy_sw - iy) ;
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
// calculate bilinear weighted pixel value and set output pixel
scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
scalar_t *inp_ptr_NC = inp_ptr_N;
for (int c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
// (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
// + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
// + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
// + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
*out_ptr_NCHW = static_cast<scalar_t>(0);
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
}
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
}
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
}
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
}
}
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
// assign nearest neighbor pixel value to output pixel
scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
scalar_t *inp_ptr_NC = inp_ptr_N;
for (int c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
*out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
} else {
*out_ptr_NCHW = static_cast<scalar_t>(0);
}
}
}
}
}
}
return output;
}
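
For reference, the nw/ne/sw/se weights above are bilinear interpolation written as areas; a minimal NumPy sketch (illustrative only, assuming the sample point is already unnormalized and all four neighbours are in bounds, so no zero-padding handling) of the same computation for a single point:

import numpy as np

def bilinear_sample_2d(img, ix, iy):
    # img has shape (C, H, W); (ix, iy) is an unnormalized source coordinate
    ix_nw, iy_nw = int(np.floor(ix)), int(np.floor(iy))
    ix_se, iy_se = ix_nw + 1, iy_nw + 1
    # each corner is weighted by the area of the rectangle spanned by the
    # sample point and the opposite corner (the four weights sum to 1)
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_nw) * (iy_se - iy)
    sw = (ix_se - ix) * (iy - iy_nw)
    se = (ix - ix_nw) * (iy - iy_nw)
    return (img[:, iy_nw, ix_nw] * nw + img[:, iy_nw, ix_se] * ne +
            img[:, iy_se, ix_nw] * sw + img[:, iy_se, ix_se] * se)
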
template<typename scalar_t>
Tensor grid_sampler_3d_forward_cpu_impl(const Tensor& input, const Tensor& grid,
GridSamplerInterpolation interpolation_mode,
GridSamplerPadding padding_mode,
bool align_corners) {
int64_t N = input.size(0);
int64_t C = input.size(1);
int64_t inp_D = input.size(2);
int64_t inp_H = input.size(3);
int64_t inp_W = input.size(4);
int64_t out_D = grid.size(1);
int64_t out_H = grid.size(2);
int64_t out_W = grid.size(3);
auto output = at::empty({N, C, out_D, out_H, out_W}, input.options());
int64_t inp_sN = input.stride(0);
int64_t inp_sC = input.stride(1);
int64_t inp_sD = input.stride(2);
int64_t inp_sH = input.stride(3);
int64_t inp_sW = input.stride(4);
int64_t grid_sN = grid.stride(0);
int64_t grid_sD = grid.stride(1);
int64_t grid_sH = grid.stride(2);
int64_t grid_sW = grid.stride(3);
int64_t grid_sCoor = grid.stride(4);
int64_t out_sN = output.stride(0);
int64_t out_sC = output.stride(1);
int64_t out_sD = output.stride(2);
int64_t out_sH = output.stride(3);
int64_t out_sW = output.stride(4);
scalar_t *inp_ptr = input.data<scalar_t>();
scalar_t *out_ptr = output.data<scalar_t>();
scalar_t *grid_ptr = grid.data<scalar_t>();
// loop over each output pixel
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t n = 0; n < N; ++n) {
scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
for (int64_t d = 0; d < out_D; ++d) {
for (int64_t h = 0; h < out_H; ++h) {
for (int64_t w = 0; w < out_W; ++w) {
// get the corresponding input x, y, z co-ordinates from grid
scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NDHW;
scalar_t iy = grid_ptr_NDHW[grid_sCoor];
scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];
ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
int64_t ix_tnw = static_cast<int64_t>(std::floor(ix));
int64_t iy_tnw = static_cast<int64_t>(std::floor(iy));
int64_t iz_tnw = static_cast<int64_t>(std::floor(iz));
int64_t ix_tne = ix_tnw + 1;
int64_t iy_tne = iy_tnw;
int64_t iz_tne = iz_tnw;
int64_t ix_tsw = ix_tnw;
int64_t iy_tsw = iy_tnw + 1;
int64_t iz_tsw = iz_tnw;
int64_t ix_tse = ix_tnw + 1;
int64_t iy_tse = iy_tnw + 1;
int64_t iz_tse = iz_tnw;
int64_t ix_bnw = ix_tnw;
int64_t iy_bnw = iy_tnw;
int64_t iz_bnw = iz_tnw + 1;
int64_t ix_bne = ix_tnw + 1;
int64_t iy_bne = iy_tnw;
int64_t iz_bne = iz_tnw + 1;
int64_t ix_bsw = ix_tnw;
int64_t iy_bsw = iy_tnw + 1;
int64_t iz_bsw = iz_tnw + 1;
int64_t ix_bse = ix_tnw + 1;
int64_t iy_bse = iy_tnw + 1;
int64_t iz_bse = iz_tnw + 1;
// get surfaces to each neighbor:
scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
// calculate bilinear weighted pixel value and set output pixel
scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
scalar_t *inp_ptr_NC = inp_ptr_N;
for (int c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
// (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
// + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
// + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
// + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
*out_ptr_NCDHW = static_cast<scalar_t>(0);
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
}
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
}
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
}
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
}
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
}
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
}
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
}
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
}
}
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
int64_t iz_nearest = static_cast<int64_t>(std::round(iz));
// assign nearest neighbor pixel value to output pixel
scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
scalar_t *inp_ptr_NC = inp_ptr_N;
for (int c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
} else {
*out_ptr_NCDHW = static_cast<scalar_t>(0);
}
}
}
}
}
}
}
return output;
}
template<typename scalar_t>
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cpu_impl(const Tensor& grad_output,
const Tensor& input, const Tensor& grid,
GridSamplerInterpolation interpolation_mode,
GridSamplerPadding padding_mode,
bool align_corners) {
auto grad_input = at::zeros_like(input);
auto grad_grid = at::empty_like(grid);
// If interpolation mode is Nearest, then grad_grid is not filled in the
// loop below.
if (interpolation_mode == GridSamplerInterpolation::Nearest) {
grad_grid.zero_();
}
int64_t N = input.size(0);
int64_t C = input.size(1);
int64_t inp_H = input.size(2);
int64_t inp_W = input.size(3);
int64_t out_H = grid.size(1);
int64_t out_W = grid.size(2);
int64_t inp_sN = input.stride(0);
int64_t inp_sC = input.stride(1);
int64_t inp_sH = input.stride(2);
int64_t inp_sW = input.stride(3);
int64_t grid_sN = grid.stride(0);
int64_t grid_sH = grid.stride(1);
int64_t grid_sW = grid.stride(2);
int64_t grid_sCoor = grid.stride(3);
int64_t gOut_sN = grad_output.stride(0);
int64_t gOut_sC = grad_output.stride(1);
int64_t gOut_sH = grad_output.stride(2);
int64_t gOut_sW = grad_output.stride(3);
int64_t gInp_sN = grad_input.stride(0);
int64_t gInp_sC = grad_input.stride(1);
int64_t gInp_sH = grad_input.stride(2);
int64_t gInp_sW = grad_input.stride(3);
int64_t gGrid_sN = grad_grid.stride(0);
int64_t gGrid_sW = grad_grid.stride(2);
scalar_t *inp_ptr = input.data<scalar_t>();
scalar_t *grid_ptr = grid.data<scalar_t>();
scalar_t *gOut_ptr = grad_output.data<scalar_t>();
scalar_t *gInp_ptr = grad_input.data<scalar_t>();
scalar_t *gGrid_ptr = grad_grid.data<scalar_t>();
// loop over each output pixel
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t n = 0; n < N; ++n) {
scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
scalar_t *gGrid_ptr_NHW = gGrid_ptr + n * gGrid_sN;
for (int64_t h = 0; h < out_H; ++h) {
for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) {
// get the corresponding input x, y, z co-ordinates from grid
scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NHW;
scalar_t iy = grid_ptr_NHW[grid_sCoor];
// multipliers for gradients on ix, iy, and iz
scalar_t gix_mult, giy_mult;
ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
int64_t iy_nw = static_cast<int64_t>(std::floor(iy));
int64_t ix_ne = ix_nw + 1;
int64_t iy_ne = iy_nw;
int64_t ix_sw = ix_nw;
int64_t iy_sw = iy_nw + 1;
int64_t ix_se = ix_nw + 1;
int64_t iy_se = iy_nw + 1;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - ix) * (iy_se - iy) ;
scalar_t ne = (ix - ix_sw) * (iy_sw - iy) ;
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
scalar_t *inp_ptr_NC = inp_ptr_N;
// calculate bilinear weighted pixel value and set output pixel
for (int c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
scalar_t gOut = *gOut_ptr_NCHW;
// calculate and set grad_input
safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut);
safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut);
safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut);
safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut);
// calculate grad_grid
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
gix -= nw_val * (iy_se - iy) * gOut;
giy -= nw_val * (ix_se - ix) * gOut;
}
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
gix += ne_val * (iy_sw - iy) * gOut;
giy -= ne_val * (ix - ix_sw) * gOut;
}
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
gix -= sw_val * (iy - iy_ne) * gOut;
giy += sw_val * (ix_ne - ix) * gOut;
}
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
gix += se_val * (iy - iy_nw) * gOut;
giy += se_val * (ix - ix_nw) * gOut;
}
}
// assuming grad_grid is contiguous
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
// assign nearest neighbor pixel value to output pixel
scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
for (int c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
// calculate and set grad_input
safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest,
gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW);
}
}
}
}
}
return std::make_tuple(grad_input, grad_grid);
}
template<typename scalar_t>
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cpu_impl(const Tensor& grad_output,
const Tensor& input, const Tensor& grid,
GridSamplerInterpolation interpolation_mode,
GridSamplerPadding padding_mode,
bool align_corners) {
auto grad_input = at::zeros_like(input);
auto grad_grid = at::empty_like(grid);
// If interpolation mode is Nearest, then grad_grid is not filled in the
// loop below.
if (interpolation_mode == GridSamplerInterpolation::Nearest) {
grad_grid.zero_();
}
int64_t N = input.size(0);
int64_t C = input.size(1);
int64_t inp_D = input.size(2);
int64_t inp_H = input.size(3);
int64_t inp_W = input.size(4);
int64_t out_D = grid.size(1);
int64_t out_H = grid.size(2);
int64_t out_W = grid.size(3);
int64_t inp_sN = input.stride(0);
int64_t inp_sC = input.stride(1);
int64_t inp_sD = input.stride(2);
int64_t inp_sH = input.stride(3);
int64_t inp_sW = input.stride(4);
int64_t grid_sN = grid.stride(0);
int64_t grid_sD = grid.stride(1);
int64_t grid_sH = grid.stride(2);
int64_t grid_sW = grid.stride(3);
int64_t grid_sCoor = grid.stride(4);
int64_t gOut_sN = grad_output.stride(0);
int64_t gOut_sC = grad_output.stride(1);
int64_t gOut_sD = grad_output.stride(2);
int64_t gOut_sH = grad_output.stride(3);
int64_t gOut_sW = grad_output.stride(4);
int64_t gInp_sN = grad_input.stride(0);
int64_t gInp_sC = grad_input.stride(1);
int64_t gInp_sD = grad_input.stride(2);
int64_t gInp_sH = grad_input.stride(3);
int64_t gInp_sW = grad_input.stride(4);
int64_t gGrid_sN = grad_grid.stride(0);
int64_t gGrid_sW = grad_grid.stride(3);
scalar_t *inp_ptr = input.data<scalar_t>();
scalar_t *grid_ptr = grid.data<scalar_t>();
scalar_t *gOut_ptr = grad_output.data<scalar_t>();
scalar_t *gInp_ptr = grad_input.data<scalar_t>();
scalar_t *gGrid_ptr = grad_grid.data<scalar_t>();
// loop over each output pixel
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t n = 0; n < N; ++n) {
scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
scalar_t *gGrid_ptr_NDHW = gGrid_ptr + n * gGrid_sN;
for (int64_t d = 0; d < out_D; ++d) {
for (int64_t h = 0; h < out_H; ++h) {
for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) {
// get the corresponding input x, y, z co-ordinates from grid
scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NDHW;
scalar_t iy = grid_ptr_NDHW[grid_sCoor];
scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];
// multipliers for gradients on ix, iy, and iz
scalar_t gix_mult, giy_mult, giz_mult;
ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
int64_t ix_tnw = static_cast<int64_t>(std::floor(ix));
int64_t iy_tnw = static_cast<int64_t>(std::floor(iy));
int64_t iz_tnw = static_cast<int64_t>(std::floor(iz));
int64_t ix_tne = ix_tnw + 1;
int64_t iy_tne = iy_tnw;
int64_t iz_tne = iz_tnw;
int64_t ix_tsw = ix_tnw;
int64_t iy_tsw = iy_tnw + 1;
int64_t iz_tsw = iz_tnw;
int64_t ix_tse = ix_tnw + 1;
int64_t iy_tse = iy_tnw + 1;
int64_t iz_tse = iz_tnw;
int64_t ix_bnw = ix_tnw;
int64_t iy_bnw = iy_tnw;
int64_t iz_bnw = iz_tnw + 1;
int64_t ix_bne = ix_tnw + 1;
int64_t iy_bne = iy_tnw;
int64_t iz_bne = iz_tnw + 1;
int64_t ix_bsw = ix_tnw;
int64_t iy_bsw = iy_tnw + 1;
int64_t iz_bsw = iz_tnw + 1;
int64_t ix_bse = ix_tnw + 1;
int64_t iy_bse = iy_tnw + 1;
int64_t iz_bse = iz_tnw + 1;
// get surfaces to each neighbor:
scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
scalar_t *inp_ptr_NC = inp_ptr_N;
// calculate bilinear weighted pixel value and set output pixel
for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
scalar_t gOut = *gOut_ptr_NCDHW;
// calculate and set grad_input
safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
// calculate grad_grid
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
}
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
}
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
}
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
}
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
}
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
}
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
}
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
}
}
// assuming grad_grid is contiguous
gGrid_ptr_NDHW[0] = gix_mult * gix;
gGrid_ptr_NDHW[1] = giy_mult * giy;
gGrid_ptr_NDHW[2] = giz_mult * giz;
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
int64_t iz_nearest = static_cast<int64_t>(std::round(iz));
// assign nearest neighbor pixel value to output pixel
scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
// calculate and set grad_input
safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest,
gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW);
}
}
}
}
}
}
return std::make_tuple(grad_input, grad_grid);
}
} // namespace
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_2d_forward_cpu(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners) {
return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_forward_cpu", [&] {
return grid_sampler_2d_forward_cpu_impl<scalar_t>(
input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode), align_corners);
});
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_3d_forward_cpu(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners) {
return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_forward_cpu", [&] {
return grid_sampler_3d_forward_cpu_impl<scalar_t>(
input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode), align_corners);
});
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_backward_cpu", [&] {
return grid_sampler_2d_backward_cpu_impl<scalar_t>(
grad_output, input, grid,
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode), align_corners);
});
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] {
return grid_sampler_3d_backward_cpu_impl<scalar_t>(
grad_output, input, grid,
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode), align_corners);
});
}
} // namespace mmdetection
// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/GridSampler.h
#pragma once
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
namespace mmdetection {
namespace detail {
enum class GridSamplerInterpolation {Bilinear, Nearest};
enum class GridSamplerPadding {Zeros, Border, Reflection};
} // namespace detail
using detail::GridSamplerInterpolation;
using detail::GridSamplerPadding;
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
// -1 --> 0
// +1 --> (size - 1)
// scale_factor = (size - 1) / 2
// if not align_corners: -1 and +1 get sent to the image edges
// -1 --> -0.5
// +1 --> (size - 1) + 0.5 == size - 0.5
// scale_factor = size / 2
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
bool align_corners) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
return ((coord + 1) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
return ((coord + 1) * size - 1) / 2;
}
}
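
A quick numeric check of the two mappings (a Python sketch, not part of the shipped sources): with size = 4, the corner coordinates land exactly where the comment above says they should.

def unnormalize(coord, size, align_corners):
    # align_corners=True:  [-1, 1] -> [0, size - 1]       (centers of corner pixels)
    # align_corners=False: [-1, 1] -> [-0.5, size - 0.5]  (image edges)
    if align_corners:
        return (coord + 1) / 2 * (size - 1)
    return ((coord + 1) * size - 1) / 2

assert unnormalize(-1.0, 4, True) == 0.0 and unnormalize(1.0, 4, True) == 3.0
assert unnormalize(-1.0, 4, False) == -0.5 and unnormalize(1.0, 4, False) == 3.5
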
// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
// except that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
bool align_corners, scalar_t *grad_in) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
*grad_in = static_cast<scalar_t>(size - 1) / 2;
return ((coord + 1) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
*grad_in = static_cast<scalar_t>(size) / 2;
return ((coord + 1) * size - 1) / 2;
}
}
// Clips coordinates to between 0 and clip_limit - 1
template<typename scalar_t>
static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
}
// clip_coordinates_set_grad works similarly to clip_coordinates except that
// it also returns the `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template<typename scalar_t>
static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
scalar_t *grad_in) {
if (in < static_cast<scalar_t>(0)) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
} else {
scalar_t max = static_cast<scalar_t>(clip_limit - 1);
if (in > max) {
*grad_in = static_cast<scalar_t>(0);
return max;
} else {
*grad_in = static_cast<scalar_t>(1);
return in;
}
}
}
// Reflects coordinates until they fall between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer values
// can be represented as ints.
template<typename scalar_t>
static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
int64_t twice_high) {
if (twice_low == twice_high) {
return static_cast<scalar_t>(0);
}
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = std::fabs(in - min);
// `fmod` returns same sign as `in`, which is positive after the `fabs` above.
scalar_t extra = std::fmod(in, span);
int flips = static_cast<int>(std::floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
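
The same reflection rule in a few lines of Python (an illustrative sketch; the twice-low/twice-high convention keeps half-integer bounds exact):

import math

def reflect(coord, twice_low, twice_high):
    if twice_low == twice_high:
        return 0.0
    lo = twice_low / 2.0
    span = (twice_high - twice_low) / 2.0
    coord = abs(coord - lo)
    extra = math.fmod(coord, span)
    flips = int(math.floor(coord / span))
    return extra + lo if flips % 2 == 0 else span - extra + lo

# with size = 4 and align_corners=False the bounds are (-1, 2 * 4 - 1),
# i.e. reflection about the image edges -0.5 and 3.5:
assert reflect(4.0, -1, 7) == 3.0
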
// reflect_coordinates_set_grad works similarly to reflect_coordinates except
// that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template<typename scalar_t>
static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
int64_t twice_high, scalar_t *grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
}
int grad_in_mult_;
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<scalar_t>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
// `fmod` returns same sign as `in`, which is positive after the `if` above.
scalar_t extra = std::fmod(in, span);
int flips = static_cast<int>(std::floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<scalar_t>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<scalar_t>(-grad_in_mult_);
return span - extra + min;
}
}
// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates(coord, 0, 2*(size - 1));
} else {
coord = reflect_coordinates(coord, -1, 2*size - 1);
// when align_corners=False, reflection does not auto clip coords
coord = clip_coordinates(coord, size);
}
}
return coord;
}
// grid_sampler_compute_source_index_set_grad works similarly to
// grid_sampler_compute_source_index except that it also returns the
// `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index_set_grad(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners,
scalar_t *grad_in) {
scalar_t grad_clip, grad_refl;
coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
*grad_in = (*grad_in) * grad_refl;
} else {
coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
// when align_corners=False, reflection does not auto clip coords
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
}
return coord;
}
static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}
template<typename scalar_t>
static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
int64_t sH, int64_t sW, int64_t H, int64_t W,
scalar_t delta) {
if (within_bounds_2d(h, w, H, W)) {
data[h * sH + w * sW] += delta;
}
}
template<typename scalar_t>
static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
int64_t sD, int64_t sH, int64_t sW,
int64_t D, int64_t H, int64_t W,
scalar_t delta) {
if (within_bounds_3d(d, h, w, D, H, W)) {
data[d * sD + h * sH + w * sW] += delta;
}
}
} // namespace mmdetection
// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cu
#include <ATen/ATen.h>
#include "grid_sampler_cuda.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>
namespace mmdetection {
using namespace at::cuda::detail;
using mmdetection::detail::GridSamplerInterpolation;
using mmdetection::detail::GridSamplerPadding;
namespace {
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_2d_forward_kernel_cuda(
const int nthreads,
TensorInfo<scalar_t, int> input,
TensorInfo<scalar_t, int> grid,
TensorInfo<scalar_t, int> output,
const GridSamplerInterpolation interpolation_mode,
const GridSamplerPadding padding_mode,
bool align_corners) {
int C = input.sizes[1];
int inp_H = input.sizes[2];
int inp_W = input.sizes[3];
int out_H = grid.sizes[1];
int out_W = grid.sizes[2];
int inp_sN = input.strides[0];
int inp_sC = input.strides[1];
int inp_sH = input.strides[2];
int inp_sW = input.strides[3];
int grid_sN = grid.strides[0];
int grid_sH = grid.strides[1];
int grid_sW = grid.strides[2];
int grid_sCoor = grid.strides[3];
int out_sN = output.strides[0];
int out_sC = output.strides[1];
int out_sH = output.strides[2];
int out_sW = output.strides[3];
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_W;
const int h = (index / out_W) % out_H;
const int n = index / (out_H * out_W);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
// get the corresponding input x, y co-ordinates from grid
scalar_t ix = grid.data[grid_offset];
scalar_t iy = grid.data[grid_offset + grid_sCoor];
ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get NE, NW, SE, SW pixel values from (x, y)
int ix_nw = static_cast<int>(::floor(ix));
int iy_nw = static_cast<int>(::floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - ix) * (iy_se - iy);
scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
// calculate bilinear weighted pixel value and set output pixel
auto inp_ptr_NC = input.data + n * inp_sN;
auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
*out_ptr_NCHW = static_cast<scalar_t>(0);
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
}
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
}
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
}
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
}
}
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
// assign nearest neighbor pixel value to output pixel
auto inp_ptr_NC = input.data + n * inp_sN;
auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
*out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
} else {
*out_ptr_NCHW = static_cast<scalar_t>(0);
}
}
}
}
}
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_3d_forward_kernel_cuda(
const int nthreads,
TensorInfo<scalar_t, int> input,
TensorInfo<scalar_t, int> grid,
TensorInfo<scalar_t, int> output,
const GridSamplerInterpolation interpolation_mode,
const GridSamplerPadding padding_mode,
bool align_corners) {
int C = input.sizes[1];
int inp_D = input.sizes[2];
int inp_H = input.sizes[3];
int inp_W = input.sizes[4];
int out_D = grid.sizes[1];
int out_H = grid.sizes[2];
int out_W = grid.sizes[3];
int inp_sN = input.strides[0];
int inp_sC = input.strides[1];
int inp_sD = input.strides[2];
int inp_sH = input.strides[3];
int inp_sW = input.strides[4];
int grid_sN = grid.strides[0];
int grid_sD = grid.strides[1];
int grid_sH = grid.strides[2];
int grid_sW = grid.strides[3];
int grid_sCoor = grid.strides[4];
int out_sN = output.strides[0];
int out_sC = output.strides[1];
int out_sD = output.strides[2];
int out_sH = output.strides[3];
int out_sW = output.strides[4];
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_W;
const int h = (index / out_W) % out_H;
const int d = (index / (out_H * out_W)) % out_D;
const int n = index / (out_D * out_H * out_W);
const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
// get the corresponding input x, y, z co-ordinates from grid
scalar_t ix = grid.data[grid_offset];
scalar_t iy = grid.data[grid_offset + grid_sCoor];
scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
int ix_tnw = static_cast<int>(::floor(ix));
int iy_tnw = static_cast<int>(::floor(iy));
int iz_tnw = static_cast<int>(::floor(iz));
int ix_tne = ix_tnw + 1;
int iy_tne = iy_tnw;
int iz_tne = iz_tnw;
int ix_tsw = ix_tnw;
int iy_tsw = iy_tnw + 1;
int iz_tsw = iz_tnw;
int ix_tse = ix_tnw + 1;
int iy_tse = iy_tnw + 1;
int iz_tse = iz_tnw;
int ix_bnw = ix_tnw;
int iy_bnw = iy_tnw;
int iz_bnw = iz_tnw + 1;
int ix_bne = ix_tnw + 1;
int iy_bne = iy_tnw;
int iz_bne = iz_tnw + 1;
int ix_bsw = ix_tnw;
int iy_bsw = iy_tnw + 1;
int iz_bsw = iz_tnw + 1;
int ix_bse = ix_tnw + 1;
int iy_bse = iy_tnw + 1;
int iz_bse = iz_tnw + 1;
// get surfaces to each neighbor:
scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
auto inp_ptr_NC = input.data + n * inp_sN;
auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
// (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
// + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
// + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
// + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
*out_ptr_NCDHW = static_cast<scalar_t>(0);
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
}
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
}
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
}
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
}
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
}
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
}
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
}
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
}
}
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
int iz_nearest = static_cast<int>(::round(iz));
// assign nearest neighbor pixel value to output pixel
auto inp_ptr_NC = input.data + n * inp_sN;
auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
} else {
*out_ptr_NCDHW = static_cast<scalar_t>(0);
}
}
}
}
}
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_2d_backward_kernel_cuda(
const int nthreads,
TensorInfo<scalar_t, int> grad_output,
TensorInfo<scalar_t, int> input,
TensorInfo<scalar_t, int> grid,
TensorInfo<scalar_t, int> grad_input, // initialized to zeros
TensorInfo<scalar_t, int> grad_grid, // initialized to empty
const GridSamplerInterpolation interpolation_mode,
const GridSamplerPadding padding_mode,
bool align_corners) {
int C = input.sizes[1];
int inp_H = input.sizes[2];
int inp_W = input.sizes[3];
int out_H = grid.sizes[1];
int out_W = grid.sizes[2];
int inp_sN = input.strides[0];
int inp_sC = input.strides[1];
int inp_sH = input.strides[2];
int inp_sW = input.strides[3];
int grid_sN = grid.strides[0];
int grid_sH = grid.strides[1];
int grid_sW = grid.strides[2];
int grid_sCoor = grid.strides[3];
int gOut_sN = grad_output.strides[0];
int gOut_sC = grad_output.strides[1];
int gOut_sH = grad_output.strides[2];
int gOut_sW = grad_output.strides[3];
int gInp_sN = grad_input.strides[0];
int gInp_sC = grad_input.strides[1];
int gInp_sH = grad_input.strides[2];
int gInp_sW = grad_input.strides[3];
int gGrid_sW = grad_grid.strides[2];
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_W;
const int h = (index / out_W) % out_H;
const int n = index / (out_H * out_W);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
// get the corresponding input x, y co-ordinates from grid
scalar_t ix = grid.data[grid_offset];
scalar_t iy = grid.data[grid_offset + grid_sCoor];
// multipliers for gradients on ix and iy
scalar_t gix_mult, giy_mult;
ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get NE, NW, SE, SW pixel values from (x, y)
int ix_nw = static_cast<int>(::floor(ix));
int iy_nw = static_cast<int>(::floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - ix) * (iy_se - iy);
scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
scalar_t *inp_ptr_NC = input.data + n * inp_sN;
for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, gInp_ptr_NC += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
scalar_t gOut = *gOut_ptr_NCHW;
// calculate and set grad_input
safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut);
safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut);
safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut);
safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut);
// calculate grad_grid
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
gix -= nw_val * (iy_se - iy) * gOut;
giy -= nw_val * (ix_se - ix) * gOut;
}
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
gix += ne_val * (iy_sw - iy) * gOut;
giy -= ne_val * (ix - ix_sw) * gOut;
}
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
gix -= sw_val * (iy - iy_ne) * gOut;
giy += sw_val * (ix_ne - ix) * gOut;
}
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
gix += se_val * (iy - iy_nw) * gOut;
giy += se_val * (ix - ix_nw) * gOut;
}
}
// assuming grad_grid is contiguous
// thus we can
// 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
// 2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
// assign nearest neighbor pixel value to output pixel
scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
for (int c = 0; c < C; ++c, gInp_ptr_NC += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
// calculate and set grad_input
safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW);
}
// assuming grad_grid is contiguous
// thus we can
// 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
// 2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
gGrid_ptr_NHW[0] = static_cast<scalar_t>(0);
gGrid_ptr_NHW[1] = static_cast<scalar_t>(0);
}
}
}
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_3d_backward_kernel_cuda(
const int nthreads,
TensorInfo<scalar_t, int> grad_output,
TensorInfo<scalar_t, int> input,
TensorInfo<scalar_t, int> grid,
TensorInfo<scalar_t, int> grad_input, // initialized to zeros
TensorInfo<scalar_t, int> grad_grid, // initialized to empty
const GridSamplerInterpolation interpolation_mode,
const GridSamplerPadding padding_mode,
bool align_corners) {
int C = input.sizes[1];
int inp_D = input.sizes[2];
int inp_H = input.sizes[3];
int inp_W = input.sizes[4];
int out_D = grid.sizes[1];
int out_H = grid.sizes[2];
int out_W = grid.sizes[3];
int inp_sN = input.strides[0];
int inp_sC = input.strides[1];
int inp_sD = input.strides[2];
int inp_sH = input.strides[3];
int inp_sW = input.strides[4];
int grid_sN = grid.strides[0];
int grid_sD = grid.strides[1];
int grid_sH = grid.strides[2];
int grid_sW = grid.strides[3];
int grid_sCoor = grid.strides[4];
int gOut_sN = grad_output.strides[0];
int gOut_sC = grad_output.strides[1];
int gOut_sD = grad_output.strides[2];
int gOut_sH = grad_output.strides[3];
int gOut_sW = grad_output.strides[4];
int gInp_sN = grad_input.strides[0];
int gInp_sC = grad_input.strides[1];
int gInp_sD = grad_input.strides[2];
int gInp_sH = grad_input.strides[3];
int gInp_sW = grad_input.strides[4];
int gGrid_sW = grad_grid.strides[3];
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_W;
const int h = (index / out_W) % out_H;
const int d = (index / (out_H * out_W)) % out_D;
const int n = index / (out_D * out_H * out_W);
const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
// get the corresponding input x, y, z co-ordinates from grid
scalar_t ix = grid.data[grid_offset];
scalar_t iy = grid.data[grid_offset + grid_sCoor];
scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
// multipliers for gradients on ix, iy, and iz
scalar_t gix_mult, giy_mult, giz_mult;
ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
int ix_tnw = static_cast<int>(::floor(ix));
int iy_tnw = static_cast<int>(::floor(iy));
int iz_tnw = static_cast<int>(::floor(iz));
int ix_tne = ix_tnw + 1;
int iy_tne = iy_tnw;
int iz_tne = iz_tnw;
int ix_tsw = ix_tnw;
int iy_tsw = iy_tnw + 1;
int iz_tsw = iz_tnw;
int ix_tse = ix_tnw + 1;
int iy_tse = iy_tnw + 1;
int iz_tse = iz_tnw;
int ix_bnw = ix_tnw;
int iy_bnw = iy_tnw;
int iz_bnw = iz_tnw + 1;
int ix_bne = ix_tnw + 1;
int iy_bne = iy_tnw;
int iz_bne = iz_tnw + 1;
int ix_bsw = ix_tnw;
int iy_bsw = iy_tnw + 1;
int iz_bsw = iz_tnw + 1;
int ix_bse = ix_tnw + 1;
int iy_bse = iy_tnw + 1;
int iz_bse = iz_tnw + 1;
// get surfaces to each neighbor:
scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
scalar_t *inp_ptr_NC = input.data + n * inp_sN;
// calculate bilinear weighted pixel value and set output pixel
for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
scalar_t gOut = *gOut_ptr_NCDHW;
// calculate and set grad_input
safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
// calculate grad_grid
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
}
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
}
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
}
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
}
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
}
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
}
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
}
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
}
}
// assuming grad_grid is contiguous
// thus we can
// 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
// 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
gGrid_ptr_NDHW[0] = gix_mult * gix;
gGrid_ptr_NDHW[1] = giy_mult * giy;
gGrid_ptr_NDHW[2] = giz_mult * giz;
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
int iz_nearest = static_cast<int>(::round(iz));
// scatter the output gradient to the nearest-neighbor input voxel
scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
// calculate and set grad_input
safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest,
gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW);
}
// assuming grad_grid is contiguous
// thus we can
// 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
// 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
gGrid_ptr_NDHW[0] = static_cast<scalar_t>(0);
gGrid_ptr_NDHW[1] = static_cast<scalar_t>(0);
gGrid_ptr_NDHW[2] = static_cast<scalar_t>(0);
}
}
}
} // namespace
using namespace at;
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_2d_forward_cuda(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners) {
auto N = input.size(0);
auto H = grid.size(1);
auto W = grid.size(2);
auto output = at::empty({N, input.size(1), H, W}, input.options());
int count = static_cast<int>(N * H * W);
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_forward_cuda", [&] {
grid_sampler_2d_forward_kernel_cuda<scalar_t>
<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
getTensorInfo<scalar_t, int>(input),
getTensorInfo<scalar_t, int>(grid),
getTensorInfo<scalar_t, int>(output),
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode),
align_corners);
});
}
return output;
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_3d_forward_cuda(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners) {
auto N = input.size(0);
auto D = grid.size(1);
auto H = grid.size(2);
auto W = grid.size(3);
auto output = at::empty({N, input.size(1), D, H, W}, input.options());
int count = static_cast<int>(N * D * H * W);
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_forward_cuda", [&] {
grid_sampler_3d_forward_kernel_cuda<scalar_t>
<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
getTensorInfo<scalar_t, int>(input),
getTensorInfo<scalar_t, int>(grid),
getTensorInfo<scalar_t, int>(output),
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode),
align_corners);
});
}
return output;
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode,
int64_t padding_mode, bool align_corners) {
auto N = input.size(0);
auto H = grid.size(1);
auto W = grid.size(2);
auto grad_input = at::zeros_like(input);
auto grad_grid = at::empty_like(grid);
int count = static_cast<int>(N * H * W);
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
grid_sampler_2d_backward_kernel_cuda<scalar_t>
<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
getTensorInfo<scalar_t, int>(grad_output),
getTensorInfo<scalar_t, int>(input),
getTensorInfo<scalar_t, int>(grid),
getTensorInfo<scalar_t, int>(grad_input),
getTensorInfo<scalar_t, int>(grad_grid),
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode),
align_corners);
});
}
return std::make_tuple(grad_input, grad_grid);
}
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
bool align_corners) {
auto N = input.size(0);
auto D = grid.size(1);
auto H = grid.size(2);
auto W = grid.size(3);
auto grad_input = at::zeros_like(input);
auto grad_grid = at::empty_like(grid);
int count = static_cast<int>(N * D * H * W);
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
grid_sampler_3d_backward_kernel_cuda<scalar_t>
<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
getTensorInfo<scalar_t, int>(grad_output),
getTensorInfo<scalar_t, int>(input),
getTensorInfo<scalar_t, int>(grid),
getTensorInfo<scalar_t, int>(grad_input),
getTensorInfo<scalar_t, int>(grad_grid),
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode),
align_corners);
});
}
return std::make_tuple(grad_input, grad_grid);
}
} // namespace mmdetection
// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cuh
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <THC/THCAtomics.cuh>
namespace mmdetection {
namespace detail {
enum class GridSamplerInterpolation {Bilinear, Nearest};
enum class GridSamplerPadding {Zeros, Border, Reflection};
} // namespace detail
using detail::GridSamplerInterpolation;
using detail::GridSamplerPadding;
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
// -1 --> 0
// +1 --> (size - 1)
// scale_factor = (size - 1) / 2
// if not align_corners: -1 and +1 get sent to the image edges
// -1 --> -0.5
// +1 --> (size - 1) + 0.5 == size - 0.5
// scale_factor = size / 2
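// e.g. for size = 4: with align_corners, -1 -> 0, 0 -> 1.5, +1 -> 3;
// without align_corners, -1 -> -0.5, 0 -> 1.5, +1 -> 3.5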
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
return ((coord + 1.f) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
return ((coord + 1.f) * size - 1) / 2;
}
}
// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
// except that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
bool align_corners, scalar_t *grad_in) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
*grad_in = static_cast<scalar_t>(size - 1) / 2;
return ((coord + 1.f) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
*grad_in = static_cast<scalar_t>(size) / 2;
return ((coord + 1.f) * size - 1) / 2;
}
}
// Clips coordinates to between 0 and clip_limit - 1
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates(scalar_t in, int clip_limit) {
return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
}
// clip_coordinates_set_grad works similarly to clip_coordinates except that
// it also returns the `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
if (in < static_cast<scalar_t>(0)) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
} else {
scalar_t max = static_cast<scalar_t>(clip_limit - 1);
if (in > max) {
*grad_in = static_cast<scalar_t>(0);
return max;
} else {
*grad_in = static_cast<scalar_t>(1);
return in;
}
}
}
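// e.g. with clip_limit = 5: in = -0.3 -> (0, grad 0), in = 2.7 -> (2.7, grad 1),
// in = 6.1 -> (4, grad 0); the gradient is zero wherever the coordinate is clipped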
// Reflects coordinates until they fall between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer values
// can be represented as ints.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
if (twice_low == twice_high) {
return static_cast<scalar_t>(0);
}
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = ::fabs(in - min);
// `fmod` returns the same sign as `in`, which is positive after the `fabs` above.
scalar_t extra = ::fmod(in, span);
int flips = static_cast<int>(::floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
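// e.g. with twice_low = 0, twice_high = 8 (size 5, align_corners):
// in = 5.5 gives span = 4, extra = 1.5, flips = 1, so the result is 2.5,
// i.e. 5.5 reflected about the upper bound 4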
// reflect_coordinates_set_grad works similarly to reflect_coordinates except
// that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
scalar_t *grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
}
int grad_in_mult_;
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<scalar_t>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
// `fmod` returns the same sign as `in`, which is positive after the `if` above.
scalar_t extra = ::fmod(in, span);
int flips = static_cast<int>(::floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<scalar_t>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<scalar_t>(-grad_in_mult_);
return span - extra + min;
}
}
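// for the same example (in = 5.5, bounds 0 and 8), flips is odd, so the
// reflected coordinate is 2.5 and grad_in is -1 (locally the map is coord -> 8 - coord)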
// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates(coord, 0, 2*(size - 1));
} else {
coord = reflect_coordinates(coord, -1, 2*size - 1);
// when align_corners=False, reflection does not auto clip coords
coord = clip_coordinates(coord, size);
}
}
return coord;
}
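// e.g. for size = 5 with align_corners and Border padding: coord = 1.2
// unnormalizes to 4.4 and is clipped to 4; with Zeros padding it stays at 4.4
// and out-of-bounds reads are handled later by the within_bounds_* checks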
// grid_sampler_compute_source_index_set_grad works similarly to
// grid_sampler_compute_source_index except that it also returns the
// `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index_set_grad(
scalar_t coord,
int size,
GridSamplerPadding padding_mode,
bool align_corners,
scalar_t *grad_in) {
scalar_t grad_clip, grad_refl;
coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
*grad_in = (*grad_in) * grad_refl;
} else {
coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
// when align_corners=False, reflection does not auto clip coords
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
}
return coord;
}
static __forceinline__ __device__
bool within_bounds_2d(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
static __forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}
template<typename scalar_t>
static __forceinline__ __device__
void safe_add_2d(scalar_t *data, int h, int w,
int sH, int sW, int H, int W,
scalar_t delta) {
if (within_bounds_2d(h, w, H, W)) {
atomicAdd(data + h * sH + w * sW, delta);
}
}
template<typename scalar_t>
static __forceinline__ __device__
void safe_add_3d(scalar_t *data, int d, int h, int w,
int sD, int sH, int sW, int D, int H, int W,
scalar_t delta) {
if (within_bounds_3d(d, h, w, D, H, W)) {
atomicAdd(data + d * sD + h * sH + w * sW, delta);
}
}
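// safe_add_2d/3d bounds-check before accumulating and use atomicAdd because
// several sampling locations (different kernel threads) can map to the same
// input element, so grad_input updates must not race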
} // namespace mmdetection
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>
#if !AT_CUDNN_ENABLED()
namespace at { namespace native {
// See Note [ATen preprocessor philosophy]
Tensor cudnn_grid_sampler_forward(
const Tensor& input_t, const Tensor& grid_t) {
AT_ERROR("cudnn_grid_sampler_forward: ATen not compiled with cuDNN support");
}
std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
const Tensor& input_t, const Tensor& grid_t,
const Tensor& grad_output_t) {
AT_ERROR("cudnn_grid_sampler_backward: ATen not compiled with cuDNN support");
}
}}
#else // AT_CUDNN_ENABLED
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/TensorUtils.h>
// TODO: descriptor checking
namespace mmdetection {
using namespace at;
namespace {
void setSamplerDescriptor(SpatialTransformerDescriptor& desc, cudnnDataType_t dataType, const at::Tensor& tensor)
{
int inputSize[4] = {0};
for (int i = 0; i < tensor.dim(); ++i) {
inputSize[i] = (int) tensor.size(i);
}
desc.set(dataType, 4, inputSize);
}
void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input)
{
// assert size of grid is n*h*w*2
// FYI: grid values lie in [-1, 1], where -1 represents the left-most pixel
// and +1 the right-most pixel (and hence 0 is the center pixel)
// if grid has values > 1 or < -1, those values are ignored
checkContiguous(c, grid);
checkDim(c, grid, 4);
// TODO: Maybe more user friendly to report where the expected size
// came from
checkSize(c, grid, 0, input->size(0));
checkSize(c, grid, 3, 2);
}
} // namespace
Tensor cudnn_grid_sampler_forward(
const Tensor& input_t, const Tensor& grid_t)
{
TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
grid{ grid_t.contiguous(), "grid", 2 };
CheckedFrom c = "cudnn_grid_sampler_forward";
checkAllSameGPU(c, {input, grid});
checkAllSameType(c, {input, grid});
checkGridSize(c, grid, input);
checkDim(c, input, 4);
auto output_t = at::empty({0}, input->options());
output_t.resize_({input->size(0), input->size(1), grid->size(1), grid->size(2)});
TensorDescriptor idesc{ *input }; // input descriptor
TensorDescriptor odesc{ output_t }; // output descriptor
SpatialTransformerDescriptor desc; // sampler descriptor
auto handle = getCudnnHandle();
auto dataType = getCudnnDataType(*input);
setSamplerDescriptor(desc, dataType, output_t);
Constant one(dataType, 1);
Constant zero(dataType, 0);
AT_CUDNN_CHECK(cudnnSpatialTfSamplerForward(
handle, desc.desc(),
&one, idesc.desc(), input->data_ptr(),
grid->data_ptr(),
&zero, odesc.desc(), output_t.data_ptr()
));
return output_t;
}
// NB: CuDNN does not support output mask; you always get both
// gradients.
std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
const Tensor& input_t, const Tensor& grid_t,
const Tensor& grad_output_t)
{
TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
grid{ grid_t.contiguous(), "grid", 2 },
grad_output{ contiguousIfZeroInStrides(grad_output_t), "grad_output", 3 };
CheckedFrom c = "cudnn_grid_sampler_backward";
checkAllSameGPU(c, {input, grad_output, grid});
checkGridSize(c, grid, input);
checkDim(c, input, 4);
checkDim(c, grad_output, 4);
auto grad_input_t = at::empty({0}, input->options());
grad_input_t.resize_(input->sizes());
auto grad_grid_t = at::empty({0}, grid->options());
grad_grid_t.resize_(grid->sizes());
TensorDescriptor idesc{ *input }; // input descriptor
TensorDescriptor odesc{ *grad_output }; // grad_output descriptor
TensorDescriptor gdesc{ grad_input_t }; // grad_input descriptor
SpatialTransformerDescriptor desc; // sampler descriptor
auto handle = getCudnnHandle();
auto dataType = getCudnnDataType(*input);
setSamplerDescriptor(desc, dataType, *grad_output);
Constant one(dataType, 1);
Constant zero(dataType, 0);
AT_CUDNN_CHECK(cudnnSpatialTfSamplerBackward(
handle, desc.desc(),
&one, idesc.desc(), input->data_ptr(),
&zero, gdesc.desc(), grad_input_t.data_ptr(),
&one, odesc.desc(), grad_output->data_ptr(),
// intriguingly, the outputs don't need descriptors
grid->data_ptr(),
&zero, grad_grid_t.data_ptr()
));
return std::tuple<Tensor, Tensor>{ grad_input_t, grad_grid_t };
}
} // namespace mmdetection
#endif
#include <torch/extension.h>
#include <ATen/DeviceGuard.h>
namespace mmdetection {
using namespace at;
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_2d_forward_cpu(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_3d_forward_cpu(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode,
int64_t padding_mode, bool align_corners);
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_2d_forward_cuda(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_3d_forward_cuda(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode,
int64_t padding_mode, bool align_corners);
// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grid_sampler_2d_forward_cpu", &grid_sampler_2d_forward_cpu, "grid_sampler_2d_forward (CPU)");
m.def("grid_sampler_2d_backward_cpu", &grid_sampler_2d_backward_cpu, "grid_sampler_2d_backward (CPU)");
m.def("grid_sampler_3d_forward_cpu", &grid_sampler_3d_forward_cpu, "grid_sampler_3d_forward (CPU)");
m.def("grid_sampler_3d_backward_cpu", &grid_sampler_3d_backward_cpu, "grid_sampler_3d_backward (CPU)");
m.def("grid_sampler_2d_forward_cuda", &grid_sampler_2d_forward_cuda, "grid_sampler_2d_forward (CUDA)");
m.def("grid_sampler_2d_backward_cuda", &grid_sampler_2d_backward_cuda, "grid_sampler_2d_backward (CUDA)");
m.def("grid_sampler_3d_forward_cuda", &grid_sampler_3d_forward_cuda, "grid_sampler_3d_forward (CUDA)");
m.def("grid_sampler_3d_backward_cuda", &grid_sampler_3d_backward_cuda, "grid_sampler_3d_backward (CUDA)");
}
} // namespace mmdetection
@@ -270,6 +270,13 @@ if __name__ == '__main__':
sources=[
'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu'
]),
make_cuda_ext(
name='grid_sampler_cuda',
module='mmdet.ops.grid_sampler',
sources=[
'src/cpu/grid_sampler_cpu.cpp',
'src/cuda/grid_sampler_cuda.cu', 'src/grid_sampler.cpp'
]),
make_cuda_ext(
name='carafe_cuda',
module='mmdet.ops.carafe',