Skip to content

Commit

Permalink
Add tests for out of order with checkpointing (#1428)
Browse files Browse the repository at this point in the history
* Add tests for out of order with checkpointing

* add warning logs back

* update test cases
  • Loading branch information
michael-diggin authored Jan 30, 2025
1 parent cad6dbe commit fcdc8b9
Showing 1 changed file with 131 additions and 0 deletions.
131 changes: 131 additions & 0 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import itertools
import json
import math
import time
import unittest
from copy import deepcopy

Expand Down Expand Up @@ -1632,5 +1634,134 @@ def test_mp(self):
self._run_test(2, CountIterCallsIter(100))


class _TestSlowIndexDataset(torch.utils.data.Dataset):
def __init__(self, end: int, slow_index: int):
self.end = end
self.slow_index = slow_index
self._worker_id = None

def __getitem__(self, idx):
if idx == self.slow_index:
time.sleep(1.0)
return idx

def __len__(self):
return self.end


class _TestSlowIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start: int, end: int):
self.start = start
self.end = end
self.mid = math.ceil((self.end - self.start) / 2)

def give_data(self, iter_start, iter_end):
for i in range(iter_start, iter_end):
if i == self.mid:
time.sleep(1.0)
yield i

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return self.give_data(iter_start, iter_end)


class TestOutOfOrderWithCheckpointing(TestCase):
def test_out_of_order_index_ds(self):
dataset = _TestSlowIndexDataset(end=10, slow_index=0)
dataloader = StatefulDataLoader(
dataset,
num_workers=2,
in_order=False,
)

# worker_id = 0 gets 'stuck' on 0 and also has 2 in it's queue
# due to prefetch_factor being 2
output = []
for i, data in enumerate(dataloader):
output.append(data)
if i == 3:
state_dict = dataloader.state_dict()
break

# 0 is the slow index, assert it isn't in the output before the pause
self.assertNotIn(0, output)

new_dataloader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
new_dataloader.load_state_dict(state_dict)
for i, data in enumerate(new_dataloader):
output.append(data)

self.assertEqual(len(output), 10)
self.assertNotEqual(output, list(range(10)))
self.assertEqual(sorted(output), list(range(10)))

def test_out_of_order_iterable_ds_one_completed_worker(self):
dataset = _TestSlowIterableDataset(start=0, end=10)
dataloader = StatefulDataLoader(
dataset,
num_workers=2,
prefetch_factor=2,
in_order=False,
)

# break later on, as one of the workers will be finished
output = []
for i, data in enumerate(dataloader):
output.append(data)
if i == 7:
state_dict = dataloader.state_dict()
break

worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
self.assertTrue(worker_0_ended)
self.assertFalse(worker_1_ended)

new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
new_dataloader.load_state_dict(state_dict)
for i, data in enumerate(new_dataloader):
output.append(data)

self.assertEqual(len(output), 10)
self.assertEqual(output, list(range(10)))
self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])

def test_out_of_order_iterable_ds_no_completed_workers(self):
dataset = _TestSlowIterableDataset(start=0, end=10)
dataloader = StatefulDataLoader(
dataset,
num_workers=2,
prefetch_factor=2,
in_order=False,
)

# break early - both workers will resume
output = []
for i, data in enumerate(dataloader):
output.append(data)
if i == 3:
state_dict = dataloader.state_dict()
break

worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
self.assertFalse(worker_0_ended)
self.assertFalse(worker_1_ended)

new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
new_dataloader.load_state_dict(state_dict)
for i, data in enumerate(new_dataloader):
output.append(data)

self.assertEqual(len(output), 10)
self.assertEqual(output, list(range(10)))
self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])


if __name__ == "__main__":
unittest.main()

0 comments on commit fcdc8b9

Please sign in to comment.