Skip to content
Snippets Groups Projects
Unverified Commit 639f934a authored by Cao Yuhang's avatar Cao Yuhang Committed by GitHub
Browse files

fix worker_init_fn (#2170)

parent 9daea73c
No related branches found
No related tags found
No related merge requests found
......@@ -61,12 +61,9 @@ def build_dataloader(dataset,
batch_size = num_gpus * imgs_per_gpu
num_workers = num_gpus * workers_per_gpu
def worker_init_fn(worker_id):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
init_fn = partial(
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None
data_loader = DataLoader(
dataset,
......@@ -75,7 +72,15 @@ def build_dataloader(dataset,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
pin_memory=False,
worker_init_fn=worker_init_fn if seed is not None else None,
worker_init_fn=init_fn,
**kwargs)
return data_loader
def worker_init_fn(worker_id, num_workers, rank, seed):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment