From 7ed8d51edaf373a7ace08750685377c75ee9578e Mon Sep 17 00:00:00 2001
From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Date: Mon, 6 Apr 2020 18:39:54 +0800
Subject: [PATCH] Change V2.0 coors (#2380)

* Refactor (all): change coordinate system

* Fix (mask_head): fix cat -1 bug in mask_paste

* Fix (unittest): modify unittest and pass CI

* reformat to pass CI

* Fix round coordinates bugs

* clean file

* Fix (test): use cpu version of aligned roi_align in tests

* Refactor (mask): clean np.stack

* Refactor (head): reformat code and fix missing -1

* Reformat: reformat and add doc strings

* Refactor (mask_head): cleaner docstring
---
 mmdet/core/anchor/anchor_generator.py        |  27 +-
 mmdet/core/anchor/guided_anchor_target.py    |  12 +-
 mmdet/core/bbox/geometry.py                  |  24 +-
 mmdet/core/bbox/transforms.py                |  46 +--
 mmdet/core/evaluation/bbox_overlaps.py       |  10 +-
 mmdet/core/evaluation/mean_ap.py             |  24 +-
 mmdet/core/mask/mask_target.py               |  12 +-
 mmdet/core/mask/structures.py                |  62 +--
 mmdet/datasets/cityscapes.py                 |   2 +-
 mmdet/datasets/coco.py                       |   8 +-
 mmdet/datasets/pipelines/instaboost.py       |  14 +-
 mmdet/datasets/pipelines/transforms.py       |  16 +-
 mmdet/datasets/xml_style.py                  |   1 +
 mmdet/models/anchor_heads/fcos_head.py       |   4 +-
 mmdet/models/anchor_heads/ga_rpn_head.py     |   4 +-
 mmdet/models/anchor_heads/rpn_head.py        |   4 +-
 mmdet/models/anchor_heads/ssd_head.py        |   2 +-
 mmdet/models/bbox_heads/bbox_head.py         |   4 +-
 mmdet/models/losses/iou_loss.py              |  16 +-
 mmdet/models/mask_heads/fcn_mask_head.py     | 169 ++++++--
 mmdet/models/mask_heads/grid_head.py         |   4 +-
 mmdet/models/roi_extractors/single_level.py  |  14 +-
 mmdet/ops/nms/nms_wrapper.py                 |   6 +-
 mmdet/ops/nms/src/cpu/nms_cpu.cpp            |  65 +--
 mmdet/ops/nms/src/cuda/nms_kernel.cu         |   6 +-
 mmdet/ops/roi_align/roi_align.py             |  28 +-
 mmdet/ops/roi_align/src/cpu/roi_align_v2.cpp | 404 +++++++++++++++++++
 mmdet/ops/roi_align/src/roi_align_ext.cpp    |  23 +-
 setup.py                                     |   5 +-
 tests/test_assigner.py                       |   2 +-
 tests/test_forward.py                        |   9 -
 tests/test_masks.py                          |  27 +-
 tests/test_nms.py                            |  16 +-
 tests/test_sampler.py                        |   4 +-
 34 files changed, 801 insertions(+), 273 deletions(-)
 create mode 100644 mmdet/ops/roi_align/src/cpu/roi_align_v2.cpp

diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py
index cd227ad0..e7926858 100644
--- a/mmdet/core/anchor/anchor_generator.py
+++ b/mmdet/core/anchor/anchor_generator.py
@@ -8,10 +8,10 @@ class AnchorGenerator(object):
         >>> self = AnchorGenerator(9, [1.], [1.])
         >>> all_anchors = self.grid_anchors((2, 2), device='cpu')
         >>> print(all_anchors)
-        tensor([[ 0.,  0.,  8.,  8.],
-                [16.,  0., 24.,  8.],
-                [ 0., 16.,  8., 24.],
-                [16., 16., 24., 24.]])
+        tensor([[-4.5000, -4.5000,  4.5000,  4.5000],
+                [11.5000, -4.5000, 20.5000,  4.5000],
+                [-4.5000, 11.5000,  4.5000, 20.5000],
+                [11.5000, 11.5000, 20.5000, 20.5000]])
     """
 
     def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
@@ -30,8 +30,8 @@ class AnchorGenerator(object):
         w = self.base_size
         h = self.base_size
         if self.ctr is None:
-            x_ctr = 0.5 * (w - 1)
-            y_ctr = 0.5 * (h - 1)
+            x_ctr = 0.
+            y_ctr = 0.
         else:
             x_ctr, y_ctr = self.ctr
 
@@ -44,14 +44,13 @@ class AnchorGenerator(object):
             ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
             hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)
 
-        # yapf: disable
-        base_anchors = torch.stack(
-            [
-                x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
-                x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
-            ],
-            dim=-1).round()
-        # yapf: enable
+        # use float anchor and the anchor's center is aligned with the
+        # pixel center
+        base_anchors = [
+            x_ctr - 0.5 * ws, y_ctr - 0.5 * hs, x_ctr + 0.5 * ws,
+            y_ctr + 0.5 * hs
+        ]
+        base_anchors = torch.stack(base_anchors, dim=-1)
 
         return base_anchors
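
The base anchors are now kept as floats and centered on the pixel center
instead of being rounded and offset by half of (size - 1). A minimal
standalone sketch (assuming only torch) that reproduces the first anchor of
the updated doctest; shifting it by the stride-16 grid offsets yields the
other three:

    import torch

    def gen_base_anchors(base_size, scales, ratios):
        # float anchors; the anchor center is aligned with the pixel
        # center at (0, 0) instead of 0.5 * (size - 1)
        h_ratios = torch.sqrt(torch.tensor(ratios))
        w_ratios = 1 / h_ratios
        ws = (base_size * torch.tensor(scales)[:, None] *
              w_ratios[None, :]).view(-1)
        hs = (base_size * torch.tensor(scales)[:, None] *
              h_ratios[None, :]).view(-1)
        return torch.stack(
            [-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs], dim=-1)

    print(gen_base_anchors(9, [1.], [1.]))
    # tensor([[-4.5000, -4.5000,  4.5000,  4.5000]])
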
 
diff --git a/mmdet/core/anchor/guided_anchor_target.py b/mmdet/core/anchor/guided_anchor_target.py
index 21162eb9..45d0177b 100644
--- a/mmdet/core/anchor/guided_anchor_target.py
+++ b/mmdet/core/anchor/guided_anchor_target.py
@@ -22,10 +22,10 @@ def calc_region(bbox, ratio, featmap_size=None):
     x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
     y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
     if featmap_size is not None:
-        x1 = x1.clamp(min=0, max=featmap_size[1] - 1)
-        y1 = y1.clamp(min=0, max=featmap_size[0] - 1)
-        x2 = x2.clamp(min=0, max=featmap_size[1] - 1)
-        y2 = y2.clamp(min=0, max=featmap_size[0] - 1)
+        x1 = x1.clamp(min=0, max=featmap_size[1])
+        y1 = y1.clamp(min=0, max=featmap_size[0])
+        x2 = x2.clamp(min=0, max=featmap_size[1])
+        y2 = y2.clamp(min=0, max=featmap_size[0])
     return (x1, y1, x2, y2)
 
 
@@ -76,8 +76,8 @@ def ga_loc_target(gt_bboxes_list,
         all_ignore_map.append(ignore_map)
     for img_id in range(img_per_gpu):
         gt_bboxes = gt_bboxes_list[img_id]
-        scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) *
-                           (gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1))
+        scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
+                           (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
         min_anchor_size = scale.new_full(
             (1, ), float(anchor_scale * anchor_strides[0]))
         # assign gt bboxes to different feature levels w.r.t. their scales
diff --git a/mmdet/core/bbox/geometry.py b/mmdet/core/bbox/geometry.py
index ff7c5d4f..6fd791ed 100644
--- a/mmdet/core/bbox/geometry.py
+++ b/mmdet/core/bbox/geometry.py
@@ -30,8 +30,8 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False):
         >>>     [10, 10, 20, 20],
         >>> ])
         >>> bbox_overlaps(bboxes1, bboxes2)
-        tensor([[0.5238, 0.0500, 0.0041],
-                [0.0323, 0.0452, 1.0000],
+        tensor([[0.5000, 0.0000, 0.0000],
+                [0.0000, 0.0000, 1.0000],
                 [0.0000, 0.0000, 0.0000]])
 
     Example:
@@ -58,14 +58,14 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False):
         lt = torch.max(bboxes1[:, :2], bboxes2[:, :2])  # [rows, 2]
         rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:])  # [rows, 2]
 
-        wh = (rb - lt + 1).clamp(min=0)  # [rows, 2]
+        wh = (rb - lt).clamp(min=0)  # [rows, 2]
         overlap = wh[:, 0] * wh[:, 1]
-        area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (
-            bboxes1[:, 3] - bboxes1[:, 1] + 1)
+        area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (
+            bboxes1[:, 3] - bboxes1[:, 1])
 
         if mode == 'iou':
-            area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (
-                bboxes2[:, 3] - bboxes2[:, 1] + 1)
+            area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (
+                bboxes2[:, 3] - bboxes2[:, 1])
             ious = overlap / (area1 + area2 - overlap)
         else:
             ious = overlap / area1
@@ -73,14 +73,14 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False):
         lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2])  # [rows, cols, 2]
         rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:])  # [rows, cols, 2]
 
-        wh = (rb - lt + 1).clamp(min=0)  # [rows, cols, 2]
+        wh = (rb - lt).clamp(min=0)  # [rows, cols, 2]
         overlap = wh[:, :, 0] * wh[:, :, 1]
-        area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (
-            bboxes1[:, 3] - bboxes1[:, 1] + 1)
+        area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (
+            bboxes1[:, 3] - bboxes1[:, 1])
 
         if mode == 'iou':
-            area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (
-                bboxes2[:, 3] - bboxes2[:, 1] + 1)
+            area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (
+                bboxes2[:, 3] - bboxes2[:, 1])
             ious = overlap / (area1[:, None] + area2 - overlap)
         else:
             ious = overlap / (area1[:, None])
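
With the "+ 1" gone, a box's width is exactly x2 - x1, which is why the
doctest entry above drops from 0.5238 to 0.5000. A quick check on a
hypothetical box pair consistent with that first entry:

    import torch

    def iou(b1, b2):
        lt = torch.max(b1[:2], b2[:2])
        rb = torch.min(b1[2:], b2[2:])
        wh = (rb - lt).clamp(min=0)  # no '+ 1' term anymore
        overlap = wh[0] * wh[1]
        area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
        area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
        return overlap / (area1 + area2 - overlap)

    b1 = torch.tensor([0., 0., 10., 10.])  # area 100
    b2 = torch.tensor([0., 0., 10., 20.])  # area 200, overlap 100
    print(iou(b1, b2))  # tensor(0.5000); the old convention gave 0.5238
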
diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py
index 58aa13ad..288d9e79 100644
--- a/mmdet/core/bbox/transforms.py
+++ b/mmdet/core/bbox/transforms.py
@@ -10,13 +10,13 @@ def bbox2delta(proposals, gt, means=[0, 0, 0, 0], stds=[1, 1, 1, 1]):
     gt = gt.float()
     px = (proposals[..., 0] + proposals[..., 2]) * 0.5
     py = (proposals[..., 1] + proposals[..., 3]) * 0.5
-    pw = proposals[..., 2] - proposals[..., 0] + 1.0
-    ph = proposals[..., 3] - proposals[..., 1] + 1.0
+    pw = proposals[..., 2] - proposals[..., 0]
+    ph = proposals[..., 3] - proposals[..., 1]
 
     gx = (gt[..., 0] + gt[..., 2]) * 0.5
     gy = (gt[..., 1] + gt[..., 3]) * 0.5
-    gw = gt[..., 2] - gt[..., 0] + 1.0
-    gh = gt[..., 3] - gt[..., 1] + 1.0
+    gw = gt[..., 2] - gt[..., 0]
+    gh = gt[..., 3] - gt[..., 1]
 
     dx = (gx - px) / pw
     dy = (gy - py) / ph
@@ -71,9 +71,9 @@ def delta2bbox(rois,
         >>>                        [ 0.7, -1.9, -0.5,  0.3]])
         >>> delta2bbox(rois, deltas, max_shape=(32, 32))
         tensor([[0.0000, 0.0000, 1.0000, 1.0000],
-                [0.2817, 0.2817, 4.7183, 4.7183],
-                [0.0000, 0.6321, 7.3891, 0.3679],
-                [5.8967, 2.9251, 5.5033, 3.2749]])
+                [0.1409, 0.1409, 2.8591, 2.8591],
+                [0.0000, 0.3161, 4.1945, 0.6839],
+                [5.0000, 5.0000, 5.0000, 5.0000]])
     """
     means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
     stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
@@ -89,8 +89,8 @@ def delta2bbox(rois,
     px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
     py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
     # Compute width/height of each roi
-    pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
-    ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
+    pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw)
+    ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh)
     # Use exp(network energy) to enlarge/shrink each roi
     gw = pw * dw.exp()
     gh = ph * dh.exp()
@@ -98,15 +98,15 @@ def delta2bbox(rois,
     gx = torch.addcmul(px, 1, pw, dx)  # gx = px + pw * dx
     gy = torch.addcmul(py, 1, ph, dy)  # gy = py + ph * dy
     # Convert center-xy/width/height to top-left, bottom-right
-    x1 = gx - gw * 0.5 + 0.5
-    y1 = gy - gh * 0.5 + 0.5
-    x2 = gx + gw * 0.5 - 0.5
-    y2 = gy + gh * 0.5 - 0.5
+    x1 = gx - gw * 0.5
+    y1 = gy - gh * 0.5
+    x2 = gx + gw * 0.5
+    y2 = gy + gh * 0.5
     if max_shape is not None:
-        x1 = x1.clamp(min=0, max=max_shape[1] - 1)
-        y1 = y1.clamp(min=0, max=max_shape[0] - 1)
-        x2 = x2.clamp(min=0, max=max_shape[1] - 1)
-        y2 = y2.clamp(min=0, max=max_shape[0] - 1)
+        x1 = x1.clamp(min=0, max=max_shape[1])
+        y1 = y1.clamp(min=0, max=max_shape[0])
+        x2 = x2.clamp(min=0, max=max_shape[1])
+        y2 = y2.clamp(min=0, max=max_shape[0])
     bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
     return bboxes
 
@@ -124,8 +124,8 @@ def bbox_flip(bboxes, img_shape):
     if isinstance(bboxes, torch.Tensor):
         assert bboxes.shape[-1] % 4 == 0
         flipped = bboxes.clone()
-        flipped[:, 0::4] = img_shape[1] - bboxes[:, 2::4] - 1
-        flipped[:, 2::4] = img_shape[1] - bboxes[:, 0::4] - 1
+        flipped[:, 0::4] = img_shape[1] - bboxes[:, 2::4]
+        flipped[:, 2::4] = img_shape[1] - bboxes[:, 0::4]
         return flipped
     elif isinstance(bboxes, np.ndarray):
         return mmcv.bbox_flip(bboxes, img_shape)
@@ -216,8 +216,8 @@ def distance2bbox(points, distance, max_shape=None):
     x2 = points[:, 0] + distance[:, 2]
     y2 = points[:, 1] + distance[:, 3]
     if max_shape is not None:
-        x1 = x1.clamp(min=0, max=max_shape[1] - 1)
-        y1 = y1.clamp(min=0, max=max_shape[0] - 1)
-        x2 = x2.clamp(min=0, max=max_shape[1] - 1)
-        y2 = y2.clamp(min=0, max=max_shape[0] - 1)
+        x1 = x1.clamp(min=0, max=max_shape[1])
+        y1 = y1.clamp(min=0, max=max_shape[0])
+        x2 = x2.clamp(min=0, max=max_shape[1])
+        y2 = y2.clamp(min=0, max=max_shape[0])
     return torch.stack([x1, y1, x2, y2], -1)
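
Dropping the +-0.5 corrections makes bbox2delta and delta2bbox exact
inverses of each other. A self-contained sketch for a single box, assuming
zero means and unit stds as in the defaults:

    import torch

    def box2delta(p, g):
        px, py = (p[0] + p[2]) / 2, (p[1] + p[3]) / 2
        pw, ph = p[2] - p[0], p[3] - p[1]  # widths without '+ 1.0'
        gx, gy = (g[0] + g[2]) / 2, (g[1] + g[3]) / 2
        gw, gh = g[2] - g[0], g[3] - g[1]
        return torch.stack([(gx - px) / pw, (gy - py) / ph,
                            torch.log(gw / pw), torch.log(gh / ph)])

    def delta2box(p, d):
        px, py = (p[0] + p[2]) / 2, (p[1] + p[3]) / 2
        pw, ph = p[2] - p[0], p[3] - p[1]
        gw, gh = pw * d[2].exp(), ph * d[3].exp()
        gx, gy = px + pw * d[0], py + ph * d[1]
        return torch.stack([gx - gw / 2, gy - gh / 2,
                            gx + gw / 2, gy + gh / 2])

    p = torch.tensor([0., 0., 10., 8.])
    g = torch.tensor([2., 1., 12., 11.])
    print(delta2box(p, box2delta(p, g)))  # recovers g exactly
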
diff --git a/mmdet/core/evaluation/bbox_overlaps.py b/mmdet/core/evaluation/bbox_overlaps.py
index ad4c7052..5507e88c 100644
--- a/mmdet/core/evaluation/bbox_overlaps.py
+++ b/mmdet/core/evaluation/bbox_overlaps.py
@@ -28,17 +28,15 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
         bboxes1, bboxes2 = bboxes2, bboxes1
         ious = np.zeros((cols, rows), dtype=np.float32)
         exchange = True
-    area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (
-        bboxes1[:, 3] - bboxes1[:, 1] + 1)
-    area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (
-        bboxes2[:, 3] - bboxes2[:, 1] + 1)
+    area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
+    area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
     for i in range(bboxes1.shape[0]):
         x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
         y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
         x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
         y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
-        overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(
-            y_end - y_start + 1, 0)
+        overlap = np.maximum(x_end - x_start, 0) * np.maximum(
+            y_end - y_start, 0)
         if mode == 'iou':
             union = area1[i] + area2 - overlap
         else:
diff --git a/mmdet/core/evaluation/mean_ap.py b/mmdet/core/evaluation/mean_ap.py
index fedc51d6..b9843a78 100644
--- a/mmdet/core/evaluation/mean_ap.py
+++ b/mmdet/core/evaluation/mean_ap.py
@@ -98,14 +98,14 @@ def tpfp_imagenet(det_bboxes,
         if area_ranges == [(None, None)]:
             fp[...] = 1
         else:
-            det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0] + 1) * (
-                det_bboxes[:, 3] - det_bboxes[:, 1] + 1)
+            det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0]) * (
+                det_bboxes[:, 3] - det_bboxes[:, 1])
             for i, (min_area, max_area) in enumerate(area_ranges):
                 fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
         return tp, fp
     ious = bbox_overlaps(det_bboxes, gt_bboxes - 1)
-    gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1
-    gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1
+    gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
+    gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
     iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)),
                           default_iou_thr)
     # sort all detections by scores in descending order
@@ -144,7 +144,7 @@ def tpfp_imagenet(det_bboxes,
                 fp[k, i] = 1
             else:
                 bbox = det_bboxes[i, :4]
-                area = (bbox[2] - bbox[0] + 1) * (bbox[3] - bbox[1] + 1)
+                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                 if area >= min_area and area < max_area:
                     fp[k, i] = 1
     return tp, fp
@@ -194,8 +194,8 @@ def tpfp_default(det_bboxes,
         if area_ranges == [(None, None)]:
             fp[...] = 1
         else:
-            det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0] + 1) * (
-                det_bboxes[:, 3] - det_bboxes[:, 1] + 1)
+            det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0]) * (
+                det_bboxes[:, 3] - det_bboxes[:, 1])
             for i, (min_area, max_area) in enumerate(area_ranges):
                 fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
         return tp, fp
@@ -213,8 +213,8 @@ def tpfp_default(det_bboxes,
         if min_area is None:
             gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
         else:
-            gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (
-                gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)
+            gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
+                gt_bboxes[:, 3] - gt_bboxes[:, 1])
             gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
         for i in sort_inds:
             if ious_max[i] >= iou_thr:
@@ -231,7 +231,7 @@ def tpfp_default(det_bboxes,
                 fp[k, i] = 1
             else:
                 bbox = det_bboxes[i, :4]
-                area = (bbox[2] - bbox[0] + 1) * (bbox[3] - bbox[1] + 1)
+                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                 if area >= min_area and area < max_area:
                     fp[k, i] = 1
     return tp, fp
@@ -332,8 +332,8 @@ def eval_map(det_results,
             if area_ranges is None:
                 num_gts[0] += bbox.shape[0]
             else:
-                gt_areas = (bbox[:, 2] - bbox[:, 0] + 1) * (
-                    bbox[:, 3] - bbox[:, 1] + 1)
+                gt_areas = (bbox[:, 2] - bbox[:, 0]) * (
+                    bbox[:, 3] - bbox[:, 1])
                 for k, (min_area, max_area) in enumerate(area_ranges):
                     num_gts[k] += np.sum((gt_areas >= min_area)
                                          & (gt_areas < max_area))
diff --git a/mmdet/core/mask/mask_target.py b/mmdet/core/mask/mask_target.py
index bfc15459..68014889 100644
--- a/mmdet/core/mask/mask_target.py
+++ b/mmdet/core/mask/mask_target.py
@@ -13,21 +13,21 @@ def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
 
 
 def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg):
+    device = pos_proposals.device
     mask_size = _pair(cfg.mask_size)
     num_pos = pos_proposals.size(0)
     if num_pos > 0:
         proposals_np = pos_proposals.cpu().numpy()
         maxh, maxw = gt_masks.height, gt_masks.width
-        proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw - 1)
-        proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh - 1)
-        proposals_np = proposals_np.astype(np.int32)
+        proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw)
+        proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh)
         pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
 
         mask_targets = gt_masks.crop_and_resize(
-            proposals_np, mask_size, inds=pos_assigned_gt_inds).to_ndarray()
+            proposals_np, mask_size, device=device,
+            inds=pos_assigned_gt_inds).to_ndarray()
 
-        mask_targets = torch.from_numpy(np.stack(mask_targets)).float().to(
-            pos_proposals.device)
+        mask_targets = torch.from_numpy(mask_targets).float().to(device)
     else:
         mask_targets = pos_proposals.new_zeros((0, ) + mask_size)
 
diff --git a/mmdet/core/mask/structures.py b/mmdet/core/mask/structures.py
index 7975f502..350d3090 100644
--- a/mmdet/core/mask/structures.py
+++ b/mmdet/core/mask/structures.py
@@ -5,6 +5,8 @@ import numpy as np
 import pycocotools.mask as maskUtils
 import torch
 
+from mmdet.ops.roi_align import roi_align
+
 
 class BaseInstanceMasks(metaclass=ABCMeta):
 
@@ -185,11 +187,11 @@ class BitmapMasks(BaseInstanceMasks):
 
         # clip the boundary
         bbox = bbox.copy()
-        bbox[0::2] = np.clip(bbox[0::2], 0, self.width - 1)
-        bbox[1::2] = np.clip(bbox[1::2], 0, self.height - 1)
+        bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
+        bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
         x1, y1, x2, y2 = bbox
-        w = np.maximum(x2 - x1 + 1, 1)
-        h = np.maximum(y2 - y1 + 1, 1)
+        w = np.maximum(x2 - x1, 1)
+        h = np.maximum(y2 - y1, 1)
 
         if len(self.masks) == 0:
             cropped_masks = np.empty((0, h, w), dtype=np.uint8)
@@ -201,6 +203,7 @@ class BitmapMasks(BaseInstanceMasks):
                         bboxes,
                         out_shape,
                         inds,
+                        device='cpu',
                         interpolation='bilinear'):
         """Crop and resize masks by the given bboxes.
 
@@ -209,9 +212,10 @@ class BitmapMasks(BaseInstanceMasks):
         assigned bbox and resize to the size of (mask_h, mask_w)
 
         Args:
-            bboxes (ndarray): bboxes in format [x1, y1, x2, y2], shape (N, 4)
+            bboxes (Tensor): bboxes in format [x1, y1, x2, y2], shape (N, 4)
             out_shape (tuple[int]): target (h, w) of resized mask
             inds (ndarray): indexes to assign masks to each bbox
+            device (str): device of bboxes
             interpolation (str): see `mmcv.imresize`
 
         Return:
@@ -221,19 +225,26 @@ class BitmapMasks(BaseInstanceMasks):
             empty_masks = np.empty((0, *out_shape), dtype=np.uint8)
             return BitmapMasks(empty_masks, *out_shape)
 
-        resized_masks = []
-        for i in range(len(bboxes)):
-            mask = self.masks[inds[i]]
-            bbox = bboxes[i, :].astype(np.int32)
-            x1, y1, x2, y2 = bbox
-            w = np.maximum(x2 - x1 + 1, 1)
-            h = np.maximum(y2 - y1 + 1, 1)
-            resized_masks.append(
-                mmcv.imresize(
-                    mask[y1:y1 + h, x1:x1 + w],
-                    out_shape,
-                    interpolation=interpolation))
-        return BitmapMasks(np.stack(resized_masks), *out_shape)
+        # convert bboxes to tensor
+        if isinstance(bboxes, np.ndarray):
+            bboxes = torch.from_numpy(bboxes).to(device=device)
+        if isinstance(inds, np.ndarray):
+            inds = torch.from_numpy(inds).to(device=device)
+
+        num_bbox = bboxes.shape[0]
+        fake_inds = torch.arange(
+            num_bbox, device=device).to(dtype=bboxes.dtype)[:, None]
+        rois = torch.cat([fake_inds, bboxes], dim=1)  # Nx5
+        rois = rois.to(device=device)
+        if num_bbox > 0:
+            gt_masks_th = torch.from_numpy(self.masks).to(device).index_select(
+                0, inds).to(dtype=rois.dtype)
+            targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape,
+                                1.0, 0, True).squeeze(1)
+            resized_masks = (targets >= 0.5).cpu().numpy()
+        else:
+            resized_masks = []
+        return BitmapMasks(resized_masks, *out_shape)
 
     def expand(self, expanded_h, expanded_w, top, left):
         """see `transforms.Expand`."""
@@ -355,7 +366,7 @@ class PolygonMasks(BaseInstanceMasks):
                 flipped_poly_per_obj = []
                 for p in poly_per_obj:
                     p = p.copy()
-                    p[idx::2] = dim - p[idx::2] - 1
+                    p[idx::2] = dim - p[idx::2]
                     flipped_poly_per_obj.append(p)
                 flipped_masks.append(flipped_poly_per_obj)
             flipped_masks = PolygonMasks(flipped_masks, self.height,
@@ -369,11 +380,11 @@ class PolygonMasks(BaseInstanceMasks):
 
         # clip the boundary
         bbox = bbox.copy()
-        bbox[0::2] = np.clip(bbox[0::2], 0, self.width - 1)
-        bbox[1::2] = np.clip(bbox[1::2], 0, self.height - 1)
+        bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
+        bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
         x1, y1, x2, y2 = bbox
-        w = np.maximum(x2 - x1 + 1, 1)
-        h = np.maximum(y2 - y1 + 1, 1)
+        w = np.maximum(x2 - x1, 1)
+        h = np.maximum(y2 - y1, 1)
 
         if len(self.masks) == 0:
             cropped_masks = PolygonMasks([], h, w)
@@ -402,6 +413,7 @@ class PolygonMasks(BaseInstanceMasks):
                         bboxes,
                         out_shape,
                         inds,
+                        device='cpu',
                         interpolation='bilinear'):
         """see BitmapMasks.crop_and_resize"""
         out_h, out_w = out_shape
@@ -413,8 +425,8 @@ class PolygonMasks(BaseInstanceMasks):
             mask = self.masks[inds[i]]
             bbox = bboxes[i, :].astype(np.int32)
             x1, y1, x2, y2 = bbox
-            w = np.maximum(x2 - x1 + 1, 1)
-            h = np.maximum(y2 - y1 + 1, 1)
+            w = np.maximum(x2 - x1, 1)
+            h = np.maximum(y2 - y1, 1)
             h_scale = out_h / h
             w_scale = out_w / w
 
diff --git a/mmdet/datasets/cityscapes.py b/mmdet/datasets/cityscapes.py
index 56bf2881..abd03541 100644
--- a/mmdet/datasets/cityscapes.py
+++ b/mmdet/datasets/cityscapes.py
@@ -60,7 +60,7 @@ class CityscapesDataset(CocoDataset):
             x1, y1, w, h = ann['bbox']
             if ann['area'] <= 0 or w < 1 or h < 1:
                 continue
-            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
+            bbox = [x1, y1, x1 + w, y1 + h]
             if ann.get('iscrowd', False):
                 gt_bboxes_ignore.append(bbox)
             else:
diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
index efd5b57d..e45c8c3a 100644
--- a/mmdet/datasets/coco.py
+++ b/mmdet/datasets/coco.py
@@ -86,7 +86,7 @@ class CocoDataset(CustomDataset):
             x1, y1, w, h = ann['bbox']
             if ann['area'] <= 0 or w < 1 or h < 1:
                 continue
-            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
+            bbox = [x1, y1, x1 + w, y1 + h]
             if ann.get('iscrowd', False):
                 gt_bboxes_ignore.append(bbox)
             else:
@@ -122,8 +122,8 @@ class CocoDataset(CustomDataset):
         return [
             _bbox[0],
             _bbox[1],
-            _bbox[2] - _bbox[0] + 1,
-            _bbox[3] - _bbox[1] + 1,
+            _bbox[2] - _bbox[0],
+            _bbox[3] - _bbox[1],
         ]
 
     def _proposal2json(self, results):
@@ -249,7 +249,7 @@ class CocoDataset(CustomDataset):
                 if ann.get('ignore', False) or ann['iscrowd']:
                     continue
                 x1, y1, w, h = ann['bbox']
-                bboxes.append([x1, y1, x1 + w - 1, y1 + h - 1])
+                bboxes.append([x1, y1, x1 + w, y1 + h])
             bboxes = np.array(bboxes, dtype=np.float32)
             if bboxes.shape[0] == 0:
                 bboxes = np.zeros((0, 4))
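
Annotation loading and xyxy2xywh are now symmetric: COCO's [x, y, w, h]
becomes [x1, y1, x1 + w, y1 + h] and converts back with w = x2 - x1, so the
round trip is lossless:

    coco_bbox = [10.0, 20.0, 30.0, 40.0]      # COCO [x, y, w, h]
    x1, y1, w, h = coco_bbox
    xyxy = [x1, y1, x1 + w, y1 + h]           # [10.0, 20.0, 40.0, 60.0]
    back = [xyxy[0], xyxy[1],
            xyxy[2] - xyxy[0], xyxy[3] - xyxy[1]]
    assert back == coco_bbox                  # no '- 1' drift either way
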
diff --git a/mmdet/datasets/pipelines/instaboost.py b/mmdet/datasets/pipelines/instaboost.py
index 6777d442..3b747955 100644
--- a/mmdet/datasets/pipelines/instaboost.py
+++ b/mmdet/datasets/pipelines/instaboost.py
@@ -44,7 +44,8 @@ class InstaBoost(object):
             bbox = bboxes[i]
             mask = masks[i]
             x1, y1, x2, y2 = bbox
-            bbox = [x1, y1, x2 - x1 + 1, y2 - y1 + 1]
+            # assert (x2 - x1) >= 1 and (y2 - y1) >= 1
+            bbox = [x1, y1, x2 - x1, y2 - y1]
             anns.append({
                 'category_id': label,
                 'segmentation': mask,
@@ -59,7 +60,10 @@ class InstaBoost(object):
         gt_masks_ann = []
         for ann in anns:
             x1, y1, w, h = ann['bbox']
-            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
+            # TODO: a more fundamental bug in instaboost needs to be fixed
+            if w <= 0 or h <= 0:
+                continue
+            bbox = [x1, y1, x1 + w, y1 + h]
             gt_bboxes.append(bbox)
             gt_labels.append(ann['category_id'])
             gt_masks_ann.append(ann['segmentation'])
@@ -73,6 +77,7 @@ class InstaBoost(object):
 
     def __call__(self, results):
         img = results['img']
+        orig_type = img.dtype
         anns = self._load_anns(results)
         if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
             try:
@@ -81,8 +86,9 @@ class InstaBoost(object):
                 raise ImportError('Please run "pip install instaboostfast" '
                                   'to install instaboostfast first.')
             anns, img = instaboost.get_new_data(
-                anns, img, self.cfg, background=None)
-        results = self._parse_anns(results, anns, img)
+                anns, img.astype(np.uint8), self.cfg, background=None)
+
+        results = self._parse_anns(results, anns, img.astype(orig_type))
         return results
 
     def __repr__(self):
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
index f7ef65b0..81099015 100644
--- a/mmdet/datasets/pipelines/transforms.py
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -143,8 +143,8 @@ class Resize(object):
         img_shape = results['img_shape']
         for key in results.get('bbox_fields', []):
             bboxes = results[key] * results['scale_factor']
-            bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1] - 1)
-            bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0] - 1)
+            bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+            bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
             results[key] = bboxes
 
     def _resize_masks(self, results):
@@ -215,12 +215,12 @@ class RandomFlip(object):
         flipped = bboxes.copy()
         if direction == 'horizontal':
             w = img_shape[1]
-            flipped[..., 0::4] = w - bboxes[..., 2::4] - 1
-            flipped[..., 2::4] = w - bboxes[..., 0::4] - 1
+            flipped[..., 0::4] = w - bboxes[..., 2::4]
+            flipped[..., 2::4] = w - bboxes[..., 0::4]
         elif direction == 'vertical':
             h = img_shape[0]
-            flipped[..., 1::4] = h - bboxes[..., 3::4] - 1
-            flipped[..., 3::4] = h - bboxes[..., 1::4] - 1
+            flipped[..., 1::4] = h - bboxes[..., 3::4]
+            flipped[..., 3::4] = h - bboxes[..., 1::4]
         else:
             raise ValueError(
                 'Invalid flipping direction "{}"'.format(direction))
@@ -372,8 +372,8 @@ class RandomCrop(object):
             bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
                                    dtype=np.float32)
             bboxes = results[key] - bbox_offset
-            bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1] - 1)
-            bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0] - 1)
+            bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+            bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
             results[key] = bboxes
 
         # crop semantic seg
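
Under the new convention a horizontal flip maps [x1, x2] to [W - x2, W - x1];
widths are preserved and flipping twice is the identity, with no '- 1'
correction needed:

    import numpy as np

    W = 100.
    b = np.array([10., 5., 30., 25.])   # [x1, y1, x2, y2], width 20
    f = b.copy()
    f[[0, 2]] = W - b[[2, 0]]           # -> [70., 5., 90., 25.]
    assert f[2] - f[0] == b[2] - b[0]   # width preserved
    ff = f.copy()
    ff[[0, 2]] = W - f[[2, 0]]
    assert np.allclose(ff, b)           # flip is an involution
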
diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py
index b99ca751..41abbd51 100644
--- a/mmdet/datasets/xml_style.py
+++ b/mmdet/datasets/xml_style.py
@@ -47,6 +47,7 @@ class XMLDataset(CustomDataset):
             label = self.cat2label[name]
             difficult = int(obj.find('difficult').text)
             bnd_box = obj.find('bndbox')
+            # TODO: check whether it is necessary to use int
             # Coordinates may be float type
             bbox = [
                 int(float(bnd_box.find('xmin').text)),
diff --git a/mmdet/models/anchor_heads/fcos_head.py b/mmdet/models/anchor_heads/fcos_head.py
index 58f27f77..1f4a69f7 100644
--- a/mmdet/models/anchor_heads/fcos_head.py
+++ b/mmdet/models/anchor_heads/fcos_head.py
@@ -369,8 +369,8 @@ class FCOSHead(nn.Module):
             return gt_labels.new_zeros(num_points), \
                    gt_bboxes.new_zeros((num_points, 4))
 
-        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (
-            gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)
+        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
+            gt_bboxes[:, 3] - gt_bboxes[:, 1])
         # TODO: figure out why these two are different
         # areas = areas[None].expand(num_points, num_gts)
         areas = areas[None].repeat(num_points, 1)
diff --git a/mmdet/models/anchor_heads/ga_rpn_head.py b/mmdet/models/anchor_heads/ga_rpn_head.py
index 11512ffc..d6e3b447 100644
--- a/mmdet/models/anchor_heads/ga_rpn_head.py
+++ b/mmdet/models/anchor_heads/ga_rpn_head.py
@@ -103,8 +103,8 @@ class GARPNHead(GuidedAnchorHead):
                                    self.target_stds, img_shape)
             # filter out too small bboxes
             if cfg.min_bbox_size > 0:
-                w = proposals[:, 2] - proposals[:, 0] + 1
-                h = proposals[:, 3] - proposals[:, 1] + 1
+                w = proposals[:, 2] - proposals[:, 0]
+                h = proposals[:, 3] - proposals[:, 1]
                 valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                            (h >= cfg.min_bbox_size)).squeeze()
                 proposals = proposals[valid_inds, :]
diff --git a/mmdet/models/anchor_heads/rpn_head.py b/mmdet/models/anchor_heads/rpn_head.py
index f88b949c..adcc7d1c 100644
--- a/mmdet/models/anchor_heads/rpn_head.py
+++ b/mmdet/models/anchor_heads/rpn_head.py
@@ -82,8 +82,8 @@ class RPNHead(AnchorHead):
             proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                    self.target_stds, img_shape)
             if cfg.min_bbox_size > 0:
-                w = proposals[:, 2] - proposals[:, 0] + 1
-                h = proposals[:, 3] - proposals[:, 1] + 1
+                w = proposals[:, 2] - proposals[:, 0]
+                h = proposals[:, 3] - proposals[:, 1]
                 valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                            (h >= cfg.min_bbox_size)).squeeze()
                 proposals = proposals[valid_inds, :]
diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py
index 57113679..dd21a79e 100644
--- a/mmdet/models/anchor_heads/ssd_head.py
+++ b/mmdet/models/anchor_heads/ssd_head.py
@@ -75,7 +75,7 @@ class SSDHead(AnchorHead):
         for k in range(len(anchor_strides)):
             base_size = min_sizes[k]
             stride = anchor_strides[k]
-            ctr = ((stride - 1) / 2., (stride - 1) / 2.)
+            ctr = (stride / 2., stride / 2.)
             scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
             ratios = [1.]
             for r in anchor_ratios[k]:
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index 36cc74f2..4867f36c 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -154,8 +154,8 @@ class BBoxHead(nn.Module):
         else:
             bboxes = rois[:, 1:].clone()
             if img_shape is not None:
-                bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1)
-                bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1)
+                bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1])
+                bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0])
 
         if rescale:
             if isinstance(scale_factor, float):
diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py
index c19c1d1d..26c042ae 100644
--- a/mmdet/models/losses/iou_loss.py
+++ b/mmdet/models/losses/iou_loss.py
@@ -40,13 +40,13 @@ def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
     """
     pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
     pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
-    pred_w = pred[:, 2] - pred[:, 0] + 1
-    pred_h = pred[:, 3] - pred[:, 1] + 1
+    pred_w = pred[:, 2] - pred[:, 0]
+    pred_h = pred[:, 3] - pred[:, 1]
     with torch.no_grad():
         target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
         target_ctry = (target[:, 1] + target[:, 3]) * 0.5
-        target_w = target[:, 2] - target[:, 0] + 1
-        target_h = target[:, 3] - target[:, 1] + 1
+        target_w = target[:, 2] - target[:, 0]
+        target_h = target[:, 3] - target[:, 1]
 
     dx = target_ctrx - pred_ctrx
     dy = target_ctry - pred_ctry
@@ -91,12 +91,12 @@ def giou_loss(pred, target, eps=1e-7):
     # overlap
     lt = torch.max(pred[:, :2], target[:, :2])
     rb = torch.min(pred[:, 2:], target[:, 2:])
-    wh = (rb - lt + 1).clamp(min=0)
+    wh = (rb - lt).clamp(min=0)
     overlap = wh[:, 0] * wh[:, 1]
 
     # union
-    ap = (pred[:, 2] - pred[:, 0] + 1) * (pred[:, 3] - pred[:, 1] + 1)
-    ag = (target[:, 2] - target[:, 0] + 1) * (target[:, 3] - target[:, 1] + 1)
+    ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
+    ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
     union = ap + ag - overlap + eps
 
     # IoU
@@ -105,7 +105,7 @@ def giou_loss(pred, target, eps=1e-7):
     # enclose area
     enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
     enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
-    enclose_wh = (enclose_x2y2 - enclose_x1y1 + 1).clamp(min=0)
+    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
     enclose_area = enclose_wh[:, 0] * enclose_wh[:, 1] + eps
 
     # GIoU
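
A worked instance of the updated giou_loss formula on one box pair (plain
tensors, no mmdet imports; eps omitted for readability):

    import torch

    pred = torch.tensor([0., 0., 10., 10.])
    target = torch.tensor([5., 5., 15., 15.])
    wh = (torch.min(pred[2:], target[2:]) -
          torch.max(pred[:2], target[:2])).clamp(min=0)
    overlap = wh[0] * wh[1]                                 # 5 * 5 = 25
    ap = (pred[2] - pred[0]) * (pred[3] - pred[1])          # 100
    ag = (target[2] - target[0]) * (target[3] - target[1])  # 100
    union = ap + ag - overlap                               # 175
    ious = overlap / union                                  # ~0.1429
    enclose_wh = (torch.max(pred[2:], target[2:]) -
                  torch.min(pred[:2], target[:2])).clamp(min=0)
    enclose_area = enclose_wh[0] * enclose_wh[1]            # 15 * 15 = 225
    gious = ious - (enclose_area - union) / enclose_area
    print(1 - gious)                                        # loss ~ 1.0794
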
diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py
index e94b209e..9353cf0f 100644
--- a/mmdet/models/mask_heads/fcn_mask_head.py
+++ b/mmdet/models/mask_heads/fcn_mask_head.py
@@ -1,4 +1,3 @@
-import mmcv
 import numpy as np
 import pycocotools.mask as mask_util
 import torch
@@ -8,9 +7,15 @@ from torch.nn.modules.utils import _pair
 from mmdet.core import auto_fp16, force_fp32, mask_target
 from mmdet.ops import ConvModule, build_upsample_layer
 from mmdet.ops.carafe import CARAFEPack
+from mmdet.ops.grid_sampler import grid_sample
 from ..builder import build_loss
 from ..registry import HEADS
 
+BYTES_PER_FLOAT = 4
+# TODO: This memory limit may be too much or too little. It would be better to
+# determine it based on available resources.
+GPU_MEM_LIMIT = 1024**3  # 1 GB memory limit
+
 
 @HEADS.register_module
 class FCNMaskHead(nn.Module):
@@ -144,7 +149,7 @@ class FCNMaskHead(nn.Module):
         """Get segmentation masks from mask_pred and bboxes.
 
         Args:
-            mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
+            mask_pred (Tensor or ndarray): shape (n, #class, h, w).
                 For single-scale testing, mask_pred is the direct output of
                 model, whose type is Tensor, while for multi-scale testing,
                 it will be converted to numpy array outside of this method.
@@ -158,15 +163,15 @@ class FCNMaskHead(nn.Module):
             list[list]: encoded masks
         """
         if isinstance(mask_pred, torch.Tensor):
-            mask_pred = mask_pred.sigmoid().cpu().numpy()
-        assert isinstance(mask_pred, np.ndarray)
-        # when enabling mixed precision training, mask_pred may be float16
-        # numpy array
-        mask_pred = mask_pred.astype(np.float32)
+            mask_pred = mask_pred.sigmoid()
+        else:
+            mask_pred = det_bboxes.new_tensor(mask_pred)
 
-        cls_segms = [[] for _ in range(self.num_classes - 1)]
-        bboxes = det_bboxes.cpu().numpy()[:, :4]
-        labels = det_labels.cpu().numpy() + 1
+        device = mask_pred.device
+        # BG is not included in num_classes
+        cls_segms = [[] for _ in range(self.num_classes)]
+        bboxes = det_bboxes[:, :4]
+        labels = det_labels + 1  # TODO: remove + 1 in cat -1
 
         if rescale:
             img_h, img_w = ori_shape[:2]
@@ -175,34 +180,130 @@ class FCNMaskHead(nn.Module):
             img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
             scale_factor = 1.0
 
-        for i in range(bboxes.shape[0]):
-            if not isinstance(scale_factor, (float, np.ndarray)):
-                scale_factor = scale_factor.cpu().numpy()
-            bbox = (bboxes[i, :] / scale_factor).astype(np.int32)
-            label = labels[i]
-            w = max(bbox[2] - bbox[0] + 1, 1)
-            h = max(bbox[3] - bbox[1] + 1, 1)
+        if not isinstance(scale_factor, (float, torch.Tensor)):
+            scale_factor = bboxes.new_tensor(scale_factor)
+        bboxes = bboxes / scale_factor
 
-            if not self.class_agnostic:
-                mask_pred_ = mask_pred[i, label, :, :]
-            else:
-                mask_pred_ = mask_pred[i, 0, :, :]
+        N = len(mask_pred)
+        # The actual implementation splits the input into chunks,
+        # and pastes them chunk by chunk.
+        if device.type == 'cpu':
+            # CPU is most efficient when masks are pasted one by one with
+            # skip_empty=True, so that it performs a minimal number of
+            # operations.
+            num_chunks = N
+        else:
+            # GPU benefits from parallelism for larger chunks,
+            # but may have memory issues
+            num_chunks = int(
+                np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
+            assert (num_chunks <=
+                    N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
+        chunks = torch.chunk(torch.arange(N, device=device), num_chunks)
 
-            bbox_mask = mmcv.imresize(mask_pred_, (w, h))
-            bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype(
-                np.uint8)
+        threshold = rcnn_test_cfg.mask_thr_binary
+        im_mask = torch.zeros(
+            N,
+            img_h,
+            img_w,
+            device=device,
+            dtype=torch.bool if threshold >= 0 else torch.uint8)
 
-            if rcnn_test_cfg.get('crop_mask', False):
-                im_mask = bbox_mask
-            else:
-                im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
-                im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask
+        if not self.class_agnostic:
+            mask_pred = mask_pred[range(N), labels][:, None]
 
-            if rcnn_test_cfg.get('rle_mask_encode', True):
-                rle = mask_util.encode(
-                    np.array(im_mask[:, :, np.newaxis], order='F'))[0]
-                cls_segms[label - 1].append(rle)
+        for inds in chunks:
+            masks_chunk, spatial_inds = _do_paste_mask(
+                mask_pred[inds],
+                bboxes[inds],
+                img_h,
+                img_w,
+                skip_empty=device.type == 'cpu')
+
+            if threshold >= 0:
+                masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
             else:
-                cls_segms[label - 1].append(im_mask)
+                # for visualization and debugging
+                masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)
+
+            im_mask[(inds, ) + spatial_inds] = masks_chunk
 
+        for i in range(N):
+            rle = mask_util.encode(
+                np.array(
+                    im_mask[i][:, :, None].cpu().numpy(),
+                    order='F',
+                    dtype='uint8'))[0]
+            cls_segms[labels[i] - 1].append(rle)  # TODO: remove -1 in cat -1
         return cls_segms
+
+
+def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
+    """Paste instance masks acoording to boxes.
+
+    This implementation is modified from
+    https://github.com/facebookresearch/detectron2/
+
+    Args:
+        masks (Tensor): N, 1, H, W
+        boxes (Tensor): N, 4
+        img_h (int): Height of the image to be pasted.
+        img_w (int): Width of the image to be pasted.
+        skip_empty (bool): Only paste masks within the region that
+            tightly bounds all boxes, and return the results for this
+            region only. An important optimization for CPU.
+
+    Returns:
+        tuple: (Tensor, tuple). The first item is the mask tensor, the
+            second one is a tuple of slice objects.
+        If skip_empty == False, the whole image will be pasted. It will
+            return a mask of shape (N, img_h, img_w) and an empty tuple.
+        If skip_empty == True, only the area around the boxes will be
+            pasted. A mask of shape (N, h', w') and the slices locating it
+            in the original image will be returned.
+    """
+    # On GPU, paste all masks together (up to chunk size)
+    # by using the entire image to sample the masks
+    # Compared to pasting them one by one,
+    # this has more operations but is faster on COCO-scale datasets.
+    device = masks.device
+    if skip_empty:
+        x0_int, y0_int = torch.clamp(
+            boxes.min(dim=0).values.floor()[:2] - 1,
+            min=0).to(dtype=torch.int32)
+        x1_int = torch.clamp(
+            boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
+        y1_int = torch.clamp(
+            boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
+    else:
+        x0_int, y0_int = 0, 0
+        x1_int, y1_int = img_w, img_h
+    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1
+
+    N = masks.shape[0]
+
+    img_y = torch.arange(
+        y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
+    img_x = torch.arange(
+        x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
+    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
+    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
+    # img_x, img_y have shapes (N, w), (N, h)
+    if torch.isinf(img_x).any():
+        inds = torch.where(torch.isinf(img_x))
+        img_x[inds] = 0
+    if torch.isinf(img_y).any():
+        inds = torch.where(torch.isinf(img_y))
+        img_y[inds] = 0
+
+    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
+    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
+    grid = torch.stack([gx, gy], dim=3)
+
+    img_masks = grid_sample(
+        masks.to(dtype=torch.float32), grid, align_corners=False)
+
+    if skip_empty:
+        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
+    else:
+        return img_masks[:, 0], ()
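
The chunking above bounds the peak size of the pasted float masks by
GPU_MEM_LIMIT. A worked example of the arithmetic (the detection count and
image size are illustrative):

    import numpy as np

    BYTES_PER_FLOAT = 4
    GPU_MEM_LIMIT = 1024**3                      # 1 GB, as defined above

    N, img_h, img_w = 1000, 800, 1333            # e.g. 1000 detections
    full = N * img_h * img_w * BYTES_PER_FLOAT   # ~4.27e9 bytes
    num_chunks = int(np.ceil(full / GPU_MEM_LIMIT))
    print(num_chunks)                            # 4 -> ~250 masks per chunk
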
diff --git a/mmdet/models/mask_heads/grid_head.py b/mmdet/models/mask_heads/grid_head.py
index ec669cb2..3449bca1 100644
--- a/mmdet/models/mask_heads/grid_head.py
+++ b/mmdet/models/mask_heads/grid_head.py
@@ -355,7 +355,7 @@ class GridHead(nn.Module):
 
         bbox_res = torch.cat(
             [bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2, cls_scores], dim=1)
-        bbox_res[:, [0, 2]].clamp_(min=0, max=img_metas[0]['img_shape'][1] - 1)
-        bbox_res[:, [1, 3]].clamp_(min=0, max=img_metas[0]['img_shape'][0] - 1)
+        bbox_res[:, [0, 2]].clamp_(min=0, max=img_metas[0]['img_shape'][1])
+        bbox_res[:, [1, 3]].clamp_(min=0, max=img_metas[0]['img_shape'][0])
 
         return bbox_res
diff --git a/mmdet/models/roi_extractors/single_level.py b/mmdet/models/roi_extractors/single_level.py
index 6620d1d8..54b25cc8 100644
--- a/mmdet/models/roi_extractors/single_level.py
+++ b/mmdet/models/roi_extractors/single_level.py
@@ -67,7 +67,7 @@ class SingleRoIExtractor(nn.Module):
             Tensor: Level index (0-based) of each RoI, shape (k, )
         """
         scale = torch.sqrt(
-            (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1))
+            (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
         target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
         target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
         return target_lvls
@@ -75,14 +75,14 @@ class SingleRoIExtractor(nn.Module):
     def roi_rescale(self, rois, scale_factor):
         cx = (rois[:, 1] + rois[:, 3]) * 0.5
         cy = (rois[:, 2] + rois[:, 4]) * 0.5
-        w = rois[:, 3] - rois[:, 1] + 1
-        h = rois[:, 4] - rois[:, 2] + 1
+        w = rois[:, 3] - rois[:, 1]
+        h = rois[:, 4] - rois[:, 2]
         new_w = w * scale_factor
         new_h = h * scale_factor
-        x1 = cx - new_w * 0.5 + 0.5
-        x2 = cx + new_w * 0.5 - 0.5
-        y1 = cy - new_h * 0.5 + 0.5
-        y2 = cy + new_h * 0.5 - 0.5
+        x1 = cx - new_w * 0.5
+        x2 = cx + new_w * 0.5
+        y1 = cy - new_h * 0.5
+        y2 = cy + new_h * 0.5
         new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
         return new_rois
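
With widths now computed as x2 - x1, map_roi_levels assigns an RoI of scale
sqrt(w * h) to FPN level floor(log2(scale / finest_scale + eps)). A
standalone sketch with the default finest_scale of 56:

    import torch

    def map_roi_levels(rois, num_levels, finest_scale=56):
        # rois are [batch_ind, x1, y1, x2, y2]
        scale = torch.sqrt(
            (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
        lvls = torch.floor(torch.log2(scale / finest_scale + 1e-6))
        return lvls.clamp(min=0, max=num_levels - 1).long()

    rois = torch.tensor([[0., 0., 0., 56., 56.],     # scale 56 -> level 0
                         [0., 0., 0., 224., 224.]])  # scale 224 -> level 2
    print(map_roi_levels(rois, 4))  # tensor([0, 2])
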
 
diff --git a/mmdet/ops/nms/nms_wrapper.py b/mmdet/ops/nms/nms_wrapper.py
index a9ebac22..9f1a84c4 100644
--- a/mmdet/ops/nms/nms_wrapper.py
+++ b/mmdet/ops/nms/nms_wrapper.py
@@ -29,7 +29,7 @@ def nms(dets, iou_thr, device_id=None):
         >>>                  [35.6, 11.8, 39.3, 14.2, 0.5],
         >>>                  [35.3, 11.5, 39.9, 14.5, 0.4],
         >>>                  [35.2, 11.7, 39.7, 15.7, 0.3]], dtype=np.float32)
-        >>> iou_thr = 0.7
+        >>> iou_thr = 0.6
         >>> suppressed, inds = nms(dets, iou_thr)
         >>> assert len(inds) == len(suppressed) == 3
     """
@@ -84,9 +84,9 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
         >>>                  [3., 1., 3., 1., 0.5],
         >>>                  [3., 1., 3., 1., 0.4],
         >>>                  [3., 1., 3., 1., 0.0]], dtype=np.float32)
-        >>> iou_thr = 0.7
+        >>> iou_thr = 0.6
         >>> new_dets, inds = soft_nms(dets, iou_thr, sigma=0.5)
-        >>> assert len(inds) == len(new_dets) == 3
+        >>> assert len(inds) == len(new_dets) == 5
     """
     # convert dets (tensor or numpy array) to tensor
     if isinstance(dets, torch.Tensor):
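
The doctest thresholds move from 0.7 to 0.6 because every IoU shrinks
slightly once the '+ 1' term is gone. Checking two of the overlapping boxes
from the doctest above:

    import numpy as np

    a = np.array([35.6, 11.8, 39.3, 14.2])
    b = np.array([35.3, 11.5, 39.9, 14.5])

    def iou(a, b, legacy=False):
        off = 1.0 if legacy else 0.0
        iw = max(min(a[2], b[2]) - max(a[0], b[0]) + off, 0)
        ih = max(min(a[3], b[3]) - max(a[1], b[1]) + off, 0)
        inter = iw * ih
        area = lambda x: (x[2] - x[0] + off) * (x[3] - x[1] + off)
        return inter / (area(a) + area(b) - inter)

    print(iou(a, b, legacy=True))  # ~0.713: suppressed at iou_thr=0.7
    print(iou(a, b))               # ~0.643: needs iou_thr=0.6 to suppress
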
diff --git a/mmdet/ops/nms/src/cpu/nms_cpu.cpp b/mmdet/ops/nms/src/cpu/nms_cpu.cpp
index 1fa589dc..4d11abec 100644
--- a/mmdet/ops/nms/src/cpu/nms_cpu.cpp
+++ b/mmdet/ops/nms/src/cpu/nms_cpu.cpp
@@ -1,4 +1,6 @@
-// Modified from https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx, Soft-NMS is added
+// Soft-NMS is added by MMDetection.
+// Modified from
+// https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx.
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <torch/extension.h>
 
@@ -16,7 +18,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) {
   auto y2_t = dets.select(1, 3).contiguous();
   auto scores = dets.select(1, 4).contiguous();
 
-  at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
+  at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
 
   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
 
@@ -49,8 +51,8 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) {
       auto xx2 = std::min(ix2, x2[j]);
       auto yy2 = std::min(iy2, y2[j]);
 
-      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
-      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
+      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
+      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
       auto inter = w * h;
       auto ovr = inter / (iarea + areas[j] - inter);
       if (ovr >= threshold) suppressed[j] = 1;
@@ -69,7 +71,8 @@ at::Tensor nms_cpu(const at::Tensor& dets, const float threshold) {
 
 template <typename scalar_t>
 at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
-                               const unsigned char method, const float sigma, const float min_score) {
+                               const unsigned char method, const float sigma,
+                               const float min_score) {
   AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
 
   if (dets.numel() == 0) {
@@ -82,7 +85,7 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
   auto y2_t = dets.select(1, 3).contiguous();
   auto scores_t = dets.select(1, 4).contiguous();
 
-  at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
+  at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
 
   auto ndets = dets.size(0);
   auto x1 = x1_t.data<scalar_t>();
@@ -110,12 +113,12 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
 
     pos = i + 1;
     // get max box
-    while (pos < ndets){
-        if (max_score < scores[pos]) {
-            max_score = scores[pos];
-            max_pos = pos;
-        }
-        pos = pos + 1;
+    while (pos < ndets) {
+      if (max_score < scores[pos]) {
+        max_score = scores[pos];
+        max_pos = pos;
+      }
+      pos = pos + 1;
     }
     // add max box as a detection
     x1[i] = x1[max_pos];
@@ -127,10 +130,10 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
     inds[i] = inds[max_pos];
 
     // swap ith box with position of max box
-    x1[max_pos] =  ix1;
-    y1[max_pos] =  iy1;
-    x2[max_pos] =  ix2;
-    y2[max_pos] =  iy2;
+    x1[max_pos] = ix1;
+    y1[max_pos] = iy1;
+    x2[max_pos] = ix2;
+    y2[max_pos] = iy2;
     scores[max_pos] = iscore;
     areas[max_pos] = iarea;
     inds[max_pos] = iind;
@@ -143,32 +146,30 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
     iarea = areas[i];
 
     pos = i + 1;
-    // NMS iterations, note that N changes if detection boxes fall below threshold
+    // NMS iterations, note that N changes if detection boxes fall below
+    // threshold
     while (pos < ndets) {
       auto xx1 = std::max(ix1, x1[pos]);
       auto yy1 = std::max(iy1, y1[pos]);
       auto xx2 = std::min(ix2, x2[pos]);
       auto yy2 = std::min(iy2, y2[pos]);
 
-      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
-      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
+      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
+      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
       auto inter = w * h;
       auto ovr = inter / (iarea + areas[pos] - inter);
 
       scalar_t weight = 1.;
       if (method == 1) {
         if (ovr > threshold) weight = 1 - ovr;
-      }
-      else if (method == 2) {
+      } else if (method == 2) {
         weight = std::exp(-(ovr * ovr) / sigma);
-      }
-      else {
+      } else {
         // original NMS
         if (ovr > threshold) {
-            weight = 0;
-        }
-        else {
-            weight = 1;
+          weight = 0;
+        } else {
+          weight = 1;
         }
       }
       scores[pos] = weight * scores[pos];
@@ -182,7 +183,7 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
         scores[pos] = scores[ndets - 1];
         areas[pos] = areas[ndets - 1];
         inds[pos] = inds[ndets - 1];
-        ndets = ndets -1;
+        ndets = ndets - 1;
         pos = pos - 1;
       }
       pos = pos + 1;
@@ -196,15 +197,17 @@ at::Tensor soft_nms_cpu_kernel(const at::Tensor& dets, const float threshold,
   result[4] = scores_t.slice(0, 0, ndets);
   result[5] = inds_t.slice(0, 0, ndets);
 
-  result =result.t().contiguous();
+  result = result.t().contiguous();
   return result;
 }
 
 at::Tensor soft_nms_cpu(const at::Tensor& dets, const float threshold,
-                    const unsigned char method, const float sigma, const float min_score) {
+                        const unsigned char method, const float sigma,
+                        const float min_score) {
   at::Tensor result;
   AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "soft_nms", [&] {
-    result = soft_nms_cpu_kernel<scalar_t>(dets, threshold, method, sigma, min_score);
+    result = soft_nms_cpu_kernel<scalar_t>(dets, threshold, method, sigma,
+                                           min_score);
   });
   return result;
 }
diff --git a/mmdet/ops/nms/src/cuda/nms_kernel.cu b/mmdet/ops/nms/src/cuda/nms_kernel.cu
index 8dc98be1..0c084f7a 100644
--- a/mmdet/ops/nms/src/cuda/nms_kernel.cu
+++ b/mmdet/ops/nms/src/cuda/nms_kernel.cu
@@ -14,10 +14,10 @@ int const threadsPerBlock = sizeof(unsigned long long) * 8;
 __device__ inline float devIoU(float const * const a, float const * const b) {
   float left = max(a[0], b[0]), right = min(a[2], b[2]);
   float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
-  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
+  float width = max(right - left, 0.f), height = max(bottom - top, 0.f);
   float interS = width * height;
-  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
-  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
+  float Sa = (a[2] - a[0]) * (a[3] - a[1]);
+  float Sb = (b[2] - b[0]) * (b[3] - b[1]);
   return interS / (Sa + Sb - interS);
 }
 
diff --git a/mmdet/ops/roi_align/roi_align.py b/mmdet/ops/roi_align/roi_align.py
index 203c1152..4f792fa9 100644
--- a/mmdet/ops/roi_align/roi_align.py
+++ b/mmdet/ops/roi_align/roi_align.py
@@ -24,20 +24,18 @@ class RoIAlignFunction(Function):
         ctx.feature_size = features.size()
         ctx.aligned = aligned
 
-        if features.is_cuda:
-            if not aligned:
-                (batch_size, num_channels, data_height,
-                 data_width) = features.size()
-                num_rois = rois.size(0)
-
-                output = features.new_zeros(num_rois, num_channels, out_h,
-                                            out_w)
-                roi_align_ext.forward_v1(features, rois, out_h, out_w,
-                                         spatial_scale, sample_num, output)
-            else:
-                output = roi_align_ext.forward_v2(features, rois,
-                                                  spatial_scale, out_h, out_w,
-                                                  sample_num, aligned)
+        if aligned:
+            output = roi_align_ext.forward_v2(features, rois, spatial_scale,
+                                              out_h, out_w, sample_num,
+                                              aligned)
+        elif features.is_cuda:
+            (batch_size, num_channels, data_height,
+             data_width) = features.size()
+            num_rois = rois.size(0)
+
+            output = features.new_zeros(num_rois, num_channels, out_h, out_w)
+            roi_align_ext.forward_v1(features, rois, out_h, out_w,
+                                     spatial_scale, sample_num, output)
         else:
             raise NotImplementedError
 
@@ -85,7 +83,7 @@ class RoIAlign(nn.Module):
                  spatial_scale,
                  sample_num=0,
                  use_torchvision=False,
-                 aligned=False):
+                 aligned=True):
         """
         Args:
             out_size (tuple): h, w
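
With aligned=True now the default, RoIAlignFunction.forward routes every aligned call through forward_v2 on CPU and GPU alike, keeping forward_v1 only for the legacy unaligned CUDA path. A usage sketch, assuming the mmdet.ops import path and the constructor shown above:

    import torch
    from mmdet.ops import RoIAlign  # import path assumed

    # rois are (batch_ind, x1, y1, x2, y2) in image coordinates
    roi_align = RoIAlign(out_size=(7, 7), spatial_scale=1 / 16., sample_num=2)
    feats = torch.randn(1, 256, 50, 50)
    rois = torch.tensor([[0., 16., 16., 160., 160.]])
    out = roi_align(feats, rois)  # now also works on CPU via forward_v2
    assert out.shape == (1, 256, 7, 7)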
diff --git a/mmdet/ops/roi_align/src/cpu/roi_align_v2.cpp b/mmdet/ops/roi_align/src/cpu/roi_align_v2.cpp
new file mode 100644
index 00000000..2c6b557d
--- /dev/null
+++ b/mmdet/ops/roi_align/src/cpu/roi_align_v2.cpp
@@ -0,0 +1,404 @@
+// Modified from
+// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlign
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#include <ATen/ATen.h>
+#include <ATen/TensorUtils.h>
+
+// implementation taken from Caffe2
+template <typename T>
+struct PreCalc {
+  int pos1;
+  int pos2;
+  int pos3;
+  int pos4;
+  T w1;
+  T w2;
+  T w3;
+  T w4;
+};
+
+template <typename T>
+void pre_calc_for_bilinear_interpolate(
+    const int height, const int width, const int pooled_height,
+    const int pooled_width, const int iy_upper, const int ix_upper,
+    T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
+    int roi_bin_grid_h, int roi_bin_grid_w, std::vector<PreCalc<T>>& pre_calc) {
+  int pre_calc_index = 0;
+  for (int ph = 0; ph < pooled_height; ph++) {
+    for (int pw = 0; pw < pooled_width; pw++) {
+      for (int iy = 0; iy < iy_upper; iy++) {
+        const T yy = roi_start_h + ph * bin_size_h +
+                     static_cast<T>(iy + .5f) * bin_size_h /
+                         static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+        for (int ix = 0; ix < ix_upper; ix++) {
+          const T xx = roi_start_w + pw * bin_size_w +
+                       static_cast<T>(ix + .5f) * bin_size_w /
+                           static_cast<T>(roi_bin_grid_w);
+
+          T x = xx;
+          T y = yy;
+          // deal with sample points outside the feature map boundary
+          if (y < -1.0 || y > height || x < -1.0 || x > width) {
+            // empty
+            PreCalc<T> pc;
+            pc.pos1 = 0;
+            pc.pos2 = 0;
+            pc.pos3 = 0;
+            pc.pos4 = 0;
+            pc.w1 = 0;
+            pc.w2 = 0;
+            pc.w3 = 0;
+            pc.w4 = 0;
+            pre_calc[pre_calc_index] = pc;
+            pre_calc_index += 1;
+            continue;
+          }
+
+          if (y <= 0) {
+            y = 0;
+          }
+          if (x <= 0) {
+            x = 0;
+          }
+
+          int y_low = (int)y;
+          int x_low = (int)x;
+          int y_high;
+          int x_high;
+
+          if (y_low >= height - 1) {
+            y_high = y_low = height - 1;
+            y = (T)y_low;
+          } else {
+            y_high = y_low + 1;
+          }
+
+          if (x_low >= width - 1) {
+            x_high = x_low = width - 1;
+            x = (T)x_low;
+          } else {
+            x_high = x_low + 1;
+          }
+
+          T ly = y - y_low;
+          T lx = x - x_low;
+          T hy = 1. - ly, hx = 1. - lx;
+          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+          // save weights and indices
+          PreCalc<T> pc;
+          pc.pos1 = y_low * width + x_low;
+          pc.pos2 = y_low * width + x_high;
+          pc.pos3 = y_high * width + x_low;
+          pc.pos4 = y_high * width + x_high;
+          pc.w1 = w1;
+          pc.w2 = w2;
+          pc.w3 = w3;
+          pc.w4 = w4;
+          pre_calc[pre_calc_index] = pc;
+
+          pre_calc_index += 1;
+        }
+      }
+    }
+  }
+}
+
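
pre_calc_for_bilinear_interpolate caches, once per RoI, the four neighbour indices and bilinear weights of every sample point, so the per-channel loop in ROIAlignForward below reduces to pure multiply-adds. A Python sketch of the per-point computation (hypothetical standalone helper):

    def bilinear_weights(y, x, height, width):
        # mirrors the branchy logic above for a single sample point
        if y < -1.0 or y > height or x < -1.0 or x > width:
            return (0, 0, 0, 0), (0.0, 0.0, 0.0, 0.0)  # out-of-bounds sample
        y, x = max(y, 0.0), max(x, 0.0)
        y_low, x_low = int(y), int(x)
        if y_low >= height - 1:
            y_low = y_high = height - 1
            y = float(y_low)
        else:
            y_high = y_low + 1
        if x_low >= width - 1:
            x_low = x_high = width - 1
            x = float(x_low)
        else:
            x_high = x_low + 1
        ly, lx = y - y_low, x - x_low
        hy, hx = 1.0 - ly, 1.0 - lx
        pos = (y_low * width + x_low, y_low * width + x_high,
               y_high * width + x_low, y_high * width + x_high)
        return pos, (hy * hx, hy * lx, ly * hx, ly * lx)

For in-bounds samples the four weights sum to 1, which is what makes the backward pass further down an exact transpose of the forward.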
+template <typename T>
+void ROIAlignForward(const int nthreads, const T* input, const T& spatial_scale,
+                     const int channels, const int height, const int width,
+                     const int pooled_height, const int pooled_width,
+                     const int sampling_ratio, const T* rois, T* output,
+                     bool aligned) {
+  int n_rois = nthreads / channels / pooled_width / pooled_height;
+  // (n, c, ph, pw) is an element in the pooled output
+  // can be parallelized using omp
+  // #pragma omp parallel for num_threads(32)
+  for (int n = 0; n < n_rois; n++) {
+    int index_n = n * channels * pooled_width * pooled_height;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T offset = aligned ? (T)0.5 : (T)0.0;
+    T roi_start_w = offset_rois[1] * spatial_scale - offset;
+    T roi_start_h = offset_rois[2] * spatial_scale - offset;
+    T roi_end_w = offset_rois[3] * spatial_scale - offset;
+    T roi_end_h = offset_rois[4] * spatial_scale - offset;
+
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    if (aligned) {
+      AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
+                 "ROIs in ROIAlign cannot have non-negative size!");
+    } else {  // for backward-compatibility only
+      roi_width = std::max(roi_width, (T)1.);
+      roi_height = std::max(roi_height, (T)1.);
+    }
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    // We use roi_bin_grid to sample the grid and mimic integral pooling
+    int roi_bin_grid_h = (sampling_ratio > 0)
+                             ? sampling_ratio
+                             : ceil(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    // When the grid is empty, output 0/1 = 0 instead of 0/0 = NaN.
+    const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4
+
+    // we want to precalculate indices and weights shared by all channels;
+    // this is the key point of the optimization
+    std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
+                                     pooled_width * pooled_height);
+    pre_calc_for_bilinear_interpolate(
+        height, width, pooled_height, pooled_width, roi_bin_grid_h,
+        roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
+        roi_bin_grid_h, roi_bin_grid_w, pre_calc);
+
+    for (int c = 0; c < channels; c++) {
+      int index_n_c = index_n + c * pooled_width * pooled_height;
+      const T* offset_input =
+          input + (roi_batch_ind * channels + c) * height * width;
+      int pre_calc_index = 0;
+
+      for (int ph = 0; ph < pooled_height; ph++) {
+        for (int pw = 0; pw < pooled_width; pw++) {
+          int index = index_n_c + ph * pooled_width + pw;
+
+          T output_val = 0.;
+          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+              PreCalc<T> pc = pre_calc[pre_calc_index];
+              output_val += pc.w1 * offset_input[pc.pos1] +
+                            pc.w2 * offset_input[pc.pos2] +
+                            pc.w3 * offset_input[pc.pos3] +
+                            pc.w4 * offset_input[pc.pos4];
+
+              pre_calc_index += 1;
+            }
+          }
+          output_val /= count;
+
+          output[index] = output_val;
+        }  // for pw
+      }    // for ph
+    }      // for c
+  }        // for n
+}
+
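
The only difference between aligned and legacy mode is the half-pixel shift: an RoI coordinate c maps to c * spatial_scale - 0.5, so sample points land on pixel centers instead of carrying a systematic half-pixel bias. Worked numbers, assuming hypothetical RoI values:

    # sample-point layout for the first bin, assuming aligned=True
    spatial_scale, sampling_ratio = 1 / 16., 2
    x1, roi_w, pooled_w = 32.0, 112.0, 7
    roi_start_w = x1 * spatial_scale - 0.5        # 1.5: half-pixel shift
    bin_w = roi_w * spatial_scale / pooled_w      # 1.0
    xs = [roi_start_w + (ix + 0.5) * bin_w / sampling_ratio
          for ix in range(sampling_ratio)]        # [1.75, 2.25] for bin pw=0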
+template <typename T>
+void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
+                                   T& w1, T& w2, T& w3, T& w4, int& x_low,
+                                   int& x_high, int& y_low, int& y_high,
+                                   const int index /* index for debug only*/) {
+  // deal with cases where sample points fall outside the feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    w1 = w2 = w3 = w4 = 0.;
+    x_low = x_high = y_low = y_high = -1;
+    return;
+  }
+
+  if (y <= 0) y = 0;
+  if (x <= 0) x = 0;
+
+  y_low = (int)y;
+  x_low = (int)x;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  // reference in forward
+  // T v1 = input[y_low * width + x_low];
+  // T v2 = input[y_low * width + x_high];
+  // T v3 = input[y_high * width + x_low];
+  // T v4 = input[y_high * width + x_high];
+  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  return;
+}
+
+template <class T>
+inline void add(T* address, const T& val) {
+  *address += val;
+}
+
+template <typename T>
+void ROIAlignBackward(const int nthreads, const T* grad_output,
+                      const T& spatial_scale, const int channels,
+                      const int height, const int width,
+                      const int pooled_height, const int pooled_width,
+                      const int sampling_ratio, T* grad_input, const T* rois,
+                      const int n_stride, const int c_stride,
+                      const int h_stride, const int w_stride, bool aligned) {
+  for (int index = 0; index < nthreads; index++) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T offset = aligned ? (T)0.5 : (T)0.0;
+    T roi_start_w = offset_rois[1] * spatial_scale - offset;
+    T roi_start_h = offset_rois[2] * spatial_scale - offset;
+    T roi_end_w = offset_rois[3] * spatial_scale - offset;
+    T roi_end_h = offset_rois[4] * spatial_scale - offset;
+
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    if (aligned) {
+      AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
+                 "ROIs in ROIAlign do not have non-negative size!");
+    } else {  // for backward-compatibility only
+      roi_width = std::max(roi_width, (T)1.);
+      roi_height = std::max(roi_height, (T)1.);
+    }
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    T* offset_grad_input =
+        grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+    int output_offset = n * n_stride + c * c_stride;
+    const T* offset_grad_output = grad_output + output_offset;
+    const T grad_output_this_bin =
+        offset_grad_output[ph * h_stride + pw * w_stride];
+
+    // We use roi_bin_grid to sample the grid and mimic integral pooling
+    int roi_bin_grid_h = (sampling_ratio > 0)
+                             ? sampling_ratio
+                             : ceil(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+      const T y = roi_start_h + ph * bin_size_h +
+                  static_cast<T>(iy + .5f) * bin_size_h /
+                      static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = roi_start_w + pw * bin_size_w +
+                    static_cast<T>(ix + .5f) * bin_size_w /
+                        static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+                                      x_low, x_high, y_low, y_high, index);
+
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          // atomic add is not needed for now since it is single threaded
+          add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+          add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+          add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+          add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+        }  // if
+      }    // ix
+    }      // iy
+  }        // for
+}  // ROIAlignBackward
+
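
ROIAlignBackward reuses exactly the forward weights: each pooled cell's incoming gradient is split across the four neighbours of each sample point and divided by count. A sketch on top of the bilinear_weights helper above (hypothetical):

    import numpy as np

    def scatter_grad(grad_input, g, pos, w, count):
        # mirror of the add() calls above: route g * w_i / count
        # back to each of the four bilinear neighbours
        for p, wi in zip(pos, w):
            grad_input.flat[p] += g * wi / count

    grad_input = np.zeros((5, 5))
    scatter_grad(grad_input, g=1.0, pos=(6, 7, 11, 12), w=(0.25,) * 4, count=4)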
+at::Tensor ROIAlignForwardV2CPULaucher(const at::Tensor& input,
+                                       const at::Tensor& rois,
+                                       const float spatial_scale,
+                                       const int pooled_height,
+                                       const int pooled_width,
+                                       const int sampling_ratio, bool aligned) {
+  AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ROIAlignForwardV2CPULaucher";
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  at::Tensor output = at::zeros(
+      {num_rois, channels, pooled_height, pooled_width}, input.options());
+
+  auto output_size = num_rois * pooled_height * pooled_width * channels;
+
+  if (output.numel() == 0) return output;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "ROIAlign_forward", [&] {
+    ROIAlignForward<scalar_t>(
+        output_size, input.contiguous().data<scalar_t>(), spatial_scale,
+        channels, height, width, pooled_height, pooled_width, sampling_ratio,
+        rois.contiguous().data<scalar_t>(), output.data<scalar_t>(), aligned);
+  });
+  return output;
+}
+
+at::Tensor ROIAlignBackwardV2CPULaucher(
+    const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
+    const int pooled_height, const int pooled_width, const int batch_size,
+    const int channels, const int height, const int width,
+    const int sampling_ratio, bool aligned) {
+  AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ROIAlignBackwardV2CPULaucher";
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  at::Tensor grad_input =
+      at::zeros({batch_size, channels, height, width}, grad.options());
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    return grad_input;
+  }
+
+  // get stride values to ensure indexing into gradients is correct.
+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "ROIAlign_backward", [&] {
+    ROIAlignBackward<scalar_t>(
+        grad.numel(), grad.contiguous().data<scalar_t>(), spatial_scale,
+        channels, height, width, pooled_height, pooled_width, sampling_ratio,
+        grad_input.data<scalar_t>(), rois.contiguous().data<scalar_t>(),
+        n_stride, c_stride, h_stride, w_stride, aligned);
+  });
+  return grad_input;
+}
diff --git a/mmdet/ops/roi_align/src/roi_align_ext.cpp b/mmdet/ops/roi_align/src/roi_align_ext.cpp
index 50454d25..f01351a8 100644
--- a/mmdet/ops/roi_align/src/roi_align_ext.cpp
+++ b/mmdet/ops/roi_align/src/roi_align_ext.cpp
@@ -1,6 +1,5 @@
-#include <torch/extension.h>
-
 #include <ATen/ATen.h>
+#include <torch/extension.h>
 
 #include <cmath>
 #include <vector>
@@ -34,6 +33,19 @@ at::Tensor ROIAlignBackwardV2Laucher(
     const int sampling_ratio, bool aligned);
 #endif
 
+at::Tensor ROIAlignForwardV2CPULaucher(const at::Tensor& input,
+                                       const at::Tensor& rois,
+                                       const float spatial_scale,
+                                       const int pooled_height,
+                                       const int pooled_width,
+                                       const int sampling_ratio, bool aligned);
+
+at::Tensor ROIAlignBackwardV2CPULaucher(
+    const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
+    const int pooled_height, const int pooled_width, const int batch_size,
+    const int channels, const int height, const int width,
+    const int sampling_ratio, bool aligned);
+
 #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
 #define CHECK_CONTIGUOUS(x) \
   AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
@@ -125,7 +137,8 @@ inline at::Tensor ROIAlign_forwardV2(const at::Tensor& input,
     AT_ERROR("ROIAlignV2 is not compiled with GPU support");
 #endif
   }
-  AT_ERROR("ROIAlignV2 is not implemented on CPU");
+  return ROIAlignForwardV2CPULaucher(input, rois, spatial_scale, pooled_height,
+                                     pooled_width, sampling_ratio, aligned);
 }
 
 inline at::Tensor ROIAlign_backwardV2(
@@ -142,7 +155,9 @@ inline at::Tensor ROIAlign_backwardV2(
     AT_ERROR("ROIAlignV2 is not compiled with GPU support");
 #endif
   }
-  AT_ERROR("ROIAlignV2 is not implemented on CPU");
+  return ROIAlignBackwardV2CPULaucher(grad, rois, spatial_scale, pooled_height,
+                                      pooled_width, batch_size, channels,
+                                      height, width, sampling_ratio, aligned);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
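
With the CPU launchers declared above and wired into ROIAlign_forwardV2/ROIAlign_backwardV2, the V2 op no longer raises on CPU tensors. A smoke test at the extension level, with the module import path assumed:

    import torch
    from mmdet.ops.roi_align import roi_align_ext  # built extension, path assumed

    feats = torch.randn(1, 4, 16, 16)
    rois = torch.tensor([[0., 2., 2., 10., 10.]])  # (batch_ind, x1, y1, x2, y2)
    out = roi_align_ext.forward_v2(feats, rois, 0.5, 7, 7, 2, True)
    assert out.shape == (1, 4, 7, 7)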
diff --git a/setup.py b/setup.py
index 7c01e6be..14af9d1b 100755
--- a/setup.py
+++ b/setup.py
@@ -240,7 +240,10 @@ if __name__ == '__main__':
             make_cuda_ext(
                 name='roi_align_ext',
                 module='mmdet.ops.roi_align',
-                sources=['src/roi_align_ext.cpp'],
+                sources=[
+                    'src/roi_align_ext.cpp',
+                    'src/cpu/roi_align_v2.cpp',
+                ],
                 sources_cuda=[
                     'src/cuda/roi_align_kernel.cu',
                     'src/cuda/roi_align_kernel_v2.cu'
diff --git a/tests/test_assigner.py b/tests/test_assigner.py
index cb783e83..a3904b38 100644
--- a/tests/test_assigner.py
+++ b/tests/test_assigner.py
@@ -49,7 +49,7 @@ def test_max_iou_assigner_with_ignore():
         [0, 0, 10, 10],
         [10, 10, 20, 20],
         [5, 5, 15, 15],
-        [32, 32, 38, 42],
+        [30, 32, 40, 42],
     ])
     gt_bboxes = torch.FloatTensor([
         [0, 0, 10, 9],
diff --git a/tests/test_forward.py b/tests/test_forward.py
index 2ace7bbd..5a1e8948 100644
--- a/tests/test_forward.py
+++ b/tests/test_forward.py
@@ -179,9 +179,6 @@ def test_cascade_forward():
     model, train_cfg, test_cfg = _get_detector_cfg(
         'cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py')
     model['pretrained'] = None
-    # torchvision roi align supports CPU
-    model['roi_head']['bbox_roi_extractor']['roi_layer'][
-        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
@@ -233,9 +230,6 @@ def test_faster_rcnn_forward():
     model, train_cfg, test_cfg = _get_detector_cfg(
         'faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
     model['pretrained'] = None
-    # torchvision roi align supports CPU
-    model['roi_head']['bbox_roi_extractor']['roi_layer'][
-        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
@@ -287,9 +281,6 @@ def test_faster_rcnn_ohem_forward():
     model, train_cfg, test_cfg = _get_detector_cfg(
         'faster_rcnn/faster_rcnn_r50_fpn_ohem_1x_coco.py')
     model['pretrained'] = None
-    # torchvision roi align supports CPU
-    model['roi_head']['bbox_roi_extractor']['roi_layer'][
-        'use_torchvision'] = True
 
     from mmdet.models import build_detector
     detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
diff --git a/tests/test_masks.py b/tests/test_masks.py
index 7c0c8c12..11b2e8e6 100644
--- a/tests/test_masks.py
+++ b/tests/test_masks.py
@@ -36,7 +36,7 @@ def dummy_bboxes(num, max_height, max_width):
     x1y1 = np.random.randint(0, min(max_height // 2, max_width // 2), (num, 2))
     wh = np.random.randint(0, min(max_height // 2, max_width // 2), (num, 2))
     x2y2 = x1y1 + wh
-    return np.concatenate([x1y1, x2y2], axis=1).squeeze()
+    return np.concatenate([x1y1, x2y2], axis=1).squeeze().astype(np.float32)
 
 
 def test_bitmap_mask_init():
@@ -174,18 +174,18 @@ def test_bitmap_mask_crop():
     bitmap_masks = BitmapMasks(raw_masks, 28, 28)
     cropped_masks = bitmap_masks.crop(dummy_bbox)
     assert len(cropped_masks) == 0
-    assert cropped_masks.height == 18
-    assert cropped_masks.width == 11
+    assert cropped_masks.height == 17
+    assert cropped_masks.width == 10
 
     # crop with bitmap masks contain 3 instances
     raw_masks = dummy_raw_bitmap_masks((3, 28, 28))
     bitmap_masks = BitmapMasks(raw_masks, 28, 28)
     cropped_masks = bitmap_masks.crop(dummy_bbox)
     assert len(cropped_masks) == 3
-    assert cropped_masks.height == 18
-    assert cropped_masks.width == 11
+    assert cropped_masks.height == 17
+    assert cropped_masks.width == 10
     x1, y1, x2, y2 = dummy_bbox
-    assert (cropped_masks.masks == raw_masks[:, y1:y2 + 1, x1:x2 + 1]).all()
+    assert (cropped_masks.masks == raw_masks[:, y1:y2, x1:x2]).all()
 
     # crop with invalid bbox
     with pytest.raises(AssertionError):
@@ -453,9 +453,9 @@ def test_polygon_mask_crop():
     polygon_masks = PolygonMasks(raw_masks, 28, 28)
     cropped_masks = polygon_masks.crop(dummy_bbox)
     assert len(cropped_masks) == 0
-    assert cropped_masks.height == 18
-    assert cropped_masks.width == 11
-    assert cropped_masks.to_ndarray().shape == (0, 18, 11)
+    assert cropped_masks.height == 17
+    assert cropped_masks.width == 10
+    assert cropped_masks.to_ndarray().shape == (0, 17, 10)
 
     # crop with polygon masks contain 1 instances
     raw_masks = [[np.array([1., 3., 5., 1., 5., 6., 1, 6])]]
@@ -463,11 +463,10 @@ def test_polygon_mask_crop():
     bbox = np.array([0, 0, 3, 4])
     cropped_masks = polygon_masks.crop(bbox)
     assert len(cropped_masks) == 1
-    assert cropped_masks.height == 5
-    assert cropped_masks.width == 4
-    assert cropped_masks.to_ndarray().shape == (1, 5, 4)
-    truth = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 1, 1],
-                      [0, 1, 1, 1]])
+    assert cropped_masks.height == 4
+    assert cropped_masks.width == 3
+    assert cropped_masks.to_ndarray().shape == (1, 4, 3)
+    truth = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1]])
     assert (cropped_masks.to_ndarray() == truth).all()
 
     # crop with invalid bbox
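
The updated expectations encode the new crop rule: a bbox (x1, y1, x2, y2) now produces a crop of exactly (y2 - y1, x2 - x1) pixels via half-open slicing, rather than the old inclusive +1 extent. With hypothetical values matching the test's 17 x 10 result:

    import numpy as np

    mask = np.ones((28, 28), dtype=np.uint8)
    x1, y1, x2, y2 = 5, 3, 15, 20       # hypothetical bbox
    cropped = mask[y1:y2, x1:x2]        # half-open slice, no +1
    assert cropped.shape == (y2 - y1, x2 - x1) == (17, 10)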
diff --git a/tests/test_nms.py b/tests/test_nms.py
index 6861f1e5..e99af88e 100644
--- a/tests/test_nms.py
+++ b/tests/test_nms.py
@@ -13,7 +13,7 @@ def test_nms_device_and_dtypes_cpu():
     CommandLine:
         xdoctest -m tests/test_nms.py test_nms_device_and_dtypes_cpu
     """
-    iou_thr = 0.7
+    iou_thr = 0.6
     base_dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
                           [49.3, 32.9, 51.0, 35.3, 0.9],
                           [35.3, 11.5, 39.9, 14.5, 0.4],
@@ -23,22 +23,22 @@ def test_nms_device_and_dtypes_cpu():
     dets = base_dets.astype(np.float32)
     supressed, inds = nms(dets, iou_thr)
     assert dets.dtype == supressed.dtype
-    assert len(inds) == len(supressed) == 3
+    assert len(inds) == len(supressed) == 2
 
     dets = torch.FloatTensor(base_dets)
     surpressed, inds = nms(dets, iou_thr)
     assert dets.dtype == surpressed.dtype
-    assert len(inds) == len(surpressed) == 3
+    assert len(inds) == len(surpressed) == 2
 
     dets = base_dets.astype(np.float64)
     supressed, inds = nms(dets, iou_thr)
     assert dets.dtype == supressed.dtype
-    assert len(inds) == len(supressed) == 3
+    assert len(inds) == len(supressed) == 2
 
     dets = torch.DoubleTensor(base_dets)
     surpressed, inds = nms(dets, iou_thr)
     assert dets.dtype == surpressed.dtype
-    assert len(inds) == len(surpressed) == 3
+    assert len(inds) == len(surpressed) == 2
 
 
 def test_nms_device_and_dtypes_gpu():
@@ -50,7 +50,7 @@ def test_nms_device_and_dtypes_gpu():
         import pytest
         pytest.skip('test requires GPU and torch+cuda')
 
-    iou_thr = 0.7
+    iou_thr = 0.6
     base_dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
                           [49.3, 32.9, 51.0, 35.3, 0.9],
                           [35.3, 11.5, 39.9, 14.5, 0.4],
@@ -62,9 +62,9 @@ def test_nms_device_and_dtypes_gpu():
         dets = base_dets.astype(np.float32)
         supressed, inds = nms(dets, iou_thr, device_id)
         assert dets.dtype == supressed.dtype
-        assert len(inds) == len(supressed) == 3
+        assert len(inds) == len(supressed) == 2
 
         dets = torch.FloatTensor(base_dets).to(device_id)
         surpressed, inds = nms(dets, iou_thr)
         assert dets.dtype == surpressed.dtype
-        assert len(inds) == len(surpressed) == 3
+        assert len(inds) == len(surpressed) == 2
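
The threshold and count changes follow directly from the new convention. For the two top-scoring boxes, dropping the +1 shrinks both intersection and areas, so their IoU falls from about 0.703 to about 0.613; the low-score pair comes out at about 0.632, so with iou_thr = 0.6 one box from each pair is suppressed and 2 survive:

    # first pair of base_dets under the v2.0 convention
    w = min(51.0, 51.0) - max(49.1, 49.3)    # 1.7 (2.7 with the old +1)
    h = min(35.9, 35.3) - max(32.4, 32.9)    # 2.4 (3.4 with the old +1)
    inter = w * h                            # 4.08
    union = ((51.0 - 49.1) * (35.9 - 32.4)
             + (51.0 - 49.3) * (35.3 - 32.9) - inter)
    print(inter / union)                     # ~0.613, was ~0.703 with +1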
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index 53ef380f..5afa16a8 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -102,9 +102,7 @@ def _context_for_ohem():
     model, train_cfg, test_cfg = _get_detector_cfg(
         'faster_rcnn/faster_rcnn_r50_fpn_ohem_1x_coco.py')
     model['pretrained'] = None
-    # torchvision roi align supports CPU
-    model['roi_head']['bbox_roi_extractor']['roi_layer'][
-        'use_torchvision'] = True
+
     from mmdet.models import build_detector
     context = build_detector(
         model, train_cfg=train_cfg, test_cfg=test_cfg).roi_head
-- 
GitLab