From 8824417be9218e9c300c13dcd2ade46958002f3c Mon Sep 17 00:00:00 2001
From: tangyanf <18844194126@163.com>
Date: Thu, 4 Mar 2021 23:32:06 -0800
Subject: [PATCH] [Feature]: export onnx model to support dynamic input size
 (#4685)

* modify get_bboxes() so the exported ONNX model supports dynamic input shapes

---
 .../core/bbox/coder/delta_xywh_bbox_coder.py  | 26 ++++++++++++++++---
 mmdet/models/dense_heads/anchor_head.py       |  6 ++++-
 mmdet/models/detectors/single_stage.py        |  3 +++
 3 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
index 4cf8ac87..1c6bc36e 100644
--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -199,9 +199,27 @@ def delta2bbox(rois,
     x2 = gx + gw * 0.5
     y2 = gy + gh * 0.5
     if clip_border and max_shape is not None:
-        x1 = x1.clamp(min=0, max=max_shape[1])
-        y1 = y1.clamp(min=0, max=max_shape[0])
-        x2 = x2.clamp(min=0, max=max_shape[1])
-        y2 = y2.clamp(min=0, max=max_shape[0])
+        # use where() instead of clamp(): the ONNX Clip op takes
+        # min/max as static attributes, so they cannot be dynamic
+        if torch.onnx.is_in_onnx_export():
+            zero = torch.tensor(0, dtype=torch.float32)
+            zero = zero.expand(x1.size())
+            width = torch.tensor(max_shape[1], dtype=torch.float32)
+            width = width.expand(x1.size())
+            height = torch.tensor(max_shape[0], dtype=torch.float32)
+            height = height.expand(x1.size())
+            x1 = torch.where(x1 < zero, zero, x1)
+            x1 = torch.where(x1 > width, width, x1)
+            y1 = torch.where(y1 < zero, zero, y1)
+            y1 = torch.where(y1 > height, height, y1)
+            x2 = torch.where(x2 < zero, zero, x2)
+            x2 = torch.where(x2 > width, width, x2)
+            y2 = torch.where(y2 < zero, zero, y2)
+            y2 = torch.where(y2 > height, height, y2)
+        else:
+            x1 = x1.clamp(min=0, max=max_shape[1])
+            y1 = y1.clamp(min=0, max=max_shape[0])
+            x2 = x2.clamp(min=0, max=max_shape[1])
+            y2 = y2.clamp(min=0, max=max_shape[0])
     bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
     return bboxes
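
Distilled from the hunk above, a minimal sketch of the where()-based clamp. The helper name onnx_safe_clamp is hypothetical, not part of the patch; it only illustrates the pattern of broadcasting tensor bounds and clamping with torch.where() so the bounds stay dynamic at export time:

```python
import torch

def onnx_safe_clamp(x, lo, hi):
    # broadcast the bounds to x's shape, then clamp via where():
    # unlike Tensor.clamp, the bounds remain tensors in the traced
    # graph, so they can depend on a dynamic input shape
    lo = torch.as_tensor(lo, dtype=x.dtype).expand(x.size())
    hi = torch.as_tensor(hi, dtype=x.dtype).expand(x.size())
    x = torch.where(x < lo, lo, x)
    return torch.where(x > hi, hi, x)

x1 = torch.tensor([-3.0, 10.0, 700.0])
print(onnx_safe_clamp(x1, 0.0, 640.0))  # tensor([  0.,  10., 640.])
```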
diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py
index 8ae50d38..87e60136 100644
--- a/mmdet/models/dense_heads/anchor_head.py
+++ b/mmdet/models/dense_heads/anchor_head.py
@@ -567,7 +567,11 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
             bbox_pred_list = [
                 bbox_preds[i][img_id].detach() for i in range(num_levels)
             ]
-            img_shape = img_metas[img_id]['img_shape']
+            # use the original input shape so the ONNX graph stays dynamic
+            if torch.onnx.is_in_onnx_export():
+                img_shape = img_metas[img_id]['img_shape_for_onnx']
+            else:
+                img_shape = img_metas[img_id]['img_shape']
             scale_factor = img_metas[img_id]['scale_factor']
             if with_nms:
                 # some heads don't support with_nms argument
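
For context on the key swap above, a hedged sketch (the demo tensor and meta dict here are illustrative, not from the patch) of why the two entries behave differently under tracing:

```python
import torch

img = torch.randn(1, 3, 608, 608)

# 'img_shape' is a plain Python tuple: its values get baked into the
# traced graph as constants
img_meta = {'img_shape': (608, 608, 3), 'scale_factor': 1.0}

# 'img_shape_for_onnx' is a torch.Size read from the live tensor (set in
# the single_stage.py hunk below), so ops that consume it remain
# functions of the runtime input shape
img_meta['img_shape_for_onnx'] = img.shape[2:]
```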
diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py
index ce6b26ea..c2bace62 100644
--- a/mmdet/models/detectors/single_stage.py
+++ b/mmdet/models/detectors/single_stage.py
@@ -111,6 +111,9 @@ class SingleStageDetector(BaseDetector):
         """
         x = self.extract_feat(img)
         outs = self.bbox_head(x)
+        # record the original input shape so the ONNX graph stays dynamic
+        if torch.onnx.is_in_onnx_export():
+            img_metas[0]['img_shape_for_onnx'] = img.shape[2:]
         bbox_list = self.bbox_head.get_bboxes(
             *outs, img_metas, rescale=rescale)
         # skip post-processing when exporting to ONNX
-- 
GitLab
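
To see the pieces work together end to end, a self-contained sketch assuming opset 11; the WhereClamp module and the output file name are illustrative stand-ins for the mmdet decode path, not part of the patch:

```python
import torch
import torch.nn as nn

class WhereClamp(nn.Module):
    # toy stand-in for the bbox decode path: clamp x-coordinates to the
    # input image's width using the where() pattern from the patch
    def forward(self, img, x):
        width = torch.ones_like(x) * img.shape[3]
        zero = torch.zeros_like(x)
        x = torch.where(x < zero, zero, x)
        return torch.where(x > width, width, x)

model = WhereClamp().eval()
img = torch.randn(1, 3, 320, 320)
x = torch.rand(100) * 400 - 20
torch.onnx.export(
    model, (img, x), 'where_clamp.onnx',
    input_names=['img', 'x'],
    dynamic_axes={'img': {2: 'height', 3: 'width'}, 'x': {0: 'num_boxes'}},
    opset_version=11)
```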