From 5ebba9a9221c3f6f8754d4efe293cdcd12be8dd4 Mon Sep 17 00:00:00 2001
From: Haian Huang(深度眸) <1286304229@qq.com>
Date: Tue, 6 Apr 2021 17:35:38 +0800
Subject: [PATCH] Support ONNX exportable batch inference of Mask R-CNN (#4871)

* Support faster rcnn

* Fix Lint

* Add docstr

* Fix docstr

* Update docstr

* Update code

* Update docstr

* MaskRCNN support batch infer

* Update

* Delete todo

* Fix error

* Fix comments
---
 mmdet/models/roi_heads/test_mixins.py | 111 ++++++++++++++------------
 1 file changed, 60 insertions(+), 51 deletions(-)

diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py
index 23ca4935..78a092a4 100644
--- a/mmdet/models/roi_heads/test_mixins.py
+++ b/mmdet/models/roi_heads/test_mixins.py
@@ -253,58 +253,67 @@ class MaskTestMixin(object):
         # image shapes of images in the batch
         ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
         scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
-        num_imgs = len(det_bboxes)
-        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
-            segm_results = [[[] for _ in range(self.mask_head.num_classes)]
-                            for _ in range(num_imgs)]
-        else:
-            # if det_bboxes is rescaled to the original image size, we need to
-            # rescale it back to the testing scale to obtain RoIs.
-            if rescale and not isinstance(scale_factors[0], float):
-                scale_factors = [
-                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
-                    for scale_factor in scale_factors
-                ]
-            if torch.onnx.is_in_onnx_export():
-                # avoid mask_pred.split with static number of prediction
-                mask_preds = []
-                _bboxes = []
-                for i, boxes in enumerate(det_bboxes):
-                    boxes = boxes[:, :4]
-                    if rescale:
-                        boxes *= scale_factors[i]
-                    _bboxes.append(boxes)
-                    img_inds = boxes[:, :1].clone() * 0 + i
-                    mask_rois = torch.cat([img_inds, boxes], dim=-1)
-                    mask_result = self._mask_forward(x, mask_rois)
-                    mask_preds.append(mask_result['mask_pred'])
+
+        # The number of proposals may differ across the images in a batch,
+        # so pad them to the same length to form a batch.
+        if isinstance(det_bboxes, list):
+            # padding to form a batch
+            max_size = max([bboxes.size(0) for bboxes in det_bboxes])
+            for i, (bbox, label) in enumerate(zip(det_bboxes, det_labels)):
+                supplement_bbox = bbox.new_full(
+                    (max_size - bbox.size(0), bbox.size(1)), 0)
+                supplement_label = label.new_full((max_size - label.size(0), ),
+                                                  0)
+                det_bboxes[i] = torch.cat((supplement_bbox, bbox), dim=0)
+                det_labels[i] = torch.cat((supplement_label, label), dim=0)
+            det_bboxes = torch.stack(det_bboxes, dim=0)
+            det_labels = torch.stack(det_labels, dim=0)
+
+        batch_size = det_bboxes.size(0)
+        num_proposals_per_img = det_bboxes.shape[1]
+
+        # if det_bboxes is rescaled to the original image size, we need to
+        # rescale it back to the testing scale to obtain RoIs.
+        det_bboxes = det_bboxes[..., :4]
+        if rescale:
+            if not isinstance(scale_factors[0], float):
+                scale_factors = det_bboxes.new_tensor(scale_factors)
+            det_bboxes = det_bboxes * scale_factors.unsqueeze(1)
+
+        batch_index = torch.arange(
+            det_bboxes.size(0), device=det_bboxes.device).float().view(
+                -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
+        mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
+        mask_rois = mask_rois.view(-1, 5)
+        mask_results = self._mask_forward(x, mask_rois)
+        mask_pred = mask_results['mask_pred']
+
+        # Recover the batch dimension
+        mask_preds = mask_pred.reshape(batch_size, num_proposals_per_img,
+                                       *mask_pred.shape[1:])
+
+        # apply mask post-processing to each image individually
+        segm_results = []
+        for i in range(batch_size):
+            mask_pred = mask_preds[i]
+            det_bbox = det_bboxes[i]
+            det_label = det_labels[i]
+
+            # remove padding; padded boxes are all zeros (y2 == 0)
+            supplement_mask = det_bbox[..., -1] != 0
+            mask_pred = mask_pred[supplement_mask]
+            det_bbox = det_bbox[supplement_mask]
+            det_label = det_label[supplement_mask]
+
+            if det_label.shape[0] == 0:
+                segm_results.append([[]
+                                     for _ in range(self.mask_head.num_classes)
+                                     ])
             else:
-                _bboxes = [
-                    det_bboxes[i][:, :4] *
-                    scale_factors[i] if rescale else det_bboxes[i][:, :4]
-                    for i in range(len(det_bboxes))
-                ]
-                mask_rois = bbox2roi(_bboxes)
-                mask_results = self._mask_forward(x, mask_rois)
-                mask_pred = mask_results['mask_pred']
-                # split batch mask prediction back to each image
-                num_mask_roi_per_img = [
-                    det_bbox.shape[0] for det_bbox in det_bboxes
-                ]
-                mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
-
-            # apply mask post-processing to each image individually
-            segm_results = []
-            for i in range(num_imgs):
-                if det_bboxes[i].shape[0] == 0:
-                    segm_results.append(
-                        [[] for _ in range(self.mask_head.num_classes)])
-                else:
-                    segm_result = self.mask_head.get_seg_masks(
-                        mask_preds[i], _bboxes[i], det_labels[i],
-                        self.test_cfg, ori_shapes[i], scale_factors[i],
-                        rescale)
-                    segm_results.append(segm_result)
+                segm_result = self.mask_head.get_seg_masks(
+                    mask_pred, det_bbox, det_label, self.test_cfg,
+                    ori_shapes[i], scale_factors[i], rescale)
+                segm_results.append(segm_result)
         return segm_results
 
     def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
-- 
GitLab
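
Editor's note: the first half of the hunk above builds one static-shaped batch
for the mask head: per-image detections are zero-padded to a common length and
each RoI is tagged with its image index, so _mask_forward runs once for the
whole batch and the traced graph stays ONNX-exportable. Below is a minimal,
self-contained sketch of that trick; the names pad_and_batch and
build_mask_rois, and the toy tensor shapes, are illustrative assumptions, not
mmdet API.

import torch


def pad_and_batch(det_bboxes, det_labels):
    # Pad per-image (N_i, 5) bboxes and (N_i,) labels with all-zero rows so
    # that every image has the same number of detections, then stack them
    # into (B, N_max, 5) and (B, N_max) batch tensors.
    max_size = max(bboxes.size(0) for bboxes in det_bboxes)
    padded_bboxes, padded_labels = [], []
    for bbox, label in zip(det_bboxes, det_labels):
        pad = max_size - bbox.size(0)
        # Zero rows are prepended, mirroring the patch; a zero box is later
        # recognised as padding by its zero y2 coordinate.
        padded_bboxes.append(
            torch.cat((bbox.new_zeros(pad, bbox.size(1)), bbox), dim=0))
        padded_labels.append(torch.cat((label.new_zeros(pad), label), dim=0))
    return torch.stack(padded_bboxes), torch.stack(padded_labels)


def build_mask_rois(det_bboxes):
    # Turn (B, N, >=4) boxes into (B * N, 5) RoIs of the form
    # (batch_index, x1, y1, x2, y2), as the patch does before _mask_forward.
    boxes = det_bboxes[..., :4]
    batch_index = torch.arange(
        boxes.size(0), device=boxes.device).float().view(-1, 1, 1).expand(
            boxes.size(0), boxes.size(1), 1)
    return torch.cat([batch_index, boxes], dim=-1).view(-1, 5)


if __name__ == '__main__':
    # Two images with 2 and 3 detections of (x1, y1, x2, y2, score).
    det_bboxes = [torch.rand(2, 5), torch.rand(3, 5)]
    det_labels = [torch.zeros(2, dtype=torch.long),
                  torch.ones(3, dtype=torch.long)]
    bboxes, labels = pad_and_batch(det_bboxes, det_labels)
    rois = build_mask_rois(bboxes)
    print(bboxes.shape, labels.shape, rois.shape)
    # torch.Size([2, 3, 5]) torch.Size([2, 3]) torch.Size([6, 5])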
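
Editor's note: the second half of the hunk reverses the batching: the flat
(B * N, C, H, W) mask predictions are reshaped back into per-image blocks and
the zero-padded rows are dropped before get_seg_masks. A sketch of that step
under the same assumptions as above (unbatch_and_unpad and the toy shapes are
illustrative, not mmdet API):

import torch


def unbatch_and_unpad(mask_pred, det_bboxes, det_labels):
    # mask_pred: (B * N, C, H, W) mask logits for the padded, flattened RoIs.
    # Returns one (pred, bbox, label) triple per image, padding removed.
    batch_size, num_proposals = det_bboxes.shape[:2]
    # Recover the batch dimension, as the patch does with mask_pred.reshape.
    mask_preds = mask_pred.reshape(batch_size, num_proposals,
                                   *mask_pred.shape[1:])
    results = []
    for i in range(batch_size):
        det_bbox = det_bboxes[i, :, :4]
        # Padded rows are all zeros, so a zero y2 coordinate marks padding
        # (the same test as `supplement_mask` in the patch).
        keep = det_bbox[:, -1] != 0
        results.append(
            (mask_preds[i][keep], det_bbox[keep], det_labels[i][keep]))
    return results


if __name__ == '__main__':
    det_bboxes = torch.tensor([[[0., 0., 0., 0., 0.],   # padding row
                                [1., 2., 3., 4., .9]],
                               [[2., 3., 4., 5., .8],
                                [3., 4., 5., 6., .7]]])
    det_labels = torch.tensor([[0, 1], [2, 3]])
    mask_pred = torch.rand(4, 80, 28, 28)  # B * N = 4 flat predictions
    for pred, bbox, label in unbatch_and_unpad(mask_pred, det_bboxes,
                                               det_labels):
        print(pred.shape, bbox.shape, label.shape)

On the design choice: padding rows are all-zero boxes, so even after rescaling
they are identified by y2 == 0. A genuine detection with y2 == 0 would be a
degenerate zero-height box at the top edge, which is presumably why the patch
can rely on this sentinel.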