From 5ebba9a9221c3f6f8754d4efe293cdcd12be8dd4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?=
 <1286304229@qq.com>
Date: Tue, 6 Apr 2021 17:35:38 +0800
Subject: [PATCH] Support ONNX exportable batch inference of Mask R-CNN (#4871)

* Support faster rcnn

* Fix Lint

* Add docstr

* Fix docstr

* Update docstr

* Update code

* Update docstr

* MaskRCNN support batch infer

* Update

* Delete todo

* Fix error

* Fix comments
---
 mmdet/models/roi_heads/test_mixins.py | 111 ++++++++++++++------------
 1 file changed, 60 insertions(+), 51 deletions(-)

diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py
index 23ca4935..78a092a4 100644
--- a/mmdet/models/roi_heads/test_mixins.py
+++ b/mmdet/models/roi_heads/test_mixins.py
@@ -253,58 +253,67 @@ class MaskTestMixin(object):
         # image shapes of images in the batch
         ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
         scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
-        num_imgs = len(det_bboxes)
-        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
-            segm_results = [[[] for _ in range(self.mask_head.num_classes)]
-                            for _ in range(num_imgs)]
-        else:
-            # if det_bboxes is rescaled to the original image size, we need to
-            # rescale it back to the testing scale to obtain RoIs.
-            if rescale and not isinstance(scale_factors[0], float):
-                scale_factors = [
-                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
-                    for scale_factor in scale_factors
-                ]
-            if torch.onnx.is_in_onnx_export():
-                # avoid mask_pred.split with static number of prediction
-                mask_preds = []
-                _bboxes = []
-                for i, boxes in enumerate(det_bboxes):
-                    boxes = boxes[:, :4]
-                    if rescale:
-                        boxes *= scale_factors[i]
-                    _bboxes.append(boxes)
-                    img_inds = boxes[:, :1].clone() * 0 + i
-                    mask_rois = torch.cat([img_inds, boxes], dim=-1)
-                    mask_result = self._mask_forward(x, mask_rois)
-                    mask_preds.append(mask_result['mask_pred'])
+
+        # The length of proposals of different batches may be different.
+        # In order to form a batch, a padding operation is required.
+        if isinstance(det_bboxes, list):
+            # padding to form a batch
+            max_size = max([bboxes.size(0) for bboxes in det_bboxes])
+            for i, (bbox, label) in enumerate(zip(det_bboxes, det_labels)):
+                supplement_bbox = bbox.new_full(
+                    (max_size - bbox.size(0), bbox.size(1)), 0)
+                supplement_label = label.new_full((max_size - label.size(0), ),
+                                                  0)
+                det_bboxes[i] = torch.cat((supplement_bbox, bbox), dim=0)
+                det_labels[i] = torch.cat((supplement_label, label), dim=0)
+            det_bboxes = torch.stack(det_bboxes, dim=0)
+            det_labels = torch.stack(det_labels, dim=0)
+
+        batch_size = det_bboxes.size(0)
+        num_proposals_per_img = det_bboxes.shape[1]
+
+        # if det_bboxes is rescaled to the original image size, we need to
+        # rescale it back to the testing scale to obtain RoIs.
+        det_bboxes = det_bboxes[..., :4]
+        if rescale:
+            if not isinstance(scale_factors[0], float):
+                scale_factors = det_bboxes.new_tensor(scale_factors)
+            det_bboxes = det_bboxes * scale_factors.unsqueeze(1)
+
+        batch_index = torch.arange(
+            det_bboxes.size(0), device=det_bboxes.device).float().view(
+                -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
+        mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
+        mask_rois = mask_rois.view(-1, 5)
+        mask_results = self._mask_forward(x, mask_rois)
+        mask_pred = mask_results['mask_pred']
+
+        # Recover the batch dimension
+        mask_preds = mask_pred.reshape(batch_size, num_proposals_per_img,
+                                       *mask_pred.shape[1:])
+
+        # apply mask post-processing to each image individually
+        segm_results = []
+        for i in range(batch_size):
+            mask_pred = mask_preds[i]
+            det_bbox = det_bboxes[i]
+            det_label = det_labels[i]
+
+            # remove padding
+            supplement_mask = det_bbox[..., -1] != 0
+            mask_pred = mask_pred[supplement_mask]
+            det_bbox = det_bbox[supplement_mask]
+            det_label = det_label[supplement_mask]
+
+            if det_label.shape[0] == 0:
+                segm_results.append([[]
+                                     for _ in range(self.mask_head.num_classes)
+                                     ])
             else:
-                _bboxes = [
-                    det_bboxes[i][:, :4] *
-                    scale_factors[i] if rescale else det_bboxes[i][:, :4]
-                    for i in range(len(det_bboxes))
-                ]
-                mask_rois = bbox2roi(_bboxes)
-                mask_results = self._mask_forward(x, mask_rois)
-                mask_pred = mask_results['mask_pred']
-                # split batch mask prediction back to each image
-                num_mask_roi_per_img = [
-                    det_bbox.shape[0] for det_bbox in det_bboxes
-                ]
-                mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
-
-            # apply mask post-processing to each image individually
-            segm_results = []
-            for i in range(num_imgs):
-                if det_bboxes[i].shape[0] == 0:
-                    segm_results.append(
-                        [[] for _ in range(self.mask_head.num_classes)])
-                else:
-                    segm_result = self.mask_head.get_seg_masks(
-                        mask_preds[i], _bboxes[i], det_labels[i],
-                        self.test_cfg, ori_shapes[i], scale_factors[i],
-                        rescale)
-                    segm_results.append(segm_result)
+                segm_result = self.mask_head.get_seg_masks(
+                    mask_pred, det_bbox, det_label, self.test_cfg,
+                    ori_shapes[i], scale_factors[i], rescale)
+                segm_results.append(segm_result)
         return segm_results
 
     def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
-- 
GitLab