Unverified Commit d7a38be9 authored by Jerry Jiarui XU, committed by GitHub

add grid sampler (#2126)

* add grid sampler

* add doc

* remove comment
parent 639f934a
from .grid_sampler import grid_sample
__all__ = ['grid_sample']
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from . import grid_sampler_cuda
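# _GridSampler dispatches to one of four compiled kernels, chosen by the
# input's device (CPU vs. CUDA) and dimensionality (4-D input -> 2-D
# sampling, 5-D input -> 3-D sampling); forward and backward share this logic.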
class _GridSampler(Function):
@staticmethod
def forward(ctx, input, grid, mode_enum, padding_mode_enum, align_corners):
ctx.save_for_backward(input, grid)
ctx.mode_enum = mode_enum
ctx.padding_mode_enum = padding_mode_enum
ctx.align_corners = align_corners
if input.is_cuda:
if input.dim() == 4:
func = grid_sampler_cuda.grid_sampler_2d_forward_cuda
else:
func = grid_sampler_cuda.grid_sampler_3d_forward_cuda
else:
if input.dim() == 4:
func = grid_sampler_cuda.grid_sampler_2d_forward_cpu
else:
func = grid_sampler_cuda.grid_sampler_3d_forward_cpu
output = func(input, grid, mode_enum, padding_mode_enum, align_corners)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
mode_enum = ctx.mode_enum
padding_mode_enum = ctx.padding_mode_enum
align_corners = ctx.align_corners
if input.is_cuda:
if input.dim() == 4:
func = grid_sampler_cuda.grid_sampler_2d_backward_cuda
else:
func = grid_sampler_cuda.grid_sampler_3d_backward_cuda
else:
if input.dim() == 4:
func = grid_sampler_cuda.grid_sampler_2d_backward_cpu
else:
func = grid_sampler_cuda.grid_sampler_3d_backward_cpu
grad_input, grad_grid = func(grad_output, input, grid, mode_enum,
padding_mode_enum, align_corners)
return grad_input, grad_grid, None, None, None
def grid_sample(input,
grid,
mode='bilinear',
padding_mode='zeros',
align_corners=False):
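    """Differentiable grid sampling with an explicit ``align_corners`` flag.

    Dispatches to ``F.grid_sample`` whenever it matches the requested
    behavior and falls back to the self-compiled extension for
    ``align_corners=False`` on PyTorch < 1.3.
    """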
    # note: plain string comparison misorders versions like '1.10' vs '1.3',
    # so compare the parsed (major, minor) tuple instead
    if tuple(int(v) for v in torch.__version__.split('.')[:2]) >= (1, 3):
        return F.grid_sample(input, grid, mode, padding_mode, align_corners)
    elif align_corners:
        # grid_sample in torch < 1.3 has no align_corners argument and
        # always behaves as align_corners=True
        return F.grid_sample(input, grid, mode, padding_mode)
    else:
        # use self-compiled grid_sampler to support align_corners=False
assert mode in ['bilinear', 'nearest'], \
'expected mode to be bilinear or nearest, but got: {}'.format(mode)
assert padding_mode in ['zeros', 'border', 'reflection'], \
'expected padding_mode to be zeros, border, or reflection, ' \
'but got: {}'.format(padding_mode)
if mode == 'bilinear':
mode_enum = 0
else:
mode_enum = 1
if padding_mode == 'zeros':
padding_mode_enum = 0
elif padding_mode == 'border':
padding_mode_enum = 1
else:
padding_mode_enum = 2
# shape check
assert input.device == grid.device, \
'expected input and grid to be on same device, ' \
'but input is on {} and grid is on {}'.format(
input.device, grid.device)
assert input.dtype == grid.dtype, \
'expected input and grid to have the same dtype, ' \
'but input has {} and grid has {}'.format(
input.dtype, grid.dtype)
assert input.dim() == 4 or input.dim() == 5, \
            'expected 4D or 5D input and grid with same number of dimensions, ' \
'but got input with sizes {} and grid with sizes {}'.format(
input.size(), grid.size())
assert input.size(0) == grid.size(0), \
'expected input and grid to have the same batch size, ' \
'but got input with sizes {} and grid with sizes {}'.format(
input.size(), grid.size())
        assert grid.size(-1) == input.dim() - 2, \
            'expected grid to have size {} in last dimension, ' \
            'but got grid with sizes {}'.format(
                input.dim() - 2, grid.size())
for i in range(2, input.dim()):
assert input.size(i) > 0, \
'expected input to have non-empty spatial dimensions, ' \
'but input has sizes {} with dimension {} being empty'.format(
                    input.size(), i)
return _GridSampler.apply(input, grid, mode_enum, padding_mode_enum,
align_corners)
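For reference, a minimal usage sketch (not part of the commit): it builds an
identity sampling grid by hand so it exercises the same code path on any
supported PyTorch version; the `mmdet.ops.grid_sampler` import path follows
the module declared in the setup.py hunk below.

import torch
from mmdet.ops.grid_sampler import grid_sample

x = torch.randn(1, 3, 8, 8)
# identity grid for align_corners=False: pixel index i sits at
# normalized coordinate (2 * i + 1) / size - 1
ys, xs = torch.meshgrid(torch.arange(8), torch.arange(8))
grid = torch.stack(((2 * xs.float() + 1) / 8 - 1,
                    (2 * ys.float() + 1) / 8 - 1), dim=-1)[None]  # (1, 8, 8, 2)
out = grid_sample(x, grid, mode='bilinear', padding_mode='zeros',
                  align_corners=False)
assert torch.allclose(out, x, atol=1e-5)  # sampling at pixel centers is exact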
// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/GridSampler.h
#pragma once
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
namespace mmdetection {
namespace detail {
enum class GridSamplerInterpolation {Bilinear, Nearest};
enum class GridSamplerPadding {Zeros, Border, Reflection};
} // namespace detail
using detail::GridSamplerInterpolation;
using detail::GridSamplerPadding;
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
// -1 --> 0
// +1 --> (size - 1)
// scale_factor = (size - 1) / 2
// if not align_corners: -1 and +1 get sent to the image edges
// -1 --> -0.5
// +1 --> (size - 1) + 0.5 == size - 0.5
// scale_factor = size / 2
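// e.g., with size = 4: align_corners maps -1 -> 0 and +1 -> 3, while
// align_corners=False maps -1 -> -0.5 and +1 -> 3.5.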
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
bool align_corners) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
return ((coord + 1) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
return ((coord + 1) * size - 1) / 2;
}
}
// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
// except that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
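// e.g., with size = 4, grad_in is set to 1.5 (align_corners) or 2.0
// (otherwise), i.e. the constant scale factor of the unnormalization.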
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
bool align_corners, scalar_t *grad_in) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
*grad_in = static_cast<scalar_t>(size - 1) / 2;
return ((coord + 1) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
*grad_in = static_cast<scalar_t>(size) / 2;
return ((coord + 1) * size - 1) / 2;
}
}
// Clips coordinates into the valid index range [0, clip_limit - 1]
template<typename scalar_t>
static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
}
// clip_coordinates_set_grad works similarly to clip_coordinates except that
// it also returns the `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template<typename scalar_t>
static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
scalar_t *grad_in) {
if (in < static_cast<scalar_t>(0)) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
} else {
scalar_t max = static_cast<scalar_t>(clip_limit - 1);
if (in > max) {
*grad_in = static_cast<scalar_t>(0);
return max;
} else {
*grad_in = static_cast<scalar_t>(1);
return in;
}
}
}
// Reflects coordinates until they fall between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer values
// can be represented as ints.
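// e.g., with twice_low = 0 and twice_high = 4 (align_corners, size = 3)
// the valid range is [0, 2]: an input of 3.5 reflects off the high
// border to 2 - (3.5 - 2) = 0.5.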
template<typename scalar_t>
static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
int64_t twice_high) {
if (twice_low == twice_high) {
return static_cast<scalar_t>(0);
}
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = std::fabs(in - min);
// `fmod` returns the same sign as `in`, which is positive after the `fabs` above.
scalar_t extra = std::fmod(in, span);
int flips = static_cast<int>(std::floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
// reflect_coordinates_set_grad works similarly to reflect_coordinates except
// that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template<typename scalar_t>
static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
int64_t twice_high, scalar_t *grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
}
int grad_in_mult_;
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<scalar_t>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
// `fmod` returns the same sign as `in`, which is positive after the `if` above.
scalar_t extra = std::fmod(in, span);
int flips = static_cast<int>(std::floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<scalar_t>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<scalar_t>(-grad_in_mult_);
return span - extra + min;
}
}
// Computes the pixel source index value for a grid coordinate
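// e.g., size = 4, align_corners=False, Border padding: a grid value of 1.2
// unnormalizes to ((1.2 + 1) * 4 - 1) / 2 = 3.9, which is then clipped to
// the last valid index, 3.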
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates(coord, 0, 2*(size - 1));
} else {
coord = reflect_coordinates(coord, -1, 2*size - 1);
// when align_corners=False, reflection does not auto clip coords
coord = clip_coordinates(coord, size);
}
}
return coord;
}
// grid_sampler_compute_source_index_set_grad works similarly to
// grid_sampler_compute_source_index except that it also returns the
// `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index_set_grad(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners,
scalar_t *grad_in) {
scalar_t grad_clip, grad_refl;
coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
*grad_in = (*grad_in) * grad_refl;
} else {
coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
// when align_corners=False, reflection does not auto clip coords
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
}
return coord;
}
static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}
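// The safe_add helpers below accumulate gradient contributions; they take
// explicit strides so non-contiguous tensors work, and silently drop
// out-of-bounds taps (possible with zeros padding).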
template<typename scalar_t>
static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
int64_t sH, int64_t sW, int64_t H, int64_t W,
scalar_t delta) {
if (within_bounds_2d(h, w, H, W)) {
data[h * sH + w * sW] += delta;
}
}
template<typename scalar_t>
static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
int64_t sD, int64_t sH, int64_t sW,
int64_t D, int64_t H, int64_t W,
scalar_t delta) {
if (within_bounds_3d(d, h, w, D, H, W)) {
data[d * sD + h * sH + w * sW] += delta;
}
}
} // namespace mmdetection
// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cuh
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <THC/THCAtomics.cuh>
namespace mmdetection {
namespace detail {
enum class GridSamplerInterpolation {Bilinear, Nearest};
enum class GridSamplerPadding {Zeros, Border, Reflection};
} // namespace detail
using detail::GridSamplerInterpolation;
using detail::GridSamplerPadding;
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
// -1 --> 0
// +1 --> (size - 1)
// scale_factor = (size - 1) / 2
// if not align_corners: -1 and +1 get sent to the image edges
// -1 --> -0.5
// +1 --> (size - 1) + 0.5 == size - 0.5
// scale_factor = size / 2
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
return ((coord + 1.f) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
return ((coord + 1.f) * size - 1) / 2;
}
}
// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
// except that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
bool align_corners, scalar_t *grad_in) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
*grad_in = static_cast<scalar_t>(size - 1) / 2;
return ((coord + 1.f) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
*grad_in = static_cast<scalar_t>(size) / 2;
return ((coord + 1.f) * size - 1) / 2;
}
}
// Clips coordinates into the valid index range [0, clip_limit - 1]
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates(scalar_t in, int clip_limit) {
return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
}
// clip_coordinates_set_grad works similarly to clip_coordinates except that
// it also returns the `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
if (in < static_cast<scalar_t>(0)) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
} else {
scalar_t max = static_cast<scalar_t>(clip_limit - 1);
if (in > max) {
*grad_in = static_cast<scalar_t>(0);
return max;
} else {
*grad_in = static_cast<scalar_t>(1);
return in;
}
}
}
// Reflects coordinates until they fall between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer values
// can be represented as ints.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
if (twice_low == twice_high) {
return static_cast<scalar_t>(0);
}
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = ::fabs(in - min);
// `fmod` returns the same sign as `in`, which is positive after the `fabs` above.
scalar_t extra = ::fmod(in, span);
int flips = static_cast<int>(::floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
// reflect_coordinates_set_grad works similarly to reflect_coordinates except
// that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
scalar_t *grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
}
int grad_in_mult_;
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<scalar_t>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
// `fmod` returns the same sign as `in`, which is positive after the `if` above.
scalar_t extra = ::fmod(in, span);
int flips = static_cast<int>(::floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<scalar_t>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<scalar_t>(-grad_in_mult_);
return span - extra + min;
}
}
// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates(coord, 0, 2*(size - 1));
} else {
coord = reflect_coordinates(coord, -1, 2*size - 1);
// when align_corners=False, reflection does not auto clip coords
coord = clip_coordinates(coord, size);
}
}
return coord;
}
// grid_sampler_compute_source_index_set_grad works similarly to
// grid_sampler_compute_source_index except that it also returns the
// `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index_set_grad(
scalar_t coord,
int size,
GridSamplerPadding padding_mode,
bool align_corners,
scalar_t *grad_in) {
scalar_t grad_clip, grad_refl;
coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
*grad_in = (*grad_in) * grad_refl;
} else {
coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
// when align_corners=False, reflection does not auto clip coords
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
}
return coord;
}
static __forceinline__ __device__
bool within_bounds_2d(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
static __forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}
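// Unlike the CPU helpers, the versions below use atomicAdd: many CUDA
// threads may scatter gradients into the same input element concurrently.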
template<typename scalar_t>
static __forceinline__ __device__
void safe_add_2d(scalar_t *data, int h, int w,
int sH, int sW, int H, int W,
scalar_t delta) {
if (within_bounds_2d(h, w, H, W)) {
atomicAdd(data + h * sH + w * sW, delta);
}
}
template<typename scalar_t>
static __forceinline__ __device__
void safe_add_3d(scalar_t *data, int d, int h, int w,
int sD, int sH, int sW, int D, int H, int W,
scalar_t delta) {
if (within_bounds_3d(d, h, w, D, H, W)) {
atomicAdd(data + d * sD + h * sH + w * sW, delta);
}
}
} // namespace mmdetection
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>
#if !AT_CUDNN_ENABLED()
namespace at { namespace native {
// See Note [ATen preprocessor philosophy]
Tensor cudnn_grid_sampler_forward(
const Tensor& input_t, const Tensor& grid_t) {
AT_ERROR("cudnn_grid_sampler_forward: ATen not compiled with cuDNN support");
}
std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
const Tensor& input_t, const Tensor& grid_t,
const Tensor& grad_output_t) {
AT_ERROR("cudnn_grid_sampler_backward: ATen not compiled with cuDNN support");
}
}}
#else // AT_CUDNN_ENABLED
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/TensorUtils.h>
// TODO: descriptor checking
namespace mmdetection {
using namespace at;
namespace {
void setSamplerDescriptor(SpatialTransformerDescriptor& desc, cudnnDataType_t dataType, const at::Tensor& tensor)
{
int inputSize[4] = {0};
for (int i = 0; i < tensor.dim(); ++i) {
inputSize[i] = (int) tensor.size(i);
}
desc.set(dataType, 4, inputSize);
}
void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input)
{
// assert size of grid is n*h*w*2
  // FYI: grid values lie in [-1, 1], where -1 maps to the left-most pixel
  // and +1 to the right-most pixel (so 0 is the center pixel);
  // values outside [-1, 1] are ignored
checkContiguous(c, grid);
checkDim(c, grid, 4);
// TODO: Maybe more user friendly to report where the expected size
// came from
checkSize(c, grid, 0, input->size(0));
checkSize(c, grid, 3, 2);
}
} // namespace
Tensor cudnn_grid_sampler_forward(
const Tensor& input_t, const Tensor& grid_t)
{
TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
grid{ grid_t.contiguous(), "grid", 2 };
CheckedFrom c = "cudnn_grid_sampler_forward";
checkAllSameGPU(c, {input, grid});
checkAllSameType(c, {input, grid});
checkGridSize(c, grid, input);
checkDim(c, input, 4);
auto output_t = at::empty({0}, input->options());
output_t.resize_({input->size(0), input->size(1), grid->size(1), grid->size(2)});
TensorDescriptor idesc{ *input }; // input descriptor
TensorDescriptor odesc{ output_t }; // output descriptor
SpatialTransformerDescriptor desc; // sampler descriptor
auto handle = getCudnnHandle();
auto dataType = getCudnnDataType(*input);
setSamplerDescriptor(desc, dataType, output_t);
Constant one(dataType, 1);
Constant zero(dataType, 0);
AT_CUDNN_CHECK(cudnnSpatialTfSamplerForward(
handle, desc.desc(),
&one, idesc.desc(), input->data_ptr(),
grid->data_ptr(),
&zero, odesc.desc(), output_t.data_ptr()
));
return output_t;
}
// NB: CuDNN does not support output mask; you always get both
// gradients.
std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
const Tensor& input_t, const Tensor& grid_t,
const Tensor& grad_output_t)
{
TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
grid{ grid_t.contiguous(), "grid", 2 },
grad_output{ contiguousIfZeroInStrides(grad_output_t), "grad_output", 3 };
CheckedFrom c = "cudnn_grid_sampler_backward";
checkAllSameGPU(c, {input, grad_output, grid});
checkGridSize(c, grid, input);
checkDim(c, input, 4);
checkDim(c, grad_output, 4);
auto grad_input_t = at::empty({0}, input->options());
grad_input_t.resize_(input->sizes());
auto grad_grid_t = at::empty({0}, grid->options());
grad_grid_t.resize_(grid->sizes());
TensorDescriptor idesc{ *input }; // input descriptor
TensorDescriptor odesc{ *grad_output }; // grad_output descriptor
TensorDescriptor gdesc{ grad_input_t }; // grad_input descriptor
SpatialTransformerDescriptor desc; // sampler descriptor
auto handle = getCudnnHandle();
auto dataType = getCudnnDataType(*input);
setSamplerDescriptor(desc, dataType, *grad_output);
Constant one(dataType, 1);
Constant zero(dataType, 0);
AT_CUDNN_CHECK(cudnnSpatialTfSamplerBackward(
handle, desc.desc(),
&one, idesc.desc(), input->data_ptr(),
&zero, gdesc.desc(), grad_input_t.data_ptr(),
&one, odesc.desc(), grad_output->data_ptr(),
    // intriguingly, the outputs don't need descriptors
grid->data_ptr(),
&zero, grad_grid_t.data_ptr()
));
return std::tuple<Tensor, Tensor>{ grad_input_t, grad_grid_t };
}
} // namespace mmdetection
#endif
#include <torch/extension.h>
#include <ATen/DeviceGuard.h>
namespace mmdetection {
using namespace at;
// No shape checking needed here. See NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_2d_forward_cpu(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_3d_forward_cpu(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode,
int64_t padding_mode, bool align_corners);
// No shape checking needed here. See NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_2d_forward_cuda(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler_3d_forward_cuda(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
// No shape checking needed here. See NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode,
int64_t padding_mode, bool align_corners);
// No shape checking needed here. See NOTE [ grid_sampler Native Functions ].
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
bool align_corners);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grid_sampler_2d_forward_cpu", &grid_sampler_2d_forward_cpu, "grid_sampler_2d_forward (CPU)");
m.def("grid_sampler_2d_backward_cpu", &grid_sampler_2d_backward_cpu, "grid_sampler_2d_backward (CPU)");
m.def("grid_sampler_3d_forward_cpu", &grid_sampler_3d_forward_cpu, "grid_sampler_3d_forward (CPU)");
m.def("grid_sampler_3d_backward_cpu", &grid_sampler_3d_backward_cpu, "grid_sampler_3d_backward (CPU)");
m.def("grid_sampler_2d_forward_cuda", &grid_sampler_2d_forward_cuda, "grid_sampler_2d_forward (CUDA)");
m.def("grid_sampler_2d_backward_cuda", &grid_sampler_2d_backward_cuda, "grid_sampler_2d_backward (CUDA)");
m.def("grid_sampler_3d_forward_cuda", &grid_sampler_3d_forward_cuda, "grid_sampler_3d_forward (CUDA)");
m.def("grid_sampler_3d_backward_cuda", &grid_sampler_3d_backward_cuda, "grid_sampler_3d_backward (CUDA)");
}
} // namespace mmdetection
@@ -270,6 +270,13 @@ if __name__ == '__main__':
sources=[
'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu'
]),
make_cuda_ext(
name='grid_sampler_cuda',
module='mmdet.ops.grid_sampler',
sources=[
'src/cpu/grid_sampler_cpu.cpp',
'src/cuda/grid_sampler_cuda.cu', 'src/grid_sampler.cpp'
]),
make_cuda_ext(
name='carafe_cuda',
module='mmdet.ops.carafe',