diff --git a/mmdet/models/roi_heads/bbox_heads/bbox_head.py b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
index 5d3e7fd4d2f786d19d224e28ba26028115432956..408abef3a244115b4e73748049a228e37ad0665c 100644
--- a/mmdet/models/roi_heads/bbox_heads/bbox_head.py
+++ b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
@@ -275,35 +275,102 @@ class BBoxHead(nn.Module):
                    scale_factor,
                    rescale=False,
                    cfg=None):
+        """Transform network output for a batch into bbox predictions.
+
+        If the input rois have a batch dimension, the function operates in
+        `batch_mode` and returns a tuple[list[Tensor], list[Tensor]];
+        otherwise it returns a tuple[Tensor, Tensor].
+
+        Args:
+            rois (Tensor): Boxes to be transformed. Has shape (num_boxes, 5)
+                or (B, num_boxes, 5), where the first column denotes the
+                image index in the batch.
+            cls_score (list[Tensor] or Tensor): Box scores, has shape
+                (num_boxes, num_classes + 1) or (B, num_boxes,
+                num_classes + 1). If a list is given (e.g. in augmented
+                tests), the scores are averaged over the list.
+            bbox_pred (Tensor, optional): Box energies / deltas, has shape
+                (num_boxes, num_classes * 4) or (B, num_boxes,
+                num_classes * 4).
+            img_shape (Sequence[int] or torch.Tensor or Sequence[
+                Sequence[int]], optional): Maximum bounds for boxes,
+                specifies (H, W, C) or (H, W). If rois has shape
+                (B, num_boxes, 5), then img_shape should be a
+                Sequence[Sequence[int]] of length B.
+            scale_factor (tuple[ndarray] or ndarray): Scale factor of the
+                image arranged as (w_scale, h_scale, w_scale, h_scale). In
+                `batch_mode`, scale_factor is a tuple[ndarray] of length B.
+            rescale (bool): If True, return boxes in original image space.
+                Default: False.
+            cfg (:obj:`ConfigDict`, optional): `test_cfg` of Bbox Head.
+                Default: None.
+
+        Returns:
+            tuple[list[Tensor], list[Tensor]] or tuple[Tensor, Tensor]:
+                If the input has a batch dimension, the return value is a
+                tuple of two lists. Each tensor in the first list contains
+                the boxes of one image in the batch and has shape
+                (num_boxes, 5), where the last dimension represents
+                (tl_x, tl_y, br_x, br_y, score). Each tensor in the second
+                list contains the corresponding labels and has shape
+                (num_boxes, ). Both lists have length batch_size. Otherwise,
+                the return value is a tuple of two tensors: the first holds
+                the boxes with scores, the second holds the labels, with
+                the same shapes as in the batch case.
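+
+        Example:
+            >>> # A minimal sketch of `batch_mode`, mirroring the unit
+            >>> # test; with cfg=None the decoded boxes and softmax scores
+            >>> # are returned directly instead of NMS results.
+            >>> self = BBoxHead(reg_class_agnostic=True)
+            >>> rois = torch.rand(2, 100, 5)
+            >>> cls_score = torch.rand(2, 100, 6)
+            >>> bbox_pred = torch.rand(2, 100, 4)
+            >>> bboxes, scores = self.get_bboxes(
+            ...     rois, cls_score, bbox_pred, None, 2.0, rescale=True)
+            >>> assert len(bboxes) == 2 and len(scores) == 2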
+        """
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
-        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
+
+        scores = F.softmax(
+            cls_score, dim=-1) if cls_score is not None else None
+
+        batch_mode = True
+        if rois.ndim == 2:
+            # e.g. AugTest, Cascade R-CNN, HTC, SCNet...
+            batch_mode = False
+
+            # add batch dimension
+            if scores is not None:
+                scores = scores.unsqueeze(0)
+            if bbox_pred is not None:
+                bbox_pred = bbox_pred.unsqueeze(0)
+            rois = rois.unsqueeze(0)
 
         if bbox_pred is not None:
             bboxes = self.bbox_coder.decode(
-                rois[:, 1:], bbox_pred, max_shape=img_shape)
+                rois[..., 1:], bbox_pred, max_shape=img_shape)
         else:
-            bboxes = rois[:, 1:].clone()
+            bboxes = rois[..., 1:].clone()
             if img_shape is not None:
-                bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1])
-                bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0])
-
-        if rescale and bboxes.size(0) > 0:
-            if isinstance(scale_factor, float):
-                bboxes /= scale_factor
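+                # NOTE: in-place `clamp_` on advanced-indexed slices like
+                # `bboxes[:, [0, 2]]` modified a copy and silently did
+                # nothing; `torch.where` clamps correctly and broadcasts
+                # per-image max shapes in batch mode.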
+                max_shape = bboxes.new_tensor(img_shape)[..., :2]
+                min_xy = bboxes.new_tensor(0)
+                max_xy = torch.cat(
+                    [max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
+                bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+                bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+        if rescale and bboxes.size(-2) > 0:
+            if not isinstance(scale_factor, tuple):
+                scale_factor = tuple([scale_factor])
+            # reshape scale_factor to (B, 1, bboxes.size(-1)) so it
+            # broadcasts over all boxes of each image
+            scale_factor = bboxes.new_tensor(scale_factor).unsqueeze(1).repeat(
+                1, 1,
+                bboxes.size(-1) // 4)
+            bboxes /= scale_factor
+
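+        # apply post-processing per image: multiclass NMS when a test cfg
+        # is given, otherwise pass the decoded boxes and scores through
+        # (e.g. aug-test merges results and applies NMS afterwards)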
+        det_bboxes = []
+        det_labels = []
+        for (bbox, score) in zip(bboxes, scores):
+            if cfg is not None:
+                det_bbox, det_label = multiclass_nms(bbox, score,
+                                                     cfg.score_thr, cfg.nms,
+                                                     cfg.max_per_img)
             else:
-                scale_factor = bboxes.new_tensor(scale_factor)
-                bboxes = (bboxes.view(bboxes.size(0), -1, 4) /
-                          scale_factor).view(bboxes.size()[0], -1)
-
-        if cfg is None:
-            return bboxes, scores
-        else:
-            det_bboxes, det_labels = multiclass_nms(bboxes, scores,
-                                                    cfg.score_thr, cfg.nms,
-                                                    cfg.max_per_img)
-
-            return det_bboxes, det_labels
+                det_bbox, det_label = bbox, score
+            det_bboxes.append(det_bbox)
+            det_labels.append(det_label)
+
+        if not batch_mode:
+            det_bboxes = det_bboxes[0]
+            det_labels = det_labels[0]
+        return det_bboxes, det_labels
 
     @force_fp32(apply_to=('bbox_preds', ))
     def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py
index d87ff91e942c3d70dc5fcece72b4c94df8d00b65..f20f6e5570c0d8eb211093bb9fcab0ecb6fadd01 100644
--- a/mmdet/models/roi_heads/test_mixins.py
+++ b/mmdet/models/roi_heads/test_mixins.py
@@ -55,51 +55,113 @@ class BBoxTestMixin(object):
                            proposals,
                            rcnn_test_cfg,
                            rescale=False):
-        """Test only det bboxes without augmentation."""
-        rois = bbox2roi(proposals)
-        bbox_results = self._bbox_forward(x, rois)
+        """Test only det bboxes without augmentation.
+
+        Args:
+            x (tuple[Tensor]): Feature maps of all scale levels.
+            img_metas (list[dict]): Image meta info.
+            proposals (Tensor or list[Tensor]): Region proposals.
+            rcnn_test_cfg (:obj:`ConfigDict`): `test_cfg` of R-CNN.
+            rescale (bool): If True, return boxes in original image space.
+                Default: False.
+
+        Returns:
+            tuple[list[Tensor], list[Tensor]]: The first list contains
+                the boxes of each image in the batch; each tensor has
+                shape (num_boxes, 5), where the last dimension represents
+                (tl_x, tl_y, br_x, br_y, score). Each tensor in the
+                second list contains the labels with shape (num_boxes, ).
+                Both lists have length batch_size.
+        """
         # get origin input shape to support onnx dynamic input shape
         if torch.onnx.is_in_onnx_export():
-            img_shapes = tuple(meta['img_shape_for_onnx']
-                               for meta in img_metas)
+            assert len(img_metas) == 1, \
+                'Only support one input image while exporting to ONNX'
+            img_shapes = img_metas[0]['img_shape_for_onnx']
         else:
             img_shapes = tuple(meta['img_shape'] for meta in img_metas)
         scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
 
-        # split batch bbox prediction back to each image
+        # The number of proposals in each image of the batch may differ,
+        # so padding is required to form a batch.
+        if isinstance(proposals, list):
+            # padding to form a batch
+            max_size = max([proposal.size(0) for proposal in proposals])
+            for i, proposal in enumerate(proposals):
+                supplement = proposal.new_full(
+                    (max_size - proposal.size(0), 5), 0)
+                proposals[i] = torch.cat((supplement, proposal), dim=0)
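+            # e.g. proposals of sizes (3, 5) and (5, 5) are padded into a
+            # (2, 5, 5) batch; the all-zero padded rows are masked out
+            # later via `rois[..., -1] == 0`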
+            rois = torch.stack(proposals, dim=0)
+        else:
+            rois = proposals
+
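+        # prepend each roi with its image's batch index, yielding the
+        # (batch_ind, x1, y1, x2, y2) layout expected by RoI extractors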
+        batch_index = torch.arange(
+            rois.size(0), device=rois.device).float().view(-1, 1, 1).expand(
+                rois.size(0), rois.size(1), 1)
+        rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
+        batch_size = rois.shape[0]
+        num_proposals_per_img = rois.shape[1]
+
+        # Eliminate the batch dimension
+        rois = rois.view(-1, 5)
+        bbox_results = self._bbox_forward(x, rois)
         cls_score = bbox_results['cls_score']
         bbox_pred = bbox_results['bbox_pred']
-        # use shape[] to keep tracing
-        num_proposals_per_img = tuple(p.shape[0] for p in proposals)
-        rois = rois.split(num_proposals_per_img, 0)
-        cls_score = cls_score.split(num_proposals_per_img, 0)
 
-        # some detector with_reg is False, bbox_pred will be None
+        # Recover the batch dimension
+        rois = rois.reshape(batch_size, num_proposals_per_img, -1)
+        cls_score = cls_score.reshape(batch_size, num_proposals_per_img, -1)
+
+        if not torch.onnx.is_in_onnx_export():
+            # remove padding
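+            # padded rois are all-zero, so a zero last coordinate (br_y)
+            # identifies them; this assumes no real proposal has br_y == 0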
+            supplement_mask = rois[..., -1] == 0
+            cls_score[supplement_mask, :] = 0
+
+        # bbox_pred would be None in some detectors when with_reg is False,
+        # e.g. Grid R-CNN.
         if bbox_pred is not None:
             # the bbox prediction of some detectors like SABL is not Tensor
             if isinstance(bbox_pred, torch.Tensor):
-                bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
+                bbox_pred = bbox_pred.reshape(batch_size,
+                                              num_proposals_per_img, -1)
+                if not torch.onnx.is_in_onnx_export():
+                    bbox_pred[supplement_mask, :] = 0
             else:
-                bbox_pred = self.bbox_head.bbox_pred_split(
+                # TODO: find a cleaner way to handle non-Tensor bbox_pred,
+                # e.g. the bucketed predictions of SABL.
+                bbox_preds = self.bbox_head.bbox_pred_split(
                     bbox_pred, num_proposals_per_img)
+                # apply bbox post-processing to each image individually
+                det_bboxes = []
+                det_labels = []
+                for i in range(len(proposals)):
+                    # remove padding
+                    supplement_mask = proposals[i][..., -1] == 0
+                    for bbox in bbox_preds[i]:
+                        bbox[supplement_mask] = 0
+                    det_bbox, det_label = self.bbox_head.get_bboxes(
+                        rois[i],
+                        cls_score[i],
+                        bbox_preds[i],
+                        img_shapes[i],
+                        scale_factors[i],
+                        rescale=rescale,
+                        cfg=rcnn_test_cfg)
+                    det_bboxes.append(det_bbox)
+                    det_labels.append(det_label)
+                return det_bboxes, det_labels
         else:
-            bbox_pred = (None, ) * len(proposals)
+            bbox_pred = None
 
-        # apply bbox post-processing to each image individually
-        det_bboxes = []
-        det_labels = []
-        for i in range(len(proposals)):
-            det_bbox, det_label = self.bbox_head.get_bboxes(
-                rois[i],
-                cls_score[i],
-                bbox_pred[i],
-                img_shapes[i],
-                scale_factors[i],
-                rescale=rescale,
-                cfg=rcnn_test_cfg)
-            det_bboxes.append(det_bbox)
-            det_labels.append(det_label)
-        return det_bboxes, det_labels
+        return self.bbox_head.get_bboxes(
+            rois,
+            cls_score,
+            bbox_pred,
+            img_shapes,
+            scale_factors,
+            rescale=rescale,
+            cfg=rcnn_test_cfg)
 
     def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
         """Test det bboxes with test time augmentation."""
diff --git a/tests/test_models/test_roi_heads/test_bbox_head.py b/tests/test_models/test_roi_heads/test_bbox_head.py
index 0a1d9f1da8faa15e1c399478cbfa5066ac190370..a7506b9b2d1b0fb86906c5d1e16283732a606131 100644
--- a/tests/test_models/test_roi_heads/test_bbox_head.py
+++ b/tests/test_models/test_roi_heads/test_bbox_head.py
@@ -1,4 +1,5 @@
 import mmcv
+import pytest
 import torch
 
 from mmdet.core import bbox2roi
@@ -64,6 +65,31 @@ def test_bbox_head_loss():
     assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero'
 
 
+@pytest.mark.parametrize(['num_sample', 'num_batch'], [[2, 2], [0, 2], [0, 0]])
+def test_bbox_head_get_bboxes(num_sample, num_batch):
+    self = BBoxHead(reg_class_agnostic=True)
+
+    num_class = 6
+    rois = torch.rand((num_sample, 5))
+    cls_score = torch.rand((num_sample, num_class))
+    bbox_pred = torch.rand((num_sample, 4))
+    scale_factor = 2.0
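+    # with cfg=None (the default), get_bboxes skips NMS and returns the
+    # decoded boxes and softmax scores, so the returned "labels" keep the
+    # shape of cls_score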
+    det_bboxes, det_labels = self.get_bboxes(
+        rois, cls_score, bbox_pred, None, scale_factor, rescale=True)
+    if num_sample == 0:
+        assert len(det_bboxes) == 0 and len(det_labels) == 0
+    else:
+        assert det_bboxes.shape == bbox_pred.shape
+        assert det_labels.shape == cls_score.shape
+
+    rois = torch.rand((num_batch, num_sample, 5))
+    cls_score = torch.rand((num_batch, num_sample, num_class))
+    bbox_pred = torch.rand((num_batch, num_sample, 4))
+    det_bboxes, det_labels = self.get_bboxes(
+        rois, cls_score, bbox_pred, None, scale_factor, rescale=True)
+    assert len(det_bboxes) == num_batch and len(det_labels) == num_batch
+
+
 def test_refine_boxes():
     """Mirrors the doctest in
     ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_boxes`` but checks for