Commit b7968de7 authored by Kai Chen

modify MMDistributedDataParallel, no longer inherited from DistributedDataParallel

parent e74c260f
-from torch.nn.parallel import DistributedDataParallel
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                          _take_tensors)

from .scatter_gather import scatter_kwargs


-class MMDistributedDataParallel(DistributedDataParallel):
+class MMDistributedDataParallel(nn.Module):

    def __init__(self, module, dim=0, broadcast_buffers=True):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers

        self.first_synced = False
        self.broadcast_bucket_size = 32 * 1024 * 1024

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        for tensors in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(tensors)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                tensor.copy_(synced)

    def sync_params(self):
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)
        if self.broadcast_buffers:
            buffers = [b.data for b in self.module._all_buffers()]
            if len(buffers) > 0:
                self._dist_broadcast_coalesced(buffers,
                                               self.broadcast_bucket_size)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if not self.first_synced:
            self.sync_params()
            self.first_synced = True
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])
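A minimal usage sketch of the new wrapper follows (illustration only, not part of the commit; the import path, the toy model, and the launcher details are assumptions). On the first forward call the wrapper broadcasts parameters and buffers from rank 0 via sync_params(), then scatters the inputs to the current CUDA device; gradient averaging is not performed by the wrapper and is assumed to be handled elsewhere in the training loop.

    # Illustrative sketch -- names and paths below are assumptions, not part of this commit.
    import argparse

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    from mmdet.core import MMDistributedDataParallel  # assumed import path

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)  # injected by torch.distributed.launch
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl')  # rank/world size read from the launcher's env vars

    toy_model = nn.Linear(8, 2).cuda()  # stand-in for a real detector
    model = MMDistributedDataParallel(toy_model)

    x = torch.randn(4, 8).cuda()
    out = model(x)  # first call triggers sync_params(), broadcasting weights from rank 0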
@@ -2,8 +2,4 @@
 PYTHON=${PYTHON:-"python"}
-$PYTHON train.py $1 --dist --world-size $2 --rank 0 &
-let MAX_RANK=$2-1
-for i in `seq 1 $MAX_RANK`; do
-    $PYTHON train.py $1 --dist --world-size $2 --rank $i > /dev/null 2>&1 &
-done
+$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch
\ No newline at end of file
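The launch script now delegates process management to torch.distributed.launch, which starts --nproc_per_node processes and supplies the rendezvous information through environment variables plus a --local_rank argument to each copy of train.py. A sketch of what the --launcher pytorch path on the training side typically amounts to (an assumption about the training script, not code from this commit):

    # Hypothetical helper showing 'pytorch'-launcher style initialization.
    import torch
    import torch.distributed as dist


    def init_dist_pytorch(local_rank, backend='nccl'):
        # torch.distributed.launch exports MASTER_ADDR, MASTER_PORT, RANK and
        # WORLD_SIZE, so the default env:// initialization needs no extra arguments.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend)
        return dist.get_rank(), dist.get_world_size()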
@@ -95,10 +95,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     if dist:
-        model = MMDistributedDataParallel(
-            model,
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False).cuda()
+        model = MMDistributedDataParallel(model).cuda()
     else:
         model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
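Since the wrapper no longer subclasses DistributedDataParallel, the construction in train.py reduces to MMDistributedDataParallel(model).cuda(): the new wrapper always scatters inputs to the current CUDA device and broadcasts buffers by default, so the device_ids and broadcast_buffers arguments are no longer needed here.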