diff --git a/demo/inference_demo.ipynb b/demo/inference_demo.ipynb
index 4df4e7c0792e507cdfdbdf7c0b06007878eaff81..e94a8c5927a0d32b6cbd890b34cd6d2d5c60927e 100644
--- a/demo/inference_demo.ipynb
+++ b/demo/inference_demo.ipynb
@@ -64,7 +64,7 @@
    ],
    "source": [
     "# show the results\n",
-    "show_result_pyplot(img, result, model.CLASSES)"
+    "show_result_pyplot(model, img, result)"
    ]
   }
  ],
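
The updated cell passes the detector itself as the first argument; class names are read from `model.CLASSES` internally instead of being passed explicitly. A minimal sketch of the full notebook flow under the new signature (the config/checkpoint pair is a placeholder borrowed from the examples below; any model-zoo pair works):

```python
from mmdet.apis import inference_detector, init_detector, show_result_pyplot

# Placeholder config/checkpoint pair; substitute any model-zoo entry.
config_file = 'configs/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')

# show the results; the model now comes first and supplies its own classes
show_result_pyplot(model, 'demo/demo.jpg', result)
```
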
diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index e9bb07c935f7468ae36e952211b61aa627018608..61ee238de214bea030ea60e6db473f195ff446fd 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -28,8 +28,8 @@ Optional arguments:
 - `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
 - `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `proposal_fast`, `proposal`, `bbox`, `segm` are available for COCO, `mAP`, `recall` for PASCAL VOC. Cityscapes could be evaluated by `cityscapes` as well as all COCO metrics.
 - `--show`: If specified, detection results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that GUI is available in your environment, otherwise you may encounter the error like `cannot connect to X server`.
+- `--show-dir`: If specified, detection results will be plotted on the images and saved to the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You **don't** need a GUI available in your environment to use this option.
 
-If you would like to evaluate the dataset, do not specify `--show` at the same time.
 
 Examples:
 
@@ -43,7 +43,15 @@ python tools/test.py configs/faster_rcnn_r50_fpn_1x_coco.py \
     --show
 ```
 
-2. Test Faster R-CNN on PASCAL VOC (without saving the test results) and evaluate the mAP.
+2. Test Faster R-CNN and save the painted images for later visualization.
+
+```shell
+python tools/test.py configs/faster_rcnn_r50_fpn_1x_coco.py \
+    checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth \
+    --show-dir faster_rcnn_r50_fpn_1x_results
+```
+
+3. Test Faster R-CNN on PASCAL VOC (without saving the test results) and evaluate the mAP.
 
 ```shell
 python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
@@ -51,7 +59,7 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
     --eval mAP
 ```
 
-3. Test Mask R-CNN with 8 GPUs, and evaluate the bbox and mask AP.
+4. Test Mask R-CNN with 8 GPUs, and evaluate the bbox and mask AP.
 
 ```shell
 ./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x_coco.py \
@@ -59,7 +67,7 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
     8 --out results.pkl --eval bbox segm
 ```
 
-4. Test Mask R-CNN with 8 GPUs, and evaluate the **classwise** bbox and mask AP.
+5. Test Mask R-CNN with 8 GPUs, and evaluate the **classwise** bbox and mask AP.
 
 ```shell
 ./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x_coco.py \
@@ -67,7 +75,7 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
     8 --out results.pkl --eval bbox segm --options "classwise=True"
 ```
 
-5. Test Mask R-CNN on COCO test-dev with 8 GPUs, and generate the json file to be submit to the official evaluation server.
+6. Test Mask R-CNN on COCO test-dev with 8 GPUs, and generate the json files to be submitted to the official evaluation server.
 
 ```shell
 ./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x_coco.py \
@@ -77,7 +85,7 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
 
 You will get two json files `mask_rcnn_test-dev_results.bbox.json` and `mask_rcnn_test-dev_results.segm.json`.
 
-6. Test Mask R-CNN on Cityscapes test with 8 GPUs, and generate the txt and png files to be submit to the official evaluation server.
+7. Test Mask R-CNN on Cityscapes test with 8 GPUs, and generate the txt and png files to be submitted to the official evaluation server.
 
 ```shell
 ./tools/dist_test.sh configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py \
@@ -108,7 +116,7 @@ python demo/webcam_demo.py configs/faster_rcnn_r50_fpn_1x_coco.py \
 Here is an example of building the model and test given images.
 
 ```python
-from mmdet.apis import init_detector, inference_detector, show_result
+from mmdet.apis import init_detector, inference_detector
 import mmcv
 
 config_file = 'configs/faster_rcnn_r50_fpn_1x_coco.py'
@@ -121,15 +129,15 @@ model = init_detector(config_file, checkpoint_file, device='cuda:0')
 img = 'test.jpg'  # or img = mmcv.imread(img), which will only load it once
 result = inference_detector(model, img)
 # visualize the results in a new window
-show_result(img, result, model.CLASSES)
+model.show_result(img, result)
 # or save the visualization results to image files
-show_result(img, result, model.CLASSES, out_file='result.jpg')
+model.show_result(img, result, out_file='result.jpg')
 
 # test a video and show the results
 video = mmcv.VideoReader('video.mp4')
 for frame in video:
     result = inference_detector(model, frame)
-    show_result(frame, result, model.CLASSES, wait_time=1)
+    model.show_result(frame, result, wait_time=1)
 ```
 
 A notebook demo can be found in [demo/inference_demo.ipynb](https://github.com/open-mmlab/mmdetection/blob/master/demo/inference_demo.ipynb).
@@ -143,7 +151,7 @@ See `tests/async_benchmark.py` to compare the speed of synchronous and asynchron
 ```python
 import asyncio
 import torch
-from mmdet.apis import init_detector, async_inference_detector, show_result
+from mmdet.apis import init_detector, async_inference_detector
 from mmdet.utils.contextmanagers import concurrent
 
 async def main():
@@ -167,9 +175,9 @@ async def main():
         result = await async_inference_detector(model, img)
 
     # visualize the results in a new window
-    show_result(img, result, model.CLASSES)
+    model.show_result(img, result)
     # or save the visualization results to image files
-    show_result(img, result, model.CLASSES, out_file='result.jpg')
+    model.show_result(img, result, out_file='result.jpg')
 
 
 asyncio.run(main())
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index 0dfb4cdd319a605759f59b95a6f30007f352c3c1..1d8035b74877fdeccaa41cbc10a9f1f9924eac85 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,10 +1,10 @@
 from .inference import (async_inference_detector, inference_detector,
-                        init_detector, show_result, show_result_pyplot)
+                        init_detector, show_result_pyplot)
 from .test import multi_gpu_test, single_gpu_test
 from .train import get_root_logger, set_random_seed, train_detector
 
 __all__ = [
     'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
-    'async_inference_detector', 'inference_detector', 'show_result',
-    'show_result_pyplot', 'multi_gpu_test', 'single_gpu_test'
+    'async_inference_detector', 'inference_detector', 'show_result_pyplot',
+    'multi_gpu_test', 'single_gpu_test'
 ]
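
Since `show_result` is no longer exported, `from mmdet.apis import show_result` now raises an `ImportError`. A hedged migration sketch for existing callers (paths are placeholders, as elsewhere in this PR):

```python
from mmdet.apis import inference_detector, init_detector

# Placeholder config/checkpoint pair; substitute any model-zoo entry.
model = init_detector('configs/faster_rcnn_r50_fpn_1x_coco.py',
                      'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth',
                      device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')

# Before: show_result('demo/demo.jpg', result, model.CLASSES, out_file='result.jpg')
# After:  the detector method reads class names from model.CLASSES itself.
model.show_result('demo/demo.jpg', result, out_file='result.jpg')
```
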
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index c26ae0f0f7d6c57b15816d32b1d969594ce1ce11..1d5c30940ab8d3c1bceb600c80593030c4fbb2cf 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -2,8 +2,6 @@ import warnings
 
 import matplotlib.pyplot as plt
 import mmcv
-import numpy as np
-import pycocotools.mask as maskUtils
 import torch
 from mmcv.parallel import collate, scatter
 from mmcv.runner import load_checkpoint
@@ -130,76 +128,7 @@ async def async_inference_detector(model, img):
     return result
 
 
-# TODO: merge this method with the one in BaseDetector
-def show_result(img,
-                result,
-                class_names,
-                score_thr=0.3,
-                wait_time=0,
-                show=True,
-                out_file=None):
-    """Visualize the detection results on the image.
-
-    Args:
-        img (str or np.ndarray): Image filename or loaded image.
-        result (tuple[list] or list): The detection result, can be either
-            (bbox, segm) or just bbox.
-        class_names (list[str] or tuple[str]): A list of class names.
-        score_thr (float): The threshold to visualize the bboxes and masks.
-        wait_time (int): Value of waitKey param.
-        show (bool, optional): Whether to show the image with opencv or not.
-        out_file (str, optional): If specified, the visualization result will
-            be written to the out file instead of shown in a window.
-
-    Returns:
-        np.ndarray or None: If neither `show` nor `out_file` is specified, the
-            visualized image is returned, otherwise None is returned.
-    """
-    assert isinstance(class_names, (tuple, list))
-    img = mmcv.imread(img)
-    img = img.copy()
-    if isinstance(result, tuple):
-        bbox_result, segm_result = result
-    else:
-        bbox_result, segm_result = result, None
-    bboxes = np.vstack(bbox_result)
-    labels = [
-        np.full(bbox.shape[0], i, dtype=np.int32)
-        for i, bbox in enumerate(bbox_result)
-    ]
-    labels = np.concatenate(labels)
-    # draw segmentation masks
-    if segm_result is not None:
-        segms = mmcv.concat_list(segm_result)
-        inds = np.where(bboxes[:, -1] > score_thr)[0]
-        np.random.seed(42)
-        color_masks = [
-            np.random.randint(0, 256, (1, 3), dtype=np.uint8)
-            for _ in range(max(labels))
-        ]
-        for i in inds:
-            i = int(i)
-            color_mask = color_masks[labels[i]]
-            mask = maskUtils.decode(segms[i]).astype(np.bool)
-            img[mask] = img[mask] * 0.5 + color_mask * 0.5
-    # if out_file specified, do not show image in window
-    if out_file is not None:
-        show = False
-    # draw bounding boxes
-    mmcv.imshow_det_bboxes(
-        img,
-        bboxes,
-        labels,
-        class_names=class_names,
-        score_thr=score_thr,
-        show=show,
-        wait_time=wait_time,
-        out_file=out_file)
-    if not (show or out_file):
-        return img
-
-
-def show_result_pyplot(img,
+def show_result_pyplot(model,
+                       img,
                        result,
-                       class_names,
                        score_thr=0.3,
@@ -207,16 +136,15 @@
     """Visualize the detection results on the image.
 
     Args:
+        model (nn.Module): The loaded detector.
         img (str or np.ndarray): Image filename or loaded image.
         result (tuple[list] or list): The detection result, can be either
             (bbox, segm) or just bbox.
-        class_names (list[str] or tuple[str]): A list of class names.
         score_thr (float): The threshold to visualize the bboxes and masks.
         fig_size (tuple): Figure size of the pyplot figure.
-        out_file (str, optional): If specified, the visualization result will
-            be written to the out file instead of shown in a window.
     """
-    img = show_result(
-        img, result, class_names, score_thr=score_thr, show=False)
+    if hasattr(model, 'module'):
+        model = model.module
+    img = model.show_result(img, result, score_thr=score_thr, show=False)
     plt.figure(figsize=fig_size)
     plt.imshow(mmcv.bgr2rgb(img))
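
Because `show_result_pyplot` strips a `.module` wrapper itself, bare and `MMDataParallel`-wrapped detectors are interchangeable. A small single-GPU sketch (placeholder paths as above):

```python
from mmcv.parallel import MMDataParallel

from mmdet.apis import inference_detector, init_detector, show_result_pyplot

model = init_detector('configs/faster_rcnn_r50_fpn_1x_coco.py',
                      'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth',
                      device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')

# Both calls draw the same image: the wrapper is unwrapped internally.
show_result_pyplot(model, 'demo/demo.jpg', result)
show_result_pyplot(MMDataParallel(model, device_ids=[0]),
                   'demo/demo.jpg', result, score_thr=0.5)
```
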
diff --git a/mmdet/apis/test.py b/mmdet/apis/test.py
index 5b0dea2d85933641a677fd4d2c689779bb7b9c96..cb1772a3d8acd24f6932999b393a6e8acd0f7a91 100644
--- a/mmdet/apis/test.py
+++ b/mmdet/apis/test.py
@@ -8,19 +8,39 @@ import torch
 import torch.distributed as dist
 from mmcv.runner import get_dist_info
 
+from mmdet.core import tensor2imgs
 
-def single_gpu_test(model, data_loader, show=False):
+
+def single_gpu_test(model, data_loader, show=False, out_dir=None):
     model.eval()
     results = []
     dataset = data_loader.dataset
     prog_bar = mmcv.ProgressBar(len(dataset))
     for i, data in enumerate(data_loader):
         with torch.no_grad():
-            result = model(return_loss=False, rescale=not show, **data)
+            result = model(return_loss=False, rescale=True, **data)
         results.append(result)
 
-        if show:
-            model.module.show_result(data, result)
+        if show or out_dir:
+            img_tensor = data['img'][0]
+            img_metas = data['img_metas'][0].data[0]
+            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+            assert len(imgs) == len(img_metas)
+
+            for img, img_meta in zip(imgs, img_metas):
+                h, w, _ = img_meta['img_shape']
+                img_show = img[:h, :w, :]
+
+                ori_h, ori_w = img_meta['ori_shape'][:-1]
+                img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+                if out_dir:
+                    out_file = osp.join(out_dir, img_meta['filename'])
+                else:
+                    out_file = None
+
+                model.module.show_result(
+                    img_show, result, show=show, out_file=out_file)
 
         batch_size = data['img'][0].size(0)
         for _ in range(batch_size):
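
`single_gpu_test` can also be driven directly; with `show=False` and `out_dir` set, every painted image is written to disk and no display is needed. A sketch mirroring the setup in `tools/test.py` (paths are placeholders; the builder names and signatures follow the mmdetection version this PR targets and may differ elsewhere):

```python
import mmcv
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint

from mmdet.apis import single_gpu_test
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.models import build_detector

cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x_coco.py')
cfg.data.test.test_mode = True

dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=1,
    workers_per_gpu=cfg.data.workers_per_gpu,
    dist=False,
    shuffle=False)

detector = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
load_checkpoint(detector,
                'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth',
                map_location='cpu')
detector.CLASSES = dataset.CLASSES  # so painted labels use dataset names
model = MMDataParallel(detector, device_ids=[0])

# show=False with out_dir set writes every painted image to disk;
# no GUI is required.
outputs = single_gpu_test(model, data_loader, show=False,
                          out_dir='painted_results')
```
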
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 936eeba05095194f1cc723a763efc1317832ea1e..55c762ee871ea3c7dba25dd8d4709e816bdeb31b 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -5,7 +5,7 @@ import numpy as np
 import pycocotools.mask as maskUtils
 import torch.nn as nn
 
-from mmdet.core import auto_fp16, get_classes, tensor2imgs
+from mmdet.core import auto_fp16
 from mmdet.utils import print_log
 
 
@@ -151,51 +151,86 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
         else:
             return self.forward_test(img, img_metas, **kwargs)
 
-    def show_result(self, data, result, dataset=None, score_thr=0.3):
+    def show_result(self,
+                    img,
+                    result,
+                    score_thr=0.3,
+                    bbox_color='green',
+                    text_color='green',
+                    thickness=1,
+                    font_scale=0.5,
+                    win_name='',
+                    show=False,
+                    wait_time=0,
+                    out_file=None):
+        """Draw `result` over `img`.
+
+        Args:
+            img (str or ndarray): The image to be displayed.
+            result (list or tuple): The detection result to draw over
+                `img`, either `bbox_result` or `(bbox_result, segm_result)`.
+            score_thr (float, optional): Minimum score of bboxes to be shown.
+                Default: 0.3.
+            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+            text_color (str or tuple or :obj:`Color`): Color of texts.
+            thickness (int): Thickness of lines.
+            font_scale (float): Font scales of texts.
+            win_name (str): The window name.
+            wait_time (int): Value of waitKey param.
+                Default: 0.
+            show (bool): Whether to show the image.
+                Default: False.
+            out_file (str or None): The filename to write the image.
+                Default: None.
+
+        Returns:
+            img (ndarray): The image with detection results drawn on it,
+                returned only when neither `show` nor `out_file` is set.
+        """
+        img = mmcv.imread(img)
+        img = img.copy()
         if isinstance(result, tuple):
             bbox_result, segm_result = result
         else:
             bbox_result, segm_result = result, None
-
-        img_tensor = data['img'][0]
-        img_metas = data['img_metas'][0].data[0]
-        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
-        assert len(imgs) == len(img_metas)
-
-        if dataset is None:
-            class_names = self.CLASSES
-        elif isinstance(dataset, str):
-            class_names = get_classes(dataset)
-        elif isinstance(dataset, (list, tuple)):
-            class_names = dataset
-        else:
-            raise TypeError(
-                'dataset must be a valid dataset name or a sequence'
-                f' of class names, not {type(dataset)}')
-
-        for img, img_meta in zip(imgs, img_metas):
-            h, w, _ = img_meta['img_shape']
-            img_show = img[:h, :w, :]
-
-            bboxes = np.vstack(bbox_result)
-            # draw segmentation masks
-            if segm_result is not None:
-                segms = mmcv.concat_list(segm_result)
-                inds = np.where(bboxes[:, -1] > score_thr)[0]
-                for i in inds:
-                    color_mask = np.random.randint(
-                        0, 256, (1, 3), dtype=np.uint8)
-                    mask = maskUtils.decode(segms[i]).astype(np.bool)
-                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
-            # draw bounding boxes
-            labels = [
-                np.full(bbox.shape[0], i, dtype=np.int32)
-                for i, bbox in enumerate(bbox_result)
+        bboxes = np.vstack(bbox_result)
+        labels = [
+            np.full(bbox.shape[0], i, dtype=np.int32)
+            for i, bbox in enumerate(bbox_result)
+        ]
+        labels = np.concatenate(labels)
+        # draw segmentation masks
+        if segm_result is not None:
+            segms = mmcv.concat_list(segm_result)
+            inds = np.where(bboxes[:, -1] > score_thr)[0]
+            np.random.seed(42)
+            color_masks = [
+                np.random.randint(0, 256, (1, 3), dtype=np.uint8)
+                for _ in range(max(labels) + 1)
             ]
-            labels = np.concatenate(labels)
-            mmcv.imshow_det_bboxes(
-                img_show,
-                bboxes,
-                labels,
-                class_names=class_names,
-                score_thr=score_thr)
+            for i in inds:
+                i = int(i)
+                color_mask = color_masks[labels[i]]
+                mask = maskUtils.decode(segms[i]).astype(np.bool)
+                img[mask] = img[mask] * 0.5 + color_mask * 0.5
+        # if out_file specified, do not show image in window
+        if out_file is not None:
+            show = False
+        # draw bounding boxes
+        mmcv.imshow_det_bboxes(
+            img,
+            bboxes,
+            labels,
+            class_names=self.CLASSES,
+            score_thr=score_thr,
+            bbox_color=bbox_color,
+            text_color=text_color,
+            thickness=thickness,
+            font_scale=font_scale,
+            win_name=win_name,
+            show=show,
+            wait_time=wait_time,
+            out_file=out_file)
+
+        if not (show or out_file):
+            return img
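
The method signature exposes the styling parameters of `mmcv.imshow_det_bboxes` directly. A short sketch with non-default styling (placeholder paths again):

```python
from mmdet.apis import inference_detector, init_detector

model = init_detector('configs/faster_rcnn_r50_fpn_1x_coco.py',
                      'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth',
                      device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')

model.show_result(
    'demo/demo.jpg',
    result,
    score_thr=0.5,      # draw only detections scoring above 0.5
    bbox_color='red',
    text_color='red',
    thickness=2,
    font_scale=0.8,
    out_file='styled_result.jpg')  # out_file set, so no window is opened
```
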
diff --git a/tools/test.py b/tools/test.py
index 7d39587c50aa1112c6a773079bbad9bdea381b40..79bc8262d83640aeb574f8463720dfa7d6a73842 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -38,6 +38,9 @@ def parse_args():
         help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
         ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
     parser.add_argument('--show', action='store_true', help='show results')
+    parser.add_argument(
+        '--show-dir',
+        help='directory where painted images will be saved')
     parser.add_argument(
         '--gpu-collect',
         action='store_true',
@@ -45,7 +48,7 @@ def parse_args():
     parser.add_argument(
         '--tmpdir',
         help='tmp directory used for collecting results from multiple '
-        'workers, available when gpu_collect is not specified')
+        'workers, available when gpu-collect is not specified')
     parser.add_argument(
         '--options', nargs='+', action=DictAction, help='arguments in dict')
     parser.add_argument(
@@ -63,10 +66,11 @@ def parse_args():
 def main():
     args = parse_args()
 
-    assert args.out or args.eval or args.format_only or args.show, \
+    assert args.out or args.eval or args.format_only or args.show \
+        or args.show_dir, \
         ('Please specify at least one operation (save/eval/format/show the '
-         'results) with the argument "--out", "--eval", "--format_only" '
-         'or "--show"')
+         'results / save the results) with the argument "--out", "--eval", '
+         '"--format-only", "--show" or "--show-dir"')
 
     if args.eval and args.format_only:
         raise ValueError('--eval and --format_only cannot be both specified')
@@ -115,7 +119,7 @@ def main():
 
     if not distributed:
         model = MMDataParallel(model, device_ids=[0])
-        outputs = single_gpu_test(model, data_loader, args.show)
+        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
     else:
         model = MMDistributedDataParallel(
             model.cuda(),