Skip to content
Snippets Groups Projects
Unverified Commit 2f32a47a authored by robin Han's avatar robin Han Committed by GitHub
Browse files

Pytorch2onnx (#3075)


* Update pytorch2onnx.py which using new logic to convert pytorch to ONNX

* use standard API to check whether in ONNX convert process

* only compare the score value while verifying results between ONNX and pytorch

* move import onnx before import torch, or something weird will happen

* use real images for input

* modifying the way of calling nms

* modify docstring for bbox2result, and remove unnecessary part for onnx exporting

* modify the 'Convert to ONNX' part in docs

* replace or to | in docstring

* update according to the latest mmcv

* add normalize part

* raise error while using low version mmcv

* minor update

* minor update

Co-authored-by: default avatarJiarui XU <xvjiarui0826@gmail.com>
parent 9d3f9b03
No related branches found
No related tags found
No related merge requests found
...@@ -445,13 +445,13 @@ Please refer to [robustness_benchmarking.md](robustness_benchmarking.md). ...@@ -445,13 +445,13 @@ Please refer to [robustness_benchmarking.md](robustness_benchmarking.md).
### Convert to ONNX (experimental) ### Convert to ONNX (experimental)
We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
```shell ```shell
python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${ONNX_FILE} [--shape ${INPUT_SHAPE}] python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --output_file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
``` ```
**Note**: This tool is still experimental. Customized operators are not supported for now. We set `use_torchvision=True` on-the-fly for `RoIPool` and `RoIAlign`. **Note**: This tool is still experimental. Some customized operators are not supported for now. We only support exporting RetinaNet model at this moment.
### Visualize the output results ### Visualize the output results
......
...@@ -251,8 +251,12 @@ class AnchorGenerator(object): ...@@ -251,8 +251,12 @@ class AnchorGenerator(object):
torch.Tensor: Anchors in the overall feature maps. torch.Tensor: Anchors in the overall feature maps.
""" """
feat_h, feat_w = featmap_size feat_h, feat_w = featmap_size
# convert Tensor to int, so that we can covert to ONNX correctlly
feat_h = int(feat_h)
feat_w = int(feat_w)
shift_x = torch.arange(0, feat_w, device=device) * stride[0] shift_x = torch.arange(0, feat_w, device=device) * stride[0]
shift_y = torch.arange(0, feat_h, device=device) * stride[1] shift_y = torch.arange(0, feat_h, device=device) * stride[1]
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1) shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
shifts = shifts.type_as(base_anchors) shifts = shifts.type_as(base_anchors)
......
...@@ -161,8 +161,8 @@ def delta2bbox(rois, ...@@ -161,8 +161,8 @@ def delta2bbox(rois,
[0.0000, 0.3161, 4.1945, 0.6839], [0.0000, 0.3161, 4.1945, 0.6839],
[5.0000, 5.0000, 5.0000, 5.0000]]) [5.0000, 5.0000, 5.0000, 5.0000]])
""" """
means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
denorm_deltas = deltas * stds + means denorm_deltas = deltas * stds + means
dx = denorm_deltas[:, 0::4] dx = denorm_deltas[:, 0::4]
dy = denorm_deltas[:, 1::4] dy = denorm_deltas[:, 1::4]
...@@ -193,5 +193,5 @@ def delta2bbox(rois, ...@@ -193,5 +193,5 @@ def delta2bbox(rois,
y1 = y1.clamp(min=0, max=max_shape[0]) y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1]) x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0]) y2 = y2.clamp(min=0, max=max_shape[0])
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
return bboxes return bboxes
...@@ -96,8 +96,8 @@ def bbox2result(bboxes, labels, num_classes): ...@@ -96,8 +96,8 @@ def bbox2result(bboxes, labels, num_classes):
"""Convert detection results to a list of numpy arrays. """Convert detection results to a list of numpy arrays.
Args: Args:
bboxes (Tensor): shape (n, 5) bboxes (torch.Tensor | np.ndarray): shape (n, 5)
labels (Tensor): shape (n, ) labels (torch.Tensor | np.ndarray): shape (n, )
num_classes (int): class number, including background class num_classes (int): class number, including background class
Returns: Returns:
...@@ -106,8 +106,9 @@ def bbox2result(bboxes, labels, num_classes): ...@@ -106,8 +106,9 @@ def bbox2result(bboxes, labels, num_classes):
if bboxes.shape[0] == 0: if bboxes.shape[0] == 0:
return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)] return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
else: else:
bboxes = bboxes.cpu().numpy() if isinstance(bboxes, torch.Tensor):
labels = labels.cpu().numpy() bboxes = bboxes.cpu().numpy()
labels = labels.cpu().numpy()
return [bboxes[labels == i, :] for i in range(num_classes)] return [bboxes[labels == i, :] for i in range(num_classes)]
......
...@@ -31,20 +31,33 @@ def multiclass_nms(multi_bboxes, ...@@ -31,20 +31,33 @@ def multiclass_nms(multi_bboxes,
if multi_bboxes.shape[1] > 4: if multi_bboxes.shape[1] > 4:
bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
else: else:
bboxes = multi_bboxes[:, None].expand(-1, num_classes, 4) bboxes = multi_bboxes[:, None].expand(
multi_scores.size(0), num_classes, 4)
scores = multi_scores[:, :-1] scores = multi_scores[:, :-1]
# filter out boxes with low scores # filter out boxes with low scores
valid_mask = scores > score_thr valid_mask = scores > score_thr
bboxes = bboxes[valid_mask]
# We use masked_select for ONNX exporting purpose,
# which is equivalent to bboxes = bboxes[valid_mask]
# (TODO): as ONNX does not support repeat now,
# we have to use this ugly code
bboxes = torch.masked_select(
bboxes,
torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
-1)).view(-1, 4)
if score_factors is not None: if score_factors is not None:
scores = scores * score_factors[:, None] scores = scores * score_factors[:, None]
scores = scores[valid_mask] scores = torch.masked_select(scores, valid_mask)
labels = valid_mask.nonzero()[:, 1] labels = valid_mask.nonzero()[:, 1]
if bboxes.numel() == 0: if bboxes.numel() == 0:
bboxes = multi_bboxes.new_zeros((0, 5)) bboxes = multi_bboxes.new_zeros((0, 5))
labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
if torch.onnx.is_in_onnx_export():
raise RuntimeError('[ONNX Error] Can not record NMS '
'as it has not been executed this time')
return bboxes, labels return bboxes, labels
dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
......
import torch
import torch.nn as nn import torch.nn as nn
from mmdet.core import bbox2result from mmdet.core import bbox2result
...@@ -109,6 +110,10 @@ class SingleStageDetector(BaseDetector): ...@@ -109,6 +110,10 @@ class SingleStageDetector(BaseDetector):
outs = self.bbox_head(x) outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes( bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale) *outs, img_metas, rescale=rescale)
# skip post-processing when exporting to ONNX
if torch.onnx.is_in_onnx_export():
return bbox_list
bbox_results = [ bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list for det_bboxes, det_labels in bbox_list
......
...@@ -3,7 +3,7 @@ line_length = 79 ...@@ -3,7 +3,7 @@ line_length = 79
multi_line_output = 0 multi_line_output = 0
known_standard_library = setuptools known_standard_library = setuptools
known_first_party = mmdet known_first_party = mmdet
known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch,torchvision known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY
......
import argparse import argparse
import io import os.path as osp
from functools import partial
import mmcv import mmcv
import numpy as np
import onnx import onnx
import onnxruntime as rt
import torch import torch
from mmcv.ops import RoIAlign, RoIPool
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
from onnx import optimizer
from torch.onnx import OperatorExportTypes
from mmdet.models import build_detector from mmdet.models import build_detector
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=v1.0.4')
def export_onnx_model(model, inputs, passes):
"""Trace and export a model to onnx format. Modified from def pytorch2onnx(model,
https://github.com/facebookresearch/detectron2/ input_img,
input_shape,
Args: opset_version=11,
model (nn.Module): show=False,
inputs (tuple[args]): the model will be called by `model(*inputs)` output_file='tmp.onnx',
passes (None or list[str]): the optimization passed for ONNX model verify=False,
normalize_cfg=None):
Returns: model.cpu().eval()
an onnx model # read image
""" one_img = mmcv.imread(input_img)
assert isinstance(model, torch.nn.Module) if normalize_cfg:
one_img = mmcv.imnormalize(one_img, normalize_cfg['mean'],
# make sure all modules are in eval mode, onnx may change the training normalize_cfg['std'])
# state of the module if the states are not consistent one_img = mmcv.imresize(one_img, input_shape[2:]).transpose(2, 0, 1)
def _check_eval(module): one_img = torch.from_numpy(one_img).unsqueeze(0).float()
assert not module.training (_, C, H, W) = input_shape
one_meta = {
model.apply(_check_eval) 'img_shape': (H, W, C),
'ori_shape': (H, W, C),
# Export the model to ONNX 'pad_shape': (H, W, C),
with torch.no_grad(): 'filename': '<demo>.png',
with io.BytesIO() as f: 'scale_factor': 1.0,
torch.onnx.export( 'flip': False
model, }
inputs, # onnx.export does not support kwargs
f, origin_forward = model.forward
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, model.forward = partial(
# verbose=True, # NOTE: uncomment this for debugging model.forward, img_metas=[[one_meta]], return_loss=False)
# export_params=True, # pytorch has some bug in pytorch1.3, we have to fix it
) # by replacing these existing op
onnx_model = onnx.load_from_string(f.getvalue()) register_extra_symbolics(opset_version)
torch.onnx.export(
# Apply ONNX's Optimization model, ([one_img]),
if passes is not None: output_file,
all_passes = optimizer.get_available_passes() export_params=True,
assert all(p in all_passes for p in passes), \ keep_initializers_as_inputs=True,
f'Only {all_passes} are supported' verbose=show,
onnx_model = optimizer.optimize(onnx_model, passes) opset_version=opset_version)
return onnx_model model.forward = origin_forward
print(f'Successfully exported ONNX model: {output_file}')
if verify:
# check by onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
# check the numerical value
# get pytorch output
pytorch_result = model([one_img], [[one_meta]], return_loss=False)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file)
from mmdet.core import bbox2result
det_bboxes, det_labels = sess.run(
None, {net_feed_input[0]: one_img.detach().numpy()})
# only compare a part of result
bbox_results = bbox2result(det_bboxes, det_labels, 1)
onnx_results = bbox_results[0]
assert np.allclose(
pytorch_result[0][:, 4], onnx_results[:, 4]
), 'The outputs are different between Pytorch and ONNX'
print('The numerical values are same between Pytorch and ONNX')
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='MMDet pytorch model conversion to ONNX') description='Convert MMDetection models to ONNX')
parser.add_argument('config', help='test config file path') parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--input-img', type=str, help='Images for input')
parser.add_argument('--show', action='store_true', help='show onnx graph')
parser.add_argument('--output-file', type=str, default='tmp.onnx')
parser.add_argument('--opset-version', type=int, default=11)
parser.add_argument( parser.add_argument(
'--out', type=str, required=True, help='output ONNX filename') '--verify',
action='store_true',
help='verify the onnx model output against pytorch output')
parser.add_argument( parser.add_argument(
'--shape', '--shape',
type=int, type=int,
nargs='+', nargs='+',
default=[1280, 800], default=[800, 1216],
help='input image size') help='input image size')
parser.add_argument( parser.add_argument(
'--passes', type=str, nargs='+', help='ONNX optimization passes') '--mean',
type=int,
nargs='+',
default=[0, 0, 0],
help='mean value used for preprocess input data')
parser.add_argument(
'--std',
type=int,
nargs='+',
default=[1, 1, 1],
help='variance value used for preprocess input data')
args = parser.parse_args() args = parser.parse_args()
return args return args
def main(): if __name__ == '__main__':
args = parse_args() args = parse_args()
if not args.out.endswith('.onnx'): assert args.opset_version == 11, 'MMDet only support opset 11 now'
raise ValueError('The output file must be a onnx file.')
if not args.input_img:
args.input_img = osp.join(
osp.dirname(__file__), '../tests/data/color.jpg')
if len(args.shape) == 1: if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0]) input_shape = (1, 3, args.shape[0], args.shape[0])
elif len(args.shape) == 2: elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape) input_shape = (1, 3) + tuple(args.shape)
else: else:
raise ValueError('invalid input shape') raise ValueError('invalid input shape')
assert len(args.mean) == 3
assert len(args.std) == 3
normalize_cfg = {
'mean': np.array(args.mean, dtype=np.float32),
'std': np.array(args.std, dtype=np.float32)
}
cfg = mmcv.Config.fromfile(args.config) cfg = mmcv.Config.fromfile(args.config)
cfg.model.pretrained = None cfg.model.pretrained = None
cfg.data.test.test_mode = True
# build the model and load checkpoint # build the model
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
load_checkpoint(model, args.checkpoint, map_location='cpu') checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
# Only support CPU mode for now
model.cpu().eval() # conver model to onnx file
# Customized ops are not supported, use torchvision ops instead. pytorch2onnx(
for m in model.modules(): model,
if isinstance(m, (RoIPool, RoIAlign)): args.input_img,
# set use_torchvision on-the-fly input_shape,
m.use_torchvision = True opset_version=args.opset_version,
show=args.show,
# TODO: a better way to override forward function output_file=args.output_file,
if hasattr(model, 'forward_dummy'): verify=args.verify,
model.forward = model.forward_dummy normalize_cfg=normalize_cfg)
else:
raise NotImplementedError(
'ONNX conversion is currently not currently supported with '
f'{model.__class__.__name__}')
input_data = torch.empty((1, *input_shape),
dtype=next(model.parameters()).dtype,
device=next(model.parameters()).device)
onnx_model = export_onnx_model(model, (input_data, ), args.passes)
# Print a human readable representation of the graph
onnx.helper.printable_graph(onnx_model.graph)
print(f'saving model in {args.out}')
onnx.save(onnx_model, args.out)
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment