Skip to content
Snippets Groups Projects
Unverified Commit 5d75636c authored by Cao Yuhang's avatar Cao Yuhang Committed by GitHub
Browse files

reset worker_seed (#2111)

* reset worker_seed

* fix isort

* minor fix

* fix comment
parent c47e36a5
No related branches found
No related tags found
No related merge requests found
......@@ -44,8 +44,8 @@ def build_dataloader(dataset,
Returns:
DataLoader: A PyTorch dataloader.
"""
rank, world_size = get_dist_info()
if dist:
rank, world_size = get_dist_info()
# DistributedGroupSampler will definitely shuffle the data to satisfy
# that images on each GPU are in the same group
if shuffle:
......@@ -61,6 +61,13 @@ def build_dataloader(dataset,
batch_size = num_gpus * imgs_per_gpu
num_workers = num_gpus * workers_per_gpu
def worker_init_fn(worker_id):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
......@@ -72,8 +79,3 @@ def build_dataloader(dataset,
**kwargs)
return data_loader
def worker_init_fn(seed):
np.random.seed(seed)
random.seed(seed)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment