Skip to content
Snippets Groups Projects
Commit eabf4ee0 authored by sicer's avatar sicer
Browse files

update scripts

parent 7ccb7f4a
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,7 @@ buffer_cpu_only: True # If true we won't keep all of the replay buffer in vram
# --- Logging options ---
use_tensorboard: True # Log results to tensorboard
save_model: True # Save the models to disk
save_model: False # Save the models to disk
save_model_interval: 2000000 # Save models after this many timesteps
checkpoint_path: "" # Load a checkpoint from this path
evaluate: False # Evaluate model for test_nepisode episodes and quit (no training)
......
......@@ -35,6 +35,16 @@ def my_main(_run, _config, _log):
run(_run, config, _log)
def _get_comment(params, arg_name):
comment = ''
for _i, _v in enumerate(params):
if _v.split("=")[0] == arg_name:
comment = _v.split("=")[1]
del params[_i]
break
return comment
def _get_config(params, arg_name, subfolder):
config_name = None
for _i, _v in enumerate(params):
......@@ -71,8 +81,6 @@ def config_copy(config):
if __name__ == '__main__':
th.cuda.set_device(0)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
params = deepcopy(sys.argv)
# Get the defaults from default.yaml
......@@ -83,11 +91,13 @@ if __name__ == '__main__':
assert False, "default.yaml error: {}".format(exc)
# Load algorithm and env base configs
comment = _get_comment(params, '--comment')
env_config = _get_config(params, "--env-config", "envs")
alg_config = _get_config(params, "--config", "algs")
# config_dict = {**config_dict, **env_config, **alg_config}
config_dict = recursive_dict_update(config_dict, env_config)
config_dict = recursive_dict_update(config_dict, alg_config)
config_dict['comment'] = comment
# now add all the config to sacred
ex.add_config(config_dict)
......
......@@ -2,6 +2,7 @@ import datetime
import os
import pprint
import time
import json
import threading
import torch as th
from types import SimpleNamespace as SN
......@@ -33,12 +34,34 @@ def run(_run, _config, _log):
width=1)
_log.info("\n\n" + experiment_params + "\n")
# configure tensorboard logger
unique_token = "{}__{}".format(args.name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
# log config
if str(args.env).startswith('sc2'):
unique_token = "{}_{}_{}_{}".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), alg_name, args.env, args.env_args['map_name'])
else:
unique_token = "{}_{}_{}".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), alg_name, args.env)
args.unique_token = unique_token
if str(args.env).startswith('sc2'):
json_output_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", args.env, args.env_args['map_name'], alg_name)
else:
json_output_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", args.env, alg_name)
json_exp_direc = os.path.join(json_output_direc, unique_token + '.json')
json_logging = {'config': vars(args)}
with open(json_exp_direc, 'w') as f:
json.dump(json_logging, f, ensure_ascii=False)
# configure tensorboard logger
if len(args.comment) > 0:
alg_name = '{}_{}'.format(args.name, args.comment)
else:
alg_name = args.name
if str(args.env).startswith('sc2'):
tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs", args.env, args.env_args['map_name'], alg_name)
else:
tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs", args.env, alg_name)
tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
if args.use_tensorboard:
tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs")
tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token)
logger.setup_tb(tb_exp_direc)
# sacred is on by default
......@@ -47,6 +70,13 @@ def run(_run, _config, _log):
# Run and train
run_sequential(args=args, logger=logger)
if args.use_tensorboard:
print(f'Export tensorboard scalars at {tb_exp_direc} to json file {json_exp_direc}')
data_logging = export_scalar_to_json(tb_exp_direc, json_output_direc, args)
json_logging.update(data_logging)
with open(json_exp_direc, 'w') as f:
json.dump(json_logging, f, ensure_ascii=False)
# Clean up after finishing
print("Exiting Main")
......@@ -238,3 +268,17 @@ def args_sanity_check(config, _log):
config["test_nepisode"] = (config["test_nepisode"]//config["batch_size_run"]) * config["batch_size_run"]
return config
def export_scalar_to_json(tensorboard_path, output_path, args):
    """Collect all tensorboard scalar series from an event directory.

    Loads the event files under *tensorboard_path* and returns every
    scalar tag as two parallel lists: ``'<tag>_T'`` holds the step values
    and ``'<tag>'`` the scalar values. The caller is responsible for
    serialising the returned dict to json (it merges the result into its
    own log file), so this function itself writes nothing.

    Args:
        tensorboard_path: directory containing tensorboard event files.
        output_path: destination directory; created here so the caller
            can write into it without checking.
        args: experiment namespace; only ``args.seed`` is read.

    Returns:
        dict mapping ``'seed'``, ``'<tag>_T'`` and ``'<tag>'`` keys to
        their recorded values.
    """
    # Local import: tensorboard is a heavy optional dependency, only
    # needed when exporting is actually requested.
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    # Ensure the directory the caller will write the json into exists.
    os.makedirs(output_path, exist_ok=True)

    summary = EventAccumulator(tensorboard_path).Reload()
    data = {'seed': args.seed}
    for scalar_name in summary.Tags()['scalars']:
        # Read each series once instead of twice (Scalars() re-walks the
        # recorded events on every call).
        events = summary.Scalars(scalar_name)
        data['_'.join([scalar_name, 'T'])] = [event.step for event in events]
        data[scalar_name] = [event.value for event in events]
    return data
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment