Skip to content
Snippets Groups Projects
Unverified Commit 5ebba9a9 authored by Haian Huang(深度眸)'s avatar Haian Huang(深度眸) Committed by GitHub
Browse files

Support ONNX exportable batch inference of Mask R-CNN (#4871)

* Support faster rcnn

* Fix Lint

* Add docstr

* Fix docstr

* Update docstr

* Update code

* Update docstr

* MaskRCNN support batch infer

* Update

* Delete todo

* Fix error

* Fix comments
parent a56638fb
No related branches found
No related tags found
No related merge requests found
......@@ -253,58 +253,67 @@ class MaskTestMixin(object):
# image shapes of images in the batch
ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
num_imgs = len(det_bboxes)
if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
segm_results = [[[] for _ in range(self.mask_head.num_classes)]
for _ in range(num_imgs)]
else:
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
if rescale and not isinstance(scale_factors[0], float):
scale_factors = [
torch.from_numpy(scale_factor).to(det_bboxes[0].device)
for scale_factor in scale_factors
]
if torch.onnx.is_in_onnx_export():
# avoid mask_pred.split with static number of prediction
mask_preds = []
_bboxes = []
for i, boxes in enumerate(det_bboxes):
boxes = boxes[:, :4]
if rescale:
boxes *= scale_factors[i]
_bboxes.append(boxes)
img_inds = boxes[:, :1].clone() * 0 + i
mask_rois = torch.cat([img_inds, boxes], dim=-1)
mask_result = self._mask_forward(x, mask_rois)
mask_preds.append(mask_result['mask_pred'])
# The length of proposals of different batches may be different.
# In order to form a batch, a padding operation is required.
if isinstance(det_bboxes, list):
# padding to form a batch
max_size = max([bboxes.size(0) for bboxes in det_bboxes])
for i, (bbox, label) in enumerate(zip(det_bboxes, det_labels)):
supplement_bbox = bbox.new_full(
(max_size - bbox.size(0), bbox.size(1)), 0)
supplement_label = label.new_full((max_size - label.size(0), ),
0)
det_bboxes[i] = torch.cat((supplement_bbox, bbox), dim=0)
det_labels[i] = torch.cat((supplement_label, label), dim=0)
det_bboxes = torch.stack(det_bboxes, dim=0)
det_labels = torch.stack(det_labels, dim=0)
batch_size = det_bboxes.size(0)
num_proposals_per_img = det_bboxes.shape[1]
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
det_bboxes = det_bboxes[..., :4]
if rescale:
if not isinstance(scale_factors[0], float):
scale_factors = det_bboxes.new_tensor(scale_factors)
det_bboxes = det_bboxes * scale_factors.unsqueeze(1)
batch_index = torch.arange(
det_bboxes.size(0), device=det_bboxes.device).float().view(
-1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
mask_rois = mask_rois.view(-1, 5)
mask_results = self._mask_forward(x, mask_rois)
mask_pred = mask_results['mask_pred']
# Recover the batch dimension
mask_preds = mask_pred.reshape(batch_size, num_proposals_per_img,
*mask_pred.shape[1:])
# apply mask post-processing to each image individually
segm_results = []
for i in range(batch_size):
mask_pred = mask_preds[i]
det_bbox = det_bboxes[i]
det_label = det_labels[i]
# remove padding
supplement_mask = det_bbox[..., -1] != 0
mask_pred = mask_pred[supplement_mask]
det_bbox = det_bbox[supplement_mask]
det_label = det_label[supplement_mask]
if det_label.shape[0] == 0:
segm_results.append([[]
for _ in range(self.mask_head.num_classes)
])
else:
_bboxes = [
det_bboxes[i][:, :4] *
scale_factors[i] if rescale else det_bboxes[i][:, :4]
for i in range(len(det_bboxes))
]
mask_rois = bbox2roi(_bboxes)
mask_results = self._mask_forward(x, mask_rois)
mask_pred = mask_results['mask_pred']
# split batch mask prediction back to each image
num_mask_roi_per_img = [
det_bbox.shape[0] for det_bbox in det_bboxes
]
mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
# apply mask post-processing to each image individually
segm_results = []
for i in range(num_imgs):
if det_bboxes[i].shape[0] == 0:
segm_results.append(
[[] for _ in range(self.mask_head.num_classes)])
else:
segm_result = self.mask_head.get_seg_masks(
mask_preds[i], _bboxes[i], det_labels[i],
self.test_cfg, ori_shapes[i], scale_factors[i],
rescale)
segm_results.append(segm_result)
segm_result = self.mask_head.get_seg_masks(
mask_pred, det_bbox, det_label, self.test_cfg,
ori_shapes[i], scale_factors[i], rescale)
segm_results.append(segm_result)
return segm_results
def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment