diff --git a/mmdet/ops/grid_sampler/__init__.py b/mmdet/ops/grid_sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..868617a6b3f42049d8b78253bf639f07e47ec981
--- /dev/null
+++ b/mmdet/ops/grid_sampler/__init__.py
@@ -0,0 +1,3 @@
+from .grid_sampler import grid_sample
+
+__all__ = ['grid_sample']
diff --git a/mmdet/ops/grid_sampler/grid_sampler.py b/mmdet/ops/grid_sampler/grid_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d359d8dfe953f5bd14cf6f5bb90cde856a11b30a
--- /dev/null
+++ b/mmdet/ops/grid_sampler/grid_sampler.py
@@ -0,0 +1,117 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from . import grid_sampler_cuda
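+# NOTE: despite its name, the compiled `grid_sampler_cuda` extension also
+# binds the CPU kernels (grid_sampler_*_forward_cpu / *_backward_cpu), which
+# serve as the fallback when the inputs live on the CPU.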
+
+
+class _GridSampler(Function):
+
+    @staticmethod
+    def forward(ctx, input, grid, mode_enum, padding_mode_enum, align_corners):
+
+        ctx.save_for_backward(input, grid)
+        ctx.mode_enum = mode_enum
+        ctx.padding_mode_enum = padding_mode_enum
+        ctx.align_corners = align_corners
+
+        if input.is_cuda:
+            if input.dim() == 4:
+                func = grid_sampler_cuda.grid_sampler_2d_forward_cuda
+            else:
+                func = grid_sampler_cuda.grid_sampler_3d_forward_cuda
+        else:
+            if input.dim() == 4:
+                func = grid_sampler_cuda.grid_sampler_2d_forward_cpu
+            else:
+                func = grid_sampler_cuda.grid_sampler_3d_forward_cpu
+
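+        # output is shaped by grid's spatial dims: (N, C, H_out, W_out) for a
+        # 4D input and (N, C, D_out, H_out, W_out) for a 5D input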
+        output = func(input, grid, mode_enum, padding_mode_enum, align_corners)
+
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, grid = ctx.saved_tensors
+        mode_enum = ctx.mode_enum
+        padding_mode_enum = ctx.padding_mode_enum
+        align_corners = ctx.align_corners
+
+        if input.is_cuda:
+            if input.dim() == 4:
+                func = grid_sampler_cuda.grid_sampler_2d_backward_cuda
+            else:
+                func = grid_sampler_cuda.grid_sampler_3d_backward_cuda
+        else:
+            if input.dim() == 4:
+                func = grid_sampler_cuda.grid_sampler_2d_backward_cpu
+            else:
+                func = grid_sampler_cuda.grid_sampler_3d_backward_cpu
+
+        grad_input, grad_grid = func(grad_output, input, grid, mode_enum,
+                                     padding_mode_enum, align_corners)
+
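+        # the three trailing Nones match the non-tensor forward arguments
+        # (mode_enum, padding_mode_enum, align_corners), which need no gradient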
+        return grad_input, grad_grid, None, None, None
+
+
+def grid_sample(input,
+                grid,
+                mode='bilinear',
+                padding_mode='zeros',
+                align_corners=False):
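+    """Grid sampling with an explicit ``align_corners`` flag.
+
+    Illustrative, shape-only example (values are arbitrary):
+
+        >>> feat = torch.randn(1, 3, 8, 8)           # (N, C, H, W)
+        >>> coords = torch.rand(1, 4, 4, 2) * 2 - 1  # (N, H_out, W_out, 2) in [-1, 1)
+        >>> grid_sample(feat, coords, align_corners=False).shape
+        torch.Size([1, 3, 4, 4])
+    """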
+    if torch.__version__ >= '1.3':
+        return F.grid_sample(input, grid, mode, padding_mode, align_corners)
+    elif align_corners:
+        # F.grid_sample in torch < 1.3 has no align_corners argument and
+        # already behaves as align_corners=True, so the flag is simply dropped
+        return F.grid_sample(input, grid, mode, padding_mode)
+    else:
+
+        # use self-compiled grid_sampler to support align_corners=False
+
+        assert mode in ['bilinear', 'nearest'], \
+            'expected mode to be bilinear or nearest, but got: {}'.format(mode)
+
+        assert padding_mode in ['zeros', 'border', 'reflection'], \
+            'expected padding_mode to be zeros, border, or reflection, ' \
+            'but got: {}'.format(padding_mode)
+
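+        # integer codes mirror the C++ GridSamplerInterpolation enum
+        # (Bilinear = 0, Nearest = 1) declared in grid_sampler_cpu.h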
+        if mode == 'bilinear':
+            mode_enum = 0
+        else:
+            mode_enum = 1
+
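+        # likewise, padding codes follow the GridSamplerPadding enum
+        # (Zeros = 0, Border = 1, Reflection = 2)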
+        if padding_mode == 'zeros':
+            padding_mode_enum = 0
+        elif padding_mode == 'border':
+            padding_mode_enum = 1
+        else:
+            padding_mode_enum = 2
+
+        # device, dtype and shape checks
+        assert input.device == grid.device, \
+            'expected input and grid to be on same device, ' \
+            'but input is on {} and grid is on {}'.format(
+                input.device, grid.device)
+        assert input.dtype == grid.dtype, \
+            'expected input and grid to have the same dtype, ' \
+            'but input has {} and grid has {}'.format(
+                input.dtype, grid.dtype)
+        assert input.dim() == grid.dim() and input.dim() in (4, 5), \
+            'expected 4D or 5D input and grid with same number of dimensions, ' \
+            'but got input with sizes {} and grid with sizes {}'.format(
+                input.size(), grid.size())
+        assert input.size(0) == grid.size(0), \
+            'expected input and grid to have the same batch size, ' \
+            'but got input with sizes {} and grid with sizes {}'.format(
+                input.size(), grid.size())
+        assert grid.size(-1) == input.dim() - 2, \
+            'expected grid to have size {} in last dimension, ' \
+            'but got grid with sizes {}'.format(
+                input.dim() - 2, grid.size())
+        for i in range(2, input.dim()):
+            assert input.size(i) > 0, \
+                'expected input to have non-empty spatial dimensions, ' \
+                'but input has sizes {} with dimension {} being empty'.format(
+                    input.size(), i)
+
+        return _GridSampler.apply(input, grid, mode_enum, padding_mode_enum,
+                                  align_corners)
diff --git a/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.cpp b/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cf1776ed1d7de573c97dd3254ef08d2352ca4377
--- /dev/null
+++ b/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.cpp
@@ -0,0 +1,692 @@
+// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/GridSampler.cpp
+
+#include <torch/extension.h>
+#include "grid_sampler_cpu.h"
+#include <ATen/ATen.h>
+#include <ATen/Device.h>
+#include <ATen/NativeFunctions.h>
+#include <c10/core/Layout.h>
+#include <c10/util/Exception.h>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+namespace mmdetection {
+
+using namespace at;
+using mmdetection::detail::GridSamplerInterpolation;
+using mmdetection::detail::GridSamplerPadding;
+
+namespace {
+
+  template<typename scalar_t>
+  Tensor grid_sampler_2d_forward_cpu_impl(const Tensor& input, const Tensor& grid,
+                                          GridSamplerInterpolation interpolation_mode,
+                                          GridSamplerPadding padding_mode,
+                                          bool align_corners) {
+    int64_t N = input.size(0);
+    int64_t C = input.size(1);
+    int64_t inp_H = input.size(2);
+    int64_t inp_W = input.size(3);
+    int64_t out_H = grid.size(1);
+    int64_t out_W = grid.size(2);
+    auto output = at::empty({N, C, out_H, out_W}, input.options());
+    int64_t inp_sN = input.stride(0);
+    int64_t inp_sC = input.stride(1);
+    int64_t inp_sH = input.stride(2);
+    int64_t inp_sW = input.stride(3);
+    int64_t grid_sN = grid.stride(0);
+    int64_t grid_sH = grid.stride(1);
+    int64_t grid_sW = grid.stride(2);
+    int64_t grid_sCoor = grid.stride(3);
+    int64_t out_sN = output.stride(0);
+    int64_t out_sC = output.stride(1);
+    int64_t out_sH = output.stride(2);
+    int64_t out_sW = output.stride(3);
+    scalar_t *inp_ptr = input.data<scalar_t>();
+    scalar_t *out_ptr = output.data<scalar_t>();
+    scalar_t *grid_ptr = grid.data<scalar_t>();
+    // loop over each output pixel
+    #ifdef _OPENMP
+    #pragma omp parallel for
+    #endif
+    for (int64_t n = 0; n < N; ++n) {
+        scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
+        scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
+          for (int64_t h = 0; h < out_H; ++h) {
+            for (int64_t w = 0; w < out_W; ++w) {
+              // get the corresponding input x, y co-ordinates from grid
+              scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
+              scalar_t ix = *grid_ptr_NHW;
+              scalar_t iy = grid_ptr_NHW[grid_sCoor];
+
+              ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
+              iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
+
+              if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+                // get NE, NW, SE, SW pixel values from (x, y)
+                int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
+                int64_t iy_nw = static_cast<int64_t>(std::floor(iy));
+
+                int64_t ix_ne = ix_nw + 1;
+                int64_t iy_ne = iy_nw;
+
+                int64_t ix_sw = ix_nw;
+                int64_t iy_sw = iy_nw + 1;
+
+                int64_t ix_se = ix_nw + 1;
+                int64_t iy_se = iy_nw + 1;
+
+                // get surfaces to each neighbor:
+                scalar_t nw = (ix_se - ix)    * (iy_se - iy)   ;
+                scalar_t ne = (ix    - ix_sw) * (iy_sw - iy)   ;
+                scalar_t sw = (ix_ne - ix)    * (iy    - iy_ne);
+                scalar_t se = (ix    - ix_nw) * (iy    - iy_nw);
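+                // these four weights always sum to 1 and are the standard
+                // bilinear interpolation weights of (ix, iy) w.r.t. its corners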
+
+                // calculate bilinear weighted pixel value and set output pixel
+                scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
+                scalar_t *inp_ptr_NC = inp_ptr_N;
+                for (int c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
+                  //   (c, iy_nw, ix_nw) * nw + (c, iy_ne, ix_ne) * ne
+                  // + (c, iy_sw, ix_sw) * sw + (c, iy_se, ix_se) * se
+                  *out_ptr_NCHW = static_cast<scalar_t>(0);
+                  if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
+                    *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
+                  }
+                  if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
+                    *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
+                  }
+                  if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
+                    *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
+                  }
+                  if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
+                    *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
+                  }
+                }
+              } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+                int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
+                int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
+
+                // assign nearest neighbor pixel value to output pixel
+                scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
+                scalar_t *inp_ptr_NC = inp_ptr_N;
+                for (int c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
+                  if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
+                    *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
+                  } else {
+                    *out_ptr_NCHW = static_cast<scalar_t>(0);
+                  }
+                }
+              }
+            }
+          }
+      }
+
+    return output;
+  }
+
+  template<typename scalar_t>
+  Tensor grid_sampler_3d_forward_cpu_impl(const Tensor& input, const Tensor& grid,
+                                  GridSamplerInterpolation interpolation_mode,
+                                  GridSamplerPadding padding_mode,
+                                  bool align_corners) {
+    int64_t N = input.size(0);
+    int64_t C = input.size(1);
+    int64_t inp_D = input.size(2);
+    int64_t inp_H = input.size(3);
+    int64_t inp_W = input.size(4);
+    int64_t out_D = grid.size(1);
+    int64_t out_H = grid.size(2);
+    int64_t out_W = grid.size(3);
+    auto output = at::empty({N, C, out_D, out_H, out_W}, input.options());
+    int64_t inp_sN = input.stride(0);
+    int64_t inp_sC = input.stride(1);
+    int64_t inp_sD = input.stride(2);
+    int64_t inp_sH = input.stride(3);
+    int64_t inp_sW = input.stride(4);
+    int64_t grid_sN = grid.stride(0);
+    int64_t grid_sD = grid.stride(1);
+    int64_t grid_sH = grid.stride(2);
+    int64_t grid_sW = grid.stride(3);
+    int64_t grid_sCoor = grid.stride(4);
+    int64_t out_sN = output.stride(0);
+    int64_t out_sC = output.stride(1);
+    int64_t out_sD = output.stride(2);
+    int64_t out_sH = output.stride(3);
+    int64_t out_sW = output.stride(4);
+    scalar_t *inp_ptr = input.data<scalar_t>();
+    scalar_t *out_ptr = output.data<scalar_t>();
+    scalar_t *grid_ptr = grid.data<scalar_t>();
+    // loop over each output pixel
+    #ifdef _OPENMP
+    #pragma omp parallel for
+    #endif
+    for (int64_t n = 0; n < N; ++n) {
+        scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
+        scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
+        for (int64_t d = 0; d < out_D; ++d) {
+          for (int64_t h = 0; h < out_H; ++h) {
+            for (int64_t w = 0; w < out_W; ++w) {
+              // get the corresponding input x, y, z co-ordinates from grid
+              scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
+              scalar_t ix = *grid_ptr_NDHW;
+              scalar_t iy = grid_ptr_NDHW[grid_sCoor];
+              scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];
+
+              ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
+              iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
+              iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);
+
+              if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+                // get corner pixel values from (x, y, z)
+                // for 4d, we used north-east-south-west
+                // for 5d, we add top-bottom
+                int64_t ix_tnw = static_cast<int64_t>(std::floor(ix));
+                int64_t iy_tnw = static_cast<int64_t>(std::floor(iy));
+                int64_t iz_tnw = static_cast<int64_t>(std::floor(iz));
+
+                int64_t ix_tne = ix_tnw + 1;
+                int64_t iy_tne = iy_tnw;
+                int64_t iz_tne = iz_tnw;
+
+                int64_t ix_tsw = ix_tnw;
+                int64_t iy_tsw = iy_tnw + 1;
+                int64_t iz_tsw = iz_tnw;
+
+                int64_t ix_tse = ix_tnw + 1;
+                int64_t iy_tse = iy_tnw + 1;
+                int64_t iz_tse = iz_tnw;
+
+                int64_t ix_bnw = ix_tnw;
+                int64_t iy_bnw = iy_tnw;
+                int64_t iz_bnw = iz_tnw + 1;
+
+                int64_t ix_bne = ix_tnw + 1;
+                int64_t iy_bne = iy_tnw;
+                int64_t iz_bne = iz_tnw + 1;
+
+                int64_t ix_bsw = ix_tnw;
+                int64_t iy_bsw = iy_tnw + 1;
+                int64_t iz_bsw = iz_tnw + 1;
+
+                int64_t ix_bse = ix_tnw + 1;
+                int64_t iy_bse = iy_tnw + 1;
+                int64_t iz_bse = iz_tnw + 1;
+
+                // get surfaces to each neighbor:
+                scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
+                scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
+                scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
+                scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
+                scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
+                scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
+                scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
+                scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);
+
+                // calculate bilinear weighted pixel value and set output pixel
+                scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
+                scalar_t *inp_ptr_NC = inp_ptr_N;
+                for (int c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
+                  //   (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
+                  // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
+                  // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
+                  // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
+                  *out_ptr_NCDHW = static_cast<scalar_t>(0);
+                  if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
+                  }
+                  if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
+                  }
+                  if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
+                  }
+                  if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
+                  }
+                  if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
+                  }
+                  if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
+                  }
+                  if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
+                  }
+                  if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
+                  }
+                }
+              } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+                int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
+                int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
+                int64_t iz_nearest = static_cast<int64_t>(std::round(iz));
+
+                // assign nearest neighbor pixel value to output pixel
+                scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
+                scalar_t *inp_ptr_NC = inp_ptr_N;
+                for (int c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
+                  if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
+                    *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
+                  } else {
+                    *out_ptr_NCDHW = static_cast<scalar_t>(0);
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    return output;
+  }
+
+  template<typename scalar_t>
+  std::tuple<Tensor, Tensor>
+  grid_sampler_2d_backward_cpu_impl(const Tensor& grad_output,
+                                    const Tensor& input, const Tensor& grid,
+                                    GridSamplerInterpolation interpolation_mode,
+                                    GridSamplerPadding padding_mode,
+                                    bool align_corners) {
+    auto grad_input = at::zeros_like(input);
+    auto grad_grid = at::empty_like(grid);
+    // If interpolation mode is Nearest, then grad_grid is not filled in the
+    // loop below.
+    if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+      grad_grid.zero_();
+    }
+    int64_t N = input.size(0);
+    int64_t C = input.size(1);
+    int64_t inp_H = input.size(2);
+    int64_t inp_W = input.size(3);
+    int64_t out_H = grid.size(1);
+    int64_t out_W = grid.size(2);
+    int64_t inp_sN = input.stride(0);
+    int64_t inp_sC = input.stride(1);
+    int64_t inp_sH = input.stride(2);
+    int64_t inp_sW = input.stride(3);
+    int64_t grid_sN = grid.stride(0);
+    int64_t grid_sH = grid.stride(1);
+    int64_t grid_sW = grid.stride(2);
+    int64_t grid_sCoor = grid.stride(3);
+    int64_t gOut_sN = grad_output.stride(0);
+    int64_t gOut_sC = grad_output.stride(1);
+    int64_t gOut_sH = grad_output.stride(2);
+    int64_t gOut_sW = grad_output.stride(3);
+    int64_t gInp_sN = grad_input.stride(0);
+    int64_t gInp_sC = grad_input.stride(1);
+    int64_t gInp_sH = grad_input.stride(2);
+    int64_t gInp_sW = grad_input.stride(3);
+    int64_t gGrid_sN = grad_grid.stride(0);
+    int64_t gGrid_sW = grad_grid.stride(2);
+    scalar_t *inp_ptr = input.data<scalar_t>();
+    scalar_t *grid_ptr = grid.data<scalar_t>();
+    scalar_t *gOut_ptr = grad_output.data<scalar_t>();
+    scalar_t *gInp_ptr = grad_input.data<scalar_t>();
+    scalar_t *gGrid_ptr = grad_grid.data<scalar_t>();
+    // loop over each output pixel
+    #ifdef _OPENMP
+    #pragma omp parallel for
+    #endif
+    for (int64_t n = 0; n < N; ++n) {
+        scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
+        scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
+        scalar_t *gGrid_ptr_NHW = gGrid_ptr + n * gGrid_sN;
+          for (int64_t h = 0; h < out_H; ++h) {
+            for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) {
+              // get the corresponding input x, y co-ordinates from grid
+              scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
+              scalar_t ix = *grid_ptr_NHW;
+              scalar_t iy = grid_ptr_NHW[grid_sCoor];
+
+              // multipliers for gradients on ix and iy
+              scalar_t gix_mult, giy_mult;
+              ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+              iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+
+              if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+                // get NE, NW, SE, SW pixel values from (x, y)
+                int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
+                int64_t iy_nw = static_cast<int64_t>(std::floor(iy));
+
+                int64_t ix_ne = ix_nw + 1;
+                int64_t iy_ne = iy_nw;
+
+                int64_t ix_sw = ix_nw;
+                int64_t iy_sw = iy_nw + 1;
+
+                int64_t ix_se = ix_nw + 1;
+                int64_t iy_se = iy_nw + 1;
+
+                // get surfaces to each neighbor:
+                scalar_t nw = (ix_se - ix)    * (iy_se - iy)   ;
+                scalar_t ne = (ix    - ix_sw) * (iy_sw - iy)   ;
+                scalar_t sw = (ix_ne - ix)    * (iy    - iy_ne);
+                scalar_t se = (ix    - ix_nw) * (iy    - iy_nw);
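+                // the gix / giy terms accumulated below are d(out)/d(ix) and
+                // d(out)/d(iy): e.g. d(nw)/d(ix) = -(iy_se - iy), giving the
+                // "gix -= nw_val * (iy_se - iy) * gOut" term in the loop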
+
+                scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
+                scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
+                scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
+                scalar_t *inp_ptr_NC = inp_ptr_N;
+                // calculate bilinear weighted pixel value and set output pixel
+                for (int c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
+                  scalar_t gOut = *gOut_ptr_NCHW;
+
+                  // calculate and set grad_input
+                  safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut);
+                  safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut);
+                  safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut);
+                  safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut);
+
+                  // calculate grad_grid
+                  if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
+                    scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
+                    gix -= nw_val * (iy_se - iy) * gOut;
+                    giy -= nw_val * (ix_se - ix) * gOut;
+                  }
+                  if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
+                    scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
+                    gix += ne_val * (iy_sw - iy)    * gOut;
+                    giy -= ne_val * (ix    - ix_sw) * gOut;
+                  }
+                  if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
+                    scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
+                    gix -= sw_val * (iy - iy_ne) * gOut;
+                    giy += sw_val * (ix_ne - ix) * gOut;
+                  }
+                  if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
+                    scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
+                    gix += se_val * (iy - iy_nw)    * gOut;
+                    giy += se_val * (ix    - ix_nw) * gOut;
+                  }
+                }
+
+                // assuming grad_grid is contiguous
+                gGrid_ptr_NHW[0] = gix_mult * gix;
+                gGrid_ptr_NHW[1] = giy_mult * giy;
+              } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+                int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
+                int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
+
+                // assign nearest neighbor pixel value to output pixel
+                scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
+                scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
+                for (int c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
+                  // calculate and set grad_input
+                  safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest,
+                              gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW);
+                }
+              }
+            }
+          }
+      }
+    return std::make_tuple(grad_input, grad_grid);
+  }
+
+  template<typename scalar_t>
+  std::tuple<Tensor, Tensor>
+  grid_sampler_3d_backward_cpu_impl(const Tensor& grad_output,
+                                    const Tensor& input, const Tensor& grid,
+                                    GridSamplerInterpolation interpolation_mode,
+                                    GridSamplerPadding padding_mode,
+                                    bool align_corners) {
+    auto grad_input = at::zeros_like(input);
+    auto grad_grid = at::empty_like(grid);
+    // If interpolation mode is Nearest, then grad_grid is not filled in the
+    // loop below.
+    if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+      grad_grid.zero_();
+    }
+    int64_t N = input.size(0);
+    int64_t C = input.size(1);
+    int64_t inp_D = input.size(2);
+    int64_t inp_H = input.size(3);
+    int64_t inp_W = input.size(4);
+    int64_t out_D = grid.size(1);
+    int64_t out_H = grid.size(2);
+    int64_t out_W = grid.size(3);
+    int64_t inp_sN = input.stride(0);
+    int64_t inp_sC = input.stride(1);
+    int64_t inp_sD = input.stride(2);
+    int64_t inp_sH = input.stride(3);
+    int64_t inp_sW = input.stride(4);
+    int64_t grid_sN = grid.stride(0);
+    int64_t grid_sD = grid.stride(1);
+    int64_t grid_sH = grid.stride(2);
+    int64_t grid_sW = grid.stride(3);
+    int64_t grid_sCoor = grid.stride(4);
+    int64_t gOut_sN = grad_output.stride(0);
+    int64_t gOut_sC = grad_output.stride(1);
+    int64_t gOut_sD = grad_output.stride(2);
+    int64_t gOut_sH = grad_output.stride(3);
+    int64_t gOut_sW = grad_output.stride(4);
+    int64_t gInp_sN = grad_input.stride(0);
+    int64_t gInp_sC = grad_input.stride(1);
+    int64_t gInp_sD = grad_input.stride(2);
+    int64_t gInp_sH = grad_input.stride(3);
+    int64_t gInp_sW = grad_input.stride(4);
+    int64_t gGrid_sN = grad_grid.stride(0);
+    int64_t gGrid_sW = grad_grid.stride(3);
+    scalar_t *inp_ptr = input.data<scalar_t>();
+    scalar_t *grid_ptr = grid.data<scalar_t>();
+    scalar_t *gOut_ptr = grad_output.data<scalar_t>();
+    scalar_t *gInp_ptr = grad_input.data<scalar_t>();
+    scalar_t *gGrid_ptr = grad_grid.data<scalar_t>();
+    // loop over each output pixel
+    #ifdef _OPENMP
+    #pragma omp parallel for
+    #endif
+    for (int64_t n = 0; n < N; ++n) {
+        scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
+        scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
+        scalar_t *gGrid_ptr_NDHW = gGrid_ptr + n * gGrid_sN;
+        for (int64_t d = 0; d < out_D; ++d) {
+          for (int64_t h = 0; h < out_H; ++h) {
+            for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) {
+              // get the corresponding input x, y, z co-ordinates from grid
+              scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
+              scalar_t ix = *grid_ptr_NDHW;
+              scalar_t iy = grid_ptr_NDHW[grid_sCoor];
+              scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];
+
+              // multipliers for gradients on ix, iy, and iz
+              scalar_t gix_mult, giy_mult, giz_mult;
+              ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+              iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+              iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
+
+              if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+                // get corner pixel values from (x, y, z)
+                // for 4d, we used north-east-south-west
+                // for 5d, we add top-bottom
+                int64_t ix_tnw = static_cast<int64_t>(std::floor(ix));
+                int64_t iy_tnw = static_cast<int64_t>(std::floor(iy));
+                int64_t iz_tnw = static_cast<int64_t>(std::floor(iz));
+
+                int64_t ix_tne = ix_tnw + 1;
+                int64_t iy_tne = iy_tnw;
+                int64_t iz_tne = iz_tnw;
+
+                int64_t ix_tsw = ix_tnw;
+                int64_t iy_tsw = iy_tnw + 1;
+                int64_t iz_tsw = iz_tnw;
+
+                int64_t ix_tse = ix_tnw + 1;
+                int64_t iy_tse = iy_tnw + 1;
+                int64_t iz_tse = iz_tnw;
+
+                int64_t ix_bnw = ix_tnw;
+                int64_t iy_bnw = iy_tnw;
+                int64_t iz_bnw = iz_tnw + 1;
+
+                int64_t ix_bne = ix_tnw + 1;
+                int64_t iy_bne = iy_tnw;
+                int64_t iz_bne = iz_tnw + 1;
+
+                int64_t ix_bsw = ix_tnw;
+                int64_t iy_bsw = iy_tnw + 1;
+                int64_t iz_bsw = iz_tnw + 1;
+
+                int64_t ix_bse = ix_tnw + 1;
+                int64_t iy_bse = iy_tnw + 1;
+                int64_t iz_bse = iz_tnw + 1;
+
+                // get surfaces to each neighbor:
+                scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
+                scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
+                scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
+                scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
+                scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
+                scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
+                scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
+                scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);
+
+                scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
+                scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
+                scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
+                scalar_t *inp_ptr_NC = inp_ptr_N;
+                // calculate bilinear weighted pixel value and set output pixel
+                for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
+                  scalar_t gOut = *gOut_ptr_NCDHW;
+
+                  // calculate and set grad_input
+                  safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
+                  safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
+                  safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
+                  safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
+                  safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
+                  safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
+                  safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
+                  safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
+
+                  // calculate grad_grid
+                  if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+                    scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
+                    gix -= tnw_val * (iy_bse - iy)    * (iz_bse - iz)    * gOut;
+                    giy -= tnw_val * (ix_bse - ix)    * (iz_bse - iz)    * gOut;
+                    giz -= tnw_val * (ix_bse - ix)    * (iy_bse - iy)    * gOut;
+                  }
+                  if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+                    scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
+                    gix += tne_val * (iy_bsw - iy)    * (iz_bsw - iz)    * gOut;
+                    giy -= tne_val * (ix    - ix_bsw) * (iz_bsw - iz)    * gOut;
+                    giz -= tne_val * (ix    - ix_bsw) * (iy_bsw - iy)    * gOut;
+                  }
+                  if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+                    scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
+                    gix -= tsw_val * (iy - iy_bne)    * (iz_bne - iz)    * gOut;
+                    giy += tsw_val * (ix_bne - ix)    * (iz_bne - iz)    * gOut;
+                    giz -= tsw_val * (ix_bne - ix)    * (iy    - iy_bne) * gOut;
+                  }
+                  if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+                    scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
+                    gix += tse_val * (iy - iy_bnw)    * (iz_bnw - iz)    * gOut;
+                    giy += tse_val * (ix    - ix_bnw) * (iz_bnw - iz)    * gOut;
+                    giz -= tse_val * (ix    - ix_bnw) * (iy    - iy_bnw) * gOut;
+                  }
+                  if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+                    scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
+                    gix -= bnw_val * (iy_tse - iy)    * (iz - iz_tse)    * gOut;
+                    giy -= bnw_val * (ix_tse - ix)    * (iz - iz_tse)    * gOut;
+                    giz += bnw_val * (ix_tse - ix)    * (iy_tse - iy)    * gOut;
+                  }
+                  if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+                    scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
+                    gix += bne_val * (iy_tsw - iy)    * (iz - iz_tsw)    * gOut;
+                    giy -= bne_val * (ix    - ix_tsw) * (iz - iz_tsw)    * gOut;
+                    giz += bne_val * (ix    - ix_tsw) * (iy_tsw - iy)    * gOut;
+                  }
+                  if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+                    scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
+                    gix -= bsw_val * (iy - iy_tne)    * (iz - iz_tne)    * gOut;
+                    giy += bsw_val * (ix_tne - ix)    * (iz - iz_tne)    * gOut;
+                    giz += bsw_val * (ix_tne - ix)    * (iy    - iy_tne) * gOut;
+                  }
+                  if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+                    scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
+                    gix += bse_val * (iy - iy_tnw)    * (iz - iz_tnw)    * gOut;
+                    giy += bse_val * (ix    - ix_tnw) * (iz - iz_tnw)    * gOut;
+                    giz += bse_val * (ix    - ix_tnw) * (iy    - iy_tnw) * gOut;
+                  }
+                }
+
+                // assuming grad_grid is contiguous
+                gGrid_ptr_NDHW[0] = gix_mult * gix;
+                gGrid_ptr_NDHW[1] = giy_mult * giy;
+                gGrid_ptr_NDHW[2] = giz_mult * giz;
+              } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+                int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
+                int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
+                int64_t iz_nearest = static_cast<int64_t>(std::round(iz));
+
+                // assign nearest neighbor pixel value to output pixel
+                scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
+                scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
+                for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
+                  // calculate and set grad_input
+                  safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest,
+                              gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW);
+                }
+              }
+            }
+          }
+        }
+      }
+    return std::make_tuple(grad_input, grad_grid);
+  }
+
+}  // namespace
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_2d_forward_cpu(const Tensor& input, const Tensor& grid,
+                                   int64_t interpolation_mode, int64_t padding_mode,
+                                   bool align_corners) {
+    return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_forward_cpu", [&] {
+        return grid_sampler_2d_forward_cpu_impl<scalar_t>(
+                input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode),
+                static_cast<GridSamplerPadding>(padding_mode), align_corners);
+    });
+}
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_3d_forward_cpu(const Tensor& input, const Tensor& grid,
+                           int64_t interpolation_mode, int64_t padding_mode,
+                           bool align_corners) {
+  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_forward_cpu", [&] {
+    return grid_sampler_3d_forward_cpu_impl<scalar_t>(
+      input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode),
+      static_cast<GridSamplerPadding>(padding_mode), align_corners);
+  });
+}
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+std::tuple<Tensor, Tensor>
+grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
+                             int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+    return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_backward_cpu", [&] {
+        return grid_sampler_2d_backward_cpu_impl<scalar_t>(
+                grad_output, input, grid,
+                static_cast<GridSamplerInterpolation>(interpolation_mode),
+                static_cast<GridSamplerPadding>(padding_mode), align_corners);
+    });
+}
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+std::tuple<Tensor, Tensor>
+grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
+                             int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] {
+    return grid_sampler_3d_backward_cpu_impl<scalar_t>(
+      grad_output, input, grid,
+      static_cast<GridSamplerInterpolation>(interpolation_mode),
+      static_cast<GridSamplerPadding>(padding_mode), align_corners);
+  });
+}
+
+}  // namespace mmdetection
diff --git a/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.h b/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..3c9ae45063bf212d715abcf54c0bdccdb23958fc
--- /dev/null
+++ b/mmdet/ops/grid_sampler/src/cpu/grid_sampler_cpu.h
@@ -0,0 +1,225 @@
+// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/GridSampler.h
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+
+namespace mmdetection {
+
+namespace detail {
+
+  enum class GridSamplerInterpolation {Bilinear, Nearest};
+  enum class GridSamplerPadding {Zeros, Border, Reflection};
+
+}  // namespace detail
+
+using detail::GridSamplerInterpolation;
+using detail::GridSamplerPadding;
+
+// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
+// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
+// if align_corners: -1 and +1 get sent to the centers of the corner pixels
+//     -1 --> 0
+//     +1 --> (size - 1)
+//     scale_factor = (size - 1) / 2
+// if not align_corners: -1 and +1 get sent to the image edges
+//     -1 --> -0.5
+//     +1 --> (size - 1) + 0.5 == size - 0.5
+//     scale_factor = size / 2
+template <typename scalar_t>
+static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
+                                                bool align_corners) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    return ((coord + 1) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    return ((coord + 1) * size - 1) / 2;
+  }
+}
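+// Worked example with size = 4:
+//   align_corners=true : -1 -> 0,    +1 -> 3    (centers of the corner pixels)
+//   align_corners=false: -1 -> -0.5, +1 -> 3.5  (outer edges of the image)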
+
+// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
+// except that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
+                                                         bool align_corners, scalar_t *grad_in) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    *grad_in = static_cast<scalar_t>(size - 1) / 2;
+    return ((coord + 1) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    *grad_in = static_cast<scalar_t>(size) / 2;
+    return ((coord + 1) * size - 1) / 2;
+  }
+}
+
+// Clips coordinates to between 0 and clip_limit - 1
+template<typename scalar_t>
+static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
+  return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
+}
+
+// clip_coordinates_set_grad works similarly to clip_coordinates except that
+// it also returns the `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template<typename scalar_t>
+static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
+                                                 scalar_t *grad_in) {
+  if (in < static_cast<scalar_t>(0)) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  } else {
+    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
+    if (in > max) {
+      *grad_in = static_cast<scalar_t>(0);
+      return max;
+    } else {
+      *grad_in = static_cast<scalar_t>(1);
+      return in;
+    }
+  }
+}
+
+// Reflects coordinates until they fall between low and high (inclusive).
+// The bounds are passed as twice their value so that half-integer values
+// can be represented as ints.
+template<typename scalar_t>
+static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
+                                           int64_t twice_high) {
+  if (twice_low == twice_high) {
+    return static_cast<scalar_t>(0);
+  }
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = std::fabs(in - min);
+  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
+  scalar_t extra = std::fmod(in, span);
+  int flips = static_cast<int>(std::floor(in / span));
+  if (flips % 2 == 0) {
+    return extra + min;
+  } else {
+    return span - extra + min;
+  }
+}
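+// Worked example: reflect_coordinates(3.5, /*twice_low=*/0, /*twice_high=*/6),
+// i.e. a size-4 input with align_corners=true (valid range [0, 3]):
+//   min = 0, span = 3, extra = fmod(3.5, 3) = 0.5, flips = 1 (odd)
+//   -> returns span - extra + min = 2.5, folding the coordinate back inside.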
+
+// reflect_coordinates_set_grad works similarly to reflect_coordinates except
+// that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template<typename scalar_t>
+static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
+                                                    int64_t twice_high, scalar_t *grad_in) {
+  if (twice_low == twice_high) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  }
+  int grad_in_mult_;
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = in - min;
+  if (in < static_cast<scalar_t>(0)) {
+    grad_in_mult_ = -1;
+    in = -in;
+  } else {
+    grad_in_mult_ = 1;
+  }
+  // `fmod` returns same sign as `in`, which is positive after the `if` above.
+  scalar_t extra = std::fmod(in, span);
+  int flips = static_cast<int>(std::floor(in / span));
+  if (flips % 2 == 0) {
+    *grad_in = static_cast<scalar_t>(grad_in_mult_);
+    return extra + min;
+  } else {
+    *grad_in = static_cast<scalar_t>(-grad_in_mult_);
+    return span - extra + min;
+  }
+}
+
+// Computes the pixel source index value for a grid coordinate
+template <typename scalar_t>
+static inline scalar_t grid_sampler_compute_source_index(
+    scalar_t coord,
+    int64_t size,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+  coord = grid_sampler_unnormalize(coord, size, align_corners);
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates(coord, 0, 2*(size - 1));
+    } else {
+      coord = reflect_coordinates(coord, -1, 2*size - 1);
+      // when align_corners=False, reflection does not auto clip coords
+      coord = clip_coordinates(coord, size);
+    }
+  }
+  return coord;
+}
+
+// grid_sampler_compute_source_index_set_grad works similarly to
+// grid_sampler_compute_source_index except that it also returns the
+// `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static inline scalar_t grid_sampler_compute_source_index_set_grad(
+    scalar_t coord,
+    int64_t size,
+    GridSamplerPadding padding_mode,
+    bool align_corners,
+    scalar_t *grad_in) {
+  scalar_t grad_clip, grad_refl;
+  coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_clip;
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
+      *grad_in = (*grad_in) * grad_refl;
+    } else {
+      coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
+      // when align_corners=False, reflection does not auto clip coords
+      coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+      *grad_in = (*grad_in) * grad_refl * grad_clip;
+    }
+  }
+  return coord;
+}
+
+static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
+  return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
+  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
+}
+
+template<typename scalar_t>
+static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
+                               int64_t sH, int64_t sW, int64_t H, int64_t W,
+                               scalar_t delta) {
+  if (within_bounds_2d(h, w, H, W)) {
+    data[h * sH + w * sW] += delta;
+  }
+}
+
+template<typename scalar_t>
+static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
+                               int64_t sD, int64_t sH, int64_t sW,
+                               int64_t D, int64_t H, int64_t W,
+                               scalar_t delta) {
+  if (within_bounds_3d(d, h, w, D, H, W)) {
+    data[d * sD + h * sH + w * sW] += delta;
+  }
+}
+
+}  // namespace mmdetection
diff --git a/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cu b/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2d747a0b897dda1b0a29998a1825832b6b5eb99c
--- /dev/null
+++ b/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cu
@@ -0,0 +1,718 @@
+// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cu
+
+#include <ATen/ATen.h>
+#include "grid_sampler_cuda.cuh"
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/macros/Macros.h>
+
+namespace mmdetection {
+
+using namespace at::cuda::detail;
+
+using mmdetection::detail::GridSamplerInterpolation;
+using mmdetection::detail::GridSamplerPadding;
+
+namespace {
+  template <typename scalar_t>
+  C10_LAUNCH_BOUNDS_1(1024)
+  __global__ void grid_sampler_2d_forward_kernel_cuda(
+      const int nthreads,
+      TensorInfo<scalar_t, int> input,
+      TensorInfo<scalar_t, int> grid,
+      TensorInfo<scalar_t, int> output,
+      const GridSamplerInterpolation interpolation_mode,
+      const GridSamplerPadding padding_mode,
+      bool align_corners) {
+
+    int C = input.sizes[1];
+    int inp_H = input.sizes[2];
+    int inp_W = input.sizes[3];
+    int out_H = grid.sizes[1];
+    int out_W = grid.sizes[2];
+    int inp_sN = input.strides[0];
+    int inp_sC = input.strides[1];
+    int inp_sH = input.strides[2];
+    int inp_sW = input.strides[3];
+    int grid_sN = grid.strides[0];
+    int grid_sH = grid.strides[1];
+    int grid_sW = grid.strides[2];
+    int grid_sCoor = grid.strides[3];
+    int out_sN = output.strides[0];
+    int out_sC = output.strides[1];
+    int out_sH = output.strides[2];
+    int out_sW = output.strides[3];
+
+    CUDA_KERNEL_LOOP(index, nthreads) {
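+      // each iteration of the grid-stride loop handles one output location
+      // (n, h, w) and writes all C channels for it below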
+      const int w = index % out_W;
+      const int h = (index / out_W) % out_H;
+      const int n = index / (out_H * out_W);
+      const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+
+      // get the corresponding input x, y co-ordinates from grid
+      scalar_t ix = grid.data[grid_offset];
+      scalar_t iy = grid.data[grid_offset + grid_sCoor];
+
+      ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
+      iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
+
+      if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+        // get NE, NW, SE, SW pixel values from (x, y)
+        int ix_nw = static_cast<int>(::floor(ix));
+        int iy_nw = static_cast<int>(::floor(iy));
+        int ix_ne = ix_nw + 1;
+        int iy_ne = iy_nw;
+        int ix_sw = ix_nw;
+        int iy_sw = iy_nw + 1;
+        int ix_se = ix_nw + 1;
+        int iy_se = iy_nw + 1;
+
+        // get surfaces to each neighbor:
+        scalar_t nw = (ix_se - ix)    * (iy_se - iy);
+        scalar_t ne = (ix    - ix_sw) * (iy_sw - iy);
+        scalar_t sw = (ix_ne - ix)    * (iy    - iy_ne);
+        scalar_t se = (ix    - ix_nw) * (iy    - iy_nw);
+
+        // calculate bilinear weighted pixel value and set output pixel
+        auto inp_ptr_NC = input.data + n * inp_sN;
+        auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
+        for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
+          *out_ptr_NCHW = static_cast<scalar_t>(0);
+          if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
+            *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
+          }
+          if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
+            *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
+          }
+          if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
+            *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
+          }
+          if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
+            *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
+          }
+        }
+      } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+        int ix_nearest = static_cast<int>(::round(ix));
+        int iy_nearest = static_cast<int>(::round(iy));
+
+        // assign nearest neighbor pixel value to output pixel
+        auto inp_ptr_NC = input.data + n * inp_sN;
+        auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
+        for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
+          if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
+            *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
+          } else {
+            *out_ptr_NCHW = static_cast<scalar_t>(0);
+          }
+        }
+      }
+    }
+  }
+
+  template <typename scalar_t>
+  C10_LAUNCH_BOUNDS_1(1024)
+  __global__ void grid_sampler_3d_forward_kernel_cuda(
+      const int nthreads,
+      TensorInfo<scalar_t, int> input,
+      TensorInfo<scalar_t, int> grid,
+      TensorInfo<scalar_t, int> output,
+      const GridSamplerInterpolation interpolation_mode,
+      const GridSamplerPadding padding_mode,
+      bool align_corners) {
+
+    int C = input.sizes[1];
+    int inp_D = input.sizes[2];
+    int inp_H = input.sizes[3];
+    int inp_W = input.sizes[4];
+    int out_D = grid.sizes[1];
+    int out_H = grid.sizes[2];
+    int out_W = grid.sizes[3];
+    int inp_sN = input.strides[0];
+    int inp_sC = input.strides[1];
+    int inp_sD = input.strides[2];
+    int inp_sH = input.strides[3];
+    int inp_sW = input.strides[4];
+    int grid_sN = grid.strides[0];
+    int grid_sD = grid.strides[1];
+    int grid_sH = grid.strides[2];
+    int grid_sW = grid.strides[3];
+    int grid_sCoor = grid.strides[4];
+    int out_sN = output.strides[0];
+    int out_sC = output.strides[1];
+    int out_sD = output.strides[2];
+    int out_sH = output.strides[3];
+    int out_sW = output.strides[4];
+
+    CUDA_KERNEL_LOOP(index, nthreads) {
+      const int w = index % out_W;
+      const int h = (index / out_W) % out_H;
+      const int d = (index / (out_H * out_W)) % out_D;
+      const int n = index / (out_D * out_H * out_W);
+      const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
+
+      // get the corresponding input x, y, z co-ordinates from grid
+      scalar_t ix = grid.data[grid_offset];
+      scalar_t iy = grid.data[grid_offset + grid_sCoor];
+      scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
+
+      ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
+      iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
+      iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);
+
+      if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+        // get corner pixel values from (x, y, z)
+        // for 4d we use north-east-south-west
+        // for 5d we add top-bottom
+        int ix_tnw = static_cast<int>(::floor(ix));
+        int iy_tnw = static_cast<int>(::floor(iy));
+        int iz_tnw = static_cast<int>(::floor(iz));
+
+        int ix_tne = ix_tnw + 1;
+        int iy_tne = iy_tnw;
+        int iz_tne = iz_tnw;
+
+        int ix_tsw = ix_tnw;
+        int iy_tsw = iy_tnw + 1;
+        int iz_tsw = iz_tnw;
+
+        int ix_tse = ix_tnw + 1;
+        int iy_tse = iy_tnw + 1;
+        int iz_tse = iz_tnw;
+
+        int ix_bnw = ix_tnw;
+        int iy_bnw = iy_tnw;
+        int iz_bnw = iz_tnw + 1;
+
+        int ix_bne = ix_tnw + 1;
+        int iy_bne = iy_tnw;
+        int iz_bne = iz_tnw + 1;
+
+        int ix_bsw = ix_tnw;
+        int iy_bsw = iy_tnw + 1;
+        int iz_bsw = iz_tnw + 1;
+
+        int ix_bse = ix_tnw + 1;
+        int iy_bse = iy_tnw + 1;
+        int iz_bse = iz_tnw + 1;
+
+        // get surfaces to each neighbor:
+        scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
+        scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
+        scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
+        scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
+        scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
+        scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
+        scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
+        scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);
+
+        auto inp_ptr_NC = input.data + n * inp_sN;
+        auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
+        for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
+          //   (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
+          // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
+          // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
+          // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
+          *out_ptr_NCDHW = static_cast<scalar_t>(0);
+          if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
+          }
+          if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
+          }
+          if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
+          }
+          if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
+          }
+          if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
+          }
+          if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
+          }
+          if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
+          }
+          if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
+          }
+        }
+      } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+        int ix_nearest = static_cast<int>(::round(ix));
+        int iy_nearest = static_cast<int>(::round(iy));
+        int iz_nearest = static_cast<int>(::round(iz));
+
+        // assign nearest neighbor pixel value to output pixel
+        auto inp_ptr_NC = input.data + n * inp_sN;
+        auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
+        for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
+          if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
+            *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
+          } else {
+            *out_ptr_NCDHW = static_cast<scalar_t>(0);
+          }
+        }
+      }
+    }
+  }
+
+  template <typename scalar_t>
+  C10_LAUNCH_BOUNDS_1(1024)
+  __global__ void grid_sampler_2d_backward_kernel_cuda(
+      const int nthreads,
+      TensorInfo<scalar_t, int> grad_output,
+      TensorInfo<scalar_t, int> input,
+      TensorInfo<scalar_t, int> grid,
+      TensorInfo<scalar_t, int> grad_input,  // initialized to zeros
+      TensorInfo<scalar_t, int> grad_grid,   // initialized to empty
+      const GridSamplerInterpolation interpolation_mode,
+      const GridSamplerPadding padding_mode,
+      bool align_corners) {
+
+    int C = input.sizes[1];
+    int inp_H = input.sizes[2];
+    int inp_W = input.sizes[3];
+    int out_H = grid.sizes[1];
+    int out_W = grid.sizes[2];
+    int inp_sN = input.strides[0];
+    int inp_sC = input.strides[1];
+    int inp_sH = input.strides[2];
+    int inp_sW = input.strides[3];
+    int grid_sN = grid.strides[0];
+    int grid_sH = grid.strides[1];
+    int grid_sW = grid.strides[2];
+    int grid_sCoor = grid.strides[3];
+    int gOut_sN = grad_output.strides[0];
+    int gOut_sC = grad_output.strides[1];
+    int gOut_sH = grad_output.strides[2];
+    int gOut_sW = grad_output.strides[3];
+    int gInp_sN = grad_input.strides[0];
+    int gInp_sC = grad_input.strides[1];
+    int gInp_sH = grad_input.strides[2];
+    int gInp_sW = grad_input.strides[3];
+    int gGrid_sW = grad_grid.strides[2];
+
+    CUDA_KERNEL_LOOP(index, nthreads) {
+      const int w = index % out_W;
+      const int h = (index / out_W) % out_H;
+      const int n = index / (out_H * out_W);
+      const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+
+      // get the corresponding input x, y co-ordinates from grid
+      scalar_t ix = grid.data[grid_offset];
+      scalar_t iy = grid.data[grid_offset + grid_sCoor];
+
+      // multipliers for gradients on ix and iy
+      scalar_t gix_mult, giy_mult;
+      ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+      iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+
+      if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+        // get NE, NW, SE, SW pixel values from (x, y)
+        int ix_nw = static_cast<int>(::floor(ix));
+        int iy_nw = static_cast<int>(::floor(iy));
+        int ix_ne = ix_nw + 1;
+        int iy_ne = iy_nw;
+        int ix_sw = ix_nw;
+        int iy_sw = iy_nw + 1;
+        int ix_se = ix_nw + 1;
+        int iy_se = iy_nw + 1;
+
+        // get surfaces to each neighbor:
+        scalar_t nw = (ix_se - ix)    * (iy_se - iy);
+        scalar_t ne = (ix    - ix_sw) * (iy_sw - iy);
+        scalar_t sw = (ix_ne - ix)    * (iy    - iy_ne);
+        scalar_t se = (ix    - ix_nw) * (iy    - iy_nw);
+
+        scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
+        scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
+        scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
+        scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+        for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, gInp_ptr_NC += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
+          scalar_t gOut = *gOut_ptr_NCHW;
+
+          // calculate and set grad_input
+          safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut);
+          safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut);
+          safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut);
+          safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut);
+
+          // calculate grad_grid
+          if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
+            scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
+            gix -= nw_val * (iy_se - iy) * gOut;
+            giy -= nw_val * (ix_se - ix) * gOut;
+          }
+          if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
+            scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
+            gix += ne_val * (iy_sw - iy) * gOut;
+            giy -= ne_val * (ix - ix_sw) * gOut;
+          }
+          if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
+            scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
+            gix -= sw_val * (iy - iy_ne) * gOut;
+            giy += sw_val * (ix_ne - ix) * gOut;
+          }
+          if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
+            scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
+            gix += se_val * (iy - iy_nw) * gOut;
+            giy += se_val * (ix - ix_nw) * gOut;
+          }
+        }
+
+        // assuming grad_grid is contiguous
+        // thus we can
+        //   1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
+        //   2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
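+        // (for a contiguous (N, H, W, 2) grad_grid, gGrid_sW == 2, so
+        //  index * gGrid_sW addresses the (gx, gy) pair of this thread)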
+        scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
+        gGrid_ptr_NHW[0] = gix_mult * gix;
+        gGrid_ptr_NHW[1] = giy_mult * giy;
+      } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+        int ix_nearest = static_cast<int>(::round(ix));
+        int iy_nearest = static_cast<int>(::round(iy));
+
+        // scatter the output gradient to the nearest input pixel
+        scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
+        scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
+        for (int c = 0; c < C; ++c, gInp_ptr_NC += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
+          // calculate and set grad_input
+          safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW);
+        }
+
+        // assuming grad_grid is contiguous
+        // thus we can
+        //   1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
+        //   2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
+        scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
+        gGrid_ptr_NHW[0] = static_cast<scalar_t>(0);
+        gGrid_ptr_NHW[1] = static_cast<scalar_t>(0);
+      }
+    }
+  }
+
+  template <typename scalar_t>
+  C10_LAUNCH_BOUNDS_1(1024)
+  __global__ void grid_sampler_3d_backward_kernel_cuda(
+      const int nthreads,
+      TensorInfo<scalar_t, int> grad_output,
+      TensorInfo<scalar_t, int> input,
+      TensorInfo<scalar_t, int> grid,
+      TensorInfo<scalar_t, int> grad_input,  // initialized to zeros
+      TensorInfo<scalar_t, int> grad_grid,   // initialized to empty
+      const GridSamplerInterpolation interpolation_mode,
+      const GridSamplerPadding padding_mode,
+      bool align_corners) {
+
+    int C = input.sizes[1];
+    int inp_D = input.sizes[2];
+    int inp_H = input.sizes[3];
+    int inp_W = input.sizes[4];
+    int out_D = grid.sizes[1];
+    int out_H = grid.sizes[2];
+    int out_W = grid.sizes[3];
+    int inp_sN = input.strides[0];
+    int inp_sC = input.strides[1];
+    int inp_sD = input.strides[2];
+    int inp_sH = input.strides[3];
+    int inp_sW = input.strides[4];
+    int grid_sN = grid.strides[0];
+    int grid_sD = grid.strides[1];
+    int grid_sH = grid.strides[2];
+    int grid_sW = grid.strides[3];
+    int grid_sCoor = grid.strides[4];
+    int gOut_sN = grad_output.strides[0];
+    int gOut_sC = grad_output.strides[1];
+    int gOut_sD = grad_output.strides[2];
+    int gOut_sH = grad_output.strides[3];
+    int gOut_sW = grad_output.strides[4];
+    int gInp_sN = grad_input.strides[0];
+    int gInp_sC = grad_input.strides[1];
+    int gInp_sD = grad_input.strides[2];
+    int gInp_sH = grad_input.strides[3];
+    int gInp_sW = grad_input.strides[4];
+    int gGrid_sW = grad_grid.strides[3];
+
+    CUDA_KERNEL_LOOP(index, nthreads) {
+      const int w = index % out_W;
+      const int h = (index / out_W) % out_H;
+      const int d = (index / (out_H * out_W)) % out_D;
+      const int n = index / (out_D * out_H * out_W);
+      const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
+
+      // get the corresponding input x, y, z co-ordinates from grid
+      scalar_t ix = grid.data[grid_offset];
+      scalar_t iy = grid.data[grid_offset + grid_sCoor];
+      scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
+
+      // multipliers for gradients on ix, iy, and iz
+      scalar_t gix_mult, giy_mult, giz_mult;
+      ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+      iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+      iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
+
+      if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+        // get corner pixel values from (x, y, z)
+        // for 4d we use north-east-south-west
+        // for 5d we add top-bottom
+        int ix_tnw = static_cast<int>(::floor(ix));
+        int iy_tnw = static_cast<int>(::floor(iy));
+        int iz_tnw = static_cast<int>(::floor(iz));
+
+        int ix_tne = ix_tnw + 1;
+        int iy_tne = iy_tnw;
+        int iz_tne = iz_tnw;
+
+        int ix_tsw = ix_tnw;
+        int iy_tsw = iy_tnw + 1;
+        int iz_tsw = iz_tnw;
+
+        int ix_tse = ix_tnw + 1;
+        int iy_tse = iy_tnw + 1;
+        int iz_tse = iz_tnw;
+
+        int ix_bnw = ix_tnw;
+        int iy_bnw = iy_tnw;
+        int iz_bnw = iz_tnw + 1;
+
+        int ix_bne = ix_tnw + 1;
+        int iy_bne = iy_tnw;
+        int iz_bne = iz_tnw + 1;
+
+        int ix_bsw = ix_tnw;
+        int iy_bsw = iy_tnw + 1;
+        int iz_bsw = iz_tnw + 1;
+
+        int ix_bse = ix_tnw + 1;
+        int iy_bse = iy_tnw + 1;
+        int iz_bse = iz_tnw + 1;
+
+        // get surfaces to each neighbor:
+        scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
+        scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
+        scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
+        scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
+        scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
+        scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
+        scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
+        scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);
+
+        scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
+        scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
+        scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
+        scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+        // accumulate grad_input and grad_grid contributions over all channels
+        for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
+          scalar_t gOut = *gOut_ptr_NCDHW;
+
+          // calculate and set grad_input
+          safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
+          safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
+          safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
+          safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
+          safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
+          safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
+          safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
+          safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
+
+          // calculate grad_grid
+          if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+            scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
+            gix -= tnw_val * (iy_bse - iy)    * (iz_bse - iz)    * gOut;
+            giy -= tnw_val * (ix_bse - ix)    * (iz_bse - iz)    * gOut;
+            giz -= tnw_val * (ix_bse - ix)    * (iy_bse - iy)    * gOut;
+          }
+          if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+            scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
+            gix += tne_val * (iy_bsw - iy)    * (iz_bsw - iz)    * gOut;
+            giy -= tne_val * (ix    - ix_bsw) * (iz_bsw - iz)    * gOut;
+            giz -= tne_val * (ix    - ix_bsw) * (iy_bsw - iy)    * gOut;
+          }
+          if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+            scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
+            gix -= tsw_val * (iy - iy_bne)    * (iz_bne - iz)    * gOut;
+            giy += tsw_val * (ix_bne - ix)    * (iz_bne - iz)    * gOut;
+            giz -= tsw_val * (ix_bne - ix)    * (iy    - iy_bne) * gOut;
+          }
+          if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+            scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
+            gix += tse_val * (iy - iy_bnw)    * (iz_bnw - iz)    * gOut;
+            giy += tse_val * (ix    - ix_bnw) * (iz_bnw - iz)    * gOut;
+            giz -= tse_val * (ix    - ix_bnw) * (iy    - iy_bnw) * gOut;
+          }
+          if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+            scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
+            gix -= bnw_val * (iy_tse - iy)    * (iz - iz_tse)    * gOut;
+            giy -= bnw_val * (ix_tse - ix)    * (iz - iz_tse)    * gOut;
+            giz += bnw_val * (ix_tse - ix)    * (iy_tse - iy)    * gOut;
+          }
+          if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+            scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
+            gix += bne_val * (iy_tsw - iy)    * (iz - iz_tsw)    * gOut;
+            giy -= bne_val * (ix    - ix_tsw) * (iz - iz_tsw)    * gOut;
+            giz += bne_val * (ix    - ix_tsw) * (iy_tsw - iy)    * gOut;
+          }
+          if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+            scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
+            gix -= bsw_val * (iy - iy_tne)    * (iz - iz_tne)    * gOut;
+            giy += bsw_val * (ix_tne - ix)    * (iz - iz_tne)    * gOut;
+            giz += bsw_val * (ix_tne - ix)    * (iy    - iy_tne) * gOut;
+          }
+          if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+            scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
+            gix += bse_val * (iy - iy_tnw)    * (iz - iz_tnw)    * gOut;
+            giy += bse_val * (ix    - ix_tnw) * (iz - iz_tnw)    * gOut;
+            giz += bse_val * (ix    - ix_tnw) * (iy    - iy_tnw) * gOut;
+          }
+        }
+
+        // assuming grad_grid is contiguous
+        // thus we can
+        //   1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
+        //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
+        scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
+        gGrid_ptr_NDHW[0] = gix_mult * gix;
+        gGrid_ptr_NDHW[1] = giy_mult * giy;
+        gGrid_ptr_NDHW[2] = giz_mult * giz;
+      } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+        int ix_nearest = static_cast<int>(::round(ix));
+        int iy_nearest = static_cast<int>(::round(iy));
+        int iz_nearest = static_cast<int>(::round(iz));
+
+        // scatter the output gradient to the nearest input pixel
+        scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
+        scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN;
+        for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
+          // calculate and set grad_input
+          safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest,
+                      gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW);
+        }
+
+        // assuming grad_grid is contiguous
+        // thus we can
+        //   1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
+        //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
+        scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
+        gGrid_ptr_NDHW[0] = static_cast<scalar_t>(0);
+        gGrid_ptr_NDHW[1] = static_cast<scalar_t>(0);
+        gGrid_ptr_NDHW[2] = static_cast<scalar_t>(0);
+      }
+    }
+  }
+}  // namespace
+
+using namespace at;
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_2d_forward_cuda(const Tensor& input, const Tensor& grid,
+                            int64_t interpolation_mode, int64_t padding_mode,
+                            bool align_corners) {
+  auto N = input.size(0);
+  auto H = grid.size(1);
+  auto W = grid.size(2);
+  auto output = at::empty({N, input.size(1), H, W}, input.options());
+  int count = static_cast<int>(N * H * W);
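+  // one kernel thread index per output (n, h, w) location; each thread loops
+  // over the channel dimension inside the kernel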
+  if (count > 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_forward_cuda", [&] {
+      grid_sampler_2d_forward_kernel_cuda<scalar_t>
+        <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+          count,
+          getTensorInfo<scalar_t, int>(input),
+          getTensorInfo<scalar_t, int>(grid),
+          getTensorInfo<scalar_t, int>(output),
+          static_cast<GridSamplerInterpolation>(interpolation_mode),
+          static_cast<GridSamplerPadding>(padding_mode),
+          align_corners);
+    });
+  }
+  return output;
+}
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_3d_forward_cuda(const Tensor& input, const Tensor& grid,
+                            int64_t interpolation_mode, int64_t padding_mode,
+                            bool align_corners) {
+  auto N = input.size(0);
+  auto D = grid.size(1);
+  auto H = grid.size(2);
+  auto W = grid.size(3);
+  auto output = at::empty({N, input.size(1), D, H, W}, input.options());
+  int count = static_cast<int>(N * D * H * W);
+  if (count > 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_forward_cuda", [&] {
+      grid_sampler_3d_forward_kernel_cuda<scalar_t>
+        <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+          count,
+          getTensorInfo<scalar_t, int>(input),
+          getTensorInfo<scalar_t, int>(grid),
+          getTensorInfo<scalar_t, int>(output),
+          static_cast<GridSamplerInterpolation>(interpolation_mode),
+          static_cast<GridSamplerPadding>(padding_mode),
+          align_corners);
+    });
+  }
+  return output;
+}
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+std::tuple<Tensor, Tensor>
+grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
+                              const Tensor& grid, int64_t interpolation_mode,
+                              int64_t padding_mode, bool align_corners) {
+  auto N = input.size(0);
+  auto H = grid.size(1);
+  auto W = grid.size(2);
+  auto grad_input = at::zeros_like(input);
+  auto grad_grid = at::empty_like(grid);
+  int count = static_cast<int>(N * H * W);
+  if (count > 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
+      grid_sampler_2d_backward_kernel_cuda<scalar_t>
+        <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+          count,
+          getTensorInfo<scalar_t, int>(grad_output),
+          getTensorInfo<scalar_t, int>(input),
+          getTensorInfo<scalar_t, int>(grid),
+          getTensorInfo<scalar_t, int>(grad_input),
+          getTensorInfo<scalar_t, int>(grad_grid),
+          static_cast<GridSamplerInterpolation>(interpolation_mode),
+          static_cast<GridSamplerPadding>(padding_mode),
+          align_corners);
+    });
+  }
+  return std::make_tuple(grad_input, grad_grid);
+}
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+std::tuple<Tensor, Tensor>
+grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
+                              const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
+                              bool align_corners) {
+  auto N = input.size(0);
+  auto D = grid.size(1);
+  auto H = grid.size(2);
+  auto W = grid.size(3);
+  auto grad_input = at::zeros_like(input);
+  auto grad_grid = at::empty_like(grid);
+  int count = static_cast<int>(N * D * H * W);
+  if (count > 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
+      grid_sampler_3d_backward_kernel_cuda<scalar_t>
+        <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+          count,
+          getTensorInfo<scalar_t, int>(grad_output),
+          getTensorInfo<scalar_t, int>(input),
+          getTensorInfo<scalar_t, int>(grid),
+          getTensorInfo<scalar_t, int>(grad_input),
+          getTensorInfo<scalar_t, int>(grad_grid),
+          static_cast<GridSamplerInterpolation>(interpolation_mode),
+          static_cast<GridSamplerPadding>(padding_mode),
+          align_corners);
+    });
+  }
+  return std::make_tuple(grad_input, grad_grid);
+}
+
+}  // namespace mmdetection
diff --git a/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cuh b/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..a84fa7c076ecd8302aacddf6c350196cc5ce964e
--- /dev/null
+++ b/mmdet/ops/grid_sampler/src/cuda/grid_sampler_cuda.cuh
@@ -0,0 +1,233 @@
+// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cuh
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <THC/THCAtomics.cuh>
+
+namespace mmdetection {
+
+namespace detail {
+
+  enum class GridSamplerInterpolation {Bilinear, Nearest};
+  enum class GridSamplerPadding {Zeros, Border, Reflection};
+
+}  // namespace detail
+
+using detail::GridSamplerInterpolation;
+using detail::GridSamplerPadding;
+
+// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
+// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
+// if align_corners: -1 and +1 get sent to the centers of the corner pixels
+//     -1 --> 0
+//     +1 --> (size - 1)
+//     scale_factor = (size - 1) / 2
+// if not align_corners: -1 and +1 get sent to the image edges
+//     -1 --> -0.5
+//     +1 --> (size - 1) + 0.5 == size - 0.5
+//     scale_factor = size / 2
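+// e.g. with size = 5:
+//     align_corners=True : -1 -> 0,    0 -> 2, +1 -> 4
+//     align_corners=False: -1 -> -0.5, 0 -> 2, +1 -> 4.5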
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    return ((coord + 1.f) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    return ((coord + 1.f) * size - 1) / 2;
+  }
+}
+
+// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
+// except that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
+                                           bool align_corners, scalar_t *grad_in) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    *grad_in = static_cast<scalar_t>(size - 1) / 2;
+    return ((coord + 1.f) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    *grad_in = static_cast<scalar_t>(size) / 2;
+    return ((coord + 1.f) * size - 1) / 2;
+  }
+}
+
+// Clips coordinates to between 0 and clip_limit - 1
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t clip_coordinates(scalar_t in, int clip_limit) {
+  return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
+}
+
+// clip_coordinates_set_grad works similarly to clip_coordinates except that
+// it also returns the `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
+  if (in < static_cast<scalar_t>(0)) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  } else {
+    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
+    if (in > max) {
+      *grad_in = static_cast<scalar_t>(0);
+      return max;
+    } else {
+      *grad_in = static_cast<scalar_t>(1);
+      return in;
+    }
+  }
+}
+
+// Reflects coordinates until they fall between low and high (inclusive).
+// The bounds are passed as twice their value so that half-integer values
+// can be represented as ints.
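+// e.g. the callers below pass (0, 2*(size - 1)) for align_corners=True,
+// i.e. bounds [0, size - 1], and (-1, 2*size - 1) for align_corners=False,
+// i.e. bounds [-0.5, size - 0.5].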
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
+  if (twice_low == twice_high) {
+    return static_cast<scalar_t>(0);
+  }
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = ::fabs(in - min);
+  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
+  scalar_t extra = ::fmod(in, span);
+  int flips = static_cast<int>(::floor(in / span));
+  if (flips % 2 == 0) {
+    return extra + min;
+  } else {
+    return span - extra + min;
+  }
+}
+
+// reflect_coordinates_set_grad works similarly to reflect_coordinates except
+// that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
+                                      scalar_t *grad_in) {
+  if (twice_low == twice_high) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  }
+  int grad_in_mult_;
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = in - min;
+  if (in < static_cast<scalar_t>(0)) {
+    grad_in_mult_ = -1;
+    in = -in;
+  } else {
+    grad_in_mult_ = 1;
+  }
+  // `fmod` returns same sign as `in`, which is positive after the `if` above.
+  scalar_t extra = ::fmod(in, span);
+  int flips = static_cast<int>(::floor(in / span));
+  if (flips % 2 == 0) {
+    *grad_in = static_cast<scalar_t>(grad_in_mult_);
+    return extra + min;
+  } else {
+    *grad_in = static_cast<scalar_t>(-grad_in_mult_);
+    return span - extra + min;
+  }
+}
+
+// Computes the pixel source index value for a grid coordinate
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t grid_sampler_compute_source_index(
+    scalar_t coord,
+    int size,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+  coord = grid_sampler_unnormalize(coord, size, align_corners);
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates(coord, 0, 2*(size - 1));
+    } else {
+      coord = reflect_coordinates(coord, -1, 2*size - 1);
+      // when align_corners=False, reflection does not auto clip coords
+      coord = clip_coordinates(coord, size);
+    }
+  }
+  return coord;
+}
+
+// grid_sampler_compute_source_index_set_grad works similarly to
+// grid_sampler_compute_source_index except that it also returns the
+// `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t grid_sampler_compute_source_index_set_grad(
+    scalar_t coord,
+    int size,
+    GridSamplerPadding padding_mode,
+    bool align_corners,
+    scalar_t *grad_in) {
+  scalar_t grad_clip, grad_refl;
+  coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_clip;
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
+      *grad_in = (*grad_in) * grad_refl;
+    } else {
+      coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
+      // when align_corners=False, reflection does not auto clip coords
+      coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+      *grad_in = (*grad_in) * grad_refl * grad_clip;
+    }
+  }
+  return coord;
+}
+
+static __forceinline__ __device__
+bool within_bounds_2d(int h, int w, int H, int W) {
+  return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+static __forceinline__ __device__
+bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
+  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
+}
+
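+// safe_add_2d / safe_add_3d atomically accumulate `delta` into `data` at the
+// given (h, w) / (d, h, w) location if it is in bounds. Atomics are needed
+// because several sampling locations may map to the same input pixel, so
+// their gradient contributions would otherwise race on the same address.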
+template<typename scalar_t>
+static __forceinline__ __device__
+void safe_add_2d(scalar_t *data, int h, int w,
+                 int sH, int sW, int H, int W,
+                 scalar_t delta) {
+  if (within_bounds_2d(h, w, H, W)) {
+    atomicAdd(data + h * sH + w * sW, delta);
+  }
+}
+
+template<typename scalar_t>
+static __forceinline__ __device__
+void safe_add_3d(scalar_t *data, int d, int h, int w,
+                 int sD, int sH, int sW, int D, int H, int W,
+                 scalar_t delta) {
+  if (within_bounds_3d(d, h, w, D, H, W)) {
+    atomicAdd(data + d * sD + h * sH + w * sW, delta);
+  }
+}
+
+}  // namespace mmdetection
diff --git a/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp b/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f684abf12914ec3bdd02651013a3ccb8c3d32a48
--- /dev/null
+++ b/mmdet/ops/grid_sampler/src/cudnn/grid_sampler_cudnn.cpp
@@ -0,0 +1,148 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Config.h>
+#include <ATen/cuda/CUDAConfig.h>
+
+#if !AT_CUDNN_ENABLED()
+
+namespace mmdetection {
+
+// See Note [ATen preprocessor philosophy]
+
+Tensor cudnn_grid_sampler_forward(
+    const Tensor& input_t, const Tensor& grid_t) {
+  AT_ERROR("cudnn_grid_sampler_forward: ATen not compiled with cuDNN support");
+}
+
+std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
+    const Tensor& input_t, const Tensor& grid_t,
+    const Tensor& grad_output_t) {
+  AT_ERROR("cudnn_grid_sampler_backward: ATen not compiled with cuDNN support");
+}
+
+}  // namespace mmdetection
+
+#else // AT_CUDNN_ENABLED
+
+#include <ATen/cudnn/Descriptors.h>
+#include <ATen/cudnn/Types.h>
+#include <ATen/cudnn/Utils.h>
+#include <ATen/cuda/Exceptions.h>
+
+#include <ATen/TensorUtils.h>
+
+// TODO: descriptor checking
+
+
+namespace mmdetection {
+
+using namespace at;
+
+namespace {
+
+void setSamplerDescriptor(SpatialTransformerDescriptor& desc, cudnnDataType_t dataType, const at::Tensor& tensor)
+{
+  int inputSize[4] = {0};
+  for (int i = 0; i < tensor.dim(); ++i) {
+    inputSize[i] = (int) tensor.size(i);
+  }
+  desc.set(dataType, 4, inputSize);
+}
+
+void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input)
+{
+  // assert size of grid is n*h*w*2
+  // FYI: grid values lie in [-1, 1], where -1 represents the left-most pixel
+  // and +1 the right-most pixel (hence 0 is the center pixel).
+  // If grid has values > 1 or < -1, those values are ignored.
+  checkContiguous(c, grid);
+  checkDim(c, grid, 4);
+  // TODO: Maybe more user friendly to report where the expected size
+  // came from
+  checkSize(c, grid, 0, input->size(0));
+  checkSize(c, grid, 3, 2);
+}
+
+}  // namespace
+
+Tensor cudnn_grid_sampler_forward(
+    const Tensor& input_t, const Tensor& grid_t)
+{
+  TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
+            grid{ grid_t.contiguous(), "grid", 2 };
+  CheckedFrom c = "cudnn_grid_sampler_forward";
+  checkAllSameGPU(c, {input, grid});
+  checkAllSameType(c, {input, grid});
+  checkGridSize(c, grid, input);
+  checkDim(c, input, 4);
+
+  auto output_t = at::empty({0}, input->options());
+  output_t.resize_({input->size(0), input->size(1), grid->size(1), grid->size(2)});
+
+  TensorDescriptor idesc{ *input };  // input descriptor
+  TensorDescriptor odesc{ output_t };  // output descriptor
+  SpatialTransformerDescriptor desc; // sampler descriptor
+
+  auto handle = getCudnnHandle();
+  auto dataType = getCudnnDataType(*input);
+  setSamplerDescriptor(desc, dataType, output_t);
+
+  Constant one(dataType, 1);
+  Constant zero(dataType, 0);
+  AT_CUDNN_CHECK(cudnnSpatialTfSamplerForward(
+      handle, desc.desc(),
+      &one, idesc.desc(), input->data_ptr(),
+      grid->data_ptr(),
+      &zero, odesc.desc(), output_t.data_ptr()
+  ));
+
+  return output_t;
+}
+
+// NB: CuDNN does not support output mask; you always get both
+// gradients.
+std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
+    const Tensor& input_t, const Tensor& grid_t,
+    const Tensor& grad_output_t)
+{
+  TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
+            grid{ grid_t.contiguous(), "grid", 2 },
+            grad_output{ contiguousIfZeroInStrides(grad_output_t), "grad_output", 3 };
+  CheckedFrom c = "cudnn_grid_sampler_backward";
+  checkAllSameGPU(c, {input, grad_output, grid});
+  checkGridSize(c, grid, input);
+  checkDim(c, input, 4);
+  checkDim(c, grad_output, 4);
+
+  auto grad_input_t = at::empty({0}, input->options());
+  grad_input_t.resize_(input->sizes());
+  auto grad_grid_t = at::empty({0}, grid->options());
+  grad_grid_t.resize_(grid->sizes());
+
+  TensorDescriptor idesc{ *input };  // input descriptor
+  TensorDescriptor odesc{ *grad_output };  // grad_output descriptor
+  TensorDescriptor gdesc{ grad_input_t };  // grad_input descriptor
+  SpatialTransformerDescriptor desc; // sampler descriptor
+
+  auto handle = getCudnnHandle();
+  auto dataType = getCudnnDataType(*input);
+  setSamplerDescriptor(desc, dataType, *grad_output);
+
+  Constant one(dataType, 1);
+  Constant zero(dataType, 0);
+  AT_CUDNN_CHECK(cudnnSpatialTfSamplerBackward(
+    handle, desc.desc(),
+    &one, idesc.desc(), input->data_ptr(),
+    &zero, gdesc.desc(), grad_input_t.data_ptr(),
+    &one, odesc.desc(), grad_output->data_ptr(),
+    // intriguingly, the outputs don't need descriptors
+    grid->data_ptr(),
+    &zero, grad_grid_t.data_ptr()
+  ));
+
+  return std::tuple<Tensor, Tensor>{ grad_input_t, grad_grid_t };
+}
+
+}  // namespace mmdetection
+
+#endif
diff --git a/mmdet/ops/grid_sampler/src/grid_sampler.cpp b/mmdet/ops/grid_sampler/src/grid_sampler.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..009675fe7ea2ec3d248f528bbe9fe3166785b411
--- /dev/null
+++ b/mmdet/ops/grid_sampler/src/grid_sampler.cpp
@@ -0,0 +1,66 @@
+#include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
+
+namespace mmdetection {
+
+using namespace at;
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_2d_forward_cpu(const Tensor& input, const Tensor& grid,
+                                    int64_t interpolation_mode, int64_t padding_mode,
+                                    bool align_corners);
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_3d_forward_cpu(const Tensor& input, const Tensor& grid,
+                                    int64_t interpolation_mode, int64_t padding_mode,
+                                    bool align_corners);
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+std::tuple<Tensor, Tensor>
+grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input,
+                              const Tensor& grid, int64_t interpolation_mode,
+                              int64_t padding_mode, bool align_corners);
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+std::tuple<Tensor, Tensor>
+grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input,
+                              const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
+                              bool align_corners);
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_2d_forward_cuda(const Tensor& input, const Tensor& grid,
+                            int64_t interpolation_mode, int64_t padding_mode,
+                            bool align_corners);
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+Tensor grid_sampler_3d_forward_cuda(const Tensor& input, const Tensor& grid,
+                            int64_t interpolation_mode, int64_t padding_mode,
+                            bool align_corners);
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+std::tuple<Tensor, Tensor>
+grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input,
+                              const Tensor& grid, int64_t interpolation_mode,
+                              int64_t padding_mode, bool align_corners);
+
+// No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
+std::tuple<Tensor, Tensor>
+grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input,
+                              const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode,
+                              bool align_corners);
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+
+  m.def("grid_sampler_2d_forward_cpu", &grid_sampler_2d_forward_cpu, "grid_sampler_2d_forward (CPU)");
+  m.def("grid_sampler_2d_backward_cpu", &grid_sampler_2d_backward_cpu, "grid_sampler_2d_backward (CPU)");
+  m.def("grid_sampler_3d_forward_cpu", &grid_sampler_3d_forward_cpu, "grid_sampler_3d_forward (CPU)");
+  m.def("grid_sampler_3d_backward_cpu", &grid_sampler_3d_backward_cpu, "grid_sampler_3d_backward (CPU)");
+
+  m.def("grid_sampler_2d_forward_cuda", &grid_sampler_2d_forward_cuda, "grid_sampler_2d_forward (CUDA)");
+  m.def("grid_sampler_2d_backward_cuda", &grid_sampler_2d_backward_cuda, "grid_sampler_2d_backward (CUDA)");
+  m.def("grid_sampler_3d_forward_cuda", &grid_sampler_3d_forward_cuda, "grid_sampler_3d_forward (CUDA)");
+  m.def("grid_sampler_3d_backward_cuda", &grid_sampler_3d_backward_cuda, "grid_sampler_3d_backward (CUDA)");
+}
+
+}  // namespace mmdetection
diff --git a/setup.py b/setup.py
index 89e79f95a544983a74e76d70134f2679ebafff00..1b3445c75230e35f642a37d80875fceeadffc0c8 100755
--- a/setup.py
+++ b/setup.py
@@ -270,6 +270,13 @@ if __name__ == '__main__':
                 sources=[
                     'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu'
                 ]),
+            make_cuda_ext(
+                name='grid_sampler_cuda',
+                module='mmdet.ops.grid_sampler',
+                sources=[
+                    'src/cpu/grid_sampler_cpu.cpp',
+                    'src/cuda/grid_sampler_cuda.cu', 'src/grid_sampler.cpp'
+                ]),
             make_cuda_ext(
                 name='carafe_cuda',
                 module='mmdet.ops.carafe',