From 34c8ad31bbde7082d51cbfd74e582897ac614f77 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?=
 <1286304229@qq.com>
Date: Sun, 7 Feb 2021 14:27:09 +0800
Subject: [PATCH] [Feature] Support gather model metrics in .dev_scripts
 (#4558)

* Support gather model metrics

* Add max_keep_ckpts param

* Reuse get_final_results

* Add config name to display

* Optimize printing info and fix comment

* Add metric dump and fix comment
---
 .dev_scripts/benchmark_filter.py         | 11 ++--
 .dev_scripts/convert_benchmark_script.py | 28 +++++---
 .dev_scripts/gather_benchmark_metric.py  | 81 ++++++++++++++++++++++++
 setup.cfg                                |  2 +-
 4 files changed, 109 insertions(+), 13 deletions(-)
 create mode 100644 .dev_scripts/gather_benchmark_metric.py

diff --git a/.dev_scripts/benchmark_filter.py b/.dev_scripts/benchmark_filter.py
index 0ba2d04b..39f54629 100644
--- a/.dev_scripts/benchmark_filter.py
+++ b/.dev_scripts/benchmark_filter.py
@@ -36,7 +36,8 @@ basic_arch_root = [
     'foveabox', 'fp16', 'free_anchor', 'fsaf', 'gfl', 'ghm', 'grid_rcnn',
     'guided_anchoring', 'htc', 'libra_rcnn', 'mask_rcnn', 'ms_rcnn',
     'nas_fcos', 'paa', 'pisa', 'point_rend', 'reppoints', 'retinanet', 'rpn',
-    'sabl', 'ssd', 'tridentnet', 'vfnet', 'yolact', 'yolo'
+    'sabl', 'ssd', 'tridentnet', 'vfnet', 'yolact', 'yolo', 'sparse_rcnn',
+    'scnet'
 ]
 
 datasets_root = [
@@ -102,7 +103,7 @@ benchmark_pool = [
     'configs/pafpn/faster_rcnn_r50_pafpn_1x_coco.py',
     'configs/pisa/pisa_mask_rcnn_r50_fpn_1x_coco.py',
     'configs/point_rend/point_rend_r50_caffe_fpn_mstrain_1x_coco.py',
-    'configs/regnet/mask_rcnn_regnetx-3GF_fpn_1x_coco.py',
+    'configs/regnet/mask_rcnn_regnetx-3.2GF_fpn_1x_coco.py',
     'configs/reppoints/reppoints_moment_r50_fpn_gn-neck+head_1x_coco.py',
     'configs/res2net/faster_rcnn_r2_101_fpn_2x_coco.py',
     'configs/resnest/'
@@ -114,7 +115,9 @@ benchmark_pool = [
     'configs/tridentnet/tridentnet_r50_caffe_1x_coco.py',
     'configs/vfnet/vfnet_r50_fpn_1x_coco.py',
     'configs/yolact/yolact_r50_1x8_coco.py',
-    'configs/yolo/yolov3_d53_320_273e_coco.py'
+    'configs/yolo/yolov3_d53_320_273e_coco.py',
+    'configs/sparse_rcnn/sparse_rcnn_r50_fpn_1x_coco.py',
+    'configs/scnet/scnet_r50_fpn_1x_coco.py'
 ]
 
 
@@ -131,7 +134,7 @@ def main():
     if args.nn_module:
         benchmark_type += nn_module_root
 
-    special_model = args.options
+    special_model = args.model_options
     if special_model is not None:
         benchmark_type += special_model
 
diff --git a/.dev_scripts/convert_benchmark_script.py b/.dev_scripts/convert_benchmark_script.py
index e307ef0b..fc06d5e8 100644
--- a/.dev_scripts/convert_benchmark_script.py
+++ b/.dev_scripts/convert_benchmark_script.py
@@ -11,6 +11,11 @@ def parse_args():
     parser.add_argument(
         'json_path', type=str, help='json path output by benchmark_filter')
     parser.add_argument('partition', type=str, help='slurm partition name')
+    parser.add_argument(
+        '--max-keep-ckpts',
+        type=int,
+        default=1,
+        help='The maximum checkpoints to keep')
     parser.add_argument(
         '--run', action='store_true', help='run script directly')
     parser.add_argument(
@@ -40,28 +45,35 @@ def main():
     # stdout is no output
     stdout_cfg = '>/dev/null'
 
+    max_keep_ckpts = args.max_keep_ckpts
+
     commands = []
     for i, cfg in enumerate(model_cfgs):
         # print cfg name
-        echo_info = 'echo \'' + cfg + '\' &'
+        echo_info = f'echo \'{cfg}\' &'
         commands.append(echo_info)
         commands.append('\n')
 
         fname, _ = osp.splitext(osp.basename(cfg))
         out_fname = osp.join(root_name, fname)
         # default setting
-        command_info = 'GPUS=8  GPUS_PER_NODE=8  CPUS_PER_TASK=2 ' \
-                       + train_script_name + ' '
-        command_info += partition + ' '
-        command_info += fname + ' '
-        command_info += cfg + ' '
-        command_info += out_fname + ' '
-        command_info += stdout_cfg + ' &'
+        command_info = f'GPUS=8  GPUS_PER_NODE=8  ' \
+                       f'CPUS_PER_TASK=2 {train_script_name} '
+        command_info += f'{partition} '
+        command_info += f'{fname} '
+        command_info += f'{cfg} '
+        command_info += f'{out_fname} '
+        if max_keep_ckpts:
+            command_info += f'--cfg-options ' \
+                            f'checkpoint_config.max_keep_ckpts=' \
+                            f'{max_keep_ckpts}' + ' '
+        command_info += f'{stdout_cfg} &'
 
         commands.append(command_info)
 
         if i < len(model_cfgs):
             commands.append('\n')
+
     command_str = ''.join(commands)
     if args.out:
         with open(args.out, 'w') as f:
diff --git a/.dev_scripts/gather_benchmark_metric.py b/.dev_scripts/gather_benchmark_metric.py
new file mode 100644
index 00000000..3683b941
--- /dev/null
+++ b/.dev_scripts/gather_benchmark_metric.py
@@ -0,0 +1,81 @@
+import argparse
+import glob
+import os.path as osp
+
+import mmcv
+from gather_models import get_final_results
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Gather benchmarked models metric')
+    parser.add_argument(
+        'root',
+        type=str,
+        help='root path of benchmarked models to be gathered')
+    parser.add_argument(
+        'benchmark_json', type=str, help='json path of benchmark models')
+    parser.add_argument(
+        '--out', type=str, help='output path of gathered metrics to be stored')
+    parser.add_argument(
+        '--not-show', action='store_true', help='not show metrics')
+
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    root_path = args.root
+    metrics_out = args.out
+    benchmark_json_path = args.benchmark_json
+    model_configs = mmcv.load(benchmark_json_path)['models']
+
+    result_dict = {}
+    for config in model_configs:
+        config_name = osp.split(config)[-1]
+        config_name = osp.splitext(config_name)[0]
+        result_path = osp.join(root_path, config_name)
+        if osp.exists(result_path):
+            # 1 read config
+            cfg = mmcv.Config.fromfile(config)
+            total_epochs = cfg.total_epochs
+            final_results = cfg.evaluation.metric
+            if not isinstance(final_results, list):
+                final_results = [final_results]
+            final_results_out = []
+            for key in final_results:
+                if 'proposal_fast' in key:
+                    final_results_out.append('AR@1000')  # RPN
+                elif 'mAP' not in key:
+                    final_results_out.append(key + '_mAP')
+
+            # 2 determine whether total_epochs ckpt exists
+            ckpt_path = f'epoch_{total_epochs}.pth'
+            if osp.exists(osp.join(result_path, ckpt_path)):
+                log_json_path = list(
+                    sorted(glob.glob(osp.join(result_path, '*.log.json'))))[-1]
+
+                # 3 read metric
+                model_performance = get_final_results(log_json_path,
+                                                      total_epochs,
+                                                      final_results_out)
+                if model_performance is None:
+                    print(f'log file error: {log_json_path}')
+                    continue
+                result_dict[config] = model_performance
+            else:
+                print(f'{config} not exist: {ckpt_path}')
+        else:
+            print(f'not exist: {config}')
+
+    # 4 save or print results
+    if metrics_out:
+        mmcv.mkdir_or_exist(metrics_out)
+        mmcv.dump(result_dict, osp.join(metrics_out, 'model_metric_info.json'))
+    if not args.not_show:
+        print('===================================')
+        for config_name, metrics in result_dict.items():
+            print(config_name, metrics)
+        print('===================================')
diff --git a/setup.cfg b/setup.cfg
index c38be6d3..8e67c1f6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,7 +3,7 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmdet
-known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,seaborn,six,terminaltables,torch
+known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,seaborn,six,terminaltables,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
 
-- 
GitLab