diff --git a/.isort.cfg b/.isort.cfg
index 38124940effd311d4438be03accb5aeffb747cbd..06ae39a2a1c5059808dbc41dbe3b1f6f6986375f 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -3,6 +3,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmdet
-known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
+known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index 077def3e39acf7989fc203744f940d892a7e06e9..3640220b32f09f60e17bce0e05993a95bc1b2e5a 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -358,6 +358,15 @@ The final output filename will be `faster_rcnn_r50_fpn_1x_20190801-{hash id}.pth
 
 Please refer to [ROBUSTNESS_BENCHMARKING.md](ROBUSTNESS_BENCHMARKING.md).
 
+### Convert to ONNX (experimental)
+
+We provide a script to convert a model to the [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized by tools such as [Netron](https://github.com/lutzroeder/netron).
+
+```shell
+python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${ONNX_FILE} [--shape ${INPUT_SHAPE}]
+```
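+
+`INPUT_SHAPE` can be a single integer or two integers and defaults to `1280 800`. Optimization passes can be applied with `--passes`; the available passes are listed by `onnx.optimizer.get_available_passes()`.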
+
+**Note**: This tool is still experimental. Custom operators are not supported for now; we set `use_torchvision=True` on-the-fly for `RoIPool` and `RoIAlign` so that the torchvision implementations are used instead.
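+
+For example, to convert a Faster R-CNN model and then verify the exported graph with the `onnx` package (the config and checkpoint paths below are placeholders):
+
+```shell
+python tools/pytorch2onnx.py configs/faster_rcnn_r50_fpn_1x.py \
+    checkpoints/faster_rcnn_r50_fpn_1x.pth \
+    --out faster_rcnn_r50_fpn_1x.onnx
+```
+
+```python
+import onnx
+
+# load the exported model and check that the graph is well formed
+onnx_model = onnx.load('faster_rcnn_r50_fpn_1x.onnx')
+onnx.checker.check_model(onnx_model)
+```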
 
 ## How-to
 
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 4ee7c0a31b9b09763c5a912628cef7b21439a3d9..064c733399c0d4ab0de1b8b8a32963f1539f7fd6 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -126,7 +126,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
         if self.with_rpn:
             rpn_outs = self.rpn_head(x)
             outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).cuda()
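+        # put the dummy proposals on the same device as the input so
+        # that forward_dummy also works on CPU (e.g. for ONNX export)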
+        proposals = torch.randn(1000, 4).to(device=img.device)
         # bbox heads
         rois = bbox2roi([proposals])
         if self.with_bbox:
diff --git a/mmdet/models/detectors/double_head_rcnn.py b/mmdet/models/detectors/double_head_rcnn.py
index 7a783353f1eba4ee551a4e9c4368a3584dd09aa0..7b6a80f200b7c50e651765887f06111d371299ab 100644
--- a/mmdet/models/detectors/double_head_rcnn.py
+++ b/mmdet/models/detectors/double_head_rcnn.py
@@ -20,7 +20,7 @@ class DoubleHeadRCNN(TwoStageDetector):
         if self.with_rpn:
             rpn_outs = self.rpn_head(x)
             outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).cuda()
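+        # put the dummy proposals on the same device as the input so
+        # that forward_dummy also works on CPU (e.g. for ONNX export)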
+        proposals = torch.randn(1000, 4).to(device=img.device)
         # bbox head
         rois = bbox2roi([proposals])
         bbox_cls_feats = self.bbox_roi_extractor(
diff --git a/mmdet/models/detectors/grid_rcnn.py b/mmdet/models/detectors/grid_rcnn.py
index 99c12d4536388b47c71f0eb7e95350be0ed75f82..6f6cf30d4947e08e8c35b69f1185189cf94f8cb1 100644
--- a/mmdet/models/detectors/grid_rcnn.py
+++ b/mmdet/models/detectors/grid_rcnn.py
@@ -88,7 +88,7 @@ class GridRCNN(TwoStageDetector):
         if self.with_rpn:
             rpn_outs = self.rpn_head(x)
             outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).cuda()
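+        # put the dummy proposals on the same device as the input so
+        # that forward_dummy also works on CPU (e.g. for ONNX export)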
+        proposals = torch.randn(1000, 4).to(device=img.device)
         # bbox head
         rois = bbox2roi([proposals])
         bbox_feats = self.bbox_roi_extractor(
diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py
index d48bf0c731ed1e0b1575a8173ed5cb1e79b91a1b..63c4339d1f47570efd4ed6822772e960440f825c 100644
--- a/mmdet/models/detectors/htc.py
+++ b/mmdet/models/detectors/htc.py
@@ -162,7 +162,7 @@ class HybridTaskCascade(CascadeRCNN):
         if self.with_rpn:
             rpn_outs = self.rpn_head(x)
             outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).cuda()
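+        # put the dummy proposals on the same device as the input so
+        # that forward_dummy also works on CPU (e.g. for ONNX export)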
+        proposals = torch.randn(1000, 4).to(device=img.device)
         # semantic head
         if self.with_semantic:
             _, semantic_feat = self.semantic_head(x)
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index d9eee28087858a4e585bb70176dc67c112126b9e..9bb343f859c2e545f5381ed7ca8789af81672865 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -106,7 +106,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
         if self.with_rpn:
             rpn_outs = self.rpn_head(x)
             outs = outs + (rpn_outs, )
-        proposals = torch.randn(1000, 4).cuda()
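+        # put the dummy proposals on the same device as the input so
+        # that forward_dummy also works on CPU (e.g. for ONNX export)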
+        proposals = torch.randn(1000, 4).to(device=img.device)
         # bbox head
         rois = bbox2roi([proposals])
         if self.with_bbox:
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f3cf2d8fa07961762f3f7cd93f4f5879984101a
--- /dev/null
+++ b/tools/pytorch2onnx.py
@@ -0,0 +1,125 @@
+import argparse
+import io
+
+import mmcv
+import onnx
+import torch
+from mmcv.runner import load_checkpoint
+from onnx import optimizer
+from torch.onnx import OperatorExportTypes
+
+from mmdet.models import build_detector
+from mmdet.ops import RoIAlign, RoIPool
+
+
+def export_onnx_model(model, inputs, passes):
+    """
+    Trace and export a model to ONNX format.
+    Modified from https://github.com/facebookresearch/detectron2/
+
+    Args:
+        model (nn.Module): the model to be exported
+        inputs (tuple): the model will be called by `model(*inputs)`
+        passes (None or list[str]): the optimization passes to apply to
+            the exported ONNX model
+
+    Returns:
+        an onnx model
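+
+    Example:
+        A minimal sketch; assumes `model` is a detector already in eval
+        mode whose `forward` takes a single image tensor:
+
+        >>> dummy_input = torch.empty(1, 3, 1280, 800)
+        >>> onnx_model = export_onnx_model(model, (dummy_input, ), None)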
+    """
+    assert isinstance(model, torch.nn.Module)
+
+    # make sure all modules are in eval mode, onnx may change the training
+    # state of the module if the states are not consistent
+    def _check_eval(module):
+        assert not module.training
+
+    model.apply(_check_eval)
+
+    # Export the model to ONNX
+    with torch.no_grad():
+        with io.BytesIO() as f:
+            torch.onnx.export(
+                model,
+                inputs,
+                f,
+                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
+                # verbose=True,  # NOTE: uncomment this for debugging
+                # export_params=True,
+            )
+            onnx_model = onnx.load_from_string(f.getvalue())
+
+    # Apply ONNX's Optimization
+    if passes is not None:
+        all_passes = optimizer.get_available_passes()
+        assert all(p in all_passes for p in passes), \
+            'Only {} are supported'.format(all_passes)
+        onnx_model = optimizer.optimize(onnx_model, passes)
+    return onnx_model
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='MMDet pytorch model conversion to ONNX')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument(
+        '--out', type=str, required=True, help='output ONNX filename')
+    parser.add_argument(
+        '--shape',
+        type=int,
+        nargs='+',
+        default=[1280, 800],
+        help='input image size')
+    parser.add_argument(
+        '--passes', type=str, nargs='+', help='ONNX optimization passes')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if not args.out.endswith('.onnx'):
+        raise ValueError('The output file must be an .onnx file.')
+
+    if len(args.shape) == 1:
+        input_shape = (3, args.shape[0], args.shape[0])
+    elif len(args.shape) == 2:
+        input_shape = (3, ) + tuple(args.shape)
+    else:
+        raise ValueError('invalid input shape')
+
+    cfg = mmcv.Config.fromfile(args.config)
+    cfg.model.pretrained = None
+
+    # build the model and load checkpoint
+    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+    load_checkpoint(model, args.checkpoint, map_location='cpu')
+    # Only CPU mode is supported for now
+    model.cpu().eval()
+    # Customized ops are not supported; use the torchvision ops instead.
+    for m in model.modules():
+        if isinstance(m, (RoIPool, RoIAlign)):
+            # set use_torchvision on-the-fly
+            m.use_torchvision = True
+
+    # TODO: a better way to override forward function
+    if hasattr(model, 'forward_dummy'):
+        model.forward = model.forward_dummy
+    else:
+        raise NotImplementedError(
+            'ONNX conversion is currently not supported with '
+            '{}'.format(model.__class__.__name__))
+
+    input_data = torch.empty((1, *input_shape),
+                             dtype=next(model.parameters()).dtype,
+                             device=next(model.parameters()).device)
+
+    onnx_model = export_onnx_model(model, (input_data, ), args.passes)
+    # Print a human readable representation of the graph
+    print(onnx.helper.printable_graph(onnx_model.graph))
+    print('saving model to {}'.format(args.out))
+    onnx.save(onnx_model, args.out)
+
+
+if __name__ == '__main__':
+    main()