diff --git a/.isort.cfg b/.isort.cfg index 38124940effd311d4438be03accb5aeffb747cbd..06ae39a2a1c5059808dbc41dbe3b1f6f6986375f 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,6 +3,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision +known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index 077def3e39acf7989fc203744f940d892a7e06e9..3640220b32f09f60e17bce0e05993a95bc1b2e5a 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -358,6 +358,15 @@ The final output filename will be `faster_rcnn_r50_fpn_1x_20190801-{hash id}.pth Please refer to [ROBUSTNESS_BENCHMARKING.md](ROBUSTNESS_BENCHMARKING.md). +### Convert to ONNX (experimental) + +We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). + +```shell +python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${ONNX_FILE} [--shape ${INPUT_SHAPE}] +``` + +**Note**: This tool is still experimental. Customized operators are not supported for now. We set `use_torchvision=True` on-the-fly for `RoIPool` and `RoIAlign`. ## How-to diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index 4ee7c0a31b9b09763c5a912628cef7b21439a3d9..064c733399c0d4ab0de1b8b8a32963f1539f7fd6 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -126,7 +126,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): if self.with_rpn: rpn_outs = self.rpn_head(x) outs = outs + (rpn_outs, ) - proposals = torch.randn(1000, 4).cuda() + proposals = torch.randn(1000, 4).to(device=img.device) # bbox heads rois = bbox2roi([proposals]) if self.with_bbox: diff --git a/mmdet/models/detectors/double_head_rcnn.py b/mmdet/models/detectors/double_head_rcnn.py index 7a783353f1eba4ee551a4e9c4368a3584dd09aa0..7b6a80f200b7c50e651765887f06111d371299ab 100644 --- a/mmdet/models/detectors/double_head_rcnn.py +++ b/mmdet/models/detectors/double_head_rcnn.py @@ -20,7 +20,7 @@ class DoubleHeadRCNN(TwoStageDetector): if self.with_rpn: rpn_outs = self.rpn_head(x) outs = outs + (rpn_outs, ) - proposals = torch.randn(1000, 4).cuda() + proposals = torch.randn(1000, 4).to(device=img.device) # bbox head rois = bbox2roi([proposals]) bbox_cls_feats = self.bbox_roi_extractor( diff --git a/mmdet/models/detectors/grid_rcnn.py b/mmdet/models/detectors/grid_rcnn.py index 99c12d4536388b47c71f0eb7e95350be0ed75f82..6f6cf30d4947e08e8c35b69f1185189cf94f8cb1 100644 --- a/mmdet/models/detectors/grid_rcnn.py +++ b/mmdet/models/detectors/grid_rcnn.py @@ -88,7 +88,7 @@ class GridRCNN(TwoStageDetector): if self.with_rpn: rpn_outs = self.rpn_head(x) outs = outs + (rpn_outs, ) - proposals = torch.randn(1000, 4).cuda() + proposals = torch.randn(1000, 4).to(device=img.device) # bbox head rois = bbox2roi([proposals]) bbox_feats = self.bbox_roi_extractor( diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py index d48bf0c731ed1e0b1575a8173ed5cb1e79b91a1b..63c4339d1f47570efd4ed6822772e960440f825c 100644 --- a/mmdet/models/detectors/htc.py +++ b/mmdet/models/detectors/htc.py @@ -162,7 +162,7 @@ class HybridTaskCascade(CascadeRCNN): if self.with_rpn: rpn_outs = self.rpn_head(x) outs = outs + (rpn_outs, ) - proposals = torch.randn(1000, 4).cuda() + proposals = torch.randn(1000, 4).to(device=img.device) # semantic head if self.with_semantic: _, semantic_feat = self.semantic_head(x) diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index d9eee28087858a4e585bb70176dc67c112126b9e..9bb343f859c2e545f5381ed7ca8789af81672865 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -106,7 +106,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, if self.with_rpn: rpn_outs = self.rpn_head(x) outs = outs + (rpn_outs, ) - proposals = torch.randn(1000, 4).cuda() + proposals = torch.randn(1000, 4).to(device=img.device) # bbox head rois = bbox2roi([proposals]) if self.with_bbox: diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..4f3cf2d8fa07961762f3f7cd93f4f5879984101a --- /dev/null +++ b/tools/pytorch2onnx.py @@ -0,0 +1,125 @@ +import argparse +import io + +import mmcv +import onnx +import torch +from mmcv.runner import load_checkpoint +from onnx import optimizer +from torch.onnx import OperatorExportTypes + +from mmdet.models import build_detector +from mmdet.ops import RoIAlign, RoIPool + + +def export_onnx_model(model, inputs, passes): + """ + Trace and export a model to onnx format. + Modified from https://github.com/facebookresearch/detectron2/ + + Args: + model (nn.Module): + inputs (tuple[args]): the model will be called by `model(*inputs)` + passes (None or list[str]): the optimization passed for ONNX model + + Returns: + an onnx model + """ + assert isinstance(model, torch.nn.Module) + + # make sure all modules are in eval mode, onnx may change the training + # state of the module if the states are not consistent + def _check_eval(module): + assert not module.training + + model.apply(_check_eval) + + # Export the model to ONNX + with torch.no_grad(): + with io.BytesIO() as f: + torch.onnx.export( + model, + inputs, + f, + operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, + # verbose=True, # NOTE: uncomment this for debugging + # export_params=True, + ) + onnx_model = onnx.load_from_string(f.getvalue()) + + # Apply ONNX's Optimization + if passes is not None: + all_passes = optimizer.get_available_passes() + assert all(p in all_passes for p in passes), \ + 'Only {} are supported'.format(all_passes) + onnx_model = optimizer.optimize(onnx_model, passes) + return onnx_model + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMDet pytorch model conversion to ONNX') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--out', type=str, required=True, help='output ONNX filename') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[1280, 800], + help='input image size') + parser.add_argument( + '--passes', type=str, nargs='+', help='ONNX optimization passes') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if not args.out.endswith('.onnx'): + raise ValueError('The output file must be a onnx file.') + + if len(args.shape) == 1: + input_shape = (3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (3, ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + cfg = mmcv.Config.fromfile(args.config) + cfg.model.pretrained = None + + # build the model and load checkpoint + model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) + load_checkpoint(model, args.checkpoint, map_location='cpu') + # Only support CPU mode for now + model.cpu().eval() + # Customized ops are not supported, use torchvision ops instead. + for m in model.modules(): + if isinstance(m, (RoIPool, RoIAlign)): + # set use_torchvision on-the-fly + m.use_torchvision = True + + # TODO: a better way to override forward function + if hasattr(model, 'forward_dummy'): + model.forward = model.forward_dummy + else: + raise NotImplementedError( + 'ONNX conversion is currently not currently supported with ' + '{}'.format(model.__class__.__name__)) + + input_data = torch.empty((1, *input_shape), + dtype=next(model.parameters()).dtype, + device=next(model.parameters()).device) + + onnx_model = export_onnx_model(model, (input_data, ), args.passes) + # Print a human readable representation of the graph + onnx.helper.printable_graph(onnx_model.graph) + print('saving model in {}'.format(args.out)) + onnx.save(onnx_model, args.out) + + +if __name__ == '__main__': + main()