diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py
index 97edf7322b5c5135bb4a2907144c851d071692cc..e47166f6fd220d257102a02fa4b26a7458bf2b00 100644
--- a/mmdet/ops/__init__.py
+++ b/mmdet/ops/__init__.py
@@ -1,5 +1,6 @@
 from .context_block import ContextBlock
 from .conv_ws import ConvWS2d, conv_ws_2d
+from .corner_pool import CornerPool
 from .dcn import (DeformConv, DeformConvPack, DeformRoIPooling,
                   DeformRoIPoolingPack, ModulatedDeformConv,
                   ModulatedDeformConvPack, ModulatedDeformRoIPoolingPack,
@@ -16,13 +17,38 @@ from .utils import get_compiler_version, get_compiling_cuda_version
 from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
 
 __all__ = [
-    'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
-    'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
-    'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
-    'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
-    'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss',
-    'MaskedConv2d', 'ContextBlock', 'GeneralizedAttention', 'NonLocal2D',
-    'get_compiler_version', 'get_compiling_cuda_version', 'ConvWS2d',
-    'conv_ws_2d', 'build_plugin_layer', 'batched_nms', 'Conv2d',
-    'ConvTranspose2d', 'MaxPool2d', 'Linear', 'nms_match'
+    'nms',
+    'soft_nms',
+    'RoIAlign',
+    'roi_align',
+    'RoIPool',
+    'roi_pool',
+    'DeformConv',
+    'DeformConvPack',
+    'DeformRoIPooling',
+    'DeformRoIPoolingPack',
+    'ModulatedDeformRoIPoolingPack',
+    'ModulatedDeformConv',
+    'ModulatedDeformConvPack',
+    'deform_conv',
+    'modulated_deform_conv',
+    'deform_roi_pooling',
+    'SigmoidFocalLoss',
+    'sigmoid_focal_loss',
+    'MaskedConv2d',
+    'ContextBlock',
+    'GeneralizedAttention',
+    'NonLocal2D',
+    'get_compiler_version',
+    'get_compiling_cuda_version',
+    'ConvWS2d',
+    'conv_ws_2d',
+    'build_plugin_layer',
+    'batched_nms',
+    'Conv2d',
+    'ConvTranspose2d',
+    'MaxPool2d',
+    'Linear',
+    'nms_match',
+    'CornerPool',
 ]
diff --git a/mmdet/ops/corner_pool/__init__.py b/mmdet/ops/corner_pool/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5457db99f0d45cfb7c38e822294ecc0c353afd0
--- /dev/null
+++ b/mmdet/ops/corner_pool/__init__.py
@@ -0,0 +1,3 @@
+from .corner_pool import CornerPool
+
+__all__ = ['CornerPool']
diff --git a/mmdet/ops/corner_pool/corner_pool.py b/mmdet/ops/corner_pool/corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..00b6b774a4b667a888c61164874da03ad35e6f35
--- /dev/null
+++ b/mmdet/ops/corner_pool/corner_pool.py
@@ -0,0 +1,101 @@
+from torch import nn
+from torch.autograd import Function
+
+from . import corner_pool_ext
+
+
+class TopPoolFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        output = corner_pool_ext.top_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        output = corner_pool_ext.top_pool_backward(input, grad_output)
+        return output
+
+
+class BottomPoolFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        output = corner_pool_ext.bottom_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        output = corner_pool_ext.bottom_pool_backward(input, grad_output)
+        return output
+
+
+class LeftPoolFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        output = corner_pool_ext.left_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        output = corner_pool_ext.left_pool_backward(input, grad_output)
+        return output
+
+
+class RightPoolFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        output = corner_pool_ext.right_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        output = corner_pool_ext.right_pool_backward(input, grad_output)
+        return output
+
+
+class CornerPool(nn.Module):
+    """Corner Pooling.
+
+    Corner Pooling is a new type of pooling layer that helps a
+    convolutional network better localize corners of bounding boxes.
+
+    Please refer to https://arxiv.org/abs/1808.01244 for more details.
+    Code is modified from https://github.com/princeton-vl/CornerNet-Lite.
+
+    Args:
+        mode (str): Pooling orientation for the pooling layer
+
+            - 'bottom': Bottom Pooling
+            - 'left': Left Pooling
+            - 'right': Right Pooling
+            - 'top': Top Pooling
+
+    Returns:
+        Feature map after pooling.
+    """
+
+    pool_functions = {
+        'bottom': BottomPoolFunction,
+        'left': LeftPoolFunction,
+        'right': RightPoolFunction,
+        'top': TopPoolFunction,
+    }
+
+    def __init__(self, mode):
+        super(CornerPool, self).__init__()
+        assert mode in self.pool_functions
+        self.corner_pool = self.pool_functions[mode]
+
+    def forward(self, x):
+        return self.corner_pool.apply(x)
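A quick usage sketch of the new module (illustration only, not part of the diff; it assumes the extension has been compiled so that mmdet.ops.CornerPool imports cleanly). CornerNet forms its top-left corner heatmap by summing a 'top' pool with a 'left' pool, which is how these modes are meant to be composed:

import torch

from mmdet.ops import CornerPool

feat = torch.rand(1, 32, 24, 24)  # hypothetical NCHW feature map

# Elementwise sum of two directional max-pools gives the top-left
# corner response; each pool preserves the input shape.
top_left = CornerPool('top')(feat) + CornerPool('left')(feat)
assert top_left.shape == feat.shape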
+ """ + + pool_functions = { + 'bottom': BottomPoolFunction, + 'left': LeftPoolFunction, + 'right': RightPoolFunction, + 'top': TopPoolFunction, + } + + def __init__(self, mode): + super(CornerPool, self).__init__() + assert mode in self.pool_functions + self.corner_pool = self.pool_functions[mode] + + def forward(self, x): + return self.corner_pool.apply(x) diff --git a/mmdet/ops/corner_pool/src/corner_pool.cpp b/mmdet/ops/corner_pool/src/corner_pool.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a1fde8078a81dd2f2c9fa57a4d412f9ecae27685 --- /dev/null +++ b/mmdet/ops/corner_pool/src/corner_pool.cpp @@ -0,0 +1,268 @@ +// Modified from +// https://github.com/princeton-vl/CornerNet-Lite/tree/master/core/models/py_utils/_cpools/src +#include <torch/torch.h> + +#include <vector> + +at::Tensor bottom_pool_forward(at::Tensor input) { + // Initialize output + at::Tensor output = at::zeros_like(input); + + // Get height + int64_t height = input.size(2); + + output.copy_(input); + + for (int64_t ind = 1; ind < height; ind <<= 1) { + at::Tensor max_temp = at::slice(output, 2, ind, height); + at::Tensor cur_temp = at::slice(output, 2, ind, height).clone(); + at::Tensor next_temp = at::slice(output, 2, 0, height - ind).clone(); + at::max_out(max_temp, cur_temp, next_temp); + } + + return output; +} + +at::Tensor bottom_pool_backward(at::Tensor input, at::Tensor grad_output) { + auto output = at::zeros_like(input); + + int32_t batch = input.size(0); + int32_t channel = input.size(1); + int32_t height = input.size(2); + int32_t width = input.size(3); + + auto max_val = torch::zeros({batch, channel, width}, + at::device(at::kCUDA).dtype(at::kFloat)); + auto max_ind = torch::zeros({batch, channel, width}, + at::device(at::kCUDA).dtype(at::kLong)); + + auto input_temp = input.select(2, 0); + max_val.copy_(input_temp); + + max_ind.fill_(0); + + auto output_temp = output.select(2, 0); + auto grad_output_temp = grad_output.select(2, 0); + output_temp.copy_(grad_output_temp); + + auto un_max_ind = max_ind.unsqueeze(2); + auto gt_mask = torch::zeros({batch, channel, width}, + at::device(at::kCUDA).dtype(at::kBool)); + auto max_temp = torch::zeros({batch, channel, width}, + at::device(at::kCUDA).dtype(at::kFloat)); + for (int32_t ind = 0; ind < height - 1; ++ind) { + input_temp = input.select(2, ind + 1); + at::gt_out(gt_mask, input_temp, max_val); + + at::masked_select_out(max_temp, input_temp, gt_mask); + max_val.masked_scatter_(gt_mask, max_temp); + max_ind.masked_fill_(gt_mask, ind + 1); + + grad_output_temp = grad_output.select(2, ind + 1).unsqueeze(2); + output.scatter_add_(2, un_max_ind, grad_output_temp); + } + + return output; +} + +at::Tensor left_pool_forward(at::Tensor input) { + // Initialize output + at::Tensor output = at::zeros_like(input); + + // Get width + int64_t width = input.size(3); + + output.copy_(input); + + for (int64_t ind = 1; ind < width; ind <<= 1) { + at::Tensor max_temp = at::slice(output, 3, 0, width - ind); + at::Tensor cur_temp = at::slice(output, 3, 0, width - ind).clone(); + at::Tensor next_temp = at::slice(output, 3, ind, width).clone(); + at::max_out(max_temp, cur_temp, next_temp); + } + + return output; +} + +at::Tensor left_pool_backward(at::Tensor input, at::Tensor grad_output) { + auto output = at::zeros_like(input); + + int32_t batch = input.size(0); + int32_t channel = input.size(1); + int32_t height = input.size(2); + int32_t width = input.size(3); + + auto max_val = torch::zeros({batch, channel, height}, + 
at::device(at::kCUDA).dtype(at::kFloat)); + auto max_ind = torch::zeros({batch, channel, height}, + at::device(at::kCUDA).dtype(at::kLong)); + + auto input_temp = input.select(3, width - 1); + max_val.copy_(input_temp); + + max_ind.fill_(width - 1); + + auto output_temp = output.select(3, width - 1); + auto grad_output_temp = grad_output.select(3, width - 1); + output_temp.copy_(grad_output_temp); + + auto un_max_ind = max_ind.unsqueeze(3); + auto gt_mask = torch::zeros({batch, channel, height}, + at::device(at::kCUDA).dtype(at::kBool)); + auto max_temp = torch::zeros({batch, channel, height}, + at::device(at::kCUDA).dtype(at::kFloat)); + for (int32_t ind = 1; ind < width; ++ind) { + input_temp = input.select(3, width - ind - 1); + at::gt_out(gt_mask, input_temp, max_val); + + at::masked_select_out(max_temp, input_temp, gt_mask); + max_val.masked_scatter_(gt_mask, max_temp); + max_ind.masked_fill_(gt_mask, width - ind - 1); + + grad_output_temp = grad_output.select(3, width - ind - 1).unsqueeze(3); + output.scatter_add_(3, un_max_ind, grad_output_temp); + } + + return output; +} + +at::Tensor right_pool_forward(at::Tensor input) { + // Initialize output + at::Tensor output = at::zeros_like(input); + + // Get width + int64_t width = input.size(3); + + output.copy_(input); + + for (int64_t ind = 1; ind < width; ind <<= 1) { + at::Tensor max_temp = at::slice(output, 3, ind, width); + at::Tensor cur_temp = at::slice(output, 3, ind, width).clone(); + at::Tensor next_temp = at::slice(output, 3, 0, width - ind).clone(); + at::max_out(max_temp, cur_temp, next_temp); + } + + return output; +} + +at::Tensor right_pool_backward(at::Tensor input, at::Tensor grad_output) { + at::Tensor output = at::zeros_like(input); + + int32_t batch = input.size(0); + int32_t channel = input.size(1); + int32_t height = input.size(2); + int32_t width = input.size(3); + + auto max_val = torch::zeros({batch, channel, height}, + at::device(at::kCUDA).dtype(at::kFloat)); + auto max_ind = torch::zeros({batch, channel, height}, + at::device(at::kCUDA).dtype(at::kLong)); + + auto input_temp = input.select(3, 0); + max_val.copy_(input_temp); + + max_ind.fill_(0); + + auto output_temp = output.select(3, 0); + auto grad_output_temp = grad_output.select(3, 0); + output_temp.copy_(grad_output_temp); + + auto un_max_ind = max_ind.unsqueeze(3); + auto gt_mask = torch::zeros({batch, channel, height}, + at::device(at::kCUDA).dtype(at::kBool)); + auto max_temp = torch::zeros({batch, channel, height}, + at::device(at::kCUDA).dtype(at::kFloat)); + for (int32_t ind = 0; ind < width - 1; ++ind) { + input_temp = input.select(3, ind + 1); + at::gt_out(gt_mask, input_temp, max_val); + + at::masked_select_out(max_temp, input_temp, gt_mask); + max_val.masked_scatter_(gt_mask, max_temp); + max_ind.masked_fill_(gt_mask, ind + 1); + + grad_output_temp = grad_output.select(3, ind + 1).unsqueeze(3); + output.scatter_add_(3, un_max_ind, grad_output_temp); + } + + return output; +} + +at::Tensor top_pool_forward(at::Tensor input) { + // Initialize output + at::Tensor output = at::zeros_like(input); + + // Get height + int64_t height = input.size(2); + + output.copy_(input); + + for (int64_t ind = 1; ind < height; ind <<= 1) { + at::Tensor max_temp = at::slice(output, 2, 0, height - ind); + at::Tensor cur_temp = at::slice(output, 2, 0, height - ind).clone(); + at::Tensor next_temp = at::slice(output, 2, ind, height).clone(); + at::max_out(max_temp, cur_temp, next_temp); + } + + return output; +} + +at::Tensor top_pool_backward(at::Tensor input, 
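The forward kernels above compute a directional running maximum in place with a stride-doubling loop, so only O(log n) slice-wise max operations are needed instead of one per row or column. A pure-PyTorch sketch of the same recurrence for bottom_pool_forward (reference only; the helper name is mine):

import torch


def bottom_pool_forward_py(x):
    # After this loop, out[:, :, i] holds the maximum of x[:, :, :i + 1]
    # along the height axis; each iteration doubles the number of input
    # rows that every output row has seen.
    out = x.clone()
    height, stride = x.size(2), 1
    while stride < height:
        out[:, :, stride:] = torch.max(out[:, :, stride:].clone(),
                                       out[:, :, :height - stride].clone())
        stride *= 2
    return out

On the same input this should agree exactly with CornerPool('bottom'), since both take maxima over the same values.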
diff --git a/setup.py b/setup.py
index 80fa9f670a40b441ff9379af315049245eb179d3..57ccb55e27d32189a97c2ff1c43d69341a58d82a 100755
--- a/setup.py
+++ b/setup.py
@@ -294,7 +294,11 @@ if __name__ == '__main__':
                 sources_cuda=[
                     'src/cuda/carafe_naive_cuda.cpp',
                     'src/cuda/carafe_naive_cuda_kernel.cu'
-                ])
+                ]),
+            make_cuda_ext(
+                name='corner_pool_ext',
+                module='mmdet.ops.corner_pool',
+                sources=['src/corner_pool.cpp']),
         ],
         cmdclass={'build_ext': BuildExtension},
         zip_safe=False)
diff --git a/tests/test_ops/test_corner_pool.py b/tests/test_ops/test_corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb84acf0d791a5f3880487d990f9bdef7843a7a0
--- /dev/null
+++ b/tests/test_ops/test_corner_pool.py
@@ -0,0 +1,58 @@
+"""
+CommandLine:
+    pytest tests/test_ops/test_corner_pool.py
+"""
+import pytest
+import torch
+
+from mmdet.ops import CornerPool
+
+
+def test_corner_pool_device_and_dtypes_cpu():
+    """
+    CommandLine:
+        xdoctest -m tests/test_ops/test_corner_pool.py \
+            test_corner_pool_device_and_dtypes_cpu
+    """
+    with pytest.raises(AssertionError):
+        # pool mode must be in ['bottom', 'left', 'right', 'top']
+        pool = CornerPool('corner')
+
+    lr_tensor = torch.tensor([[[[0, 0, 0, 0, 0], [2, 1, 3, 0, 2],
+                                [5, 4, 1, 1, 6], [0, 0, 0, 0, 0],
+                                [0, 0, 0, 0, 0]]]])
+    tb_tensor = torch.tensor([[[[0, 3, 1, 0, 0], [0, 1, 1, 0, 0],
+                                [0, 3, 4, 0, 0], [0, 2, 2, 0, 0],
+                                [0, 0, 2, 0, 0]]]])
+    # Left Pool
+    left_answer = torch.tensor([[[[0, 0, 0, 0, 0], [3, 3, 3, 2, 2],
+                                  [6, 6, 6, 6, 6], [0, 0, 0, 0, 0],
+                                  [0, 0, 0, 0, 0]]]])
+    pool = CornerPool('left')
+    left_tensor = pool(lr_tensor)
+    assert left_tensor.type() == lr_tensor.type()
+    assert torch.equal(left_tensor, left_answer)
+    # Right Pool
+    right_answer = torch.tensor([[[[0, 0, 0, 0, 0], [2, 2, 3, 3, 3],
+                                   [5, 5, 5, 5, 6], [0, 0, 0, 0, 0],
+                                   [0, 0, 0, 0, 0]]]])
+    pool = CornerPool('right')
+    right_tensor = pool(lr_tensor)
+    assert right_tensor.type() == lr_tensor.type()
+    assert torch.equal(right_tensor, right_answer)
+    # Top Pool
+    top_answer = torch.tensor([[[[0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
+                                 [0, 3, 4, 0, 0], [0, 2, 2, 0, 0],
+                                 [0, 0, 2, 0, 0]]]])
+    pool = CornerPool('top')
+    top_tensor = pool(tb_tensor)
+    assert top_tensor.type() == tb_tensor.type()
+    assert torch.equal(top_tensor, top_answer)
+    # Bottom Pool
+    bottom_answer = torch.tensor([[[[0, 3, 1, 0, 0], [0, 3, 1, 0, 0],
+                                    [0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
+                                    [0, 3, 4, 0, 0]]]])
+    pool = CornerPool('bottom')
+    bottom_tensor = pool(tb_tensor)
+    assert bottom_tensor.type() == tb_tensor.type()
+    assert torch.equal(bottom_tensor, bottom_answer)
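The test above checks forward values on CPU only; the backward path lives entirely in the compiled extension. A numerical gradient check would complement it (a sketch under the assumption that a CUDA build of the extension is available; not part of the diff):

import torch

from mmdet.ops import CornerPool


def check_corner_pool_gradients():
    # gradcheck needs float64 inputs; the backward kernels allocate
    # their intermediates with the input's device and dtype.
    if not torch.cuda.is_available():
        return
    x = torch.randn(
        1, 2, 5, 5, dtype=torch.float64, device='cuda', requires_grad=True)
    for mode in ['bottom', 'left', 'right', 'top']:
        assert torch.autograd.gradcheck(CornerPool(mode), (x,))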
diff --git a/tests/test_merge_cells.py b/tests/test_ops/test_merge_cells.py
similarity index 100%
rename from tests/test_merge_cells.py
rename to tests/test_ops/test_merge_cells.py
diff --git a/tests/test_nms.py b/tests/test_ops/test_nms.py
similarity index 100%
rename from tests/test_nms.py
rename to tests/test_ops/test_nms.py
diff --git a/tests/test_soft_nms.py b/tests/test_ops/test_soft_nms.py
similarity index 100%
rename from tests/test_soft_nms.py
rename to tests/test_ops/test_soft_nms.py
diff --git a/tests/test_wrappers.py b/tests/test_ops/test_wrappers.py
similarity index 100%
rename from tests/test_wrappers.py
rename to tests/test_ops/test_wrappers.py