From 9946e121e7807e5927b3797bee360392761dbea0 Mon Sep 17 00:00:00 2001
From: RunningLeon <maningsheng@sensetime.com>
Date: Wed, 3 Mar 2021 17:23:02 +0800
Subject: [PATCH] [Feature]: Support ONNX inference with dynamic input shape in
 AnchorHead (#4684)

* make anchor_generator exportable to ONNX

* make k of topk dynamic for onnx

* rename nms_pre_t -> nms_pre_tensor
---
 mmdet/core/anchor/anchor_generator.py   |  9 ++++-----
 mmdet/models/dense_heads/anchor_head.py | 15 +++++++++++++--
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py
index 29b5ed04..388d2608 100644
--- a/mmdet/core/anchor/anchor_generator.py
+++ b/mmdet/core/anchor/anchor_generator.py
@@ -196,8 +196,9 @@ class AnchorGenerator(object):
         Returns:
             tuple[torch.Tensor]: The mesh grids of x and y.
         """
-        xx = x.repeat(len(y))
-        yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
+        # use shape instead of len to keep tracing while exporting to onnx
+        xx = x.repeat(y.shape[0])
+        yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
         if row_major:
             return xx, yy
         else:
@@ -250,10 +251,8 @@ class AnchorGenerator(object):
         Returns:
             torch.Tensor: Anchors in the overall feature maps.
         """
+        # keep as Tensor, so that we can convert to ONNX correctly
         feat_h, feat_w = featmap_size
-        # convert Tensor to int, so that we can covert to ONNX correctlly
-        feat_h = int(feat_h)
-        feat_w = int(feat_w)
         shift_x = torch.arange(0, feat_w, device=device) * stride[0]
         shift_y = torch.arange(0, feat_h, device=device) * stride[1]
 
diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py
index 970b63dc..8ae50d38 100644
--- a/mmdet/models/dense_heads/anchor_head.py
+++ b/mmdet/models/dense_heads/anchor_head.py
@@ -620,6 +620,11 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
         """
         cfg = self.test_cfg if cfg is None else cfg
         assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
+        # convert to tensor to keep tracing
+        nms_pre_tensor = torch.tensor(
+            cfg.get('nms_pre', -1),
+            device=cls_score_list[0].device,
+            dtype=torch.long)
         mlvl_bboxes = []
         mlvl_scores = []
         for cls_score, bbox_pred, anchors in zip(cls_score_list,
@@ -632,8 +637,14 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
             else:
                 scores = cls_score.softmax(-1)
             bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
-            nms_pre = cfg.get('nms_pre', -1)
-            if nms_pre > 0 and scores.shape[0] > nms_pre:
+            # Always keep topk op for dynamic input in onnx
+            if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
+                                       or scores.shape[-2] > nms_pre_tensor):
+                from torch import _shape_as_tensor
+                # keep shape as tensor and get k
+                num_anchor = _shape_as_tensor(scores)[-2]
+                nms_pre = torch.where(nms_pre_tensor < num_anchor,
+                                      nms_pre_tensor, num_anchor)
                 # Get maximum scores for foreground classes.
                 if self.use_sigmoid_cls:
                     max_scores, _ = scores.max(dim=1)
-- 
GitLab