Using system generated seed in RandomSampler (#1441)
* add new sampler tests

* update seed generation in sampler

* run precommit

* update seed generation

* change variable name

* update comment

* add seed to tests

* run precommit
ramanishsingh authored Feb 19, 2025
1 parent f15fd3a commit 1277308
Showing 3 changed files with 63 additions and 10 deletions.
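In short: `RandomSampler` previously seeded its default generator with a hard-coded `manual_seed(1)`, so every sampler built without an explicit generator produced the same order. It now draws its seed from PyTorch's global RNG, which makes the shuffle order controllable via `torch.manual_seed`. A minimal sketch of the resulting behavior (illustrative only; a plain list stands in for the test datasets):

```python
import torch

from torchdata.stateful_dataloader.sampler import RandomSampler

data = list(range(10))

# The default seed is now drawn from the global RNG, so fixing the global
# seed before construction makes the sampler's order reproducible.
torch.manual_seed(42)
order1 = list(RandomSampler(data))

torch.manual_seed(42)
order2 = list(RandomSampler(data))

assert order1 == order2  # same global seed -> same system-generated seed
```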
64 changes: 56 additions & 8 deletions test/stateful_dataloader/test_sampler.py
@@ -14,7 +14,7 @@
from torch.utils.data import Dataset

from torchdata.stateful_dataloader import StatefulDataLoader
-from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
+from torchdata.stateful_dataloader.sampler import RandomSampler, StatefulDistributedSampler


class MockDataset(Dataset):
@@ -34,7 +34,10 @@ def __getitem__(self, idx):
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
-@unittest.skipIf(TEST_WITH_ASAN, "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223")
+@unittest.skipIf(
+    TEST_WITH_ASAN,
+    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
+)
class TestDataLoader(TestCase):
    def setUp(self):
        super().setUp()
@@ -44,7 +47,12 @@ def setUp(self):
    def test_initialization_StatefulDistributedSampler(self):

        sampler = StatefulDistributedSampler(
-            self.dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False
+            self.dataset,
+            num_replicas=10,
+            rank=0,
+            shuffle=False,
+            seed=42,
+            drop_last=False,
        )
        self.assertEqual(sampler.dataset, self.dataset)
        self.assertEqual(sampler.num_replicas, 10)
@@ -139,7 +147,8 @@ def test_drop_last_effect(self):
        )

        self.assertTrue(
-            len(indices_with_drop) <= len(indices_without_drop), "Drop last should result in fewer or equal indices"
+            len(indices_with_drop) <= len(indices_without_drop),
+            "Drop last should result in fewer or equal indices",
        )

    def test_data_order_with_shuffle(self):
@@ -153,7 +162,11 @@ def test_data_order_with_shuffle(self):
        for batch in dataloader:
            data_loaded.extend(batch)
        self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
-        self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
+        self.assertEqual(
+            data_loaded,
+            data_sampled,
+            "Data loaded by DataLoader should match data sampled by sampler",
+        )

    def test_data_order_without_shuffle(self):
        sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
@@ -167,8 +180,16 @@ def test_data_order_without_shuffle(self):
        for batch in dataloader:
            data_loaded.extend(batch)
        self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
-        self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
-        self.assertEqual(data_loaded, list(range(100)), "Data loaded by DataLoader should be in original order")
+        self.assertEqual(
+            data_loaded,
+            data_sampled,
+            "Data loaded by DataLoader should match data sampled by sampler",
+        )
+        self.assertEqual(
+            data_loaded,
+            list(range(100)),
+            "Data loaded by DataLoader should be in original order",
+        )

    def test_data_distribution_across_replicas(self):
        num_replicas = 5
@@ -181,9 +202,36 @@ def test_data_distribution_across_replicas(self):
            data_loaded.extend([int(x.item()) for x in batch])
        all_data.extend(data_loaded)
        self.assertEqual(
-            sorted(all_data), list(range(100)), "All data points should be covered exactly once across all replicas"
+            sorted(all_data),
+            list(range(100)),
+            "All data points should be covered exactly once across all replicas",
        )

+    def test_seed_replicability(self):
+        # Test that the same seed results in the same data order.
+        # First pick a random seed, then use it to initialize two dataloaders.
+        min_seed, max_seed = 0, 1000  # [min_seed, max_seed)
+        seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item()
+        torch.manual_seed(seed)
+
+        dataloader1 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results1 = list(dataloader1)
+
+        # Repeat the same process with the same seed.
+        torch.manual_seed(seed)
+        dataloader2 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results2 = list(dataloader2)
+
+        # Repeat with a seed from a disjoint range, so it is guaranteed to differ.
+        min_seed, max_seed = 1000, 2000  # [min_seed, max_seed)
+        seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item()
+        torch.manual_seed(seed)
+        dataloader3 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results3 = list(dataloader3)
+
+        self.assertEqual(results1, results2, "Data should be replicable with same seed")
+        self.assertNotEqual(results1, results3, "Data should not be replicable with different seed")


if __name__ == "__main__":
    run_tests()
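The new `test_seed_replicability` relies on `torch.manual_seed` because the dataloaders are built without an explicit generator. Callers who want a pinned shuffle order without touching global RNG state can still pass their own generator, which the default-seed branch leaves untouched; a sketch, assuming `StatefulDataLoader` forwards the standard `DataLoader` `generator` argument:

```python
import torch

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(100))

# An explicit generator bypasses the system-generated default seed entirely:
# the `if generator is None` branch in RandomSampler never runs.
g = torch.Generator()
g.manual_seed(1)

loader = StatefulDataLoader(dataset, batch_size=10, shuffle=True, generator=g)
first_epoch = [batch.tolist() for batch in loader]
```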
4 changes: 4 additions & 0 deletions test/stateful_dataloader/test_state_dict.py
@@ -1458,6 +1458,8 @@ def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
        )

    def _run(self, data_size, num_workers, batch_size, shuffle):
+        # For test reproducibility, fix the seed
+        torch.manual_seed(0)
        dataloader1 = self.get_map_dl(
            data_size=data_size,
            num_workers=num_workers,
@@ -1493,6 +1495,8 @@ def _run(self, data_size, num_workers, batch_size, shuffle):
        self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)

        # Now run a second dataloader for 4 epochs and check that the order is the same.
+        # We need to fix the seed again to restore the global RNG to the state it was in when the first dataloader was instantiated.
+        torch.manual_seed(0)
        dataloader2 = self.get_map_dl(
            data_size=data_size,
            num_workers=num_workers,
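These `torch.manual_seed(0)` calls are a direct consequence of the sampler change: the default seed is consumed from the global RNG at sampler construction time, so a second dataloader reproduces the first only if the global RNG is first rewound to the same state. A self-contained illustration of the pattern (generic names, not from the test file):

```python
import torch

def system_generated_seed() -> int:
    # Mirrors the sampler's new default: one int64 drawn from the global RNG.
    return int(torch.empty((), dtype=torch.int64).random_().item())

torch.manual_seed(0)
seed_for_loader1 = system_generated_seed()

torch.manual_seed(0)  # rewind the global RNG before building the second loader
seed_for_loader2 = system_generated_seed()

assert seed_for_loader1 == seed_for_loader2
```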
5 changes: 3 additions & 2 deletions torchdata/stateful_dataloader/sampler.py
@@ -88,9 +88,10 @@ def __init__(
        self.replacement = replacement
        self._num_samples = num_samples
        if generator is None:
-            # Ensure that underlying sampler has something repeatable
+            # Previously the random seed was fixed as 1. We then changed it to a system-generated seed to ensure deterministic randomness.
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
-            generator.manual_seed(1)
+            generator.manual_seed(seed)
        self.generator = generator
        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
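The observable effect of the change above: two samplers constructed back to back no longer share a permutation by default, because each draw of the system-generated seed advances the global RNG. A small sketch:

```python
import torch

from torchdata.stateful_dataloader.sampler import RandomSampler

data = list(range(8))

# Under the old fixed seed of 1, these two samplers always produced identical
# permutations. Each one now draws a fresh seed from the global RNG, so their
# orders differ (except for rare coincidences).
a = list(RandomSampler(data))
b = list(RandomSampler(data))
print(a == b)  # was always True before this change; now typically False
```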
