Skip to content
Snippets Groups Projects
Unverified Commit 9946e121 authored by RunningLeon's avatar RunningLeon Committed by GitHub
Browse files

[Feature]: Support ONNX inference with dynamic input shape in AnchorHead (#4684)

* make anchor_generator exportable to ONNX

* make k of topk dynamic for onnx

* rename nms_pre_t -> nms_pre_tensor
parent 66ccfeb2
No related branches found
No related tags found
No related merge requests found
...@@ -196,8 +196,9 @@ class AnchorGenerator(object): ...@@ -196,8 +196,9 @@ class AnchorGenerator(object):
Returns: Returns:
tuple[torch.Tensor]: The mesh grids of x and y. tuple[torch.Tensor]: The mesh grids of x and y.
""" """
xx = x.repeat(len(y)) # use shape instead of len to keep tracing while exporting to onnx
yy = y.view(-1, 1).repeat(1, len(x)).view(-1) xx = x.repeat(y.shape[0])
yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
if row_major: if row_major:
return xx, yy return xx, yy
else: else:
...@@ -250,10 +251,8 @@ class AnchorGenerator(object): ...@@ -250,10 +251,8 @@ class AnchorGenerator(object):
Returns: Returns:
torch.Tensor: Anchors in the overall feature maps. torch.Tensor: Anchors in the overall feature maps.
""" """
# keep as Tensor, so that we can covert to ONNX correctly
feat_h, feat_w = featmap_size feat_h, feat_w = featmap_size
# convert Tensor to int, so that we can covert to ONNX correctlly
feat_h = int(feat_h)
feat_w = int(feat_w)
shift_x = torch.arange(0, feat_w, device=device) * stride[0] shift_x = torch.arange(0, feat_w, device=device) * stride[0]
shift_y = torch.arange(0, feat_h, device=device) * stride[1] shift_y = torch.arange(0, feat_h, device=device) * stride[1]
......
...@@ -620,6 +620,11 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): ...@@ -620,6 +620,11 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
""" """
cfg = self.test_cfg if cfg is None else cfg cfg = self.test_cfg if cfg is None else cfg
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
# convert to tensor to keep tracing
nms_pre_tensor = torch.tensor(
cfg.get('nms_pre', -1),
device=cls_score_list[0].device,
dtype=torch.long)
mlvl_bboxes = [] mlvl_bboxes = []
mlvl_scores = [] mlvl_scores = []
for cls_score, bbox_pred, anchors in zip(cls_score_list, for cls_score, bbox_pred, anchors in zip(cls_score_list,
...@@ -632,8 +637,14 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): ...@@ -632,8 +637,14 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
else: else:
scores = cls_score.softmax(-1) scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1) # Always keep topk op for dynamic input in onnx
if nms_pre > 0 and scores.shape[0] > nms_pre: if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
or scores.shape[-2] > nms_pre_tensor):
from torch import _shape_as_tensor
# keep shape as tensor and get k
num_anchor = _shape_as_tensor(scores)[-2]
nms_pre = torch.where(nms_pre_tensor < num_anchor,
nms_pre_tensor, num_anchor)
# Get maximum scores for foreground classes. # Get maximum scores for foreground classes.
if self.use_sigmoid_cls: if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1) max_scores, _ = scores.max(dim=1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment