From caa4a4ed7d9ac99d816ab631bc573a9c3f752845 Mon Sep 17 00:00:00 2001
From: Jiangmiao Pang <pangjiangmiao@gmail.com>
Date: Thu, 19 Nov 2020 16:33:54 +0800
Subject: [PATCH] Support unclip border bbox regression (#4076)

* update

* clip border

* clip border

* clip

* update

* update

* update

* update
---
 .gitignore                                    |  1 +
 mmdet/core/bbox/coder/bucketing_bbox_coder.py | 15 +++--
 .../core/bbox/coder/delta_xywh_bbox_coder.py  | 15 +++--
 mmdet/core/bbox/coder/tblr_bbox_coder.py      | 15 +++--
 mmdet/datasets/pipelines/transforms.py        | 59 ++++++++++++++-----
 mmdet/models/dense_heads/anchor_head.py       |  1 +
 mmdet/models/dense_heads/ssd_head.py          |  1 +
 .../models/roi_heads/bbox_heads/bbox_head.py  |  1 +
 8 files changed, 81 insertions(+), 27 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8eb9117a..77ca0d7c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -104,6 +104,7 @@ venv.bak/
 .mypy_cache/
 
 data/
+data
 .vscode
 .idea
 .DS_Store
diff --git a/mmdet/core/bbox/coder/bucketing_bbox_coder.py b/mmdet/core/bbox/coder/bucketing_bbox_coder.py
index 4f670879..e8c450c5 100644
--- a/mmdet/core/bbox/coder/bucketing_bbox_coder.py
+++ b/mmdet/core/bbox/coder/bucketing_bbox_coder.py
@@ -26,6 +26,8 @@ class BucketingBBoxCoder(BaseBBoxCoder):
              To avoid too large offset displacements. Defaults to 1.0.
         cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
              Defaults to True.
+        clip_border (bool, optional): Whether to clip the objects outside the
+            border of the image. Defaults to True.
     """
 
     def __init__(self,
@@ -33,13 +35,15 @@ class BucketingBBoxCoder(BaseBBoxCoder):
                  scale_factor,
                  offset_topk=2,
                  offset_upperbound=1.0,
-                 cls_ignore_neighbor=True):
+                 cls_ignore_neighbor=True,
+                 clip_border=True):
         super(BucketingBBoxCoder, self).__init__()
         self.num_buckets = num_buckets
         self.scale_factor = scale_factor
         self.offset_topk = offset_topk
         self.offset_upperbound = offset_upperbound
         self.cls_ignore_neighbor = cls_ignore_neighbor
+        self.clip_border = clip_border
 
     def encode(self, bboxes, gt_bboxes):
         """Get bucketing estimation and fine regression targets during
@@ -81,7 +85,7 @@ class BucketingBBoxCoder(BaseBBoxCoder):
             0) == bboxes.size(0)
         decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
                                      self.num_buckets, self.scale_factor,
-                                     max_shape)
+                                     max_shape, self.clip_border)
 
         return decoded_bboxes
 
@@ -262,7 +266,8 @@ def bucket2bbox(proposals,
                 offset_preds,
                 num_buckets,
                 scale_factor=1.0,
-                max_shape=None):
+                max_shape=None,
+                clip_border=True):
     """Apply bucketing estimation (cls preds) and fine regression (offset
     preds) to generate det bboxes.
 
@@ -273,6 +278,8 @@ def bucket2bbox(proposals,
         num_buckets (int): Number of buckets.
         scale_factor (float): Scale factor to rescale proposals.
         max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
+        clip_border (bool, optional): Whether to clip the objects outside the
+            border of the image. Defaults to True.
 
     Returns:
         tuple[Tensor]: (bboxes, loc_confidence).
@@ -322,7 +329,7 @@ def bucket2bbox(proposals,
     y1 = t_buckets - t_offsets * bucket_h
     y2 = d_buckets - d_offsets * bucket_h
 
-    if max_shape is not None:
+    if clip_border and max_shape is not None:
         x1 = x1.clamp(min=0, max=max_shape[1] - 1)
         y1 = y1.clamp(min=0, max=max_shape[0] - 1)
         x2 = x2.clamp(min=0, max=max_shape[1] - 1)
diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
index 82bf5947..e9eb3579 100644
--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -18,14 +18,18 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
             delta coordinates
         target_stds (Sequence[float]): Denormalizing standard deviation of
             target for delta coordinates
+        clip_border (bool, optional): Whether to clip the objects outside the
+            border of the image. Defaults to True.
     """
 
     def __init__(self,
                  target_means=(0., 0., 0., 0.),
-                 target_stds=(1., 1., 1., 1.)):
+                 target_stds=(1., 1., 1., 1.),
+                 clip_border=True):
         super(BaseBBoxCoder, self).__init__()
         self.means = target_means
         self.stds = target_stds
