Skip to content
Snippets Groups Projects
Unverified Commit 51df8a9b authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

add affine_grid (#2180)

* add affine_grid

* missing setup

* remove import

* reformat

* rename and reformat

* reformat cpp
parent 9f6cddfd
No related branches found
No related tags found
No related merge requests found
from .affine_grid import affine_grid
__all__ = ['affine_grid']
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from . import affine_grid_cuda
class _AffineGridGenerator(Function):
@staticmethod
def forward(ctx, theta, size, align_corners):
ctx.save_for_backward(theta)
ctx.size = size
ctx.align_corners = align_corners
func = affine_grid_cuda.affine_grid_generator_forward
output = func(theta, size, align_corners)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
theta = ctx.saved_tensors
size = ctx.size
align_corners = ctx.align_corners
func = affine_grid_cuda.affine_grid_generator_backward
grad_input = func(grad_output, theta, size, align_corners)
return grad_input, None, None
def affine_grid(theta, size, align_corners=False):
if torch.__version__ >= '1.3':
return F.affine_grid(theta, size, align_corners)
elif align_corners:
return F.affine_grid(theta, size)
else:
# enforce floating point dtype on theta
if not theta.is_floating_point():
raise ValueError(
'Expected theta to have floating point type, but got {}'.
format(theta.dtype))
# check that shapes and sizes match
if len(size) == 4:
if theta.dim() != 3 or theta.size(-2) != 2 or theta.size(-1) != 3:
raise ValueError(
'Expected a batch of 2D affine matrices of shape Nx2x3 '
'for size {}. Got {}.'.format(size, theta.shape))
elif len(size) == 5:
if theta.dim() != 3 or theta.size(-2) != 3 or theta.size(-1) != 4:
raise ValueError(
'Expected a batch of 3D affine matrices of shape Nx3x4 '
'for size {}. Got {}.'.format(size, theta.shape))
else:
raise NotImplementedError(
'affine_grid only supports 4D and 5D sizes, '
'for 2D and 3D affine transforms, respectively. '
'Got size {}.'.format(size))
if min(size) <= 0:
raise ValueError(
'Expected non-zero, positive output size. Got {}'.format(size))
return _AffineGridGenerator.apply(theta, size, align_corners)
// Modified from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AffineGridGenerator.cpp
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <torch/extension.h>
namespace mmdetection {
using namespace at;
at::Tensor linspace_from_neg_one(const Tensor& grid, int64_t num_steps,
bool align_corners) {
if (num_steps <= 1) {
return at::tensor(0, grid.options());
}
auto range = at::linspace(-1, 1, num_steps, grid.options());
if (!align_corners) {
range = range * (num_steps - 1) / num_steps;
}
return range;
}
Tensor make_base_grid_4D(const Tensor& theta, int64_t N, int64_t C, int64_t H,
int64_t W, bool align_corners) {
auto base_grid = at::empty({N, H, W, 3}, theta.options());
base_grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W, align_corners));
base_grid.select(-1, 1).copy_(
linspace_from_neg_one(theta, H, align_corners).unsqueeze_(-1));
base_grid.select(-1, 2).fill_(1);
return base_grid;
}
Tensor make_base_grid_5D(const Tensor& theta, int64_t N, int64_t C, int64_t D,
int64_t H, int64_t W, bool align_corners) {
auto base_grid = at::empty({N, D, H, W, 4}, theta.options());
base_grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W, align_corners));
base_grid.select(-1, 1).copy_(
linspace_from_neg_one(theta, H, align_corners).unsqueeze_(-1));
base_grid.select(-1, 2).copy_(linspace_from_neg_one(theta, D, align_corners)
.unsqueeze_(-1)
.unsqueeze_(-1));
base_grid.select(-1, 3).fill_(1);
return base_grid;
}
Tensor affine_grid_generator_4D_forward(const Tensor& theta, int64_t N,
int64_t C, int64_t H, int64_t W,
bool align_corners) {
Tensor base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners);
auto grid = base_grid.view({N, H * W, 3}).bmm(theta.transpose(1, 2));
return grid.view({N, H, W, 2});
}
Tensor affine_grid_generator_5D_forward(const Tensor& theta, int64_t N,
int64_t C, int64_t D, int64_t H,
int64_t W, bool align_corners) {
Tensor base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners);
auto grid = base_grid.view({N, D * H * W, 4}).bmm(theta.transpose(1, 2));
return grid.view({N, D, H, W, 3});
}
Tensor affine_grid_generator_forward(const Tensor& theta, IntArrayRef size,
bool align_corners) {
if (size.size() == 4) {
return affine_grid_generator_4D_forward(theta, size[0], size[1], size[2],
size[3], align_corners);
} else {
return affine_grid_generator_5D_forward(theta, size[0], size[1], size[2],
size[3], size[4], align_corners);
}
}
Tensor affine_grid_generator_4D_backward(const Tensor& grad_grid, int64_t N,
int64_t C, int64_t H, int64_t W,
bool align_corners) {
auto base_grid = make_base_grid_4D(grad_grid, N, C, H, W, align_corners);
AT_ASSERT(grad_grid.sizes() == IntArrayRef({N, H, W, 2}));
auto grad_theta = base_grid.view({N, H * W, 3})
.transpose(1, 2)
.bmm(grad_grid.view({N, H * W, 2}));
return grad_theta.transpose(1, 2);
}
Tensor affine_grid_generator_5D_backward(const Tensor& grad_grid, int64_t N,
int64_t C, int64_t D, int64_t H,
int64_t W, bool align_corners) {
auto base_grid = make_base_grid_5D(grad_grid, N, C, D, H, W, align_corners);
AT_ASSERT(grad_grid.sizes() == IntArrayRef({N, D, H, W, 3}));
auto grad_theta = base_grid.view({N, D * H * W, 4})
.transpose(1, 2)
.bmm(grad_grid.view({N, D * H * W, 3}));
return grad_theta.transpose(1, 2);
}
Tensor affine_grid_generator_backward(const Tensor& grad, IntArrayRef size,
bool align_corners) {
if (size.size() == 4) {
return affine_grid_generator_4D_backward(grad, size[0], size[1], size[2],
size[3], align_corners);
} else {
return affine_grid_generator_5D_backward(grad, size[0], size[1], size[2],
size[3], size[4], align_corners);
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("affine_grid_generator_forward", &affine_grid_generator_forward,
"affine_grid_generator_forward");
m.def("affine_grid_generator_backward", &affine_grid_generator_backward,
"affine_grid_generator_backward");
}
} // namespace mmdetection
......@@ -61,8 +61,10 @@ def grid_sample(input,
mode='bilinear',
padding_mode='zeros',
align_corners=False):
if torch.__version__ >= '1.3' or align_corners:
if torch.__version__ >= '1.3':
return F.grid_sample(input, grid, mode, padding_mode, align_corners)
elif align_corners:
return F.grid_sample(input, grid, mode, padding_mode)
else:
# use self-compiled grid_sampler to support align_corners=False
......
......@@ -269,6 +269,10 @@ if __name__ == '__main__':
sources=[
'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu'
]),
make_cuda_ext(
name='affine_grid_cuda',
module='mmdet.ops.affine_grid',
sources=['src/affine_grid_cuda.cpp']),
make_cuda_ext(
name='grid_sampler_cuda',
module='mmdet.ops.grid_sampler',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment