Skip to content
Snippets Groups Projects
Commit 300f7157 authored by Kai Chen's avatar Kai Chen
Browse files

allow manually setting random seeds

parent 143a8372
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ import argparse
import logging
from collections import OrderedDict
import numpy as np
import torch
from mmcv import Config
from mmcv.torchpack import Runner, obj_from_dict
......@@ -53,6 +54,12 @@ def get_logger(log_level):
return logger
def set_random_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
......@@ -63,6 +70,7 @@ def parse_args():
help='whether to add a validate phase')
parser.add_argument(
'--gpus', type=int, default=1, help='number of gpus to use')
parser.add_argument('--seed', type=int, help='random seed')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
......@@ -84,6 +92,11 @@ def main():
logger = get_logger(cfg.log_level)
# set random seed if specified
if args.seed is not None:
logger.info('Set random seed to {}'.format(args.seed))
set_random_seed(args.seed)
# init distributed environment if necessary
if args.launcher == 'none':
dist = False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment