From 2f32a47a170396dc0090bf4cb228f358b69ed26e Mon Sep 17 00:00:00 2001 From: robin Han <drcut@users.noreply.github.com> Date: Wed, 5 Aug 2020 00:03:07 +0800 Subject: [PATCH] Pytorch2onnx (#3075) * Update pytorch2onnx.py which using new logic to convert pytorch to ONNX * use standard API to check whether in ONNX convert process * only compare the score value while verifying results between ONNX and pytorch * move import onnx before import torch, or something weird will happen * use real images for input * modifying the way of calling nms * modify docstring for bbox2result, and remove unnecessary part for onnx exporting * modify the 'Convert to ONNX' part in docs * replace or to | in docstring * update according to the latest mmcv * add normalize part * raise error while using low version mmcv * minor update * minor update Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com> --- docs/getting_started.md | 6 +- mmdet/core/anchor/anchor_generator.py | 4 + .../core/bbox/coder/delta_xywh_bbox_coder.py | 6 +- mmdet/core/bbox/transforms.py | 9 +- mmdet/core/post_processing/bbox_nms.py | 19 +- mmdet/models/detectors/single_stage.py | 5 + setup.cfg | 2 +- tools/pytorch2onnx.py | 212 +++++++++++------- 8 files changed, 164 insertions(+), 99 deletions(-) diff --git a/docs/getting_started.md b/docs/getting_started.md index f9e77c5c..64e431c3 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -445,13 +445,13 @@ Please refer to [robustness_benchmarking.md](robustness_benchmarking.md). ### Convert to ONNX (experimental) -We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). +We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model. ```shell -python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${ONNX_FILE} [--shape ${INPUT_SHAPE}] +python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --output_file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify] ``` -**Note**: This tool is still experimental. Customized operators are not supported for now. We set `use_torchvision=True` on-the-fly for `RoIPool` and `RoIAlign`. +**Note**: This tool is still experimental. Some customized operators are not supported for now. We only support exporting RetinaNet model at this moment. ### Visualize the output results diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py index 611c0eb5..f42a48cd 100644 --- a/mmdet/core/anchor/anchor_generator.py +++ b/mmdet/core/anchor/anchor_generator.py @@ -251,8 +251,12 @@ class AnchorGenerator(object): torch.Tensor: Anchors in the overall feature maps. """ feat_h, feat_w = featmap_size + # convert Tensor to int, so that we can covert to ONNX correctlly + feat_h = int(feat_h) + feat_w = int(feat_w) shift_x = torch.arange(0, feat_w, device=device) * stride[0] shift_y = torch.arange(0, feat_h, device=device) * stride[1] + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1) shifts = shifts.type_as(base_anchors) diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py index 9c206c10..82bf5947 100644 --- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py +++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py @@ -161,8 +161,8 @@ def delta2bbox(rois, [0.0000, 0.3161, 4.1945, 0.6839], [5.0000, 5.0000, 5.0000, 5.0000]]) """ - means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) - stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) + means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4) + stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4) denorm_deltas = deltas * stds + means dx = denorm_deltas[:, 0::4] dy = denorm_deltas[:, 1::4] @@ -193,5 +193,5 @@ def delta2bbox(rois, y1 = y1.clamp(min=0, max=max_shape[0]) x2 = x2.clamp(min=0, max=max_shape[1]) y2 = y2.clamp(min=0, max=max_shape[0]) - bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) + bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) return bboxes diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py index 40ed93e1..f9ee7aae 100644 --- a/mmdet/core/bbox/transforms.py +++ b/mmdet/core/bbox/transforms.py @@ -96,8 +96,8 @@ def bbox2result(bboxes, labels, num_classes): """Convert detection results to a list of numpy arrays. Args: - bboxes (Tensor): shape (n, 5) - labels (Tensor): shape (n, ) + bboxes (torch.Tensor | np.ndarray): shape (n, 5) + labels (torch.Tensor | np.ndarray): shape (n, ) num_classes (int): class number, including background class Returns: @@ -106,8 +106,9 @@ def bbox2result(bboxes, labels, num_classes): if bboxes.shape[0] == 0: return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)] else: - bboxes = bboxes.cpu().numpy() - labels = labels.cpu().numpy() + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.cpu().numpy() + labels = labels.cpu().numpy() return [bboxes[labels == i, :] for i in range(num_classes)] diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py index 1c89bb43..1d87624c 100644 --- a/mmdet/core/post_processing/bbox_nms.py +++ b/mmdet/core/post_processing/bbox_nms.py @@ -31,20 +31,33 @@ def multiclass_nms(multi_bboxes, if multi_bboxes.shape[1] > 4: bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) else: - bboxes = multi_bboxes[:, None].expand(-1, num_classes, 4) + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) scores = multi_scores[:, :-1] # filter out boxes with low scores valid_mask = scores > score_thr - bboxes = bboxes[valid_mask] + + # We use masked_select for ONNX exporting purpose, + # which is equivalent to bboxes = bboxes[valid_mask] + # (TODO): as ONNX does not support repeat now, + # we have to use this ugly code + bboxes = torch.masked_select( + bboxes, + torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), + -1)).view(-1, 4) if score_factors is not None: scores = scores * score_factors[:, None] - scores = scores[valid_mask] + scores = torch.masked_select(scores, valid_mask) labels = valid_mask.nonzero()[:, 1] if bboxes.numel() == 0: bboxes = multi_bboxes.new_zeros((0, 5)) labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) + + if torch.onnx.is_in_onnx_export(): + raise RuntimeError('[ONNX Error] Can not record NMS ' + 'as it has not been executed this time') return bboxes, labels dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py index bf163ad7..96225376 100644 --- a/mmdet/models/detectors/single_stage.py +++ b/mmdet/models/detectors/single_stage.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn from mmdet.core import bbox2result @@ -109,6 +110,10 @@ class SingleStageDetector(BaseDetector): outs = self.bbox_head(x) bbox_list = self.bbox_head.get_bboxes( *outs, img_metas, rescale=rescale) + # skip post-processing when exporting to ONNX + if torch.onnx.is_in_onnx_export(): + return bbox_list + bbox_results = [ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) for det_bboxes, det_labels in bbox_list diff --git a/setup.cfg b/setup.cfg index 39934a82..6e8d3c2f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch,torchvision +known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 4a251be3..4386467d 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -1,124 +1,166 @@ import argparse -import io +import os.path as osp +from functools import partial import mmcv +import numpy as np import onnx +import onnxruntime as rt import torch -from mmcv.ops import RoIAlign, RoIPool from mmcv.runner import load_checkpoint -from onnx import optimizer -from torch.onnx import OperatorExportTypes from mmdet.models import build_detector +try: + from mmcv.onnx.symbolic import register_extra_symbolics +except ModuleNotFoundError: + raise NotImplementedError('please update mmcv to version>=v1.0.4') -def export_onnx_model(model, inputs, passes): - """Trace and export a model to onnx format. Modified from - https://github.com/facebookresearch/detectron2/ - - Args: - model (nn.Module): - inputs (tuple[args]): the model will be called by `model(*inputs)` - passes (None or list[str]): the optimization passed for ONNX model - - Returns: - an onnx model - """ - assert isinstance(model, torch.nn.Module) - - # make sure all modules are in eval mode, onnx may change the training - # state of the module if the states are not consistent - def _check_eval(module): - assert not module.training - - model.apply(_check_eval) - - # Export the model to ONNX - with torch.no_grad(): - with io.BytesIO() as f: - torch.onnx.export( - model, - inputs, - f, - operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, - # verbose=True, # NOTE: uncomment this for debugging - # export_params=True, - ) - onnx_model = onnx.load_from_string(f.getvalue()) - - # Apply ONNX's Optimization - if passes is not None: - all_passes = optimizer.get_available_passes() - assert all(p in all_passes for p in passes), \ - f'Only {all_passes} are supported' - onnx_model = optimizer.optimize(onnx_model, passes) - return onnx_model + +def pytorch2onnx(model, + input_img, + input_shape, + opset_version=11, + show=False, + output_file='tmp.onnx', + verify=False, + normalize_cfg=None): + model.cpu().eval() + # read image + one_img = mmcv.imread(input_img) + if normalize_cfg: + one_img = mmcv.imnormalize(one_img, normalize_cfg['mean'], + normalize_cfg['std']) + one_img = mmcv.imresize(one_img, input_shape[2:]).transpose(2, 0, 1) + one_img = torch.from_numpy(one_img).unsqueeze(0).float() + (_, C, H, W) = input_shape + one_meta = { + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': '<demo>.png', + 'scale_factor': 1.0, + 'flip': False + } + # onnx.export does not support kwargs + origin_forward = model.forward + model.forward = partial( + model.forward, img_metas=[[one_meta]], return_loss=False) + # pytorch has some bug in pytorch1.3, we have to fix it + # by replacing these existing op + register_extra_symbolics(opset_version) + torch.onnx.export( + model, ([one_img]), + output_file, + export_params=True, + keep_initializers_as_inputs=True, + verbose=show, + opset_version=opset_version) + model.forward = origin_forward + print(f'Successfully exported ONNX model: {output_file}') + if verify: + # check by onnx + onnx_model = onnx.load(output_file) + onnx.checker.check_model(onnx_model) + + # check the numerical value + # get pytorch output + pytorch_result = model([one_img], [[one_meta]], return_loss=False) + + # get onnx output + input_all = [node.name for node in onnx_model.graph.input] + input_initializer = [ + node.name for node in onnx_model.graph.initializer + ] + net_feed_input = list(set(input_all) - set(input_initializer)) + assert (len(net_feed_input) == 1) + sess = rt.InferenceSession(output_file) + from mmdet.core import bbox2result + det_bboxes, det_labels = sess.run( + None, {net_feed_input[0]: one_img.detach().numpy()}) + # only compare a part of result + bbox_results = bbox2result(det_bboxes, det_labels, 1) + onnx_results = bbox_results[0] + assert np.allclose( + pytorch_result[0][:, 4], onnx_results[:, 4] + ), 'The outputs are different between Pytorch and ONNX' + print('The numerical values are same between Pytorch and ONNX') def parse_args(): parser = argparse.ArgumentParser( - description='MMDet pytorch model conversion to ONNX') + description='Convert MMDetection models to ONNX') parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--input-img', type=str, help='Images for input') + parser.add_argument('--show', action='store_true', help='show onnx graph') + parser.add_argument('--output-file', type=str, default='tmp.onnx') + parser.add_argument('--opset-version', type=int, default=11) parser.add_argument( - '--out', type=str, required=True, help='output ONNX filename') + '--verify', + action='store_true', + help='verify the onnx model output against pytorch output') parser.add_argument( '--shape', type=int, nargs='+', - default=[1280, 800], + default=[800, 1216], help='input image size') parser.add_argument( - '--passes', type=str, nargs='+', help='ONNX optimization passes') + '--mean', + type=int, + nargs='+', + default=[0, 0, 0], + help='mean value used for preprocess input data') + parser.add_argument( + '--std', + type=int, + nargs='+', + default=[1, 1, 1], + help='variance value used for preprocess input data') args = parser.parse_args() return args -def main(): +if __name__ == '__main__': args = parse_args() - if not args.out.endswith('.onnx'): - raise ValueError('The output file must be a onnx file.') + assert args.opset_version == 11, 'MMDet only support opset 11 now' + + if not args.input_img: + args.input_img = osp.join( + osp.dirname(__file__), '../tests/data/color.jpg') if len(args.shape) == 1: - input_shape = (3, args.shape[0], args.shape[0]) + input_shape = (1, 3, args.shape[0], args.shape[0]) elif len(args.shape) == 2: - input_shape = (3, ) + tuple(args.shape) + input_shape = (1, 3) + tuple(args.shape) else: raise ValueError('invalid input shape') + assert len(args.mean) == 3 + assert len(args.std) == 3 + + normalize_cfg = { + 'mean': np.array(args.mean, dtype=np.float32), + 'std': np.array(args.std, dtype=np.float32) + } + cfg = mmcv.Config.fromfile(args.config) cfg.model.pretrained = None + cfg.data.test.test_mode = True - # build the model and load checkpoint + # build the model model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) - load_checkpoint(model, args.checkpoint, map_location='cpu') - # Only support CPU mode for now - model.cpu().eval() - # Customized ops are not supported, use torchvision ops instead. - for m in model.modules(): - if isinstance(m, (RoIPool, RoIAlign)): - # set use_torchvision on-the-fly - m.use_torchvision = True - - # TODO: a better way to override forward function - if hasattr(model, 'forward_dummy'): - model.forward = model.forward_dummy - else: - raise NotImplementedError( - 'ONNX conversion is currently not currently supported with ' - f'{model.__class__.__name__}') - - input_data = torch.empty((1, *input_shape), - dtype=next(model.parameters()).dtype, - device=next(model.parameters()).device) - - onnx_model = export_onnx_model(model, (input_data, ), args.passes) - # Print a human readable representation of the graph - onnx.helper.printable_graph(onnx_model.graph) - print(f'saving model in {args.out}') - onnx.save(onnx_model, args.out) - - -if __name__ == '__main__': - main() + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + + # conver model to onnx file + pytorch2onnx( + model, + args.input_img, + input_shape, + opset_version=args.opset_version, + show=args.show, + output_file=args.output_file, + verify=args.verify, + normalize_cfg=normalize_cfg) -- GitLab