So I just got bitten by this on a separate project and realized that it likely affects CenterNet as well. Essentially, when we use Pytorch DataLoaders with num_workers > 1, Pytorch uses multiprocessing in the background to spawn the worker processes. The problem, however, is related to numpy and random seeds. What ends up happening is that all the child processes end up with identical numpy seeds. So if you have num_workers = 4, the 1st image processed by each of the workers will get exactly identical augmentations. What's worse is that once we finish an epoch, the next time around all the workers will again start from the same initial numpy seed. This is more easily explained by example, based on the discussion here: pytorch/pytorch#5059
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class FalseDataset(Dataset):
    # Toy dataset: each item returns a numpy random draw plus the worker's torch seed.
    def __init__(self, length):
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, ID):
        r = np.random.randint(1, 10000)
        return [r, torch.initial_seed()]

false_instance = FalseDataset(8)

# worker_init_fn re-seeds numpy in each worker: torch.initial_seed() is re-drawn
# every epoch, and the worker id keeps the seeds distinct across workers.
train_loader = DataLoader(
    false_instance,
    shuffle=False,
    num_workers=4,
    worker_init_fn=lambda id: np.random.seed(torch.initial_seed() // 2**32 + id),
)

def train_epoch(loader, epoch):
    for batch_idx, output in enumerate(loader):
        print(batch_idx, output)
    return "epoch" + str(epoch) + " ended"

for i in range(2):  # simulating epochs
    print(train_epoch(train_loader, i))
If this is run without the worker_init_fn parameter (which is the default in Pytorch DataLoaders), the output shows every worker returning the same random numbers: the 4 workers produce identical numpy draws within an epoch, and the second epoch repeats the first.
However, if we include the worker_init_fn parameter, we make use of two things. We use the unique worker id passed to each child process to differentiate the seeds across workers within a given epoch. We also use torch.initial_seed(), which Pytorch handles correctly and re-randomizes after each epoch (i.e., each full iteration through the DataLoader), so the seeds also stay randomized across epochs. With this in place the output is as desired: the numpy draws differ across workers and across epochs.
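As an aside, the same re-seeding can be written as a named function instead of a lambda, which is easier to reuse and, unlike a lambda, can be pickled if the workers are started with the spawn method. This is just a sketch of the idea above (the function name is mine), using modulo instead of integer division to keep the result inside numpy's 32-bit seed range:

import numpy as np
import torch

def seed_numpy_per_worker(worker_id):
    # torch.initial_seed() is re-drawn by the DataLoader every epoch and already
    # differs per worker; fold it into numpy's 32-bit seed range and offset by
    # the worker id so the workers never collide.
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)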
Unfortunately I couldn't confirm this for sure in this repo, since the imports in the files only work via some specific entrypoints into the codebase, so I couldn't make a quick test script like the one above to iterate through the Dataset and see whether identical augmentations were being applied across workers and across epochs. But I am fairly sure this is impacting this codebase, and I wonder if you might get better performance if you modify the Datasets to use the worker_init_fn mechanism so the seeds are properly randomized during training.
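For what it's worth, here is a hypothetical sketch of what such a quick check could look like once the imports are sorted out, reusing the toy FalseDataset and the seed_numpy_per_worker helper from above (both names are mine, not from this codebase). On a fork-based platform the first call should report mostly duplicated draws and the second should not:

from torch.utils.data import DataLoader

def check_numpy_draws(dataset, num_workers=4, epochs=2, init_fn=None):
    # Iterate the dataset the way training would and count distinct numpy draws;
    # repeated draws across workers/epochs indicate the shared-seed problem.
    # Relies on the dataset returning [numpy draw, torch seed] per item.
    loader = DataLoader(dataset, num_workers=num_workers, worker_init_fn=init_fn)
    all_draws = []
    for epoch in range(epochs):
        epoch_draws = [int(r) for r, _seed in loader]
        all_draws.extend(epoch_draws)
        print(f"epoch {epoch}: {len(set(epoch_draws))} distinct draws out of {len(epoch_draws)}")
    print(f"overall: {len(set(all_draws))} distinct draws out of {len(all_draws)}")

check_numpy_draws(FalseDataset(8))                                 # duplicated draws expected
check_numpy_draws(FalseDataset(8), init_fn=seed_numpy_per_worker)  # distinct draws expected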