From 8824417be9218e9c300c13dcd2ade46958002f3c Mon Sep 17 00:00:00 2001 From: tangyanf <18844194126@163.com> Date: Thu, 4 Mar 2021 23:32:06 -0800 Subject: [PATCH] [Feature]: export onnx model to support dynamic input size (#4685) * modify get_bboes() to export onnx model which support dynamic input shape * update code * update code * update code * update code * update code * update code * update code * update code * update code * update code --- .../core/bbox/coder/delta_xywh_bbox_coder.py | 26 ++++++++++++++++--- mmdet/models/dense_heads/anchor_head.py | 6 ++++- mmdet/models/detectors/single_stage.py | 3 +++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py index 4cf8ac87..1c6bc36e 100644 --- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py +++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py @@ -199,9 +199,27 @@ def delta2bbox(rois, x2 = gx + gw * 0.5 y2 = gy + gh * 0.5 if clip_border and max_shape is not None: - x1 = x1.clamp(min=0, max=max_shape[1]) - y1 = y1.clamp(min=0, max=max_shape[0]) - x2 = x2.clamp(min=0, max=max_shape[1]) - y2 = y2.clamp(min=0, max=max_shape[0]) + # use where() to replace clip(), + # because clip()'s attr min/max do not support dynamic in onnx + if torch.onnx.is_in_onnx_export(): + zero = torch.tensor(0, dtype=torch.float32) + zero = zero.expand(x1.size()) + width = torch.tensor(max_shape[1], dtype=torch.float32) + width = width.expand(x1.size()) + height = torch.tensor(max_shape[0], dtype=torch.float32) + height = height.expand(x1.size()) + x1 = torch.where(x1 < zero, zero, x1) + x1 = torch.where(x1 > width, width, x1) + y1 = torch.where(y1 < zero, zero, y1) + y1 = torch.where(y1 > height, height, y1) + x2 = torch.where(x2 < zero, zero, x2) + x2 = torch.where(x2 > width, width, x2) + y2 = torch.where(y2 < zero, zero, y2) + y2 = torch.where(y2 > height, height, y2) + else: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) return bboxes diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py index 8ae50d38..87e60136 100644 --- a/mmdet/models/dense_heads/anchor_head.py +++ b/mmdet/models/dense_heads/anchor_head.py @@ -567,7 +567,11 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): bbox_pred_list = [ bbox_preds[i][img_id].detach() for i in range(num_levels) ] - img_shape = img_metas[img_id]['img_shape'] + # get origin input shape to support onnx dynamic shape + if torch.onnx.is_in_onnx_export(): + img_shape = img_metas[img_id]['img_shape_for_onnx'] + else: + img_shape = img_metas[img_id]['img_shape'] scale_factor = img_metas[img_id]['scale_factor'] if with_nms: # some heads don't support with_nms argument diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py index ce6b26ea..c2bace62 100644 --- a/mmdet/models/detectors/single_stage.py +++ b/mmdet/models/detectors/single_stage.py @@ -111,6 +111,9 @@ class SingleStageDetector(BaseDetector): """ x = self.extract_feat(img) outs = self.bbox_head(x) + # get origin input shape to support onnx dynamic shape + if torch.onnx.is_in_onnx_export(): + img_metas[0]['img_shape_for_onnx'] = img.shape[2:] bbox_list = self.bbox_head.get_bboxes( *outs, img_metas, rescale=rescale) # skip post-processing when exporting to ONNX -- GitLab