Commit 904d875a authored by Kai Chen
modify distributed training api and use coalesced all_reduce

parent 15e9d026
import os
from collections import OrderedDict
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.nn.utils import clip_grad
from mmcv.torchpack import Hook, OptimizerHook
__all__ = [
-    'init_dist', 'average_gradients', 'broadcast_params', 'DistOptimizerHook',
-    'DistSamplerSeedHook'
+    'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook'
]
-def init_dist(world_size,
-              rank,
-              backend='gloo',
-              master_ip='127.0.0.1',
-              port=29500):
+def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))
def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
-    os.environ['MASTER_ADDR'] = master_ip
-    os.environ['MASTER_PORT'] = str(port)
-    if backend == 'nccl':
-        dist.init_process_group(backend='nccl')
-    else:
-        dist.init_process_group(
-            backend='gloo', rank=rank, world_size=world_size)
+    dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


-def average_gradients(model):
-    for param in model.parameters():
-        if param.requires_grad and not (param.grad is None):
-            dist.all_reduce(param.grad.data)


def _init_dist_slurm(backend, **kwargs):
    raise NotImplementedError


-def broadcast_params(model):
-    for p in model.state_dict().values():
-        dist.broadcast(p, 0)
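For context, here is a minimal sketch of how the reworked entry point is meant to be driven. This is my own illustration, not part of the commit: `_init_dist_pytorch` reads RANK from the environment and forwards `**kwargs` to `dist.init_process_group`, so the rendezvous variables are assumed to come from a launcher such as `python -m torch.distributed.launch`; gloo is used here only so the snippet runs on CPU.

# Hypothetical single-process dry run (not from the commit). The launcher
# would normally export RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT.
import os
import torch.distributed as dist

os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

# init_dist('pytorch', backend='gloo') boils down to this call,
# which reads the rendezvous info from the environment:
dist.init_process_group(backend='gloo', init_method='env://')
print(dist.get_rank(), dist.get_world_size())  # 0 1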
# modified from https://github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py#L9
def coalesce_all_reduce(tensors):
    # bucket tensors by type so each bucket can be flattened into one
    # contiguous buffer and reduced with a single all_reduce call
    buckets = OrderedDict()
    for tensor in tensors:
        tp = tensor.type()
        if tp not in buckets:
            buckets[tp] = []
        buckets[tp].append(tensor)

    for tp in buckets:
        bucket = buckets[tp]
        coalesced = _flatten_dense_tensors(bucket)
        dist.all_reduce(coalesced)
        coalesced /= dist.get_world_size()
        # scatter the averaged values back into the original tensors
        for buf, synced in zip(bucket,
                               _unflatten_dense_tensors(coalesced, bucket)):
            buf.copy_(synced)
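To make the bucketing concrete, here is a small self-contained illustration (my own sketch, not part of the commit) of the `_flatten_dense_tensors` / `_unflatten_dense_tensors` round trip that `coalesce_all_reduce` relies on; no process group is needed because the reduction step is replaced by a simple in-place scaling.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Two "gradients" of different shapes but the same dtype land in one bucket.
bucket = [torch.ones(2, 3), torch.arange(4.0)]
coalesced = _flatten_dense_tensors(bucket)   # one contiguous buffer, shape (10,)
coalesced *= 2                               # stand-in for all_reduce + averaging
for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
    buf.copy_(synced)                        # write results back in place
print(bucket[0])   # all twos
print(bucket[1])   # tensor([0., 2., 4., 6.])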
def reduce_grads(model, coalesce=True):
    grads = [
        param.grad.data for param in model.parameters()
        if param.requires_grad and param.grad is not None
    ]
    if coalesce:
        coalesce_all_reduce(grads)
    else:
        # average (not just sum) so this branch matches the coalesced path
        world_size = dist.get_world_size()
        for tensor in grads:
            dist.all_reduce(tensor)
            tensor.div_(world_size)
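A quick way to exercise `reduce_grads` end to end is a single-process "group of one" on the gloo backend. This is a hypothetical smoke test of mine (run inside this module, or with `reduce_grads` imported), not from the commit; with world_size == 1 the averaging is a no-op, so the gradients must come back unchanged. The TCP address and port are arbitrary.

import torch
import torch.nn as nn
import torch.distributed as dist

# Hypothetical smoke test: a process group with a single member.
dist.init_process_group(backend='gloo',
                        init_method='tcp://127.0.0.1:29510',
                        rank=0, world_size=1)

model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

before = model.weight.grad.clone()
reduce_grads(model, coalesce=True)   # all_reduce + divide by world_size (== 1)
assert torch.allclose(before, model.weight.grad)
dist.destroy_process_group()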
class DistOptimizerHook(OptimizerHook):

    def __init__(self, grad_clip=None, coalesce=True):
        self.grad_clip = grad_clip
        self.coalesce = coalesce

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
-        average_gradients(runner.model)
+        reduce_grads(runner.model, self.coalesce)
        if self.grad_clip is not None:
            clip_grad.clip_grad_norm_(
                filter(lambda p: p.requires_grad, runner.model.parameters()),
......
from functools import partial
from mmcv.torchpack import get_dist_info
from torch.utils.data import DataLoader
from .collate import collate
......
@@ -11,10 +12,9 @@ def build_dataloader(dataset,
                     workers_per_gpu,
                     num_gpus,
                     dist=True,
-                     world_size=1,
-                     rank=0,
                     **kwargs):
    if dist:
+        rank, world_size = get_dist_info()
        sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size,
                                          rank)
        batch_size = imgs_per_gpu
......
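The point of the change above is that callers no longer thread world_size and rank through build_dataloader; they are looked up from the process group. Below is a rough sketch of the same idea using stock PyTorch pieces, with `DistributedSampler` standing in for `DistributedGroupSampler` and a local `get_dist_info` stand-in; these are my assumptions for illustration, not the commit's code.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def get_dist_info():
    # stand-in: fall back to a single-process setup when dist is not initialized
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1

dataset = TensorDataset(torch.arange(16).float())
rank, world_size = get_dist_info()
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, batch_size=4, sampler=sampler, num_workers=0)
print(len(loader))  # 4 batches per rank when world_size == 1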
@@ -121,8 +121,7 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
load_from = None
......
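With `device_ids` and the explicit `port`/`master_ip` gone from the configs, the rendezvous details are expected to come from the launcher's environment instead. A small sketch of that assumption (mine, not part of the commit; the fallback values are illustrative):

import os

# What the simplified config now carries:
dist_params = dict(backend='nccl')

# What the old config carried, now provided by the environment
# (torch.distributed.launch and similar launchers export these):
master_addr = os.environ.get('MASTER_ADDR', '127.0.0.1')
master_port = os.environ.get('MASTER_PORT', '29500')

# init_dist(args.launcher, **dist_params) then reduces to
# dist.init_process_group(backend='nccl'), which reads MASTER_ADDR/MASTER_PORT.
print(dist_params, master_addr, master_port)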
@@ -134,8 +134,7 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
load_from = None
......
@@ -100,8 +100,7 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='gloo', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='gloo')
log_level = 'INFO'
work_dir = './work_dirs/fpn_rpn_r50_1x'
load_from = None
......
@@ -39,9 +39,7 @@ def batch_processor(model, data, train_mode):
    loss, log_vars = parse_losses(losses)

    outputs = dict(
-        loss=loss / args.world_size,
-        log_vars=log_vars,
-        num_samples=len(data['img'].data))
+        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs
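The dropped `/ args.world_size` is not lost, it just moved: `coalesce_all_reduce` divides the summed gradients by the world size, which is equivalent to scaling every per-rank loss beforehand. A toy numeric check of that equivalence (hypothetical numbers, not from the commit):

world_size = 4
per_rank_grad = 2.0   # gradient each rank computes from an undivided loss

# old scheme: scale the loss (hence the gradient) on every rank, then sum
old_style = sum(per_rank_grad / world_size for _ in range(world_size))
# new scheme: sum the raw gradients, then divide once (coalesced /= world_size)
new_style = sum(per_rank_grad for _ in range(world_size)) / world_size

assert old_style == new_style == 2.0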
......
@@ -54,61 +52,65 @@ def parse_args():
        action='store_true',
        help='whether to add a validate phase')
    parser.add_argument(
-        '--dist', action='store_true', help='use distributed training or not')
-    parser.add_argument('--world-size', default=1, type=int)
-    parser.add_argument('--rank', default=0, type=int)
+        '--gpus', type=int, default=1, help='number of gpus to use')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    return args
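As a quick sanity check of the new CLI surface, here is a hypothetical snippet of mine that mirrors the flags above (`--local_rank` is the argument that `torch.distributed.launch` injects per process):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpus', type=int, default=1)
parser.add_argument(
    '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none')
parser.add_argument('--local_rank', type=int, default=0)

# e.g. what a launcher-driven invocation would look like:
args = parser.parse_args(['--launcher', 'pytorch', '--gpus', '8'])
print(args.launcher, args.gpus, args.local_rank)  # pytorch 8 0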
-args = parse_args()


def main():
    # get config from file
    args = parse_args()
    cfg = Config.fromfile(args.config)
-    cfg.update(world_size=args.world_size, rank=args.rank)
+    cfg.update(gpus=args.gpus)

    # init distributed environment if necessary
-    if args.dist:
-        print('Enable distributed training.')
-        init_dist(args.world_size, args.rank, **cfg.dist_params)
-    else:
+    if args.launcher == 'none':
+        dist = False
+        print('Disabled distributed training.')
+    else:
+        dist = True
+        print('Enabled distributed training.')
+        init_dist(args.launcher, **cfg.dist_params)
    # prepare data loaders
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    data_loaders = [
-        build_dataloader(
-            train_dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu,
-            len(cfg.device_ids), args.dist, cfg.world_size, cfg.rank)
+        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
+                         cfg.data.workers_per_gpu, cfg.gpus, dist)
    ]
    if args.validate:
        val_dataset = obj_from_dict(cfg.data.val, datasets)
        data_loaders.append(
-            build_dataloader(
-                val_dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu,
-                len(cfg.device_ids), args.dist, cfg.world_size, cfg.rank))
+            build_dataloader(val_dataset, cfg.data.imgs_per_gpu,
+                             cfg.data.workers_per_gpu, cfg.gpus, dist))
    # build model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    if args.dist:
+    if dist:
        model = MMDistributedDataParallel(
-            model, device_ids=[cfg.rank], broadcast_buffers=False).cuda()
+            model,
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False).cuda()
    else:
-        model = MMDataParallel(model, device_ids=cfg.device_ids).cuda()
+        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
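Since `_init_dist_pytorch` already called `torch.cuda.set_device(rank % num_gpus)`, wrapping with `device_ids=[torch.cuda.current_device()]` pins each replica to the GPU chosen at init time. Below is a rough, guarded sketch of the same pattern with plain `torch.nn.parallel.DistributedDataParallel` standing in for `MMDistributedDataParallel`; this is my assumption-based illustration, not the commit's code, and it only wraps the model when a process group and a GPU are actually available.

import torch
import torch.nn as nn
import torch.distributed as dist

def wrap_model(model):
    # only meaningful once the process group exists and a GPU was selected
    if dist.is_available() and dist.is_initialized() and torch.cuda.is_available():
        model = model.cuda()
        return nn.parallel.DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
    return model  # fall back to the bare model in a non-distributed run

print(wrap_model(nn.Linear(4, 2)))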
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)

    # register hooks
    optimizer_config = DistOptimizerHook(
-        **cfg.optimizer_config) if args.dist else cfg.optimizer_config
+        **cfg.optimizer_config) if dist else cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
-    if args.dist:
+    if dist:
        runner.register_hook(DistSamplerSeedHook())

    if cfg.resume_from:
......