Unverified commit a1c3aa49, authored by David de la Iglesia Castro, committed by GitHub

Merge non dist and dist train (#2440)

* Merge _dist_train and _non_dist_train

* Add missing distributed arg to Fp16OptimizerHook
parent 25ede58a
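
Note (not part of the commit): for context, a minimal usage sketch of the unified entry point. After the merge, a single `train_detector` call covers both the single-GPU and distributed cases via the `distributed` flag, which internally selects the parallel wrapper, the optimizer hook, and the eval hook. The config path and the `Config`/`build_detector`/`build_dataset` helpers below are assumptions based on the mmdetection APIs of this era, not code taken from the diff.

# Illustrative usage only (assumed names/paths, not from this commit).
from mmcv import Config

from mmdet.apis import train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector

cfg = Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')  # hypothetical config path
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
datasets = [build_dataset(cfg.data.train)]

# distributed=False -> MMDataParallel, EvalHook, plain optimizer hook
# distributed=True  -> MMDistributedDataParallel, DistEvalHook,
#                      DistSamplerSeedHook, DistOptimizerHook
train_detector(model, datasets, cfg, distributed=False, validate=True)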
@@ -90,104 +90,6 @@ def train_detector(model,
                    meta=None):
     logger = get_root_logger(cfg.log_level)
-    # start training
-    if distributed:
-        _dist_train(
-            model,
-            dataset,
-            cfg,
-            validate=validate,
-            logger=logger,
-            timestamp=timestamp,
-            meta=meta)
-    else:
-        _non_dist_train(
-            model,
-            dataset,
-            cfg,
-            validate=validate,
-            logger=logger,
-            timestamp=timestamp,
-            meta=meta)
-def _dist_train(model,
-                dataset,
-                cfg,
-                validate=False,
-                logger=None,
-                timestamp=None,
-                meta=None):
-    # prepare data loaders
-    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            dist=True,
-            seed=cfg.seed) for ds in dataset
-    ]
-    # put model on gpus
-    find_unused_parameters = cfg.get('find_unused_parameters', False)
-    # Sets the `find_unused_parameters` parameter in
-    # torch.nn.parallel.DistributedDataParallel
-    model = MMDistributedDataParallel(
-        model.cuda(),
-        device_ids=[torch.cuda.current_device()],
-        broadcast_buffers=False,
-        find_unused_parameters=find_unused_parameters)
-    # build runner
-    optimizer = build_optimizer(model, cfg.optimizer)
-    runner = Runner(
-        model,
-        batch_processor,
-        optimizer,
-        cfg.work_dir,
-        logger=logger,
-        meta=meta)
-    # an ugly walkaround to make the .log and .log.json filenames the same
-    runner.timestamp = timestamp
-    # fp16 setting
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,
-                                             **fp16_cfg)
-    else:
-        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
-    # register hooks
-    runner.register_training_hooks(cfg.lr_config, optimizer_config,
-                                   cfg.checkpoint_config, cfg.log_config)
-    runner.register_hook(DistSamplerSeedHook())
-    # register eval hooks
-    if validate:
-        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
-        val_dataloader = build_dataloader(
-            val_dataset,
-            imgs_per_gpu=1,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            dist=True,
-            shuffle=False)
-        eval_cfg = cfg.get('evaluation', {})
-        runner.register_hook(DistEvalHook(val_dataloader, **eval_cfg))
-    if cfg.resume_from:
-        runner.resume(cfg.resume_from)
-    elif cfg.load_from:
-        runner.load_checkpoint(cfg.load_from)
-    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
-def _non_dist_train(model,
-                    dataset,
-                    cfg,
-                    validate=False,
-                    logger=None,
-                    timestamp=None,
-                    meta=None):
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
@@ -195,12 +97,25 @@ def _non_dist_train(model,
             ds,
             cfg.data.imgs_per_gpu,
             cfg.data.workers_per_gpu,
             # cfg.gpus will be ignored if distributed
             len(cfg.gpu_ids),
-            dist=False,
+            dist=distributed,
             seed=cfg.seed) for ds in dataset
     ]
     # put model on gpus
-    model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+    if distributed:
+        find_unused_parameters = cfg.get('find_unused_parameters', False)
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters)
+    else:
+        model = MMDataParallel(
+            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
@@ -213,15 +128,22 @@ def _non_dist_train(model,
         meta=meta)
     # an ugly walkaround to make the .log and .log.json filenames the same
     runner.timestamp = timestamp
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
     if fp16_cfg is not None:
         optimizer_config = Fp16OptimizerHook(
-            **cfg.optimizer_config, **fp16_cfg, distributed=False)
+            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+    elif distributed:
+        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
     else:
         optimizer_config = cfg.optimizer_config
     # register hooks
     runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
+    if distributed:
+        runner.register_hook(DistSamplerSeedHook())
     # register eval hooks
     if validate:
@@ -230,10 +152,11 @@ def _non_dist_train(model,
             val_dataset,
             imgs_per_gpu=1,
             workers_per_gpu=cfg.data.workers_per_gpu,
-            dist=False,
+            dist=distributed,
             shuffle=False)
         eval_cfg = cfg.get('evaluation', {})
-        runner.register_hook(EvalHook(val_dataloader, **eval_cfg))
+        eval_hook = DistEvalHook if distributed else EvalHook
+        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
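
Note (not part of the commit): the second bullet of the commit message, "Add missing distributed arg to Fp16OptimizerHook", is visible in the hunks above: the merged code now forwards `distributed` to Fp16OptimizerHook instead of hard-coding False. As a hedged illustration, a config fragment like the following is what feeds those branches; the concrete values are typical mmdetection defaults, assumed rather than taken from this diff.

# Assumed config fragment (typical values, not from this commit):
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# Presence of this dict switches the runner to Fp16OptimizerHook, which now
# receives distributed=distributed rather than a hard-coded False.
fp16 = dict(loss_scale=512.)
# Consumed by DistEvalHook/EvalHook when validate=True.
evaluation = dict(interval=1, metric='bbox')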