From 51df8a9b7ad5f25ebd75cf8e0969c3b728bde08d Mon Sep 17 00:00:00 2001
From: Jerry Jiarui XU <xvjiarui0826@gmail.com>
Date: Mon, 2 Mar 2020 10:42:11 +0800
Subject: [PATCH] add affine_grid (#2180)

* add affine_grid

* missing setup

* remove import

* reformat

* rename and reformat

* reformat cpp
---
 mmdet/ops/affine_grid/__init__.py             |   3 +
 mmdet/ops/affine_grid/affine_grid.py          |  68 +++++++++++
 .../ops/affine_grid/src/affine_grid_cuda.cpp  | 115 ++++++++++++++++++
 mmdet/ops/grid_sampler/grid_sampler.py        |   4 +-
 setup.py                                      |   4 +
 5 files changed, 193 insertions(+), 1 deletion(-)
 create mode 100644 mmdet/ops/affine_grid/__init__.py
 create mode 100644 mmdet/ops/affine_grid/affine_grid.py
 create mode 100644 mmdet/ops/affine_grid/src/affine_grid_cuda.cpp

diff --git a/mmdet/ops/affine_grid/__init__.py b/mmdet/ops/affine_grid/__init__.py
new file mode 100644
index 00000000..8530ade3
--- /dev/null
+++ b/mmdet/ops/affine_grid/__init__.py
@@ -0,0 +1,3 @@
+from .affine_grid import affine_grid
+
+__all__ = ['affine_grid']
diff --git a/mmdet/ops/affine_grid/affine_grid.py b/mmdet/ops/affine_grid/affine_grid.py
new file mode 100644
index 00000000..94bacb5e
--- /dev/null
+++ b/mmdet/ops/affine_grid/affine_grid.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from . import affine_grid_cuda
+
+
+class _AffineGridGenerator(Function):
+
+    @staticmethod
+    def forward(ctx, theta, size, align_corners):
+
+        ctx.save_for_backward(theta)
+        ctx.size = size
+        ctx.align_corners = align_corners
+
+        func = affine_grid_cuda.affine_grid_generator_forward
+
+        output = func(theta, size, align_corners)
+
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        theta = ctx.saved_tensors
+        size = ctx.size
+        align_corners = ctx.align_corners
+
+        func = affine_grid_cuda.affine_grid_generator_backward
+
+        grad_input = func(grad_output, theta, size, align_corners)
+
+        return grad_input, None, None
+
+
+def affine_grid(theta, size, align_corners=False):
+    if torch.__version__ >= '1.3':
+        return F.affine_grid(theta, size, align_corners)
+    elif align_corners:
+        return F.affine_grid(theta, size)
+    else:
+        # enforce floating point dtype on theta
+        if not theta.is_floating_point():
+            raise ValueError(
+                'Expected theta to have floating point type, but got {}'.
+                format(theta.dtype))
+        # check that shapes and sizes match
+        if len(size) == 4:
+            if theta.dim() != 3 or theta.size(-2) != 2 or theta.size(-1) != 3:
+                raise ValueError(
+                    'Expected a batch of 2D affine matrices of shape Nx2x3 '
+                    'for size {}. Got {}.'.format(size, theta.shape))
+        elif len(size) == 5:
+            if theta.dim() != 3 or theta.size(-2) != 3 or theta.size(-1) != 4:
+                raise ValueError(
+                    'Expected a batch of 3D affine matrices of shape Nx3x4 '
+                    'for size {}. Got {}.'.format(size, theta.shape))
+        else:
+            raise NotImplementedError(
+                'affine_grid only supports 4D and 5D sizes, '
+                'for 2D and 3D affine transforms, respectively. '
+                'Got size {}.'.format(size))
+        if min(size) <= 0:
+            raise ValueError(
+                'Expected non-zero, positive output size. Got {}'.format(size))
+        return _AffineGridGenerator.apply(theta, size, align_corners)
diff --git a/mmdet/ops/affine_grid/src/affine_grid_cuda.cpp b/mmdet/ops/affine_grid/src/affine_grid_cuda.cpp
new file mode 100644
index 00000000..3874128c
--- /dev/null
+++ b/mmdet/ops/affine_grid/src/affine_grid_cuda.cpp
@@ -0,0 +1,115 @@
+// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AffineGridGenerator.cpp
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <torch/extension.h>
+
+namespace mmdetection {
+
+using namespace at;
+
+at::Tensor linspace_from_neg_one(const Tensor& grid, int64_t num_steps,
+                                 bool align_corners) {
+  if (num_steps <= 1) {
+    return at::tensor(0, grid.options());
+  }
+  auto range = at::linspace(-1, 1, num_steps, grid.options());
+  if (!align_corners) {
+    range = range * (num_steps - 1) / num_steps;
+  }
+  return range;
+}
+
+Tensor make_base_grid_4D(const Tensor& theta, int64_t N, int64_t C, int64_t H,
+                         int64_t W, bool align_corners) {
+  auto base_grid = at::empty({N, H, W, 3}, theta.options());
+
+  base_grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W, align_corners));
+  base_grid.select(-1, 1).copy_(
+      linspace_from_neg_one(theta, H, align_corners).unsqueeze_(-1));
+  base_grid.select(-1, 2).fill_(1);
+
+  return base_grid;
+}
+
+Tensor make_base_grid_5D(const Tensor& theta, int64_t N, int64_t C, int64_t D,
+                         int64_t H, int64_t W, bool align_corners) {
+  auto base_grid = at::empty({N, D, H, W, 4}, theta.options());
+
+  base_grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W, align_corners));
+  base_grid.select(-1, 1).copy_(
+      linspace_from_neg_one(theta, H, align_corners).unsqueeze_(-1));
+  base_grid.select(-1, 2).copy_(linspace_from_neg_one(theta, D, align_corners)
+                                    .unsqueeze_(-1)
+                                    .unsqueeze_(-1));
+  base_grid.select(-1, 3).fill_(1);
+
+  return base_grid;
+}
+
+Tensor affine_grid_generator_4D_forward(const Tensor& theta, int64_t N,
+                                        int64_t C, int64_t H, int64_t W,
+                                        bool align_corners) {
+  Tensor base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners);
+  auto grid = base_grid.view({N, H * W, 3}).bmm(theta.transpose(1, 2));
+  return grid.view({N, H, W, 2});
+}
+
+Tensor affine_grid_generator_5D_forward(const Tensor& theta, int64_t N,
+                                        int64_t C, int64_t D, int64_t H,
+                                        int64_t W, bool align_corners) {
+  Tensor base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners);
+  auto grid = base_grid.view({N, D * H * W, 4}).bmm(theta.transpose(1, 2));
+  return grid.view({N, D, H, W, 3});
+}
+
+Tensor affine_grid_generator_forward(const Tensor& theta, IntArrayRef size,
+                                     bool align_corners) {
+  if (size.size() == 4) {
+    return affine_grid_generator_4D_forward(theta, size[0], size[1], size[2],
+                                            size[3], align_corners);
+  } else {
+    return affine_grid_generator_5D_forward(theta, size[0], size[1], size[2],
+                                            size[3], size[4], align_corners);
+  }
+}
+
+Tensor affine_grid_generator_4D_backward(const Tensor& grad_grid, int64_t N,
+                                         int64_t C, int64_t H, int64_t W,
+                                         bool align_corners) {
+  auto base_grid = make_base_grid_4D(grad_grid, N, C, H, W, align_corners);
+  AT_ASSERT(grad_grid.sizes() == IntArrayRef({N, H, W, 2}));
+  auto grad_theta = base_grid.view({N, H * W, 3})
+                        .transpose(1, 2)
+                        .bmm(grad_grid.view({N, H * W, 2}));
+  return grad_theta.transpose(1, 2);
+}
+
+Tensor affine_grid_generator_5D_backward(const Tensor& grad_grid, int64_t N,
+                                         int64_t C, int64_t D, int64_t H,
+                                         int64_t W, bool align_corners) {
+  auto base_grid = make_base_grid_5D(grad_grid, N, C, D, H, W, align_corners);
+  AT_ASSERT(grad_grid.sizes() == IntArrayRef({N, D, H, W, 3}));
+  auto grad_theta = base_grid.view({N, D * H * W, 4})
+                        .transpose(1, 2)
+                        .bmm(grad_grid.view({N, D * H * W, 3}));
+  return grad_theta.transpose(1, 2);
+}
+
+Tensor affine_grid_generator_backward(const Tensor& grad, IntArrayRef size,
+                                      bool align_corners) {
+  if (size.size() == 4) {
+    return affine_grid_generator_4D_backward(grad, size[0], size[1], size[2],
+                                             size[3], align_corners);
+  } else {
+    return affine_grid_generator_5D_backward(grad, size[0], size[1], size[2],
+                                             size[3], size[4], align_corners);
+  }
+}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("affine_grid_generator_forward", &affine_grid_generator_forward,
+        "affine_grid_generator_forward");
+  m.def("affine_grid_generator_backward", &affine_grid_generator_backward,
+        "affine_grid_generator_backward");
+}
+
+}  // namespace mmdetection
diff --git a/mmdet/ops/grid_sampler/grid_sampler.py b/mmdet/ops/grid_sampler/grid_sampler.py
index d359d8df..a112fa82 100644
--- a/mmdet/ops/grid_sampler/grid_sampler.py
+++ b/mmdet/ops/grid_sampler/grid_sampler.py
@@ -61,8 +61,10 @@ def grid_sample(input,
                 mode='bilinear',
                 padding_mode='zeros',
                 align_corners=False):
-    if torch.__version__ >= '1.3' or align_corners:
+    if torch.__version__ >= '1.3':
         return F.grid_sample(input, grid, mode, padding_mode, align_corners)
+    elif align_corners:
+        return F.grid_sample(input, grid, mode, padding_mode)
     else:
 
         # use self-compiled grid_sampler to support align_corners=False
diff --git a/setup.py b/setup.py
index 9704b4d6..a11263f2 100755
--- a/setup.py
+++ b/setup.py
@@ -269,6 +269,10 @@ if __name__ == '__main__':
                 sources=[
                     'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu'
                 ]),
+            make_cuda_ext(
+                name='affine_grid_cuda',
+                module='mmdet.ops.affine_grid',
+                sources=['src/affine_grid_cuda.cpp']),
             make_cuda_ext(
                 name='grid_sampler_cuda',
                 module='mmdet.ops.grid_sampler',
-- 
GitLab