Commit 2f32a47a authored by robin Han, committed by GitHub

Pytorch2onnx (#3075)


* Update pytorch2onnx.py, which uses new logic to convert PyTorch models to ONNX

* use the standard API to check whether we are in the ONNX export process

* only compare the score values when verifying results between ONNX and PyTorch

* move `import onnx` before `import torch`, otherwise strange failures occur

* use real images for input

* modify the way NMS is called

* modify the docstring for bbox2result, and remove an unnecessary part for ONNX exporting

* modify the 'Convert to ONNX' part in docs

* replace `or` with `|` in docstrings

* update according to the latest mmcv

* add the normalize step

* raise an error when using a low mmcv version

* minor update

* minor update

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
parent 9d3f9b03
@@ -445,13 +445,13 @@ Please refer to [robustness_benchmarking.md](robustness_benchmarking.md).
### Convert to ONNX (experimental)
We provide a script to convert a model to [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized with tools such as [Netron](https://github.com/lutzroeder/netron).
We provide a script to convert a model to [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized with tools such as [Netron](https://github.com/lutzroeder/netron). We also support comparing the output results between the PyTorch and ONNX models.
```shell
python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${ONNX_FILE} [--shape ${INPUT_SHAPE}]
python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --output_file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
```
**Note**: This tool is still experimental. Customized operators are not supported for now. We set `use_torchvision=True` on-the-fly for `RoIPool` and `RoIAlign`.
**Note**: This tool is still experimental. Some customized operators are not supported for now. Only the RetinaNet model can be exported at the moment.
### Visualize the output results
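The `--verify` flow described in the documentation change above reduces to a few onnx / onnxruntime calls. A minimal standalone sketch, assuming a hypothetical exported file `tmp.onnx` with two outputs (boxes and labels) and an 800x1216 input:

```python
import numpy as np
import onnx
import onnxruntime as rt

# structural check of the exported graph
onnx_model = onnx.load('tmp.onnx')
onnx.checker.check_model(onnx_model)

# run the model with a dummy input; the two-output unpacking assumes
# the exported RetinaNet graph returns (det_bboxes, det_labels)
sess = rt.InferenceSession('tmp.onnx')
input_name = sess.get_inputs()[0].name
dummy = np.random.rand(1, 3, 800, 1216).astype(np.float32)
det_bboxes, det_labels = sess.run(None, {input_name: dummy})
print(det_bboxes.shape, det_labels.shape)
```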
@@ -251,8 +251,12 @@ class AnchorGenerator(object):
torch.Tensor: Anchors in the overall feature maps.
"""
feat_h, feat_w = featmap_size
# convert Tensor to int, so that we can convert to ONNX correctly
feat_h = int(feat_h)
feat_w = int(feat_w)
shift_x = torch.arange(0, feat_w, device=device) * stride[0]
shift_y = torch.arange(0, feat_h, device=device) * stride[1]
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
shifts = shifts.type_as(base_anchors)
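Context for the `int()` coercion above: under ONNX tracing, `featmap_size` can arrive as 0-dim tensors, and `torch.arange` needs plain Python ints to be recorded correctly. A standalone sketch of the shift computation (`grid_shifts` and the toy sizes are illustrative, not the mmdet API):

```python
import torch

def grid_shifts(featmap_size, stride, device='cpu'):
    feat_h, feat_w = featmap_size
    # the fix from this hunk: force plain Python ints before arange
    feat_h, feat_w = int(feat_h), int(feat_w)
    shift_x = torch.arange(0, feat_w, device=device) * stride[0]
    shift_y = torch.arange(0, feat_h, device=device) * stride[1]
    # mirrors the meshgrid trick AnchorGenerator._meshgrid uses
    shift_xx = shift_x.repeat(len(shift_y))
    shift_yy = shift_y.view(-1, 1).repeat(1, len(shift_x)).view(-1)
    return torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)

print(grid_shifts((torch.tensor(2), torch.tensor(3)), (8, 8)).shape)
# torch.Size([6, 4])
```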
@@ -161,8 +161,8 @@ def delta2bbox(rois,
[0.0000, 0.3161, 4.1945, 0.6839],
[5.0000, 5.0000, 5.0000, 5.0000]])
"""
means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
denorm_deltas = deltas * stds + means
dx = denorm_deltas[:, 0::4]
dy = denorm_deltas[:, 1::4]
@@ -193,5 +193,5 @@ def delta2bbox(rois,
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
return bboxes
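The `view(1, -1)` added above makes `means`/`stds` explicitly 2-D before `repeat(1, k)`; PyTorch performs that unsqueeze implicitly on a 1-D tensor, but the implicit form was not exported reliably at the time. Likewise, `view(deltas.size())` replaces `view_as(deltas)` with an equivalent call the exporter handles. A quick shape check (numbers made up):

```python
import torch

deltas = torch.rand(5, 12)  # e.g. 3 classes x 4 box deltas
means = deltas.new_tensor([0., 0., 0., 0.])
old = means.repeat(1, deltas.size(1) // 4)              # implicit unsqueeze
new = means.view(1, -1).repeat(1, deltas.size(1) // 4)  # explicit 2-D
assert old.shape == new.shape == (1, 12)
assert torch.equal(old, new)
```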
@@ -96,8 +96,8 @@ def bbox2result(bboxes, labels, num_classes):
"""Convert detection results to a list of numpy arrays.
Args:
bboxes (Tensor): shape (n, 5)
labels (Tensor): shape (n, )
bboxes (torch.Tensor | np.ndarray): shape (n, 5)
labels (torch.Tensor | np.ndarray): shape (n, )
num_classes (int): class number, including background class
Returns:
@@ -106,8 +106,9 @@ def bbox2result(bboxes, labels, num_classes):
if bboxes.shape[0] == 0:
return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
else:
bboxes = bboxes.cpu().numpy()
labels = labels.cpu().numpy()
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.cpu().numpy()
labels = labels.cpu().numpy()
return [bboxes[labels == i, :] for i in range(num_classes)]
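With the `isinstance` guard above, `bbox2result` also accepts the numpy arrays that onnxruntime returns. A toy illustration of the per-class output format (values made up):

```python
import numpy as np

bboxes = np.array([[0., 0., 10., 10., 0.9],
                   [5., 5., 20., 20., 0.8]], dtype=np.float32)
labels = np.array([0, 1])
num_classes = 2
# per-class (n_i, 5) arrays, as the docstring describes
result = [bboxes[labels == i, :] for i in range(num_classes)]
print([r.shape for r in result])  # [(1, 5), (1, 5)]
```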
@@ -31,20 +31,33 @@ def multiclass_nms(multi_bboxes,
if multi_bboxes.shape[1] > 4:
bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
else:
bboxes = multi_bboxes[:, None].expand(-1, num_classes, 4)
bboxes = multi_bboxes[:, None].expand(
multi_scores.size(0), num_classes, 4)
scores = multi_scores[:, :-1]
# filter out boxes with low scores
valid_mask = scores > score_thr
bboxes = bboxes[valid_mask]
# We use masked_select for ONNX export purposes,
# which is equivalent to bboxes = bboxes[valid_mask]
# TODO: as ONNX does not support repeat yet,
# we have to use this workaround
bboxes = torch.masked_select(
bboxes,
torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
-1)).view(-1, 4)
if score_factors is not None:
scores = scores * score_factors[:, None]
scores = scores[valid_mask]
scores = torch.masked_select(scores, valid_mask)
labels = valid_mask.nonzero()[:, 1]
if bboxes.numel() == 0:
bboxes = multi_bboxes.new_zeros((0, 5))
labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
if torch.onnx.is_in_onnx_export():
raise RuntimeError('[ONNX Error] Can not record NMS '
'as it has not been executed this time')
return bboxes, labels
dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
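As the new comment notes, the `masked_select` form is equivalent to boolean indexing; a quick self-check with random inputs (all shapes and values made up):

```python
import torch

bboxes = torch.rand(8, 3, 4)   # 8 proposals, 3 classes, 4 coords
scores = torch.rand(8, 3)
valid_mask = scores > 0.5
plain = bboxes[valid_mask]
exported = torch.masked_select(
    bboxes,
    torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
                -1)).view(-1, 4)
assert torch.equal(plain, exported)
assert torch.equal(torch.masked_select(scores, valid_mask),
                   scores[valid_mask])
```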
import torch
import torch.nn as nn
from mmdet.core import bbox2result
@@ -109,6 +110,10 @@ class SingleStageDetector(BaseDetector):
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
# skip post-processing when exporting to ONNX
if torch.onnx.is_in_onnx_export():
return bbox_list
bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list
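`torch.onnx.is_in_onnx_export()` is the standard API mentioned in the commit message: it returns `True` only while `torch.onnx.export` is tracing the model. A minimal sketch of the pattern used above (`Toy` is hypothetical):

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        out = x * 2
        # during export, return raw outputs and skip post-processing
        if torch.onnx.is_in_onnx_export():
            return out
        return out.clamp(min=0)  # stand-in for post-processing
```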
@@ -3,7 +3,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmdet
known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch,torchvision
known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
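Adding `onnxruntime` to `known_third_party` simply tells isort to group it with the other external packages, so the import block of the rewritten script below sorts as, e.g.:

```python
import onnx
import onnxruntime as rt
import torch
```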
import argparse
import io
import os.path as osp
from functools import partial
import mmcv
import numpy as np
import onnx
import onnxruntime as rt
import torch
from mmcv.ops import RoIAlign, RoIPool
from mmcv.runner import load_checkpoint
from onnx import optimizer
from torch.onnx import OperatorExportTypes
from mmdet.models import build_detector
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=v1.0.4')
def export_onnx_model(model, inputs, passes):
"""Trace and export a model to onnx format. Modified from
https://github.com/facebookresearch/detectron2/
Args:
model (nn.Module):
inputs (tuple[args]): the model will be called by `model(*inputs)`
passes (None or list[str]): the optimization passes for the ONNX model
Returns:
an onnx model
"""
assert isinstance(model, torch.nn.Module)
# make sure all modules are in eval mode; onnx may change the training
# state of a module if the states are not consistent
def _check_eval(module):
assert not module.training
model.apply(_check_eval)
# Export the model to ONNX
with torch.no_grad():
with io.BytesIO() as f:
torch.onnx.export(
model,
inputs,
f,
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
# verbose=True, # NOTE: uncomment this for debugging
# export_params=True,
)
onnx_model = onnx.load_from_string(f.getvalue())
# Apply ONNX's Optimization
if passes is not None:
all_passes = optimizer.get_available_passes()
assert all(p in all_passes for p in passes), \
f'Only {all_passes} are supported'
onnx_model = optimizer.optimize(onnx_model, passes)
return onnx_model
def pytorch2onnx(model,
input_img,
input_shape,
opset_version=11,
show=False,
output_file='tmp.onnx',
verify=False,
normalize_cfg=None):
model.cpu().eval()
# read image
one_img = mmcv.imread(input_img)
if normalize_cfg:
one_img = mmcv.imnormalize(one_img, normalize_cfg['mean'],
normalize_cfg['std'])
one_img = mmcv.imresize(one_img, input_shape[2:]).transpose(2, 0, 1)
one_img = torch.from_numpy(one_img).unsqueeze(0).float()
(_, C, H, W) = input_shape
one_meta = {
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': '<demo>.png',
'scale_factor': 1.0,
'flip': False
}
# onnx.export does not support kwargs
origin_forward = model.forward
model.forward = partial(
model.forward, img_metas=[[one_meta]], return_loss=False)
# PyTorch 1.3 has some bugs in ONNX export; we fix them
# by replacing the affected ops
register_extra_symbolics(opset_version)
torch.onnx.export(
model, ([one_img]),
output_file,
export_params=True,
keep_initializers_as_inputs=True,
verbose=show,
opset_version=opset_version)
model.forward = origin_forward
print(f'Successfully exported ONNX model: {output_file}')
if verify:
# check by onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
# check the numerical value
# get pytorch output
pytorch_result = model([one_img], [[one_meta]], return_loss=False)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file)
from mmdet.core import bbox2result
det_bboxes, det_labels = sess.run(
None, {net_feed_input[0]: one_img.detach().numpy()})
# only compare part of the result
bbox_results = bbox2result(det_bboxes, det_labels, 1)
onnx_results = bbox_results[0]
assert np.allclose(
pytorch_result[0][:, 4], onnx_results[:, 4]
), 'The outputs are different between PyTorch and ONNX'
print('The numerical values are the same between PyTorch and ONNX')
def parse_args():
parser = argparse.ArgumentParser(
description='MMDet pytorch model conversion to ONNX')
description='Convert MMDetection models to ONNX')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--input-img', type=str, help='Images for input')
parser.add_argument('--show', action='store_true', help='show onnx graph')
parser.add_argument('--output-file', type=str, default='tmp.onnx')
parser.add_argument('--opset-version', type=int, default=11)
parser.add_argument(
'--out', type=str, required=True, help='output ONNX filename')
'--verify',
action='store_true',
help='verify the onnx model output against pytorch output')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[1280, 800],
default=[800, 1216],
help='input image size')
parser.add_argument(
'--passes', type=str, nargs='+', help='ONNX optimization passes')
'--mean',
type=int,
nargs='+',
default=[0, 0, 0],
help='mean value used to preprocess the input data')
parser.add_argument(
'--std',
type=int,
nargs='+',
default=[1, 1, 1],
help='variance value used to preprocess the input data')
args = parser.parse_args()
return args
def main():
if __name__ == '__main__':
args = parse_args()
if not args.out.endswith('.onnx'):
raise ValueError('The output file must be an onnx file.')
assert args.opset_version == 11, 'MMDet only supports opset 11 now'
if not args.input_img:
args.input_img = osp.join(
osp.dirname(__file__), '../tests/data/color.jpg')
if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
input_shape = (1, 3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
input_shape = (1, 3) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
assert len(args.mean) == 3
assert len(args.std) == 3
normalize_cfg = {
'mean': np.array(args.mean, dtype=np.float32),
'std': np.array(args.std, dtype=np.float32)
}
cfg = mmcv.Config.fromfile(args.config)
cfg.model.pretrained = None
cfg.data.test.test_mode = True
# build the model and load checkpoint
# build the model
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
load_checkpoint(model, args.checkpoint, map_location='cpu')
# Only support CPU mode for now
model.cpu().eval()
# Customized ops are not supported, use torchvision ops instead.
for m in model.modules():
if isinstance(m, (RoIPool, RoIAlign)):
# set use_torchvision on-the-fly
m.use_torchvision = True
# TODO: a better way to override forward function
if hasattr(model, 'forward_dummy'):
model.forward = model.forward_dummy
else:
raise NotImplementedError(
'ONNX conversion is currently not supported with '
f'{model.__class__.__name__}')
input_data = torch.empty((1, *input_shape),
dtype=next(model.parameters()).dtype,
device=next(model.parameters()).device)
onnx_model = export_onnx_model(model, (input_data, ), args.passes)
# Print a human readable representation of the graph
onnx.helper.printable_graph(onnx_model.graph)
print(f'saving model in {args.out}')
onnx.save(onnx_model, args.out)
if __name__ == '__main__':
main()
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
# convert model to ONNX file
pytorch2onnx(
model,
args.input_img,
input_shape,
opset_version=args.opset_version,
show=args.show,
output_file=args.output_file,
verify=args.verify,
normalize_cfg=normalize_cfg)
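One pattern in the new script worth highlighting: `torch.onnx.export` only passes positional arguments, so `pytorch2onnx` binds `img_metas` and `return_loss` up front with `functools.partial` and restores the original `forward` afterwards. A self-contained sketch of the same pattern (`Toy` is hypothetical):

```python
from functools import partial

import torch

class Toy(torch.nn.Module):
    def forward(self, img, img_metas=None, return_loss=True):
        return img.sum() if return_loss else img * 2

model = Toy().eval()
origin_forward = model.forward
# bind kwargs so export can call model(one_img) positionally
model.forward = partial(model.forward, img_metas=[[{}]], return_loss=False)
torch.onnx.export(model, (torch.rand(1, 3, 32, 32), ), 'toy.onnx',
                  opset_version=11)
model.forward = origin_forward
```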