From c1ef12df9c9d35f1402734435b23a4ae711f3084 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Mon, 20 Apr 2020 22:22:09 +0800 Subject: [PATCH] Remove datasets.loader and add datasets.samplers (#2482) * remove datasets.loader and add datasets.samplers * fix typo in the filename --- .gitignore | 1 + mmdet/datasets/__init__.py | 8 +- mmdet/datasets/builder.py | 87 ++++++++++++++++++ mmdet/datasets/loader/__init__.py | 4 - mmdet/datasets/loader/build_loader.py | 88 ------------------- mmdet/datasets/samplers/__init__.py | 4 + .../datasets/samplers/distributed_sampler.py | 28 ++++++ .../sampler.py => samplers/group_sampler.py} | 27 ------ 8 files changed, 124 insertions(+), 123 deletions(-) delete mode 100644 mmdet/datasets/loader/__init__.py delete mode 100644 mmdet/datasets/loader/build_loader.py create mode 100644 mmdet/datasets/samplers/__init__.py create mode 100644 mmdet/datasets/samplers/distributed_sampler.py rename mmdet/datasets/{loader/sampler.py => samplers/group_sampler.py} (83%) diff --git a/.gitignore b/.gitignore index e1a14b06..d9498d94 100644 --- a/.gitignore +++ b/.gitignore @@ -107,6 +107,7 @@ mmdet/version.py data .vscode .idea +.DS_Store # custom *.pkl diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 7ad926d4..0272a1c6 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -1,10 +1,10 @@ -from .builder import build_dataset +from .builder import build_dataloader, build_dataset from .cityscapes import CityscapesDataset from .coco import CocoDataset from .custom import CustomDataset from .dataset_wrappers import ConcatDataset, RepeatDataset -from .loader import DistributedGroupSampler, GroupSampler, build_dataloader from .registry import DATASETS +from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler from .voc import VOCDataset from .wider_face import WIDERFaceDataset from .xml_style import XMLDataset @@ -12,6 +12,6 @@ from .xml_style import XMLDataset __all__ = [ 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler', - 'build_dataloader', 'ConcatDataset', 'RepeatDataset', 'WIDERFaceDataset', - 'DATASETS', 'build_dataset' + 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', + 'WIDERFaceDataset', 'DATASETS', 'build_dataset' ] diff --git a/mmdet/datasets/builder.py b/mmdet/datasets/builder.py index 83e0486b..f20f48a9 100644 --- a/mmdet/datasets/builder.py +++ b/mmdet/datasets/builder.py @@ -1,8 +1,25 @@ import copy +import platform +import random +from functools import partial + +import numpy as np +from mmcv.parallel import collate +from mmcv.runner import get_dist_info +from torch.utils.data import DataLoader from mmdet.utils import build_from_cfg from .dataset_wrappers import ConcatDataset, RepeatDataset from .registry import DATASETS +from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + hard_limit = rlimit[1] + soft_limit = min(4096, hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) def _concat_dataset(cfg, default_args=None): @@ -39,3 +56,73 @@ def build_dataset(cfg, default_args=None): dataset = build_from_cfg(cfg, DATASETS, default_args) return dataset + + +def build_dataloader(dataset, + imgs_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + 
seed=None, + **kwargs): + """Build PyTorch DataLoader. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of + each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + if dist: + # DistributedGroupSampler will definitely shuffle the data to satisfy + # that images on each GPU are in the same group + if shuffle: + sampler = DistributedGroupSampler(dataset, imgs_per_gpu, + world_size, rank) + else: + sampler = DistributedSampler( + dataset, world_size, rank, shuffle=False) + batch_size = imgs_per_gpu + num_workers = workers_per_gpu + else: + sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None + batch_size = num_gpus * imgs_per_gpu + num_workers = num_gpus * workers_per_gpu + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu), + pin_memory=False, + worker_init_fn=init_fn, + **kwargs) + + return data_loader + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/mmdet/datasets/loader/__init__.py b/mmdet/datasets/loader/__init__.py deleted file mode 100644 index 4404615b..00000000 --- a/mmdet/datasets/loader/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .build_loader import build_dataloader -from .sampler import DistributedGroupSampler, GroupSampler - -__all__ = ['GroupSampler', 'DistributedGroupSampler', 'build_dataloader'] diff --git a/mmdet/datasets/loader/build_loader.py b/mmdet/datasets/loader/build_loader.py deleted file mode 100644 index c9a0919d..00000000 --- a/mmdet/datasets/loader/build_loader.py +++ /dev/null @@ -1,88 +0,0 @@ -import platform -import random -from functools import partial - -import numpy as np -from mmcv.parallel import collate -from mmcv.runner import get_dist_info -from torch.utils.data import DataLoader - -from .sampler import DistributedGroupSampler, DistributedSampler, GroupSampler - -if platform.system() != 'Windows': - # https://github.com/pytorch/pytorch/issues/973 - import resource - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) - hard_limit = rlimit[1] - soft_limit = min(4096, hard_limit) - resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) - - -def build_dataloader(dataset, - imgs_per_gpu, - workers_per_gpu, - num_gpus=1, - dist=True, - shuffle=True, - seed=None, - **kwargs): - """Build PyTorch DataLoader. - - In distributed training, each GPU/process has a dataloader. - In non-distributed training, there is only one dataloader for all GPUs. - - Args: - dataset (Dataset): A PyTorch dataset. 
- imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of - each GPU. - workers_per_gpu (int): How many subprocesses to use for data loading - for each GPU. - num_gpus (int): Number of GPUs. Only used in non-distributed training. - dist (bool): Distributed training/test or not. Default: True. - shuffle (bool): Whether to shuffle the data at every epoch. - Default: True. - kwargs: any keyword argument to be used to initialize DataLoader - - Returns: - DataLoader: A PyTorch dataloader. - """ - rank, world_size = get_dist_info() - if dist: - # DistributedGroupSampler will definitely shuffle the data to satisfy - # that images on each GPU are in the same group - if shuffle: - sampler = DistributedGroupSampler(dataset, imgs_per_gpu, - world_size, rank) - else: - sampler = DistributedSampler( - dataset, world_size, rank, shuffle=False) - batch_size = imgs_per_gpu - num_workers = workers_per_gpu - else: - sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None - batch_size = num_gpus * imgs_per_gpu - num_workers = num_gpus * workers_per_gpu - - init_fn = partial( - worker_init_fn, num_workers=num_workers, rank=rank, - seed=seed) if seed is not None else None - - data_loader = DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - num_workers=num_workers, - collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu), - pin_memory=False, - worker_init_fn=init_fn, - **kwargs) - - return data_loader - - -def worker_init_fn(worker_id, num_workers, rank, seed): - # The seed of each worker equals to - # num_worker * rank + worker_id + user_seed - worker_seed = num_workers * rank + worker_id + seed - np.random.seed(worker_seed) - random.seed(worker_seed) diff --git a/mmdet/datasets/samplers/__init__.py b/mmdet/datasets/samplers/__init__.py new file mode 100644 index 00000000..2596aeb2 --- /dev/null +++ b/mmdet/datasets/samplers/__init__.py @@ -0,0 +1,4 @@ +from .distributed_sampler import DistributedSampler +from .group_sampler import DistributedGroupSampler, GroupSampler + +__all__ = ['DistributedSampler', 'DistributedGroupSampler', 'GroupSampler'] diff --git a/mmdet/datasets/samplers/distributed_sampler.py b/mmdet/datasets/samplers/distributed_sampler.py new file mode 100644 index 00000000..2a85619c --- /dev/null +++ b/mmdet/datasets/samplers/distributed_sampler.py @@ -0,0 +1,28 @@ +import torch +from torch.utils.data import DistributedSampler as _DistributedSampler + + +class DistributedSampler(_DistributedSampler): + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + super().__init__(dataset, num_replicas=num_replicas, rank=rank) + self.shuffle = shuffle + + def __iter__(self): + # deterministically shuffle based on epoch + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/mmdet/datasets/loader/sampler.py b/mmdet/datasets/samplers/group_sampler.py similarity index 83% rename from mmdet/datasets/loader/sampler.py rename to mmdet/datasets/samplers/group_sampler.py index f3dd9962..35e21b35 100644 --- a/mmdet/datasets/loader/sampler.py +++ b/mmdet/datasets/samplers/group_sampler.py @@ -4,36 
+4,9 @@ import math import numpy as np import torch from mmcv.runner import get_dist_info -from torch.utils.data import DistributedSampler as _DistributedSampler from torch.utils.data import Sampler -class DistributedSampler(_DistributedSampler): - - def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): - super().__init__(dataset, num_replicas=num_replicas, rank=rank) - self.shuffle = shuffle - - def __iter__(self): - # deterministically shuffle based on epoch - if self.shuffle: - g = torch.Generator() - g.manual_seed(self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() - else: - indices = torch.arange(len(self.dataset)).tolist() - - # add extra samples to make it evenly divisible - indices += indices[:(self.total_size - len(indices))] - assert len(indices) == self.total_size - - # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] - assert len(indices) == self.num_samples - - return iter(indices) - - class GroupSampler(Sampler): def __init__(self, dataset, samples_per_gpu=1): -- GitLab
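
Usage note (not part of the patch): a minimal sketch of the relocated API after this change. build_dataloader is now imported from mmdet.datasets and implemented in mmdet/datasets/builder.py, with the samplers under mmdet.datasets.samplers. The config path and the batch/worker/seed values below are illustrative only.

    from mmcv import Config

    from mmdet.datasets import build_dataloader, build_dataset  # new import location

    cfg = Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')  # hypothetical config path
    dataset = build_dataset(cfg.data.train)
    data_loader = build_dataloader(
        dataset,
        imgs_per_gpu=2,      # per-GPU batch size
        workers_per_gpu=2,   # dataloader workers per GPU
        dist=False,          # single process here; set True under torch.distributed
        shuffle=True,        # non-distributed + shuffle -> GroupSampler
        seed=42)             # seeds numpy/random in every worker via worker_init_fn

    for data_batch in data_loader:
        break  # data_batch is a collated batch dict produced by mmcv's collate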
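
Also for reference, the deterministic per-worker seeding applied by worker_init_fn above, shown with made-up numbers: each worker derives num_workers * rank + worker_id + seed, so no two workers on any rank share a seed.

    import random

    import numpy as np


    def worker_init_fn(worker_id, num_workers, rank, seed):
        # same rule as in builder.py: a distinct seed per (rank, worker) pair
        worker_seed = num_workers * rank + worker_id + seed
        np.random.seed(worker_seed)
        random.seed(worker_seed)


    # e.g. 2 workers per GPU, rank 1, base seed 42 -> worker seeds 44 and 45
    for worker_id in range(2):
        worker_init_fn(worker_id, num_workers=2, rank=1, seed=42)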
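
Finally, a small worked example (made-up sizes, not from the patch) of the padding and striding done by the new DistributedSampler.__iter__: with 10 samples and 4 replicas, torch's base class computes num_samples = 3 and total_size = 12, the first two indices are repeated as padding, and rank r takes every 4th index starting at r, so every rank sees exactly num_samples indices.

    # Illustrative only: mirrors the pad-then-stride logic in distributed_sampler.py
    indices = list(range(10))                 # pretend len(dataset) == 10, shuffle disabled
    num_replicas, num_samples = 4, 3          # num_samples = ceil(10 / 4)
    total_size = num_samples * num_replicas   # 12

    indices += indices[:(total_size - len(indices))]  # pad with the first 2 indices
    assert len(indices) == total_size

    for rank in range(num_replicas):
        per_rank = indices[rank:total_size:num_replicas]
        print(rank, per_rank)  # 0 -> [0, 4, 8], 1 -> [1, 5, 9], 2 -> [2, 6, 0], 3 -> [3, 7, 1]
        assert len(per_rank) == num_samples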