From e8117f0f7add6b8145339a7c2ae227f8859537de Mon Sep 17 00:00:00 2001
From: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>
Date: Wed, 29 Apr 2020 14:32:51 +0200
Subject: [PATCH] Add option to save the result images of running tools/test
 (#2414)

* Use apis/inference in detectors show_result

* Add images_out_dir arg to test apis

* Include --images_out_dir in assertion

* Fix single class parsing

* Remove single class hack

* Move import

* Add example of saving results

* Fix list number

* Refactor show_result

* Update docs

* Update __init__

* Fix CLASSES reference

* Remove unnecessary assert for custom classes

* Use - instead of _

* Use - instead of _

* Rename images_out_dir -> out_dir

* Remove unnecessary model.module

* Update show_result with all params and docstring

* Fix missing comma

* Set rescale always to True

* Remove outdated restriction from docs

* Drop pathlib

* Rename out_dir -> show_dir

* Update docstring

* More explicit code. Fix out_dir

* Update docstrings

* Flake8

Co-authored-by: mmeendez8 <miguelmndez@gmail.com>
Co-authored-by: sbugallo <sbugallo@gradiant.org>
---
 demo/inference_demo.ipynb      |   2 +-
 docs/GETTING_STARTED.md        |  34 +++++----
 mmdet/apis/__init__.py         |   6 +-
 mmdet/apis/inference.py        |  84 ++--------------------
 mmdet/apis/test.py             |  28 ++++++--
 mmdet/models/detectors/base.py | 122 +++++++++++++++++++++------------
 tools/test.py                  |  14 ++--
 7 files changed, 142 insertions(+), 148 deletions(-)

diff --git a/demo/inference_demo.ipynb b/demo/inference_demo.ipynb
index 4df4e7c0..e94a8c59 100644
--- a/demo/inference_demo.ipynb
+++ b/demo/inference_demo.ipynb
@@ -64,7 +64,7 @@
    ],
    "source": [
     "# show the results\n",
-    "show_result_pyplot(img, result, model.CLASSES)"
+    "show_result_pyplot(model, img, result)"
    ]
   }
  ],
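
For reference, the updated notebook call in plain-script form — a minimal sketch, assuming placeholder config and checkpoint paths (any detector pair behaves the same): `show_result_pyplot` now takes the detector as its first argument and reads class names from `model.CLASSES` rather than a separate `class_names` argument.

```python
# Sketch of the new show_result_pyplot signature; the config/checkpoint
# paths are placeholders, not files shipped with this patch.
from mmdet.apis import inference_detector, init_detector, show_result_pyplot

config_file = 'configs/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')

# The model comes first; class names are taken from model.CLASSES.
show_result_pyplot(model, 'demo/demo.jpg', result, score_thr=0.3)
```
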
diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index e9bb07c9..61ee238d 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -28,8 +28,8 @@ Optional arguments:
 - `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
 - `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `proposal_fast`, `proposal`, `bbox`, `segm` are available for COCO, `mAP`, `recall` for PASCAL VOC. Cityscapes could be evaluated by `cityscapes` as well as all COCO metrics.
 - `--show`: If specified, detection results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that GUI is available in your environment, otherwise you may encounter the error like `cannot connect to X server`.
+- `--show-dir`: If specified, detection results will be plotted on the images and saved to the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You **don't** need a GUI in your environment to use this option.
 
-If you would like to evaluate the dataset, do not specify `--show` at the same time.
 
 Examples:
 
@@ -43,7 +43,15 @@ python tools/test.py configs/faster_rcnn_r50_fpn_1x_coco.py \
     --show
 ```
 
-2. Test Faster R-CNN on PASCAL VOC (without saving the test results) and evaluate the mAP.
+2. Test Faster R-CNN and save the painted images for later visualization.
+
+```shell
+python tools/test.py configs/faster_rcnn_r50_fpn_1x.py \
+    checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth \
+    --show-dir faster_rcnn_r50_fpn_1x_results
+```
+
+3. Test Faster R-CNN on PASCAL VOC (without saving the test results) and evaluate the mAP.
 
 ```shell
 python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
@@ -51,7 +59,7 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
     --eval mAP
 ```
 
-3. Test Mask R-CNN with 8 GPUs, and evaluate the bbox and mask AP.
+4. Test Mask R-CNN with 8 GPUs, and evaluate the bbox and mask AP.
 
 ```shell
 ./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x_coco.py \
@@ -59,7 +67,7 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
     8 --out results.pkl --eval bbox segm
 ```
 
-4. Test Mask R-CNN with 8 GPUs, and evaluate the **classwise** bbox and mask AP.
+5. Test Mask R-CNN with 8 GPUs, and evaluate the **classwise** bbox and mask AP.
 
 ```shell
 ./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x_coco.py \
@@ -67,7 +75,7 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
     8 --out results.pkl --eval bbox segm --options "classwise=True"
 ```
 
-5. Test Mask R-CNN on COCO test-dev with 8 GPUs, and generate the json file to be submit to the official evaluation server.
+6. Test Mask R-CNN on COCO test-dev with 8 GPUs, and generate the json files to be submitted to the official evaluation server.
 
 ```shell
 ./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x_coco.py \
@@ -77,7 +85,7 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
 
 You will get two json files `mask_rcnn_test-dev_results.bbox.json` and `mask_rcnn_test-dev_results.segm.json`.
 
-6. Test Mask R-CNN on Cityscapes test with 8 GPUs, and generate the txt and png files to be submit to the official evaluation server.
+7. Test Mask R-CNN on Cityscapes test with 8 GPUs, and generate the txt and png files to be submitted to the official evaluation server.
 
 ```shell
 ./tools/dist_test.sh configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py \
@@ -108,7 +116,7 @@ python demo/webcam_demo.py configs/faster_rcnn_r50_fpn_1x_coco.py \
 Here is an example of building the model and testing given images.
 
 ```python
-from mmdet.apis import init_detector, inference_detector, show_result
+from mmdet.apis import init_detector, inference_detector
 import mmcv
 
 config_file = 'configs/faster_rcnn_r50_fpn_1x_coco.py'
@@ -121,15 +129,15 @@ model = init_detector(config_file, checkpoint_file, device='cuda:0')
 img = 'test.jpg'  # or img = mmcv.imread(img), which will only load it once
 result = inference_detector(model, img)
 # visualize the results in a new window
-show_result(img, result, model.CLASSES)
+model.show_result(img, result)
 # or save the visualization results to image files
-show_result(img, result, model.CLASSES, out_file='result.jpg')
+model.show_result(img, result, out_file='result.jpg')
 
 # test a video and show the results
 video = mmcv.VideoReader('video.mp4')
 for frame in video:
     result = inference_detector(model, frame)
-    show_result(frame, result, model.CLASSES, wait_time=1)
+    model.show_result(frame, result, wait_time=1)
 ```
 
 A notebook demo can be found in [demo/inference_demo.ipynb](https://github.com/open-mmlab/mmdetection/blob/master/demo/inference_demo.ipynb).
@@ -143,7 +151,7 @@ See `tests/async_benchmark.py` to compare the speed of synchronous and asynchron
 ```python
 import asyncio
 import torch
-from mmdet.apis import init_detector, async_inference_detector, show_result
+from mmdet.apis import init_detector, async_inference_detector
 from mmdet.utils.contextmanagers import concurrent
 
 async def main():
@@ -167,9 +175,9 @@ async def main():
         result = await async_inference_detector(model, img)
 
     # visualize the results in a new window
-    show_result(img, result, model.CLASSES)
+    model.show_result(img, result)
     # or save the visualization results to image files
-    show_result(img, result, model.CLASSES, out_file='result.jpg')
+    model.show_result(img, result, out_file='result.jpg')
 
 
 asyncio.run(main())
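
Combined with the `out_file` argument, the video loop above can save a painted copy of every frame instead of displaying it — a sketch under the same assumptions as the snippets above (placeholder config, checkpoint, and video paths):

```python
import os

import mmcv

from mmdet.apis import inference_detector, init_detector

# Placeholder paths; the model is loaded exactly as in the examples above.
model = init_detector('configs/faster_rcnn_r50_fpn_1x_coco.py',
                      'checkpoints/faster_rcnn_r50_fpn_1x_coco.pth')

video = mmcv.VideoReader('video.mp4')
out_dir = 'video_results'
os.makedirs(out_dir, exist_ok=True)

for idx, frame in enumerate(video):
    result = inference_detector(model, frame)
    # Passing out_file suppresses the display window and writes the frame.
    model.show_result(
        frame, result, out_file=os.path.join(out_dir, f'{idx:06d}.jpg'))
```
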
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index 0dfb4cdd..1d8035b7 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,10 +1,10 @@
 from .inference import (async_inference_detector, inference_detector,
-                        init_detector, show_result, show_result_pyplot)
+                        init_detector, show_result_pyplot)
 from .test import multi_gpu_test, single_gpu_test
 from .train import get_root_logger, set_random_seed, train_detector
 
 __all__ = [
     'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
-    'async_inference_detector', 'inference_detector', 'show_result',
-    'show_result_pyplot', 'multi_gpu_test', 'single_gpu_test'
+    'async_inference_detector', 'inference_detector', 'show_result_pyplot',
+    'multi_gpu_test', 'single_gpu_test'
 ]
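
For downstream code, removing `show_result` from `mmdet.apis` means existing imports need a one-line migration to the detector method — a sketch, with `model`, `img` and `result` as in the GETTING_STARTED examples:

```python
# Migration sketch: show_result is no longer importable from mmdet.apis.
# Before this patch:
#     from mmdet.apis import show_result
#     show_result(img, result, model.CLASSES, out_file='result.jpg')
# After this patch, visualization is a method on the detector itself and
# the class names are read from model.CLASSES internally:
model.show_result(img, result, out_file='result.jpg')
```
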
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index c26ae0f0..1d5c3094 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -2,8 +2,6 @@ import warnings
 
 import matplotlib.pyplot as plt
 import mmcv
-import numpy as np
-import pycocotools.mask as maskUtils
 import torch
 from mmcv.parallel import collate, scatter
 from mmcv.runner import load_checkpoint
@@ -130,76 +128,7 @@ async def async_inference_detector(model, img):
     return result
 
 
-# TODO: merge this method with the one in BaseDetector
-def show_result(img,
-                result,
-                class_names,
-                score_thr=0.3,
-                wait_time=0,
-                show=True,
-                out_file=None):
-    """Visualize the detection results on the image.
-
-    Args:
-        img (str or np.ndarray): Image filename or loaded image.
-        result (tuple[list] or list): The detection result, can be either
-            (bbox, segm) or just bbox.
-        class_names (list[str] or tuple[str]): A list of class names.
-        score_thr (float): The threshold to visualize the bboxes and masks.
-        wait_time (int): Value of waitKey param.
-        show (bool, optional): Whether to show the image with opencv or not.
-        out_file (str, optional): If specified, the visualization result will
-            be written to the out file instead of shown in a window.
-
-    Returns:
-        np.ndarray or None: If neither `show` nor `out_file` is specified, the
-            visualized image is returned, otherwise None is returned.
-    """
-    assert isinstance(class_names, (tuple, list))
-    img = mmcv.imread(img)
-    img = img.copy()
-    if isinstance(result, tuple):
-        bbox_result, segm_result = result
-    else:
-        bbox_result, segm_result = result, None
-    bboxes = np.vstack(bbox_result)
-    labels = [
-        np.full(bbox.shape[0], i, dtype=np.int32)
-        for i, bbox in enumerate(bbox_result)
-    ]
-    labels = np.concatenate(labels)
-    # draw segmentation masks
-    if segm_result is not None:
-        segms = mmcv.concat_list(segm_result)
-        inds = np.where(bboxes[:, -1] > score_thr)[0]
-        np.random.seed(42)
-        color_masks = [
-            np.random.randint(0, 256, (1, 3), dtype=np.uint8)
-            for _ in range(max(labels))
-        ]
-        for i in inds:
-            i = int(i)
-            color_mask = color_masks[labels[i]]
-            mask = maskUtils.decode(segms[i]).astype(np.bool)
-            img[mask] = img[mask] * 0.5 + color_mask * 0.5
-    # if out_file specified, do not show image in window
-    if out_file is not None:
-        show = False
-    # draw bounding boxes
-    mmcv.imshow_det_bboxes(
-        img,
-        bboxes,
-        labels,
-        class_names=class_names,
-        score_thr=score_thr,
-        show=show,
-        wait_time=wait_time,
-        out_file=out_file)
-    if not (show or out_file):
-        return img
-
-
-def show_result_pyplot(img,
+def show_result_pyplot(model,
+                       img,
                        result,
-                       class_names,
                        score_thr=0.3,
@@ -207,16 +136,15 @@ def show_result_pyplot(img,
     """Visualize the detection results on the image.
 
     Args:
+        model (nn.Module): The loaded detector.
         img (str or np.ndarray): Image filename or loaded image.
         result (tuple[list] or list): The detection result, can be either
             (bbox, segm) or just bbox.
-        class_names (list[str] or tuple[str]): A list of class names.
         score_thr (float): The threshold to visualize the bboxes and masks.
         fig_size (tuple): Figure size of the pyplot figure.
-        out_file (str, optional): If specified, the visualization result will
-            be written to the out file instead of shown in a window.
     """
-    img = show_result(
-        img, result, class_names, score_thr=score_thr, show=False)
+    if hasattr(model, 'module'):
+        model = model.module
+    img = model.show_result(img, result, score_thr=score_thr, show=False)
     plt.figure(figsize=fig_size)
     plt.imshow(mmcv.bgr2rgb(img))
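
The `hasattr(model, 'module')` check above is what lets `show_result_pyplot` accept a detector wrapped in `MMDataParallel` (whose underlying module holds `show_result` and `CLASSES`) as well as a bare one — a sketch with placeholder config and checkpoint paths:

```python
from mmcv.parallel import MMDataParallel

from mmdet.apis import inference_detector, init_detector, show_result_pyplot

# Placeholder config/checkpoint pair; any detector behaves the same way.
detector = init_detector('configs/faster_rcnn_r50_fpn_1x_coco.py',
                         'checkpoints/faster_rcnn_r50_fpn_1x_coco.pth')
result = inference_detector(detector, 'demo/demo.jpg')

# Wrapping for single-GPU testing hides show_result behind .module:
wrapped = MMDataParallel(detector, device_ids=[0])
wrapped.module.show_result('demo/demo.jpg', result, out_file='result.jpg')

# show_result_pyplot unwraps .module itself, so either object can be passed.
show_result_pyplot(wrapped, 'demo/demo.jpg', result)
show_result_pyplot(detector, 'demo/demo.jpg', result)
```
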
diff --git a/mmdet/apis/test.py b/mmdet/apis/test.py
index 5b0dea2d..cb1772a3 100644
--- a/mmdet/apis/test.py
+++ b/mmdet/apis/test.py
@@ -8,19 +8,39 @@ import torch
 import torch.distributed as dist
 from mmcv.runner import get_dist_info
 
+from mmdet.core import tensor2imgs
 
-def single_gpu_test(model, data_loader, show=False):
+
+def single_gpu_test(model, data_loader, show=False, out_dir=None):
     model.eval()
     results = []
     dataset = data_loader.dataset
     prog_bar = mmcv.ProgressBar(len(dataset))
     for i, data in enumerate(data_loader):
         with torch.no_grad():
-            result = model(return_loss=False, rescale=not show, **data)
+            result = model(return_loss=False, rescale=True, **data)
         results.append(result)
 
-        if show:
-            model.module.show_result(data, result)
+        if show or out_dir:
+            img_tensor = data['img'][0]
+            img_metas = data['img_metas'][0].data[0]
+            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+            assert len(imgs) == len(img_metas)
+
+            for img, img_meta in zip(imgs, img_metas):
+                h, w, _ = img_meta['img_shape']
+                img_show = img[:h, :w, :]
+
+                ori_h, ori_w = img_meta['ori_shape'][:-1]
+                img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+                if out_dir:
+                    out_file = osp.join(out_dir, img_meta['filename'])
+                else:
+                    out_file = None
+
+                model.module.show_result(
+                    img_show, result, show=show, out_file=out_file)
 
         batch_size = data['img'][0].size(0)
         for _ in range(batch_size):
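
Called directly, the extended `single_gpu_test` looks like the sketch below; the config/checkpoint paths are placeholders and the dataloader setup is abbreviated relative to tools/test.py. With `show=False` and `out_dir` set, every painted image is written under `out_dir` using the sample's original filename, so no GUI is required.

```python
from mmcv.parallel import MMDataParallel

from mmdet.apis import init_detector, single_gpu_test
from mmdet.datasets import build_dataloader, build_dataset

# Placeholder paths; init_detector attaches the parsed config as model.cfg.
model = init_detector('configs/faster_rcnn_r50_fpn_1x_coco.py',
                      'checkpoints/faster_rcnn_r50_fpn_1x_coco.pth')

dataset = build_dataset(model.cfg.data.test)
data_loader = build_dataloader(
    dataset, samples_per_gpu=1, workers_per_gpu=2, dist=False, shuffle=False)

model = MMDataParallel(model, device_ids=[0])
# Paint and save every result without opening a window.
outputs = single_gpu_test(model, data_loader, show=False, out_dir='painted')
```
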
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 936eeba0..55c762ee 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -5,7 +5,7 @@ import numpy as np
 import pycocotools.mask as maskUtils
 import torch.nn as nn
 
-from mmdet.core import auto_fp16, get_classes, tensor2imgs
+from mmdet.core import auto_fp16
 from mmdet.utils import print_log
 
 
@@ -151,51 +151,85 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
         else:
             return self.forward_test(img, img_metas, **kwargs)
 
-    def show_result(self, data, result, dataset=None, score_thr=0.3):
+    def show_result(self,
+                    img,
+                    result,
+                    score_thr=0.3,
+                    bbox_color='green',
+                    text_color='green',
+                    thickness=1,
+                    font_scale=0.5,
+                    win_name='',
+                    show=False,
+                    wait_time=0,
+                    out_file=None):
+        """Draw `result` over `img`.
+
+        Args:
+            img (str or ndarray): The image to be displayed.
+            result (Tensor or tuple): The results to draw over `img`,
+                either `bbox_result` or `(bbox_result, segm_result)`.
+            score_thr (float, optional): Minimum score of bboxes to be shown.
+                Default: 0.3.
+            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+            text_color (str or tuple or :obj:`Color`): Color of texts.
+            thickness (int): Thickness of lines.
+            font_scale (float): Font scales of texts.
+            win_name (str): The window name.
+            wait_time (int): Value of waitKey param.
+                Default: 0.
+            show (bool): Whether to show the image.
+                Default: False.
+            out_file (str or None): The filename to write the image.
+                Default: None.
+
+        Returns:
+            img (ndarray): Only returned if `show` and `out_file` are not set.
+        """
+        img = mmcv.imread(img)
+        img = img.copy()
         if isinstance(result, tuple):
             bbox_result, segm_result = result
         else:
             bbox_result, segm_result = result, None
-
-        img_tensor = data['img'][0]
-        img_metas = data['img_metas'][0].data[0]
-        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
-        assert len(imgs) == len(img_metas)
-
-        if dataset is None:
-            class_names = self.CLASSES
-        elif isinstance(dataset, str):
-            class_names = get_classes(dataset)
-        elif isinstance(dataset, (list, tuple)):
-            class_names = dataset
-        else:
-            raise TypeError(
-                'dataset must be a valid dataset name or a sequence'
-                f' of class names, not {type(dataset)}')
-
-        for img, img_meta in zip(imgs, img_metas):
-            h, w, _ = img_meta['img_shape']
-            img_show = img[:h, :w, :]
-
-            bboxes = np.vstack(bbox_result)
-            # draw segmentation masks
-            if segm_result is not None:
-                segms = mmcv.concat_list(segm_result)
-                inds = np.where(bboxes[:, -1] > score_thr)[0]
-                for i in inds:
-                    color_mask = np.random.randint(
-                        0, 256, (1, 3), dtype=np.uint8)
-                    mask = maskUtils.decode(segms[i]).astype(np.bool)
-                    img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5
-            # draw bounding boxes
-            labels = [
-                np.full(bbox.shape[0], i, dtype=np.int32)
-                for i, bbox in enumerate(bbox_result)
+        bboxes = np.vstack(bbox_result)
+        labels = [
+            np.full(bbox.shape[0], i, dtype=np.int32)
+            for i, bbox in enumerate(bbox_result)
+        ]
+        labels = np.concatenate(labels)
+        # draw segmentation masks
+        if segm_result is not None:
+            segms = mmcv.concat_list(segm_result)
+            inds = np.where(bboxes[:, -1] > score_thr)[0]
+            np.random.seed(42)
+            color_masks = [
+                np.random.randint(0, 256, (1, 3), dtype=np.uint8)
+                for _ in range(max(labels) + 1)
             ]
-            labels = np.concatenate(labels)
-            mmcv.imshow_det_bboxes(
-                img_show,
-                bboxes,
-                labels,
-                class_names=class_names,
-                score_thr=score_thr)
+            for i in inds:
+                i = int(i)
+                color_mask = color_masks[labels[i]]
+                mask = maskUtils.decode(segms[i]).astype(np.bool)
+                img[mask] = img[mask] * 0.5 + color_mask * 0.5
+        # if out_file specified, do not show image in window
+        if out_file is not None:
+            show = False
+        # draw bounding boxes
+        mmcv.imshow_det_bboxes(
+            img,
+            bboxes,
+            labels,
+            class_names=self.CLASSES,
+            score_thr=score_thr,
+            bbox_color=bbox_color,
+            text_color=text_color,
+            thickness=thickness,
+            font_scale=font_scale,
+            win_name=win_name,
+            show=show,
+            wait_time=wait_time,
+            out_file=out_file)
+
+        if not (show or out_file):
+            return img
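
A side effect worth calling out: seeding NumPy with 42 before sampling makes the per-class mask colors deterministic, so the same class is painted the same color in every image and across runs. A standalone sketch of just that color-assignment step:

```python
import numpy as np

# Deterministic per-class mask colors, mirroring show_result above.
num_classes = 80  # e.g. COCO; any class count works

np.random.seed(42)
color_masks = [
    np.random.randint(0, 256, (1, 3), dtype=np.uint8)
    for _ in range(num_classes)
]

# A mask for class label l is blended into the image at 50% opacity:
#     img[mask] = img[mask] * 0.5 + color_masks[l] * 0.5
print(color_masks[0], color_masks[1])  # identical values on every run
```
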
diff --git a/tools/test.py b/tools/test.py
index 7d39587c..79bc8262 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -38,6 +38,9 @@ def parse_args():
         help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
         ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
     parser.add_argument('--show', action='store_true', help='show results')
+    parser.add_argument(
+        '--show-dir',
+        help='(Optional) directory where painted images will be saved')
     parser.add_argument(
         '--gpu-collect',
         action='store_true',
@@ -45,7 +48,7 @@ def parse_args():
     parser.add_argument(
         '--tmpdir',
         help='tmp directory used for collecting results from multiple '
-        'workers, available when gpu_collect is not specified')
+        'workers, available when gpu-collect is not specified')
     parser.add_argument(
         '--options', nargs='+', action=DictAction, help='arguments in dict')
     parser.add_argument(
@@ -63,10 +66,11 @@ def parse_args():
 def main():
     args = parse_args()
 
-    assert args.out or args.eval or args.format_only or args.show, \
+    assert args.out or args.eval or args.format_only or args.show \
+        or args.show_dir, \
         ('Please specify at least one operation (save/eval/format/show the '
-         'results) with the argument "--out", "--eval", "--format_only" '
-         'or "--show"')
+         'results) with the argument "--out", "--eval", "--format-only", '
+         '"--show" or "--show-dir"')
 
     if args.eval and args.format_only:
         raise ValueError('--eval and --format_only cannot be both specified')
@@ -115,7 +119,7 @@ def main():
 
     if not distributed:
         model = MMDataParallel(model, device_ids=[0])
-        outputs = single_gpu_test(model, data_loader, args.show)
+        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
     else:
         model = MMDistributedDataParallel(
             model.cuda(),
-- 
GitLab