From 2f32a47a170396dc0090bf4cb228f358b69ed26e Mon Sep 17 00:00:00 2001
From: robin Han <drcut@users.noreply.github.com>
Date: Wed, 5 Aug 2020 00:03:07 +0800
Subject: [PATCH] Pytorch2onnx (#3075)

* Update pytorch2onnx.py which using new logic to convert pytorch to ONNX

* use standard API to check whether in ONNX convert process

* only compare the score value while verifying results between ONNX and pytorch

* move import onnx before import torch, or something weird will happen

* use real images for input

* modifying the way of calling nms

* modify docstring for bbox2result, and remove unnecessary part for onnx exporting

* modify the 'Convert to ONNX' part in docs

* replace or to | in docstring

* update according to the latest mmcv

* add normalize part

* raise error while using low version mmcv

* minor update

* minor update

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
---
 docs/getting_started.md                       |   6 +-
 mmdet/core/anchor/anchor_generator.py         |   4 +
 .../core/bbox/coder/delta_xywh_bbox_coder.py  |   6 +-
 mmdet/core/bbox/transforms.py                 |   9 +-
 mmdet/core/post_processing/bbox_nms.py        |  19 +-
 mmdet/models/detectors/single_stage.py        |   5 +
 setup.cfg                                     |   2 +-
 tools/pytorch2onnx.py                         | 212 +++++++++++-------
 8 files changed, 164 insertions(+), 99 deletions(-)

diff --git a/docs/getting_started.md b/docs/getting_started.md
index f9e77c5c..64e431c3 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -445,13 +445,13 @@ Please refer to [robustness_benchmarking.md](robustness_benchmarking.md).
 
 ### Convert to ONNX (experimental)
 
-We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron).
+We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
 
 ```shell
-python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${ONNX_FILE} [--shape ${INPUT_SHAPE}]
+python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --output_file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
 ```
 
-**Note**: This tool is still experimental. Customized operators are not supported for now. We set `use_torchvision=True` on-the-fly for `RoIPool` and `RoIAlign`.
+**Note**: This tool is still experimental. Some customized operators are not supported for now. We only support exporting RetinaNet model at this moment.
 
 ### Visualize the output results
 
diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py
index 611c0eb5..f42a48cd 100644
--- a/mmdet/core/anchor/anchor_generator.py
+++ b/mmdet/core/anchor/anchor_generator.py
@@ -251,8 +251,12 @@ class AnchorGenerator(object):
             torch.Tensor: Anchors in the overall feature maps.
         """
         feat_h, feat_w = featmap_size
+        # convert Tensor to int, so that we can covert to ONNX correctlly
+        feat_h = int(feat_h)
+        feat_w = int(feat_w)
         shift_x = torch.arange(0, feat_w, device=device) * stride[0]
         shift_y = torch.arange(0, feat_h, device=device) * stride[1]
+
         shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
         shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
         shifts = shifts.type_as(base_anchors)
diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
index 9c206c10..82bf5947 100644
--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -161,8 +161,8 @@ def delta2bbox(rois,
                 [0.0000, 0.3161, 4.1945, 0.6839],
                 [5.0000, 5.0000, 5.0000, 5.0000]])
     """
-    means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
-    stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
+    means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
+    stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
     denorm_deltas = deltas * stds + means
     dx = denorm_deltas[:, 0::4]
     dy = denorm_deltas[:, 1::4]
