Skip to content
Snippets Groups Projects
Unverified Commit 8824417b authored by tangyanf's avatar tangyanf Committed by GitHub
Browse files

[Feature]: export onnx model to support dynamic input size (#4685)

* modify get_bboxes() to export an ONNX model that supports dynamic input shapes

* update code

* update code

* update code

* update code

* update code

* update code

* update code

* update code

* update code

* update code
parent 9946e121
No related branches found
No related tags found
No related merge requests found
......@@ -199,9 +199,27 @@ def delta2bbox(rois,
x2 = gx + gw * 0.5
y2 = gy + gh * 0.5
if clip_border and max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
# use where() to replace clip(),
# because clip()'s attr min/max do not support dynamic in onnx
if torch.onnx.is_in_onnx_export():
zero = torch.tensor(0, dtype=torch.float32)
zero = zero.expand(x1.size())
width = torch.tensor(max_shape[1], dtype=torch.float32)
width = width.expand(x1.size())
height = torch.tensor(max_shape[0], dtype=torch.float32)
height = height.expand(x1.size())
x1 = torch.where(x1 < zero, zero, x1)
x1 = torch.where(x1 > width, width, x1)
y1 = torch.where(y1 < zero, zero, y1)
y1 = torch.where(y1 > height, height, y1)
x2 = torch.where(x2 < zero, zero, x2)
x2 = torch.where(x2 > width, width, x2)
y2 = torch.where(y2 < zero, zero, y2)
y2 = torch.where(y2 > height, height, y2)
else:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
return bboxes
......@@ -567,7 +567,11 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
# get origin input shape to support onnx dynamic shape
if torch.onnx.is_in_onnx_export():
img_shape = img_metas[img_id]['img_shape_for_onnx']
else:
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
if with_nms:
# some heads don't support with_nms argument
......
......@@ -111,6 +111,9 @@ class SingleStageDetector(BaseDetector):
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
# get origin input shape to support onnx dynamic shape
if torch.onnx.is_in_onnx_export():
img_metas[0]['img_shape_for_onnx'] = img.shape[2:]
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
# skip post-processing when exporting to ONNX
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment