Cherry-picking changes in main to the release branch (#1446)
* Fix end of epoch StatefulDataLoader restart (#1439); the resume semantics are sketched just after this list

* add test for end of epoch state dict check

* run precommit

update stateful_dataloader

run precommit

local changes

update test to test the order of batches

update test

update tests

revert changes in SDL

revert changes in SDL

update tests

run precommit

* update sampler

* run precommit

* remove unnecessary comment

* add test for statedict before and after endofepoch

* run precommit

* check if _sampler_iter is exhausted

* run precommit

* remove commented lines

* remove default values

* only exhaust sampler_iter if present in sd

* update _StatefulRandomSamplerIterator

update state dict if the iterator has finished

add comment about why we're updating state dict

run precommit

* update randomsampleriter state_dict fully

* run precommit

* fork torch.utils.data RandomSampler

revert changes to sdl.py

generator to iterator

run precommit

update generator usage

* update class name

* run precommit

* add a method to generate permutations

* update return type

* update next logic

* add comment

* update tests to include non-stateful samplers

* add comments

* Using system generated seed in RandomSampler (#1441)

* add new sampler tests

* update seed generation in sampler

* run precommit

* update seed generation

* change variable name

* update comment

* add seed to tests

* run precommit
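
The changes above boil down to two behaviours that the new tests in this diff exercise: a state dict captured after an epoch has fully ended resumes with a fresh epoch, and a freshly constructed loader reproduces the same shuffled order under the same torch.manual_seed. The snippet below is a minimal sketch of that expected behaviour, not code from this commit; it relies only on calls that appear in the diff (StatefulDataLoader, state_dict(), load_state_dict(), torch.manual_seed), and ToyDataset is a hypothetical stand-in for the DummyMapDataset used by the tests.

import torch
from torch.utils.data import Dataset

from torchdata.stateful_dataloader import StatefulDataLoader


class ToyDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return idx


torch.manual_seed(0)
dl = StatefulDataLoader(ToyDataset(8), batch_size=1, shuffle=True)
first_epoch = [int(batch.item()) for batch in dl]  # drain one full epoch
sd_after_epoch = dl.state_dict()  # captured after the epoch has ended

# Per the end-of-epoch tests, loading a state dict taken after the epoch
# finished starts a fresh epoch instead of yielding nothing.
resumed = StatefulDataLoader(ToyDataset(8), batch_size=1, shuffle=True)
resumed.load_state_dict(sd_after_epoch)
assert sum(1 for _ in resumed) == 8

# Per the seed-replicability test, resetting the same manual seed before
# constructing a new loader reproduces the same shuffled order.
torch.manual_seed(0)
dl_again = StatefulDataLoader(ToyDataset(8), batch_size=1, shuffle=True)
assert [int(batch.item()) for batch in dl_again] == first_epoch
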
ramanishsingh authored Feb 19, 2025
1 parent 89a1c71 commit c4177af
Showing 4 changed files with 412 additions and 85 deletions.
64 changes: 56 additions & 8 deletions test/stateful_dataloader/test_sampler.py
@@ -14,7 +14,7 @@
from torch.utils.data import Dataset

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from torchdata.stateful_dataloader.sampler import RandomSampler, StatefulDistributedSampler


class MockDataset(Dataset):
@@ -34,7 +34,10 @@ def __getitem__(self, idx):
"Fails with TSAN with the following error: starting new threads after multi-threaded "
"fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(TEST_WITH_ASAN, "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223")
@unittest.skipIf(
TEST_WITH_ASAN,
"DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
)
class TestDataLoader(TestCase):
def setUp(self):
super().setUp()
@@ -44,7 +47,12 @@ def setUp(self):
def test_initialization_StatefulDistributedSampler(self):

sampler = StatefulDistributedSampler(
self.dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False
self.dataset,
num_replicas=10,
rank=0,
shuffle=False,
seed=42,
drop_last=False,
)
self.assertEqual(sampler.dataset, self.dataset)
self.assertEqual(sampler.num_replicas, 10)
@@ -139,7 +147,8 @@ def test_drop_last_effect(self):
)

self.assertTrue(
len(indices_with_drop) <= len(indices_without_drop), "Drop last should result in fewer or equal indices"
len(indices_with_drop) <= len(indices_without_drop),
"Drop last should result in fewer or equal indices",
)

def test_data_order_with_shuffle(self):
@@ -153,7 +162,11 @@ def test_data_order_with_shuffle(self):
for batch in dataloader:
data_loaded.extend(batch)
self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
self.assertEqual(
data_loaded,
data_sampled,
"Data loaded by DataLoader should match data sampled by sampler",
)

def test_data_order_without_shuffle(self):
sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
@@ -167,8 +180,16 @@ def test_data_order_without_shuffle(self):
for batch in dataloader:
data_loaded.extend(batch)
self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
self.assertEqual(data_loaded, list(range(100)), "Data loaded by DataLoader should be in original order")
self.assertEqual(
data_loaded,
data_sampled,
"Data loaded by DataLoader should match data sampled by sampler",
)
self.assertEqual(
data_loaded,
list(range(100)),
"Data loaded by DataLoader should be in original order",
)

def test_data_distribution_across_replicas(self):
num_replicas = 5
@@ -181,9 +202,36 @@ def test_data_distribution_across_replicas(self):
data_loaded.extend([int(x.item()) for x in batch])
all_data.extend(data_loaded)
self.assertEqual(
sorted(all_data), list(range(100)), "All data points should be covered exactly once across all replicas"
sorted(all_data),
list(range(100)),
"All data points should be covered exactly once across all replicas",
)

def test_seed_replicability(self):
# Test that the same seed will result in the same data order
# We first pick a random number as seed, then use it to initialize two dataloaders
min_seed, max_seed = 0, 1000 # [min_seed, max_seed)
seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item()
torch.manual_seed(seed)

dataloader1 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
results1 = list(dataloader1)

# Repeat the same process with the same seed
torch.manual_seed(seed)
dataloader2 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
results2 = list(dataloader2)

# Repeat the same process with a different seed, making sure that the seed is different
min_seed, max_seed = 1000, 2000 # [min_seed, max_seed)
seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item()
torch.manual_seed(seed)
dataloader3 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
results3 = list(dataloader3)

self.assertEqual(results1, results2, "Data should be replicable with same seed")
self.assertNotEqual(results1, results3, "Data should not be replicable with different seed")


if __name__ == "__main__":
run_tests()
213 changes: 211 additions & 2 deletions test/stateful_dataloader/test_state_dict.py
@@ -15,6 +15,8 @@

import torch
import torch.utils.data

from parameterized import parameterized
from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader

@@ -1314,7 +1316,7 @@ def test(self):
dataset=dataset,
num_workers=num_workers,
collate_fn=identity,
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)
it = iter(dl)
# Fetch at least one batch from each worker
@@ -1325,7 +1327,10 @@
if num_workers > 0:
for i in range(num_workers):
# Ensure worker state is stored only once if the dataset is also the iterator
self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
self.assertEqual(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
None,
)
self.assertTrue(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
"dataset_iter_state"
@@ -1441,6 +1446,210 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
self._run_test(17, 19)


class TestMultiEpochSDL_shard0(TestCase):
def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
dataset = DummyMapDataset(data_size, shuffle=False)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _run(self, data_size, num_workers, batch_size, shuffle):
# For reproducibility of testing, fixing the seed
torch.manual_seed(0)
dataloader1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
# Run through the dataloader for 2 epochs and count the number of items yielded
num_items_yielded = 0
dataloader1_items = []
for _ in range(2):
for batch in dataloader1:
dataloader1_items.append(batch)
num_items_yielded += 1
# Save the state dict
state_dict = dataloader1.state_dict()
# Create a new StatefulDataLoader instance and load the state dict
new_dataloader1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
new_dataloader1.load_state_dict(state_dict)
# Run through the new dataloader for another 2 epochs and count the number of items yielded
additional_num_items_yielded = 0
for i in range(2):
epoch_num_items_yielded = 0
for batch in new_dataloader1:
dataloader1_items.append(batch)
epoch_num_items_yielded += 1
additional_num_items_yielded += epoch_num_items_yielded
# Check that the total number of items yielded is correct
self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)

# now run a second dataloader for 4 epochs and check that the order is the same.
# we need to fix the seed again since we want to bring the initial conditions to the same state as at the time of instantiating the first dataloader
torch.manual_seed(0)
dataloader2 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dataloader2_items = []
for _ in range(4):
for batch in dataloader2:
dataloader2_items.append(batch)

self.assertEqual(dataloader1_items, dataloader2_items)

@parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle):
self._run(datasize, num_workers, batch_size, shuffle)


class TestEndOfEpochBehavior_shard0(TestCase):
def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
dataset = DummyMapDataset(data_size, shuffle=False)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
num_items_yielded = 0
for batch in data_loader:
num_items_yielded += 1
return num_items_yielded

def _run(self, data_size, num_workers, batch_size, shuffle):
dataloader = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
# Run through the dataloader for 1 epoch and count the number of items yielded
num_items_yielded = 0

for batch in dataloader:
num_items_yielded += 1
sd_in = dataloader.state_dict()
sd_out = dataloader.state_dict()

self.assertEqual(num_items_yielded, data_size)

# Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch
dataloader_sd_in = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dataloader_sd_in.load_state_dict(sd_in)

# Run through the new dataloader for 1 epoch and count the number of items yielded
# num_items_yielded should be 0 since the state dict was saved before the end of epoch
num_items_yielded = self._count_items_yielded(dataloader_sd_in)
self.assertEqual(num_items_yielded, 0)

# Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch
dataloader_sd_out = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dataloader_sd_out.load_state_dict(sd_out)

# Run through the new dataloader for 1 epoch and count the number of items yielded
# num_items_yielded should be data_size since the state dict was saved after the end of epoch
num_items_yielded = self._count_items_yielded(dataloader_sd_out)
self.assertEqual(num_items_yielded, data_size)

@parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle):
self._run(datasize, num_workers, batch_size, shuffle)


class TestNotStatefulSamplerSDL_shard0(TestCase):
def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls):
dataset = DummyMapDataset(data_size, shuffle=False)
sampler = sampler_cls(dataset)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=sampler,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
torch.manual_seed(0) # Fixing seed for deterministic results
dataloader1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
sampler_cls=sampler_cls,
)
# interrupt the dataloader after interrupt batches and save the state dict
results_dataloader1 = []
for i, batch in enumerate(dataloader1):
results_dataloader1.append(batch)
if i == interrupt:
break
state_dict = dataloader1.state_dict()

torch.manual_seed(
0
) # We need to fix seed again so that before fast forwarding we are at the same state of gen as before
resumed_dataloader1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
sampler_cls=sampler_cls,
)
resumed_dataloader1.load_state_dict(state_dict)

for batch in resumed_dataloader1:
results_dataloader1.append(batch)

# now start a completely new dataloader and get all the batches
torch.manual_seed(0)
dataloader2 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
sampler_cls=sampler_cls,
)
results_dataloader2 = []
for batch in dataloader2:
results_dataloader2.append(batch)
self.assertEqual(results_dataloader1, results_dataloader2)

@parameterized.expand(
itertools.product(
[100],
[0, 2],
[1],
[10, 50, 80],
[torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler],
)
)
def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
self._run(100, 0, 1, interrupt, sampler_cls)


class TestMultiEpochState_shard0(TestCase):
def get_iterable_dl(self, pw, num_workers):
data_size = [25, 50, 100, 75]