@@ -193,5 +193,5 @@ def delta2bbox(rois,
         y1 = y1.clamp(min=0, max=max_shape[0])
         x2 = x2.clamp(min=0, max=max_shape[1])
         y2 = y2.clamp(min=0, max=max_shape[0])
-    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
+    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
     return bboxes
diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py
index 40ed93e1..f9ee7aae 100644
--- a/mmdet/core/bbox/transforms.py
+++ b/mmdet/core/bbox/transforms.py
@@ -96,8 +96,8 @@ def bbox2result(bboxes, labels, num_classes):
     """Convert detection results to a list of numpy arrays.
 
     Args:
-        bboxes (Tensor): shape (n, 5)
-        labels (Tensor): shape (n, )
+        bboxes (torch.Tensor | np.ndarray): shape (n, 5)
+        labels (torch.Tensor | np.ndarray): shape (n, )
         num_classes (int): class number, including background class
 
     Returns:
@@ -106,8 +106,9 @@ def bbox2result(bboxes, labels, num_classes):
     if bboxes.shape[0] == 0:
         return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
     else:
-        bboxes = bboxes.cpu().numpy()
-        labels = labels.cpu().numpy()
+        if isinstance(bboxes, torch.Tensor):
+            bboxes = bboxes.cpu().numpy()
+            labels = labels.cpu().numpy()
         return [bboxes[labels == i, :] for i in range(num_classes)]
 
 
diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py
index 1c89bb43..1d87624c 100644
--- a/mmdet/core/post_processing/bbox_nms.py
+++ b/mmdet/core/post_processing/bbox_nms.py
@@ -31,20 +31,33 @@ def multiclass_nms(multi_bboxes,
     if multi_bboxes.shape[1] > 4:
         bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
     else:
-        bboxes = multi_bboxes[:, None].expand(-1, num_classes, 4)
+        bboxes = multi_bboxes[:, None].expand(
+            multi_scores.size(0), num_classes, 4)
     scores = multi_scores[:, :-1]
 
     # filter out boxes with low scores
     valid_mask = scores > score_thr
-    bboxes = bboxes[valid_mask]
+
+    # We use masked_select for ONNX exporting purpose,
+    # which is equivalent to bboxes = bboxes[valid_mask]
+    # (TODO): as ONNX does not support repeat now,
+    # we have to use this ugly code
+    bboxes = torch.masked_select(
+        bboxes,
+        torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
+                    -1)).view(-1, 4)
     if score_factors is not None:
         scores = scores * score_factors[:, None]
-    scores = scores[valid_mask]
+    scores = torch.masked_select(scores, valid_mask)
     labels = valid_mask.nonzero()[:, 1]
 
     if bboxes.numel() == 0:
         bboxes = multi_bboxes.new_zeros((0, 5))
         labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
+
+        if torch.onnx.is_in_onnx_export():
+            raise RuntimeError('[ONNX Error] Can not record NMS '
+                               'as it has not been executed this time')
         return bboxes, labels
 
     dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py
index bf163ad7..96225376 100644
--- a/mmdet/models/detectors/single_stage.py
+++ b/mmdet/models/detectors/single_stage.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 
 from mmdet.core import bbox2result
@@ -109,6 +110,10 @@ class SingleStageDetector(BaseDetector):
         outs = self.bbox_head(x)
         bbox_list = self.bbox_head.get_bboxes(
             *outs, img_metas, rescale=rescale)
+        # skip post-processing when exporting to ONNX
+        if torch.onnx.is_in_onnx_export():
+            return bbox_list
+
         bbox_results = [
             bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
             for det_bboxes, det_labels in bbox_list
diff --git a/setup.cfg b/setup.cfg
index 39934a82..6e8d3c2f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,7 +3,7 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmdet
-known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch,torchvision
+known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
 
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index 4a251be3..4386467d 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -1,124 +1,166 @@
 import argparse
-import io
+import os.path as osp
+from functools import partial
 
 import mmcv
+import numpy as np
 import onnx
+import onnxruntime as rt
 import torch
-from mmcv.ops import RoIAlign, RoIPool
 from mmcv.runner import load_checkpoint
-from onnx import optimizer
-from torch.onnx import OperatorExportTypes
 
 from mmdet.models import build_detector
 
+try:
+    from mmcv.onnx.symbolic import register_extra_symbolics
+except ModuleNotFoundError:
+    raise NotImplementedError('please update mmcv to version>=v1.0.4')
 
-def export_onnx_model(model, inputs, passes):
-    """Trace and export a model to onnx format. Modified from
-    https://github.com/facebookresearch/detectron2/
-
-    Args:
-        model (nn.Module):
-        inputs (tuple[args]): the model will be called by `model(*inputs)`
-        passes (None or list[str]): the optimization passed for ONNX model
-
-    Returns:
-        an onnx model
-    """
-    assert isinstance(model, torch.nn.Module)
-
-    # make sure all modules are in eval mode, onnx may change the training
-    # state of the module if the states are not consistent
-    def _check_eval(module):
-        assert not module.training
-
-    model.apply(_check_eval)
-
-    # Export the model to ONNX
-    with torch.no_grad():
-        with io.BytesIO() as f:
-            torch.onnx.export(
-                model,
-                inputs,
-                f,
-                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
-                # verbose=True,  # NOTE: uncomment this for debugging
-                # export_params=True,
-            )
-            onnx_model = onnx.load_from_string(f.getvalue())
-
-    # Apply ONNX's Optimization
-    if passes is not None:
-        all_passes = optimizer.get_available_passes()
-        assert all(p in all_passes for p in passes), \
-            f'Only {all_passes} are supported'
-    onnx_model = optimizer.optimize(onnx_model, passes)
-    return onnx_model
+
+def pytorch2onnx(model,
+                 input_img,
+                 input_shape,
+                 opset_version=11,
+                 show=False,
+                 output_file='tmp.onnx',
+                 verify=False,
+                 normalize_cfg=None):
+    model.cpu().eval()
+    # read image
+    one_img = mmcv.imread(input_img)
+    if normalize_cfg:
+        one_img = mmcv.imnormalize(one_img, normalize_cfg['mean'],
+                                   normalize_cfg['std'])
+    one_img = mmcv.imresize(one_img, input_shape[2:]).transpose(2, 0, 1)
+    one_img = torch.from_numpy(one_img).unsqueeze(0).float()
+    (_, C, H, W) = input_shape
+    one_meta = {
+        'img_shape': (H, W, C),
+        'ori_shape': (H, W, C),
+        'pad_shape': (H, W, C),
+        'filename': '<demo>.png',
+        'scale_factor': 1.0,
+        'flip': False
+    }
+    # onnx.export does not support kwargs
+    origin_forward = model.forward
+    model.forward = partial(
+        model.forward, img_metas=[[one_meta]], return_loss=False)
+    # pytorch has some bug in pytorch1.3, we have to fix it
+    # by replacing these existing op
+    register_extra_symbolics(opset_version)
+    torch.onnx.export(
+        model, ([one_img]),
+        output_file,
+        export_params=True,
+        keep_initializers_as_inputs=True,
+        verbose=show,
+        opset_version=opset_version)
+    model.forward = origin_forward
+    print(f'Successfully exported ONNX model: {output_file}')
+    if verify:
+        # check by onnx
+        onnx_model = onnx.load(output_file)
+        onnx.checker.check_model(onnx_model)
+
+        # check the numerical value
+        # get pytorch output
+        pytorch_result = model([one_img], [[one_meta]], return_loss=False)
+
+        # get onnx output
+        input_all = [node.name for node in onnx_model.graph.input]
+        input_initializer = [
+            node.name for node in onnx_model.graph.initializer
+        ]
+        net_feed_input = list(set(input_all) - set(input_initializer))
+        assert (len(net_feed_input) == 1)
+        sess = rt.InferenceSession(output_file)
+        from mmdet.core import bbox2result
+        det_bboxes, det_labels = sess.run(
+            None, {net_feed_input[0]: one_img.detach().numpy()})
+        # only compare a part of result
+        bbox_results = bbox2result(det_bboxes, det_labels, 1)
+        onnx_results = bbox_results[0]
+        assert np.allclose(
+            pytorch_result[0][:, 4], onnx_results[:, 4]
+        ), 'The outputs are different between Pytorch and ONNX'
+        print('The numerical values are same between Pytorch and ONNX')
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description='MMDet pytorch model conversion to ONNX')
+        description='Convert MMDetection models to ONNX')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument('--input-img', type=str, help='Images for input')
+    parser.add_argument('--show', action='store_true', help='show onnx graph')
+    parser.add_argument('--output-file', type=str, default='tmp.onnx')
+    parser.add_argument('--opset-version', type=int, default=11)
     parser.add_argument(
-        '--out', type=str, required=True, help='output ONNX filename')
+        '--verify',
+        action='store_true',
+        help='verify the onnx model output against pytorch output')
     parser.add_argument(
         '--shape',
         type=int,
         nargs='+',
-        default=[1280, 800],
+        default=[800, 1216],
         help='input image size')
     parser.add_argument(
-        '--passes', type=str, nargs='+', help='ONNX optimization passes')
+        '--mean',
+        type=int,
+        nargs='+',
+        default=[0, 0, 0],
+        help='mean value used for preprocess input data')
+    parser.add_argument(
+        '--std',
+        type=int,
+        nargs='+',
+        default=[1, 1, 1],
+        help='variance value used for preprocess input data')
     args = parser.parse_args()
     return args
 
 
-def main():
+if __name__ == '__main__':
     args = parse_args()
 
-    if not args.out.endswith('.onnx'):
-        raise ValueError('The output file must be a onnx file.')
+    assert args.opset_version == 11, 'MMDet only support opset 11 now'
+
+    if not args.input_img:
+        args.input_img = osp.join(
+            osp.dirname(__file__), '../tests/data/color.jpg')
 
     if len(args.shape) == 1:
-        input_shape = (3, args.shape[0], args.shape[0])
+        input_shape = (1, 3, args.shape[0], args.shape[0])
     elif len(args.shape) == 2:
-        input_shape = (3, ) + tuple(args.shape)
+        input_shape = (1, 3) + tuple(args.shape)
     else:
         raise ValueError('invalid input shape')
 
+    assert len(args.mean) == 3
+    assert len(args.std) == 3
+
+    normalize_cfg = {
+        'mean': np.array(args.mean, dtype=np.float32),
+        'std': np.array(args.std, dtype=np.float32)
+    }
+
     cfg = mmcv.Config.fromfile(args.config)
     cfg.model.pretrained = None
+    cfg.data.test.test_mode = True
 
-    # build the model and load checkpoint
+    # build the model
     model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
-    load_checkpoint(model, args.checkpoint, map_location='cpu')
-    # Only support CPU mode for now
-    model.cpu().eval()
-    # Customized ops are not supported, use torchvision ops instead.
-    for m in model.modules():
-        if isinstance(m, (RoIPool, RoIAlign)):
-            # set use_torchvision on-the-fly
-            m.use_torchvision = True
-
-    # TODO: a better way to override forward function
-    if hasattr(model, 'forward_dummy'):
-        model.forward = model.forward_dummy
-    else:
-        raise NotImplementedError(
-            'ONNX conversion is currently not currently supported with '
-            f'{model.__class__.__name__}')
-
-    input_data = torch.empty((1, *input_shape),
-                             dtype=next(model.parameters()).dtype,
-                             device=next(model.parameters()).device)
-
-    onnx_model = export_onnx_model(model, (input_data, ), args.passes)
-    # Print a human readable representation of the graph
-    onnx.helper.printable_graph(onnx_model.graph)
-    print(f'saving model in {args.out}')
-    onnx.save(onnx_model, args.out)
-
-
-if __name__ == '__main__':
-    main()
+    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+    # conver model to onnx file
+    pytorch2onnx(
+        model,
+        args.input_img,
+        input_shape,
+        opset_version=args.opset_version,
+        show=args.show,
+        output_file=args.output_file,
+        verify=args.verify,
+        normalize_cfg=normalize_cfg)
-- 
GitLab