Unverified commit 16a6f7da, authored by Kai Chen, committed by GitHub

add some docstring (#1869)

parent 629b9ff2
@@ -17,13 +17,6 @@ from mmdet.datasets import DATASETS, build_dataloader
 from mmdet.models import RPN


-def set_random_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-
-
 def get_root_logger(log_file=None, log_level=logging.INFO):
     logger = logging.getLogger('mmdet')
     # if the logger has been initialized, just return it
@@ -45,6 +38,25 @@ def get_root_logger(log_file=None, log_level=logging.INFO):
     return logger


+def set_random_seed(seed, deterministic=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    if deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
 def parse_losses(losses):
     log_vars = OrderedDict()
     for loss_name, loss_value in losses.items():
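For readers skimming the diff: a minimal usage sketch of the reworked function. The import path and the seed value are assumptions (set_random_seed was exposed via mmdet.apis in this era of the codebase), not part of the commit itself.

from mmdet.apis import set_random_seed  # assumed export path

# Fully reproducible runs: seeds the Python, NumPy and all-GPU torch RNGs,
# and forces deterministic cuDNN kernels. Usually slower, since
# benchmark=False disables cuDNN's algorithm auto-tuning.
set_random_seed(42, deterministic=True)

# Seeded but faster: cuDNN may still pick non-deterministic algorithms.
set_random_seed(42)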
@@ -70,6 +82,21 @@ def parse_losses(losses):
 def batch_processor(model, data, train_mode):
+    """Process a data batch.
+
+    This method is required as an argument of Runner, which defines how to
+    process a data batch and obtain proper outputs. The first 3 arguments of
+    batch_processor are fixed.
+
+    Args:
+        model (nn.Module): A PyTorch model.
+        data (dict): The data batch in a dict.
+        train_mode (bool): Training mode or not. It may be useless for some
+            models.
+
+    Returns:
+        dict: A dict containing losses and log vars.
+    """
     losses = model(**data)
     loss, log_vars = parse_losses(losses)
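A sketch of where batch_processor plugs in, assuming the mmcv Runner API of the same period; model, optimizer, data_loaders and cfg are taken as already built (e.g. inside train_detector) and are not shown here.

from mmcv.runner import Runner

# Runner calls batch_processor(model, data_batch, train_mode) on every
# batch; the returned dict must carry 'loss' for the optimizer step and
# 'log_vars' for logging.
runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                cfg.log_level)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)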
@@ -21,8 +21,30 @@ def build_dataloader(dataset,
                      dist=True,
                      shuffle=True,
                      **kwargs):
+    """Build PyTorch DataLoader.
+
+    In distributed training, each GPU/process has a dataloader.
+    In non-distributed training, there is only one dataloader for all GPUs.
+
+    Args:
+        dataset (Dataset): A PyTorch dataset.
+        imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
+            each GPU.
+        workers_per_gpu (int): How many subprocesses to use for data loading
+            for each GPU.
+        num_gpus (int): Number of GPUs. Only used in non-distributed training.
+        dist (bool): Distributed training/test or not. Default: True.
+        shuffle (bool): Whether to shuffle the data at every epoch.
+            Default: True.
+        kwargs: any keyword argument to be used to initialize DataLoader
+
+    Returns:
+        DataLoader: A PyTorch dataloader.
+    """
     if dist:
         rank, world_size = get_dist_info()
+        # DistributedGroupSampler will definitely shuffle the data to satisfy
+        # that images on each GPU are in the same group
         if shuffle:
             sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
                                               world_size, rank)
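A hedged usage sketch of the documented loader in both modes. The dataset object and the numeric values are placeholders; in practice the dataset comes from the training config (e.g. via build_dataset).

# Distributed: one loader per process; per-GPU batch size is imgs_per_gpu.
train_loader = build_dataloader(dataset, imgs_per_gpu=2, workers_per_gpu=2,
                                dist=True, shuffle=True)

# Non-distributed: a single loader feeding all GPUs; the effective batch
# size is imgs_per_gpu * num_gpus.
train_loader = build_dataloader(dataset, imgs_per_gpu=2, workers_per_gpu=2,
                                num_gpus=2, dist=False)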
@@ -32,6 +32,10 @@ def parse_args():
         help='number of gpus to use '
         '(only applicable to non-distributed training)')
     parser.add_argument('--seed', type=int, default=None, help='random seed')
+    parser.add_argument(
+        '--deterministic',
+        action='store_true',
+        help='whether to set deterministic options for CUDNN backend.')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
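With the new flag in place, a fully seeded training launch might look like the following (the config path is a placeholder, not taken from this commit):

python tools/train.py configs/faster_rcnn_r50_fpn_1x.py --seed 0 --deterministic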
@@ -88,8 +92,9 @@ def main():

     # set random seeds
     if args.seed is not None:
-        logger.info('Set random seed to {}'.format(args.seed))
-        set_random_seed(args.seed)
+        logger.info('Set random seed to {}, deterministic: {}'.format(
+            args.seed, args.deterministic))
+        set_random_seed(args.seed, deterministic=args.deterministic)

     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)