+        self.clip_border = clip_border
 
     def encode(self, bboxes, gt_bboxes):
         """Get box regression transformation deltas that can be used to
@@ -66,7 +70,7 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
 
         assert pred_bboxes.size(0) == bboxes.size(0)
         decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
-                                    max_shape, wh_ratio_clip)
+                                    max_shape, wh_ratio_clip, self.clip_border)
 
         return decoded_bboxes
 
@@ -121,7 +125,8 @@ def delta2bbox(rois,
                means=(0., 0., 0., 0.),
                stds=(1., 1., 1., 1.),
                max_shape=None,
-               wh_ratio_clip=16 / 1000):
+               wh_ratio_clip=16 / 1000,
+               clip_border=True):
     """Apply deltas to shift/scale base boxes.
 
     Typically the rois are anchor or proposed bounding boxes and the deltas are
@@ -138,6 +143,8 @@ def delta2bbox(rois,
             coordinates
         max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
         wh_ratio_clip (float): Maximum aspect ratio for boxes.
+        clip_border (bool, optional): Whether to clip the objects outside the
+            border of the image. Defaults to True.
 
     Returns:
         Tensor: Boxes with shape (N, 4), where columns represent
@@ -188,7 +195,7 @@ def delta2bbox(rois,
     y1 = gy - gh * 0.5
     x2 = gx + gw * 0.5
     y2 = gy + gh * 0.5
-    if max_shape is not None:
+    if clip_border and max_shape is not None:
         x1 = x1.clamp(min=0, max=max_shape[1])
         y1 = y1.clamp(min=0, max=max_shape[0])
         x2 = x2.clamp(min=0, max=max_shape[1])
diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py
index f586a419..276fe1f5 100644
--- a/mmdet/core/bbox/coder/tblr_bbox_coder.py
+++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py
@@ -17,11 +17,14 @@ class TBLRBBoxCoder(BaseBBoxCoder):
           divided with when coding the coordinates. If it is a list, it should
           have length of 4 indicating normalization factor in tblr dims.
           Otherwise it is a unified float factor for all dims. Default: 4.0
+        clip_border (bool, optional): Whether to clip the objects outside the
+            border of the image. Defaults to True.
     """
 
-    def __init__(self, normalizer=4.0):
+    def __init__(self, normalizer=4.0, clip_border=True):
         super(BaseBBoxCoder, self).__init__()
         self.normalizer = normalizer
+        self.clip_border = clip_border
 
     def encode(self, bboxes, gt_bboxes):
         """Get box regression transformation deltas that can be used to
@@ -59,7 +62,8 @@ class TBLRBBoxCoder(BaseBBoxCoder):
             bboxes,
             pred_bboxes,
             normalizer=self.normalizer,
-            max_shape=max_shape)
+            max_shape=max_shape,
+            clip_border=self.clip_border)
 
         return decoded_bboxes
 
@@ -114,7 +118,8 @@ def tblr2bboxes(priors,
                 tblr,
                 normalizer=4.0,
                 normalize_by_wh=True,
-                max_shape=None):
+                max_shape=None,
+                clip_border=True):
     """Decode tblr outputs to prediction boxes.
 
     The process includes 3 steps: 1) De-normalize tblr coordinates by
@@ -136,6 +141,8 @@ def tblr2bboxes(priors,
           normalized by the side length (wh) of prior bboxes.
         max_shape (tuple, optional): Shape of the image. Decoded bboxes
           exceeding which will be clamped.
+        clip_border (bool, optional): Whether to clip the objects outside the
+            border of the image. Defaults to True.
 
     Return:
         encoded boxes (Tensor), Shape: (n, 4)
@@ -157,7 +164,7 @@ def tblr2bboxes(priors,
     ymin = prior_centers[:, 1].unsqueeze(1) - top
     ymax = prior_centers[:, 1].unsqueeze(1) + bottom
     boxes = torch.cat((xmin, ymin, xmax, ymax), dim=1)
-    if max_shape is not None:
+    if clip_border and max_shape is not None:
         boxes[:, 0].clamp_(min=0, max=max_shape[1])
         boxes[:, 1].clamp_(min=0, max=max_shape[0])
         boxes[:, 2].clamp_(min=0, max=max_shape[1])
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
index 90185bb8..87d7a413 100644
--- a/mmdet/datasets/pipelines/transforms.py
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -49,6 +49,8 @@ class Resize(object):
         ratio_range (tuple[float]): (min_ratio, max_ratio)
         keep_ratio (bool): Whether to keep the aspect ratio when resizing the
             image.
+        bbox_clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
         backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
             These two backends generates slightly different results. Defaults
             to 'cv2'.
@@ -59,6 +61,7 @@ class Resize(object):
                  multiscale_mode='range',
                  ratio_range=None,
                  keep_ratio=True,
+                 bbox_clip_border=True,
                  backend='cv2'):
         if img_scale is None:
             self.img_scale = None
@@ -80,6 +83,7 @@ class Resize(object):
         self.multiscale_mode = multiscale_mode
         self.ratio_range = ratio_range
         self.keep_ratio = keep_ratio
+        self.bbox_clip_border = bbox_clip_border
 
     @staticmethod
     def random_select(img_scales):
@@ -219,11 +223,12 @@ class Resize(object):
 
     def _resize_bboxes(self, results):
         """Resize bounding boxes with ``results['scale_factor']``."""
-        img_shape = results['img_shape']
         for key in results.get('bbox_fields', []):
             bboxes = results[key] * results['scale_factor']
-            bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
-            bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
+            if self.bbox_clip_border:
+                img_shape = results['img_shape']
+                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
             results[key] = bboxes
 
     def _resize_masks(self, results):
@@ -290,6 +295,7 @@ class Resize(object):
         repr_str += f'multiscale_mode={self.multiscale_mode}, '
         repr_str += f'ratio_range={self.ratio_range}, '
-        repr_str += f'keep_ratio={self.keep_ratio})'
+        repr_str += f'keep_ratio={self.keep_ratio}, '
+        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
         return repr_str
 
 
@@ -570,6 +576,8 @@ class RandomCrop(object):
         crop_size (tuple): Expected size after cropping, (h, w).
         allow_negative_crop (bool): Whether to allow a crop that does not
             contain any bbox area. Default to False.
+        bbox_clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
 
     Note:
         - If the image is smaller than the crop size, return the original image
@@ -581,10 +589,14 @@ class RandomCrop(object):
           `allow_negative_crop` is set to False, skip this image.
     """
 
-    def __init__(self, crop_size, allow_negative_crop=False):
+    def __init__(self,
+                 crop_size,
+                 allow_negative_crop=False,
+                 bbox_clip_border=True):
         assert crop_size[0] > 0 and crop_size[1] > 0
         self.crop_size = crop_size
         self.allow_negative_crop = allow_negative_crop
+        self.bbox_clip_border = bbox_clip_border
         # The key correspondence from bboxes to labels and masks.
         self.bbox2label = {
             'gt_bboxes': 'gt_labels',
@@ -628,8 +640,9 @@ class RandomCrop(object):
             bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
                                    dtype=np.float32)
             bboxes = results[key] - bbox_offset
-            bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
-            bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
+            if self.bbox_clip_border:
+                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
             valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (
                 bboxes[:, 3] > bboxes[:, 1])
             # If the crop does not contain any gt-bbox area and
@@ -657,7 +670,9 @@ class RandomCrop(object):
         return results
 
     def __repr__(self):
-        return self.__class__.__name__ + f'(crop_size={self.crop_size})'
+        repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}, '
+        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+        return repr_str
 
 
 @PIPELINES.register_module()
@@ -907,6 +922,8 @@ class MinIoURandomCrop(object):
         bounding boxes
         min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
         where a >= min_crop_size).
+        bbox_clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
 
     Note:
         The keys for bboxes, labels and masks should be paired. That is, \
@@ -914,11 +931,15 @@ class MinIoURandomCrop(object):
         `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
     """
 
-    def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3):
+    def __init__(self,
+                 min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
+                 min_crop_size=0.3,
+                 bbox_clip_border=True):
         # 1: return ori img
         self.min_ious = min_ious
         self.sample_mode = (1, *min_ious, 0)
         self.min_crop_size = min_crop_size
+        self.bbox_clip_border = bbox_clip_border
         self.bbox2label = {
             'gt_bboxes': 'gt_labels',
             'gt_bboxes_ignore': 'gt_labels_ignore'
@@ -995,8 +1016,9 @@ class MinIoURandomCrop(object):
                         boxes = results[key].copy()
                         mask = is_center_of_bboxes_in_patch(boxes, patch)
                         boxes = boxes[mask]
-                        boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
-                        boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
+                        if self.bbox_clip_border:
+                            boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
+                            boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
                         boxes -= np.tile(patch[:2], 2)
 
                         results[key] = boxes
@@ -1024,7 +1046,8 @@ class MinIoURandomCrop(object):
     def __repr__(self):
         repr_str = self.__class__.__name__
         repr_str += f'(min_ious={self.min_ious}, '
-        repr_str += f'min_crop_size={self.min_crop_size})'
+        repr_str += f'min_crop_size={self.min_crop_size}, '
+        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
         return repr_str
 
 
@@ -1351,6 +1374,8 @@ class RandomCenterCropPad(object):
             - 'logical_or': final_shape = input_shape | padding_shape_value
             - 'size_divisor': final_shape = int(
               ceil(input_shape / padding_shape_value) * padding_shape_value)
+        bbox_clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
     """
 
     def __init__(self,
@@ -1361,7 +1386,8 @@ class RandomCenterCropPad(object):
                  std=None,
                  to_rgb=None,
                  test_mode=False,
-                 test_pad_mode=('logical_or', 127)):
+                 test_pad_mode=('logical_or', 127),
+                 bbox_clip_border=True):
         if test_mode:
             assert crop_size is None, 'crop_size must be None in test mode'
             assert ratios is None, 'ratios must be None in test mode'
@@ -1394,6 +1420,7 @@ class RandomCenterCropPad(object):
             self.std = std
         self.test_mode = test_mode
         self.test_pad_mode = test_pad_mode
+        self.bbox_clip_border = bbox_clip_border
 
     def _get_border(self, border, size):
         """Get final border for the target size.
@@ -1527,8 +1554,9 @@ class RandomCenterCropPad(object):
                     bboxes = results[key][mask]
                     bboxes[:, 0:4:2] += cropped_center_x - left_w - x0
                     bboxes[:, 1:4:2] += cropped_center_y - top_h - y0
-                    bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w)
-                    bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h)
+                    if self.bbox_clip_border:
+                        bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w)
+                        bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h)
                     keep = (bboxes[:, 2] > bboxes[:, 0]) & (
                         bboxes[:, 3] > bboxes[:, 1])
                     bboxes = bboxes[keep]
@@ -1602,7 +1630,8 @@ class RandomCenterCropPad(object):
         repr_str += f'std={self.input_std}, '
         repr_str += f'to_rgb={self.to_rgb}, '
         repr_str += f'test_mode={self.test_mode}, '
-        repr_str += f'test_pad_mode={self.test_pad_mode})'
+        repr_str += f'test_pad_mode={self.test_pad_mode}, '
+        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
         return repr_str
 
 
diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py
index 79c019fa..a5bb4137 100644
--- a/mmdet/models/dense_heads/anchor_head.py
+++ b/mmdet/models/dense_heads/anchor_head.py
@@ -41,6 +41,7 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
                      strides=[4, 8, 16, 32, 64]),
                  bbox_coder=dict(
                      type='DeltaXYWHBBoxCoder',
+                     clip_border=True,
                      target_means=(.0, .0, .0, .0),
                      target_stds=(1.0, 1.0, 1.0, 1.0)),
                  reg_decoded_bbox=False,
diff --git a/mmdet/models/dense_heads/ssd_head.py b/mmdet/models/dense_heads/ssd_head.py
index 9d2f755a..42554c12 100644
--- a/mmdet/models/dense_heads/ssd_head.py
+++ b/mmdet/models/dense_heads/ssd_head.py
@@ -40,6 +40,7 @@ class SSDHead(AnchorHead):
                      basesize_ratio_range=(0.1, 0.9)),
                  bbox_coder=dict(
                      type='DeltaXYWHBBoxCoder',
+                     clip_border=True,
                      target_means=[.0, .0, .0, .0],
                      target_stds=[1.0, 1.0, 1.0, 1.0],
                  ),
diff --git a/mmdet/models/roi_heads/bbox_heads/bbox_head.py b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
index a872e651..e0931e17 100644
--- a/mmdet/models/roi_heads/bbox_heads/bbox_head.py
+++ b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
@@ -23,6 +23,7 @@ class BBoxHead(nn.Module):
                  num_classes=80,
                  bbox_coder=dict(
                      type='DeltaXYWHBBoxCoder',
+                     clip_border=True,
                      target_means=[0., 0., 0., 0.],
                      target_stds=[0.1, 0.1, 0.2, 0.2]),
                  reg_class_agnostic=False,
-- 
GitLab