From eabf4ee002ee5852a53c5b84666f00a18f7045a3 Mon Sep 17 00:00:00 2001
From: sicer <mansicer@qq.com>
Date: Sat, 13 Nov 2021 15:45:33 +0800
Subject: [PATCH] update scripts

---
 src/config/default.yaml |  2 +-
 src/main.py             | 14 +++++++++--
 src/run.py              | 52 +++++++++++++++++++++++++++++++++++++----
 3 files changed, 61 insertions(+), 7 deletions(-)

diff --git a/src/config/default.yaml b/src/config/default.yaml
index 1218a0b..8f5823e 100755
--- a/src/config/default.yaml
+++ b/src/config/default.yaml
@@ -19,7 +19,7 @@ buffer_cpu_only: True # If true we won't keep all of the replay buffer in vram
 
 # --- Logging options ---
 use_tensorboard: True # Log results to tensorboard
-save_model: True # Save the models to disk
+save_model: False # Save the models to disk
 save_model_interval: 2000000 # Save models after this many timesteps
 checkpoint_path: "" # Load a checkpoint from this path
 evaluate: False # Evaluate model for test_nepisode episodes and quit (no training)
diff --git a/src/main.py b/src/main.py
index 7b191f0..a6a6e21 100755
--- a/src/main.py
+++ b/src/main.py
@@ -35,6 +35,16 @@ def my_main(_run, _config, _log):
     run(_run, config, _log)
 
 
+def _get_comment(params, arg_name):
+    comment = ''
+    for _i, _v in enumerate(params):
+        if _v.split("=")[0] == arg_name:
+            comment = _v.split("=", 1)[1]
+            del params[_i]
+            break
+    return comment
+
+
 def _get_config(params, arg_name, subfolder):
     config_name = None
     for _i, _v in enumerate(params):
@@ -71,8 +81,6 @@ def config_copy(config):
 
 
 if __name__ == '__main__':
-    th.cuda.set_device(0)
-    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
     params = deepcopy(sys.argv)
 
     # Get the defaults from default.yaml
@@ -83,11 +91,13 @@ if __name__ == '__main__':
             assert False, "default.yaml error: {}".format(exc)
 
     # Load algorithm and env base configs
+    comment = _get_comment(params, '--comment')
     env_config = _get_config(params, "--env-config", "envs")
     alg_config = _get_config(params, "--config", "algs")
     # config_dict = {**config_dict, **env_config, **alg_config}
     config_dict = recursive_dict_update(config_dict, env_config)
     config_dict = recursive_dict_update(config_dict, alg_config)
+    config_dict['comment'] = comment
 
     # now add all the config to sacred
     ex.add_config(config_dict)
diff --git a/src/run.py b/src/run.py
index 19acc16..5653d0f 100755
--- a/src/run.py
+++ b/src/run.py
@@ -2,6 +2,7 @@ import datetime
 import os
 import pprint
 import time
+import json
 import threading
 import torch as th
 from types import SimpleNamespace as SN
@@ -33,12 +34,34 @@ def run(_run, _config, _log):
                                        width=1)
     _log.info("\n\n" + experiment_params + "\n")
 
-    # configure tensorboard logger
-    unique_token = "{}__{}".format(args.name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
+    # log config (alg_name must be computed before the token/paths below use it)
+    if len(args.comment) > 0:
+        alg_name = '{}_{}'.format(args.name, args.comment)
+    else:
+        alg_name = args.name
+    if str(args.env).startswith('sc2'):
+        unique_token = "{}_{}_{}_{}".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), alg_name, args.env, args.env_args['map_name'])
+    else:
+        unique_token = "{}_{}_{}".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), alg_name, args.env)
     args.unique_token = unique_token
+
+    if str(args.env).startswith('sc2'):
+        json_output_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", args.env, args.env_args['map_name'], alg_name)
+    else:
+        json_output_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", args.env, alg_name)
+    os.makedirs(json_output_direc, exist_ok=True)
+    json_exp_direc = os.path.join(json_output_direc, unique_token + '.json')
+    json_logging = {'config': vars(args)}
+    with open(json_exp_direc, 'w') as f:
+        json.dump(json_logging, f, ensure_ascii=False)
+
+    # configure tensorboard logger
+    if str(args.env).startswith('sc2'):
+        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs", args.env, args.env_args['map_name'], alg_name)
+    else:
+        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs", args.env, alg_name)
+    tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
     if args.use_tensorboard:
-        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs")
-        tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token)
         logger.setup_tb(tb_exp_direc)
 
     # sacred is on by default
@@ -47,6 +70,13 @@ def run(_run, _config, _log):
     # Run and train
     run_sequential(args=args, logger=logger)
 
+    if args.use_tensorboard:
+        print(f'Export tensorboard scalars at {tb_exp_direc} to json file {json_exp_direc}')
+        data_logging = export_scalar_to_json(tb_exp_direc, json_output_direc, args)
+        json_logging.update(data_logging)
+        with open(json_exp_direc, 'w') as f:
+            json.dump(json_logging, f, ensure_ascii=False)
+
     # Clean up after finishing
     print("Exiting Main")
 
@@ -238,3 +268,17 @@ def args_sanity_check(config, _log):
         config["test_nepisode"] = (config["test_nepisode"]//config["batch_size_run"]) * config["batch_size_run"]
 
     return config
+
+def export_scalar_to_json(tensorboard_path, output_path, args):
+    # Read every scalar series from the tensorboard event files under
+    # tensorboard_path and return them as {tag+'_T': steps, tag: values} plus the run seed.
+    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+    os.makedirs(output_path, exist_ok=True)
+    summary = EventAccumulator(tensorboard_path).Reload()
+    stone_dict = {'seed': args.seed}
+    for scalar_name in summary.Tags()['scalars']:
+        # read each series once instead of twice
+        events = summary.Scalars(scalar_name)
+        stone_dict['_'.join([scalar_name, 'T'])] = [scalar.step for scalar in events]
+        stone_dict[scalar_name] = [scalar.value for scalar in events]
+    return stone_dict
-- 
GitLab