Skip to content
Snippets Groups Projects
Unverified Commit 263ae90c authored by Haian Huang(深度眸)'s avatar Haian Huang(深度眸) Committed by GitHub
Browse files

[Feature]: Support ONNX tracable Batch Inference for Faster R-CNN (#4785)

* Support faster rcnn

* Fix Lint

* Add docstr

* Fix docstr

* Update docstr

* Update code

* Update docstr
parent e3857b5f
No related branches found
No related tags found
No related merge requests found
......@@ -275,35 +275,102 @@ class BBoxHead(nn.Module):
scale_factor,
rescale=False,
cfg=None):
"""Transform network output for a batch into bbox predictions.
If the input rois has batch dimension, the function would be in
`batch_mode` and return is a tuple[list[Tensor], list[Tensor]],
otherwise, the return is a tuple[Tensor, Tensor].
Args:
rois (Tensor): Boxes to be transformed. Has shape (num_boxes, 5)
or (B, num_boxes, 5)
cls_score (list[Tensor] or Tensor): Box scores for
each scale level, each is a 4D-tensor, the channel number is
num_points * num_classes.
bbox_pred (Tensor, optional): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_classes * 4.
img_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]], optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If rois shape is (B, num_boxes, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
scale_factor (tuple[ndarray] or ndarray): Scale factor of the
image arange as (w_scale, h_scale, w_scale, h_scale). In
`batch_mode`, the scale_factor shape is tuple[ndarray].
rescale (bool): If True, return boxes in original image space.
Default: False.
cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
Returns:
tuple[list[Tensor], list[Tensor]] or tuple[Tensor, Tensor]:
If the input has a batch dimension, the return value is
a tuple of the list. The first list contains the boxes of
the corresponding image in a batch, each tensor has the
shape (num_boxes, 5) and last dimension 5 represent
(tl_x, tl_y, br_x, br_y, score). Each Tensor in the second
list is the labels with shape (num_boxes, ). The length of
both lists should be equal to batch_size. Otherwise return
value is a tuple of two tensors, the first tensor is the
boxes with scores, the second tensor is the labels, both
have the same shape as the first case.
"""
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
scores = F.softmax(
cls_score, dim=-1) if cls_score is not None else None
batch_mode = True
if rois.ndim == 2:
# e.g. AugTest, Cascade R-CNN, HTC, SCNet...
batch_mode = False
# add batch dimension
if scores is not None:
scores = scores.unsqueeze(0)
if bbox_pred is not None:
bbox_pred = bbox_pred.unsqueeze(0)
rois = rois.unsqueeze(0)
if bbox_pred is not None:
bboxes = self.bbox_coder.decode(
rois[:, 1:], bbox_pred, max_shape=img_shape)
rois[..., 1:], bbox_pred, max_shape=img_shape)
else:
bboxes = rois[:, 1:].clone()
bboxes = rois[..., 1:].clone()
if img_shape is not None:
bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1])
bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0])
if rescale and bboxes.size(0) > 0:
if isinstance(scale_factor, float):
bboxes /= scale_factor
max_shape = bboxes.new_tensor(img_shape)[..., :2]
min_xy = bboxes.new_tensor(0)
max_xy = torch.cat(
[max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
if rescale and bboxes.size(-2) > 0:
if not isinstance(scale_factor, tuple):
scale_factor = tuple([scale_factor])
# B, 1, bboxes.size(-1)
scale_factor = bboxes.new_tensor(scale_factor).unsqueeze(1).repeat(
1, 1,
bboxes.size(-1) // 4)
bboxes /= scale_factor
det_bboxes = []
det_labels = []
for (bbox, score) in zip(bboxes, scores):
if cfg is not None:
det_bbox, det_label = multiclass_nms(bbox, score,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
else:
scale_factor = bboxes.new_tensor(scale_factor)
bboxes = (bboxes.view(bboxes.size(0), -1, 4) /
scale_factor).view(bboxes.size()[0], -1)
if cfg is None:
return bboxes, scores
else:
det_bboxes, det_labels = multiclass_nms(bboxes, scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
det_bbox, det_label = bbox, score
det_bboxes.append(det_bbox)
det_labels.append(det_label)
if not batch_mode:
det_bboxes = det_bboxes[0]
det_labels = det_labels[0]
return det_bboxes, det_labels
@force_fp32(apply_to=('bbox_preds', ))
def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
......
......@@ -55,51 +55,113 @@ class BBoxTestMixin(object):
proposals,
rcnn_test_cfg,
rescale=False):
"""Test only det bboxes without augmentation."""
rois = bbox2roi(proposals)
bbox_results = self._bbox_forward(x, rois)
"""Test only det bboxes without augmentation.
Args:
x (tuple[Tensor]): Feature maps of all scale level.
img_metas (list[dict]): Image meta info.
proposals (Tensor or List[Tensor]): Region proposals.
rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
rescale (bool): If True, return boxes in original image space.
Default: False.
Returns:
tuple[list[Tensor], list[Tensor]]: The first list contains
the boxes of the corresponding image in a batch, each
tensor has the shape (num_boxes, 5) and last dimension
5 represent (tl_x, tl_y, br_x, br_y, score). Each Tensor
in the second list is the labels with shape (num_boxes, ).
The length of both lists should be equal to batch_size.
"""
# get origin input shape to support onnx dynamic input shape
if torch.onnx.is_in_onnx_export():
img_shapes = tuple(meta['img_shape_for_onnx']
for meta in img_metas)
assert len(
img_metas
) == 1, 'Only support one input image while in exporting to ONNX'
img_shapes = img_metas[0]['img_shape_for_onnx']
else:
img_shapes = tuple(meta['img_shape'] for meta in img_metas)
scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
# split batch bbox prediction back to each image
# The length of proposals of different batches may be different.
# In order to form a batch, a padding operation is required.
if isinstance(proposals, list):
# padding to form a batch
max_size = max([proposal.size(0) for proposal in proposals])
for i, proposal in enumerate(proposals):
supplement = proposal.new_full(
(max_size - proposal.size(0), 5), 0)
proposals[i] = torch.cat((supplement, proposal), dim=0)
rois = torch.stack(proposals, dim=0)
else:
rois = proposals
batch_index = torch.arange(
rois.size(0), device=rois.device).float().view(-1, 1, 1).expand(
rois.size(0), rois.size(1), 1)
rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
batch_size = rois.shape[0]
num_proposals_per_img = rois.shape[1]
# Eliminate the batch dimension
rois = rois.view(-1, 5)
bbox_results = self._bbox_forward(x, rois)
cls_score = bbox_results['cls_score']
bbox_pred = bbox_results['bbox_pred']
# use shape[] to keep tracing
num_proposals_per_img = tuple(p.shape[0] for p in proposals)
rois = rois.split(num_proposals_per_img, 0)
cls_score = cls_score.split(num_proposals_per_img, 0)
# some detector with_reg is False, bbox_pred will be None
# Recover the batch dimension
rois = rois.reshape(batch_size, num_proposals_per_img, -1)
cls_score = cls_score.reshape(batch_size, num_proposals_per_img, -1)
if not torch.onnx.is_in_onnx_export():
# remove padding
supplement_mask = rois[..., -1] == 0
cls_score[supplement_mask, :] = 0
# bbox_pred would be None in some detector when with_reg is False,
# e.g. Grid R-CNN.
if bbox_pred is not None:
# the bbox prediction of some detectors like SABL is not Tensor
if isinstance(bbox_pred, torch.Tensor):
bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
bbox_pred = bbox_pred.reshape(batch_size,
num_proposals_per_img, -1)
if not torch.onnx.is_in_onnx_export():
bbox_pred[supplement_mask, :] = 0
else:
bbox_pred = self.bbox_head.bbox_pred_split(
# TODO: Looking forward to a better way
# For SABL
bbox_preds = self.bbox_head.bbox_pred_split(
bbox_pred, num_proposals_per_img)
# apply bbox post-processing to each image individually
det_bboxes = []
det_labels = []
for i in range(len(proposals)):
# remove padding
supplement_mask = proposals[i][..., -1] == 0
for bbox in bbox_preds[i]:
bbox[supplement_mask] = 0
det_bbox, det_label = self.bbox_head.get_bboxes(
rois[i],
cls_score[i],
bbox_preds[i],
img_shapes[i],
scale_factors[i],
rescale=rescale,
cfg=rcnn_test_cfg)
det_bboxes.append(det_bbox)
det_labels.append(det_label)
return det_bboxes, det_labels
else:
bbox_pred = (None, ) * len(proposals)
bbox_pred = None
# apply bbox post-processing to each image individually
det_bboxes = []
det_labels = []
for i in range(len(proposals)):
det_bbox, det_label = self.bbox_head.get_bboxes(
rois[i],
cls_score[i],
bbox_pred[i],
img_shapes[i],
scale_factors[i],
rescale=rescale,
cfg=rcnn_test_cfg)
det_bboxes.append(det_bbox)
det_labels.append(det_label)
return det_bboxes, det_labels
return self.bbox_head.get_bboxes(
rois,
cls_score,
bbox_pred,
img_shapes,
scale_factors,
rescale=rescale,
cfg=rcnn_test_cfg)
def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
"""Test det bboxes with test time augmentation."""
......
import mmcv
import pytest
import torch
from mmdet.core import bbox2roi
......@@ -64,6 +65,31 @@ def test_bbox_head_loss():
assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero'
@pytest.mark.parametrize(['num_sample', 'num_batch'], [[2, 2], [0, 2], [0, 0]])
def test_bbox_head_get_bboxes(num_sample, num_batch):
self = BBoxHead(reg_class_agnostic=True)
num_class = 6
rois = torch.rand((num_sample, 5))
cls_score = torch.rand((num_sample, num_class))
bbox_pred = torch.rand((num_sample, 4))
scale_factor = 2.0
det_bboxes, det_labels = self.get_bboxes(
rois, cls_score, bbox_pred, None, scale_factor, rescale=True)
if num_sample == 0:
assert len(det_bboxes) == 0 and len(det_labels) == 0
else:
assert det_bboxes.shape == bbox_pred.shape
assert det_labels.shape == cls_score.shape
rois = torch.rand((num_batch, num_sample, 5))
cls_score = torch.rand((num_batch, num_sample, num_class))
bbox_pred = torch.rand((num_batch, num_sample, 4))
det_bboxes, det_labels = self.get_bboxes(
rois, cls_score, bbox_pred, None, scale_factor, rescale=True)
assert len(det_bboxes) == num_batch and len(det_labels) == num_batch
def test_refine_boxes():
"""Mirrors the doctest in
``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_boxes`` but checks for
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment