diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py
index 97edf7322b5c5135bb4a2907144c851d071692cc..e47166f6fd220d257102a02fa4b26a7458bf2b00 100644
--- a/mmdet/ops/__init__.py
+++ b/mmdet/ops/__init__.py
@@ -1,5 +1,6 @@
 from .context_block import ContextBlock
 from .conv_ws import ConvWS2d, conv_ws_2d
+from .corner_pool import CornerPool
 from .dcn import (DeformConv, DeformConvPack, DeformRoIPooling,
                   DeformRoIPoolingPack, ModulatedDeformConv,
                   ModulatedDeformConvPack, ModulatedDeformRoIPoolingPack,
@@ -16,13 +17,38 @@ from .utils import get_compiler_version, get_compiling_cuda_version
 from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
 
 __all__ = [
-    'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
-    'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
-    'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
-    'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
-    'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss',
-    'MaskedConv2d', 'ContextBlock', 'GeneralizedAttention', 'NonLocal2D',
-    'get_compiler_version', 'get_compiling_cuda_version', 'ConvWS2d',
-    'conv_ws_2d', 'build_plugin_layer', 'batched_nms', 'Conv2d',
-    'ConvTranspose2d', 'MaxPool2d', 'Linear', 'nms_match'
+    'nms',
+    'soft_nms',
+    'RoIAlign',
+    'roi_align',
+    'RoIPool',
+    'roi_pool',
+    'DeformConv',
+    'DeformConvPack',
+    'DeformRoIPooling',
+    'DeformRoIPoolingPack',
+    'ModulatedDeformRoIPoolingPack',
+    'ModulatedDeformConv',
+    'ModulatedDeformConvPack',
+    'deform_conv',
+    'modulated_deform_conv',
+    'deform_roi_pooling',
+    'SigmoidFocalLoss',
+    'sigmoid_focal_loss',
+    'MaskedConv2d',
+    'ContextBlock',
+    'GeneralizedAttention',
+    'NonLocal2D',
+    'get_compiler_version',
+    'get_compiling_cuda_version',
+    'ConvWS2d',
+    'conv_ws_2d',
+    'build_plugin_layer',
+    'batched_nms',
+    'Conv2d',
+    'ConvTranspose2d',
+    'MaxPool2d',
+    'Linear',
+    'nms_match',
+    'CornerPool',
 ]
diff --git a/mmdet/ops/corner_pool/__init__.py b/mmdet/ops/corner_pool/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5457db99f0d45cfb7c38e822294ecc0c353afd0
--- /dev/null
+++ b/mmdet/ops/corner_pool/__init__.py
@@ -0,0 +1,3 @@
+from .corner_pool import CornerPool
+
+__all__ = ['CornerPool']
diff --git a/mmdet/ops/corner_pool/corner_pool.py b/mmdet/ops/corner_pool/corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..00b6b774a4b667a888c61164874da03ad35e6f35
--- /dev/null
+++ b/mmdet/ops/corner_pool/corner_pool.py
@@ -0,0 +1,101 @@
+from torch import nn
+from torch.autograd import Function
+
+from . import corner_pool_ext
+
+
+class TopPoolFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        output = corner_pool_ext.top_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        output = corner_pool_ext.top_pool_backward(input, grad_output)
+        return output
+
+
+class BottomPoolFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        output = corner_pool_ext.bottom_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        output = corner_pool_ext.bottom_pool_backward(input, grad_output)
+        return output
+
+
+class LeftPoolFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        output = corner_pool_ext.left_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        output = corner_pool_ext.left_pool_backward(input, grad_output)
+        return output
+
+
+class RightPoolFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        output = corner_pool_ext.right_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        output = corner_pool_ext.right_pool_backward(input, grad_output)
+        return output
+
+
+class CornerPool(nn.Module):
+    """Corner Pooling.
+
+    Corner Pooling is a new type of pooling layer that helps a
+    convolutional network better localize corners of bounding boxes.
+
+    Please refer to https://arxiv.org/abs/1808.01244 for more details.
+    Code is modified from https://github.com/princeton-vl/CornerNet-Lite.
+
+    Args:
+        mode (str): Pooling orientation for the pooling layer.
+
+            - 'bottom': Bottom Pooling
+            - 'left': Left Pooling
+            - 'right': Right Pooling
+            - 'top': Top Pooling
+
+    Returns:
+        torch.Tensor: Feature map after pooling.
+    """
+
+    pool_functions = {
+        'bottom': BottomPoolFunction,
+        'left': LeftPoolFunction,
+        'right': RightPoolFunction,
+        'top': TopPoolFunction,
+    }
+
+    def __init__(self, mode):
+        super(CornerPool, self).__init__()
+        assert mode in self.pool_functions
+        self.corner_pool = self.pool_functions[mode]
+
+    def forward(self, x):
+        return self.corner_pool.apply(x)
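The four autograd Functions above are thin wrappers around the compiled kernels, so the pooling semantics are easiest to see against a reference. Below is a minimal usage sketch that cross-checks `CornerPool` with a pure-PyTorch equivalent built on `torch.cummax` (available since PyTorch 1.5): each mode is a directional running maximum, with 'bottom'/'right' scanning forward along H/W and 'top'/'left' scanning backward. The `cummax_reference` helper is hypothetical, written only to illustrate the recurrence, and running the sketch assumes the `corner_pool_ext` extension has been built (see the `setup.py` hunk below).

```python
import torch

from mmdet.ops import CornerPool


def cummax_reference(x, mode):
    """Illustrative corner pooling via cumulative max (not the real op)."""
    dim, flip = {
        'bottom': (2, False),  # running max from top to bottom
        'top': (2, True),      # running max from bottom to top
        'right': (3, False),   # running max from left to right
        'left': (3, True),     # running max from right to left
    }[mode]
    if flip:
        x = x.flip(dim)
    x = x.cummax(dim=dim)[0]
    return x.flip(dim) if flip else x


x = torch.rand(2, 8, 16, 16)
for mode in ['bottom', 'left', 'right', 'top']:
    # both sides only select values, never compute new ones, so the
    # comparison is exact even in floating point
    assert torch.equal(CornerPool(mode)(x), cummax_reference(x, mode))
```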
diff --git a/mmdet/ops/corner_pool/src/corner_pool.cpp b/mmdet/ops/corner_pool/src/corner_pool.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a1fde8078a81dd2f2c9fa57a4d412f9ecae27685
--- /dev/null
+++ b/mmdet/ops/corner_pool/src/corner_pool.cpp
@@ -0,0 +1,268 @@
+// Modified from
+// https://github.com/princeton-vl/CornerNet-Lite/tree/master/core/models/py_utils/_cpools/src
+#include <torch/extension.h>
+
+#include <vector>
+
+at::Tensor bottom_pool_forward(at::Tensor input) {
+  // Initialize output
+  at::Tensor output = at::zeros_like(input);
+
+  // Get height
+  int64_t height = input.size(2);
+
+  output.copy_(input);
+
+  for (int64_t ind = 1; ind < height; ind <<= 1) {
+    at::Tensor max_temp = at::slice(output, 2, ind, height);
+    at::Tensor cur_temp = at::slice(output, 2, ind, height).clone();
+    at::Tensor next_temp = at::slice(output, 2, 0, height - ind).clone();
+    at::max_out(max_temp, cur_temp, next_temp);
+  }
+
+  return output;
+}
+
+at::Tensor bottom_pool_backward(at::Tensor input, at::Tensor grad_output) {
+  auto output = at::zeros_like(input);
+
+  int32_t batch = input.size(0);
+  int32_t channel = input.size(1);
+  int32_t height = input.size(2);
+  int32_t width = input.size(3);
+
+  auto max_val = torch::zeros({batch, channel, width},
+                              input.options());
+  auto max_ind = torch::zeros({batch, channel, width},
+                              input.options().dtype(at::kLong));
+
+  auto input_temp = input.select(2, 0);
+  max_val.copy_(input_temp);
+
+  max_ind.fill_(0);
+
+  auto output_temp = output.select(2, 0);
+  auto grad_output_temp = grad_output.select(2, 0);
+  output_temp.copy_(grad_output_temp);
+
+  auto un_max_ind = max_ind.unsqueeze(2);
+  auto gt_mask = torch::zeros({batch, channel, width},
+                              input.options().dtype(at::kBool));
+  auto max_temp = torch::zeros({batch, channel, width},
+                               input.options());
+  for (int32_t ind = 0; ind < height - 1; ++ind) {
+    input_temp = input.select(2, ind + 1);
+    at::gt_out(gt_mask, input_temp, max_val);
+
+    at::masked_select_out(max_temp, input_temp, gt_mask);
+    max_val.masked_scatter_(gt_mask, max_temp);
+    max_ind.masked_fill_(gt_mask, ind + 1);
+
+    grad_output_temp = grad_output.select(2, ind + 1).unsqueeze(2);
+    output.scatter_add_(2, un_max_ind, grad_output_temp);
+  }
+
+  return output;
+}
+
+at::Tensor left_pool_forward(at::Tensor input) {
+  // Initialize output
+  at::Tensor output = at::zeros_like(input);
+
+  // Get width
+  int64_t width = input.size(3);
+
+  output.copy_(input);
+
+  for (int64_t ind = 1; ind < width; ind <<= 1) {
+    at::Tensor max_temp = at::slice(output, 3, 0, width - ind);
+    at::Tensor cur_temp = at::slice(output, 3, 0, width - ind).clone();
+    at::Tensor next_temp = at::slice(output, 3, ind, width).clone();
+    at::max_out(max_temp, cur_temp, next_temp);
+  }
+
+  return output;
+}
+
+at::Tensor left_pool_backward(at::Tensor input, at::Tensor grad_output) {
+  auto output = at::zeros_like(input);
+
+  int32_t batch = input.size(0);
+  int32_t channel = input.size(1);
+  int32_t height = input.size(2);
+  int32_t width = input.size(3);
+
+  auto max_val = torch::zeros({batch, channel, height},
+                              input.options());
+  auto max_ind = torch::zeros({batch, channel, height},
+                              input.options().dtype(at::kLong));
+
+  auto input_temp = input.select(3, width - 1);
+  max_val.copy_(input_temp);
+
+  max_ind.fill_(width - 1);
+
+  auto output_temp = output.select(3, width - 1);
+  auto grad_output_temp = grad_output.select(3, width - 1);
+  output_temp.copy_(grad_output_temp);
+
+  auto un_max_ind = max_ind.unsqueeze(3);
+  auto gt_mask = torch::zeros({batch, channel, height},
+                              input.options().dtype(at::kBool));
+  auto max_temp = torch::zeros({batch, channel, height},
+                               input.options());
+  for (int32_t ind = 1; ind < width; ++ind) {
+    input_temp = input.select(3, width - ind - 1);
+    at::gt_out(gt_mask, input_temp, max_val);
+
+    at::masked_select_out(max_temp, input_temp, gt_mask);
+    max_val.masked_scatter_(gt_mask, max_temp);
+    max_ind.masked_fill_(gt_mask, width - ind - 1);
+
+    grad_output_temp = grad_output.select(3, width - ind - 1).unsqueeze(3);
+    output.scatter_add_(3, un_max_ind, grad_output_temp);
+  }
+
+  return output;
+}
+
+at::Tensor right_pool_forward(at::Tensor input) {
+  // Initialize output
+  at::Tensor output = at::zeros_like(input);
+
+  // Get width
+  int64_t width = input.size(3);
+
+  output.copy_(input);
+
+  for (int64_t ind = 1; ind < width; ind <<= 1) {
+    at::Tensor max_temp = at::slice(output, 3, ind, width);
+    at::Tensor cur_temp = at::slice(output, 3, ind, width).clone();
+    at::Tensor next_temp = at::slice(output, 3, 0, width - ind).clone();
+    at::max_out(max_temp, cur_temp, next_temp);
+  }
+
+  return output;
+}
+
+at::Tensor right_pool_backward(at::Tensor input, at::Tensor grad_output) {
+  at::Tensor output = at::zeros_like(input);
+
+  int32_t batch = input.size(0);
+  int32_t channel = input.size(1);
+  int32_t height = input.size(2);
+  int32_t width = input.size(3);
+
+  auto max_val = torch::zeros({batch, channel, height},
+                              input.options());
+  auto max_ind = torch::zeros({batch, channel, height},
+                              input.options().dtype(at::kLong));
+
+  auto input_temp = input.select(3, 0);
+  max_val.copy_(input_temp);
+
+  max_ind.fill_(0);
+
+  auto output_temp = output.select(3, 0);
+  auto grad_output_temp = grad_output.select(3, 0);
+  output_temp.copy_(grad_output_temp);
+
+  auto un_max_ind = max_ind.unsqueeze(3);
+  auto gt_mask = torch::zeros({batch, channel, height},
+                              input.options().dtype(at::kBool));
+  auto max_temp = torch::zeros({batch, channel, height},
+                               input.options());
+  for (int32_t ind = 0; ind < width - 1; ++ind) {
+    input_temp = input.select(3, ind + 1);
+    at::gt_out(gt_mask, input_temp, max_val);
+
+    at::masked_select_out(max_temp, input_temp, gt_mask);
+    max_val.masked_scatter_(gt_mask, max_temp);
+    max_ind.masked_fill_(gt_mask, ind + 1);
+
+    grad_output_temp = grad_output.select(3, ind + 1).unsqueeze(3);
+    output.scatter_add_(3, un_max_ind, grad_output_temp);
+  }
+
+  return output;
+}
+
+at::Tensor top_pool_forward(at::Tensor input) {
+  // Initialize output
+  at::Tensor output = at::zeros_like(input);
+
+  // Get height
+  int64_t height = input.size(2);
+
+  output.copy_(input);
+
+  for (int64_t ind = 1; ind < height; ind <<= 1) {
+    at::Tensor max_temp = at::slice(output, 2, 0, height - ind);
+    at::Tensor cur_temp = at::slice(output, 2, 0, height - ind).clone();
+    at::Tensor next_temp = at::slice(output, 2, ind, height).clone();
+    at::max_out(max_temp, cur_temp, next_temp);
+  }
+
+  return output;
+}
+
+at::Tensor top_pool_backward(at::Tensor input, at::Tensor grad_output) {
+  auto output = at::zeros_like(input);
+
+  int32_t batch = input.size(0);
+  int32_t channel = input.size(1);
+  int32_t height = input.size(2);
+  int32_t width = input.size(3);
+
+  auto max_val = torch::zeros({batch, channel, width},
+                              input.options());
+  auto max_ind = torch::zeros({batch, channel, width},
+                              input.options().dtype(at::kLong));
+
+  auto input_temp = input.select(2, height - 1);
+  max_val.copy_(input_temp);
+
+  max_ind.fill_(height - 1);
+
+  auto output_temp = output.select(2, height - 1);
+  auto grad_output_temp = grad_output.select(2, height - 1);
+  output_temp.copy_(grad_output_temp);
+
+  auto un_max_ind = max_ind.unsqueeze(2);
+  auto gt_mask = torch::zeros({batch, channel, width},
+                              input.options().dtype(at::kBool));
+  auto max_temp = torch::zeros({batch, channel, width},
+                               input.options());
+  for (int32_t ind = 1; ind < height; ++ind) {
+    input_temp = input.select(2, height - ind - 1);
+    at::gt_out(gt_mask, input_temp, max_val);
+
+    at::masked_select_out(max_temp, input_temp, gt_mask);
+    max_val.masked_scatter_(gt_mask, max_temp);
+    max_ind.masked_fill_(gt_mask, height - ind - 1);
+
+    grad_output_temp = grad_output.select(2, height - ind - 1).unsqueeze(2);
+    output.scatter_add_(2, un_max_ind, grad_output_temp);
+  }
+
+  return output;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("bottom_pool_forward", &bottom_pool_forward, "Bottom Pool Forward",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("bottom_pool_backward", &bottom_pool_backward, "Bottom Pool Backward",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("left_pool_forward", &left_pool_forward, "Left Pool Forward",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("left_pool_backward", &left_pool_backward, "Left Pool Backward",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("right_pool_forward", &right_pool_forward, "Right Pool Forward",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("right_pool_backward", &right_pool_backward, "Right Pool Backward",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("top_pool_forward", &top_pool_forward, "Top Pool Forward",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("top_pool_backward", &top_pool_backward, "Top Pool Backward",
+        py::call_guard<py::gil_scoped_release>());
+}
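Two implementation details in `corner_pool.cpp` are worth spelling out. The forward passes compute the running maximum with offsets that double each iteration (`ind <<= 1`), a Hillis-Steele style inclusive scan that needs only ceil(log2(n)) max-merges instead of n - 1 sequential ones; the `.clone()` calls keep the overlapping slices from aliasing mid-update. The backward passes then re-derive the argmax of each scan position sequentially (tracking `max_val`/`max_ind`) and `scatter_add_` each incoming gradient onto the position that produced the maximum, which is why they allocate the per-row value and index buffers. The following standalone Python sketch, not part of the extension, mirrors the doubling loop and checks it against `torch.cummax`:

```python
import torch


def doubling_cummax(x, dim):
    """Cumulative max along `dim` using power-of-two offsets."""
    out = x.clone()
    n = x.size(dim)
    ind = 1
    while ind < n:
        # max-merge out[ind:] with out[:n - ind], mirroring the at::slice /
        # at::max_out calls; the clones keep overlapping views from aliasing
        lead = out.narrow(dim, ind, n - ind)
        lag = out.narrow(dim, 0, n - ind)
        lead.copy_(torch.max(lead.clone(), lag.clone()))
        ind <<= 1
    return out


x = torch.rand(4, 3, 7, 7)
assert torch.equal(doubling_cummax(x, dim=2), x.cummax(dim=2)[0])
```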
diff --git a/setup.py b/setup.py
index 80fa9f670a40b441ff9379af315049245eb179d3..57ccb55e27d32189a97c2ff1c43d69341a58d82a 100755
--- a/setup.py
+++ b/setup.py
@@ -294,7 +294,11 @@ if __name__ == '__main__':
                 sources_cuda=[
                     'src/cuda/carafe_naive_cuda.cpp',
                     'src/cuda/carafe_naive_cuda_kernel.cu'
-                ])
+                ]),
+            make_cuda_ext(
+                name='corner_pool_ext',
+                module='mmdet.ops.corner_pool',
+                sources=['src/corner_pool.cpp']),
         ],
         cmdclass={'build_ext': BuildExtension},
         zip_safe=False)
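Note that although the new extension is registered through `make_cuda_ext`, `corner_pool.cpp` contains only device-agnostic ATen calls, so the forward path also runs on CPU tensors — which is what the new test exercises. For quick iteration on the C++ sources without reinstalling the package, PyTorch's JIT loader can compile them on the fly. A development-only sketch; the module name and source path are taken from this diff, and the packaged build still goes through `setup.py`:

```python
import torch
from torch.utils.cpp_extension import load

# Compile the sources in-process (development convenience only).
corner_pool_ext = load(
    name='corner_pool_ext',
    sources=['mmdet/ops/corner_pool/src/corner_pool.cpp'],
    verbose=True)

x = torch.rand(1, 1, 5, 5)
print(corner_pool_ext.top_pool_forward(x))
```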
diff --git a/tests/test_ops/test_corner_pool.py b/tests/test_ops/test_corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb84acf0d791a5f3880487d990f9bdef7843a7a0
--- /dev/null
+++ b/tests/test_ops/test_corner_pool.py
@@ -0,0 +1,58 @@
+"""
+CommandLine:
+    pytest tests/test_ops/test_corner_pool.py
+"""
+import pytest
+import torch
+
+from mmdet.ops import CornerPool
+
+
+def test_corner_pool_device_and_dtypes_cpu():
+    """
+    CommandLine:
+        xdoctest -m tests/test_ops/test_corner_pool.py \
+            test_corner_pool_device_and_dtypes_cpu
+    """
+    with pytest.raises(AssertionError):
+        # pool mode must be one of ['bottom', 'left', 'right', 'top']
+        CornerPool('corner')
+
+    lr_tensor = torch.tensor([[[[0, 0, 0, 0, 0], [2, 1, 3, 0, 2],
+                                [5, 4, 1, 1, 6], [0, 0, 0, 0, 0],
+                                [0, 0, 0, 0, 0]]]])
+    tb_tensor = torch.tensor([[[[0, 3, 1, 0, 0], [0, 1, 1, 0, 0],
+                                [0, 3, 4, 0, 0], [0, 2, 2, 0, 0],
+                                [0, 0, 2, 0, 0]]]])
+    # Left Pool
+    left_answer = torch.tensor([[[[0, 0, 0, 0, 0], [3, 3, 3, 2, 2],
+                                  [6, 6, 6, 6, 6], [0, 0, 0, 0, 0],
+                                  [0, 0, 0, 0, 0]]]])
+    pool = CornerPool('left')
+    left_tensor = pool(lr_tensor)
+    assert left_tensor.type() == lr_tensor.type()
+    assert torch.equal(left_tensor, left_answer)
+    # Right Pool
+    right_answer = torch.tensor([[[[0, 0, 0, 0, 0], [2, 2, 3, 3, 3],
+                                   [5, 5, 5, 5, 6], [0, 0, 0, 0, 0],
+                                   [0, 0, 0, 0, 0]]]])
+    pool = CornerPool('right')
+    right_tensor = pool(lr_tensor)
+    assert right_tensor.type() == lr_tensor.type()
+    assert torch.equal(right_tensor, right_answer)
+    # Top Pool
+    top_answer = torch.tensor([[[[0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
+                                 [0, 3, 4, 0, 0], [0, 2, 2, 0, 0],
+                                 [0, 0, 2, 0, 0]]]])
+    pool = CornerPool('top')
+    top_tensor = pool(tb_tensor)
+    assert top_tensor.type() == tb_tensor.type()
+    assert torch.equal(top_tensor, top_answer)
+    # Bottom Pool
+    bottom_answer = torch.tensor([[[[0, 3, 1, 0, 0], [0, 3, 1, 0, 0],
+                                    [0, 3, 4, 0, 0], [0, 3, 4, 0, 0],
+                                    [0, 3, 4, 0, 0]]]])
+    pool = CornerPool('bottom')
+    bottom_tensor = pool(tb_tensor)
+    assert bottom_tensor.type() == tb_tensor.type()
+    assert torch.equal(bottom_tensor, bottom_answer)
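The test above covers the forward path only. A natural follow-up, not included in this patch, is a numerical check of the hand-written backward with `torch.autograd.gradcheck`. The sketch below assumes the backward buffers follow the input's dtype and device (`input.options()`, as in the sources above) and uses double-precision inputs with all-distinct values, since the subgradient of a maximum is ambiguous at ties and would make the check flaky:

```python
import torch
from torch.autograd import gradcheck

from mmdet.ops import CornerPool


def test_corner_pool_gradcheck():
    # all-distinct values keep the argmax stable under gradcheck's
    # finite-difference perturbations
    x = torch.randperm(2 * 3 * 5 * 5).double().reshape(2, 3, 5, 5)
    x.requires_grad_(True)
    for mode in ['bottom', 'left', 'right', 'top']:
        assert gradcheck(CornerPool(mode), (x,), eps=1e-6, atol=1e-4)
```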
diff --git a/tests/test_merge_cells.py b/tests/test_ops/test_merge_cells.py
similarity index 100%
rename from tests/test_merge_cells.py
rename to tests/test_ops/test_merge_cells.py
diff --git a/tests/test_nms.py b/tests/test_ops/test_nms.py
similarity index 100%
rename from tests/test_nms.py
rename to tests/test_ops/test_nms.py
diff --git a/tests/test_soft_nms.py b/tests/test_ops/test_soft_nms.py
similarity index 100%
rename from tests/test_soft_nms.py
rename to tests/test_ops/test_soft_nms.py
diff --git a/tests/test_wrappers.py b/tests/test_ops/test_wrappers.py
similarity index 100%
rename from tests/test_wrappers.py
rename to tests/test_ops/test_wrappers.py