Skip to content

Commit 4ba17dd

Browse files
NivekTfacebook-github-bot
authored and committed
Saving and restoring initial seed generator (#998)
Summary: Pull Request resolved: #998 Changes to `DataLoader2`: - Modifying `state_dict` to store `randomness_state`, which includes: - `_seed: int` - `_reset_seed: bool` - flag indicating whether `_seed` needs to be set - `_seed_generator` - the latest version at the time when `state_dict` is called - `_initial_seed_generator` - the version that is saved at the beginning of every epoch - Modifying `from_state` and `load_state_dict` to restore `randomness_state` - Adding a method `_restore_checkpoint_beginning_of_epoch` - This sets `self._seed_generator = self._initial_seed_generator`, allowing users to re-create an epoch from the beginning. --- ### Considerations Storing the randomness states provides more flexibility for users to restore as they see fit. The decision to do that should not be controversial. I decided to add a new method for checkpointing at the beginning of the epoch, to ensure that users are not confused about what randomness is restored by default. The basic idea is that we want to allow users to restore `dl2._seed_generator` to the previously saved version. From that point on, they can create a new `__iter__` and continue from the beginning of the epoch. - Note that since `_seed` and `_reset_seed` are also saved, if the users were planning to use a different seed or if there was a need to re-seed, those remain valid after restoring the checkpoint. - Finally, if users change their mind at any point (after restoring) and want to manually set `seed`, that `seed` will override any other behavior and the `seed` will be used. Test Plan: Imported from OSS Reviewed By: wenleix Differential Revision: D44390519 Pulled By: NivekT fbshipit-source-id: 0faa3d7aef23463e86765571018f83384bcb31e1
1 parent aeda987 commit 4ba17dd

File tree

4 files changed

+110
-5
lines changed

4 files changed

+110
-5
lines changed

test/dataloader2/test_mprs.py

+62-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
87
import multiprocessing as mp
98
import unittest
109
from unittest import TestCase
@@ -14,7 +13,7 @@
1413
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
1514

1615
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
17-
from torchdata.datapipes.iter import IterableWrapper
16+
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
1817

1918

2019
def _add_one(x: int) -> int:
@@ -46,6 +45,17 @@ def _dispatching_dp(n_elements=1000):
4645
return dp
4746

4847

48+
class NonShardableDataPipe(IterDataPipe):
49+
def __init__(self, dp: IterDataPipe):
50+
self.dp = dp
51+
52+
def is_replicable(self):
53+
return False
54+
55+
def __iter__(self):
56+
yield from self.dp
57+
58+
4959
class TestMultiProcessingReadingService(TestCase):
5060
r"""
5161
This tests specific functionalities of MultiProcessingReadingService, notably
@@ -64,7 +74,7 @@ def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
6474
worker_prefetch_cnt=worker_prefetch,
6575
multiprocessing_context=ctx,
6676
)
67-
dl = DataLoader2(dp, reading_service=rs)
77+
dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
6878
it = iter(dl)
6979
for _ in range(10):
7080
_ = next(it)
@@ -82,7 +92,7 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
8292
worker_prefetch_cnt=worker_prefetch,
8393
multiprocessing_context=ctx,
8494
)
85-
dl = DataLoader2(dp, reading_service=rs)
95+
dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
8696
_ = list(dl)
8797
dl.shutdown()
8898

@@ -248,6 +258,54 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
248258
res.append(x)
249259
self.assertEqual(9, len(res))
250260

261+
def test_initial_epoch_checkpointing(self):
262+
dp = IterableWrapper(range(20)).shuffle().sharding_filter()
263+
# Note that the second `shuffle` occurs in the main process, which uses a different RNG from
264+
# the `shuffle` done in the worker processes
265+
dp = NonShardableDataPipe(dp).shuffle() # type: ignore[assignment, arg-type]
266+
rs = MultiProcessingReadingService(num_workers=2)
267+
268+
# Functional Test: Saving state before iterator is created
269+
dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
270+
dl.seed(1)
271+
initial_state = dl.state_dict()
272+
it1 = iter(dl)
273+
274+
restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
275+
restored_dl._restore_checkpoint_beginning_of_epoch()
276+
self.assertEqual(list(it1), list(restored_dl))
277+
278+
dl.shutdown()
279+
restored_dl.shutdown()
280+
281+
# Functional Test: Saving state after iterator is created
282+
dl = DataLoader2(datapipe=dp, reading_service=rs)
283+
dl.seed(1)
284+
it1 = iter(dl)
285+
initial_state = dl.state_dict()
286+
287+
restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
288+
restored_dl._restore_checkpoint_beginning_of_epoch()
289+
self.assertEqual(list(it1), list(restored_dl))
290+
291+
dl.shutdown()
292+
restored_dl.shutdown()
293+
294+
# Functional Test: Saving state after iterator is created and began iterating
295+
dl = DataLoader2(datapipe=dp, reading_service=rs)
296+
dl.seed(1)
297+
it1 = iter(dl)
298+
temp = next(it1) # Starts iterating
299+
initial_state = dl.state_dict()
300+
301+
restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
302+
restored_dl._restore_checkpoint_beginning_of_epoch()
303+
304+
self.assertEqual([temp] + list(it1), list(restored_dl)) # Note skipping over 1st element from actual result
305+
306+
dl.shutdown()
307+
restored_dl.shutdown()
308+
251309
# TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
252310
# Currently, using sharding_round_robin raises a warning
253311
# def test_round_robin_dispatching_pause_limit(self):

torchdata/dataloader2/dataloader2.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
7+
import pickle
88
import warnings
99

1010
from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union
@@ -19,6 +19,7 @@
1919
T_co = TypeVar("T_co", covariant=True)
2020
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
2121
READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
22+
RANDOMNESS_STATE_KEY_NAME = "randomness_state"
2223

2324

2425
class DataLoader2Iterator(Iterator[T_co]):
@@ -176,6 +177,8 @@ def __init__(
176177
self._seed_generator: SeedGenerator = SeedGenerator()
177178
self._seed: Optional[int] = None
178179
self._reset_seed: bool = True
180+
# Seed generator as of beginning of each epoch
181+
self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)
179182

180183
def __iter__(self) -> DataLoader2Iterator[T_co]:
181184
r"""
@@ -198,6 +201,9 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
198201
else:
199202
self._seed_generator.seed()
200203

204+
# Saving initial seed generator state
205+
self._initial_seed_generator = clone(self._seed_generator)
206+
201207
if not self._adapted and self.reading_service is not None:
202208
if self.reading_service_state is None:
203209
self.datapipe = self.reading_service.initialize(self.datapipe)
@@ -269,10 +275,17 @@ def state_dict(self) -> Dict[str, Any]:
269275

270276
# Serialize datapipe after applying adapters and before reading service adaption
271277
serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
278+
serialized_randomness_state = (
279+
self._seed,
280+
self._reset_seed,
281+
pickle.dumps(self._seed_generator),
282+
pickle.dumps(self._initial_seed_generator),
283+
)
272284

273285
return {
274286
SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
275287
READING_SERVICE_STATE_KEY_NAME: reading_service_state,
288+
RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state,
276289
}
277290

278291
@classmethod
@@ -294,6 +307,12 @@ def from_state(
294307
reading_service=reading_service,
295308
)
296309
data_loader.reading_service_state = reading_service_state
310+
311+
randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
312+
data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
313+
data_loader._seed_generator = pickle.loads(randomness_state[2])
314+
data_loader._initial_seed_generator = pickle.loads(randomness_state[3])
315+
297316
return data_loader
298317

299318
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -320,12 +339,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
320339
self.datapipe = deserialized_datapipe
321340
self.reading_service_state = reading_service_state
322341

342+
randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
343+
self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
344+
self._seed_generator = pickle.loads(randomness_state[2])
345+
self._initial_seed_generator = pickle.loads(randomness_state[3])
346+
323347
# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
324348
if self.datapipe_adapter_fns is not None:
325349
for adapter_fn in self.datapipe_adapter_fns:
326350
self.datapipe = adapter_fn(self.datapipe)
327351
self._datapipe_before_reading_service_adapt = clone(self.datapipe)
328352

353+
def _restore_checkpoint_beginning_of_epoch(self) -> None:
354+
r"""
355+
At the beginning of each iteration (epoch), the initial state of randomness is automatically saved.
356+
That state is also saved as part of ``state_dict``. This method restores the current DataLoader2 RNG state
357+
to that initial state.
358+
359+
The common use case is to invoke this method after ``DataLoader2``'s state is restored (through
360+
``.from_state(...)`` or ``load_state_dict(...)``) in order to resume from the beginning of the last-ran epoch.
361+
"""
362+
self._seed_generator = self._initial_seed_generator
363+
329364
def _pause(self):
330365
if hasattr(self.reading_service, "_pause"):
331366
self._is_paused = True

torchdata/dataloader2/random/seed_generator.py

+10
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,13 @@ def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator":
8383
self._worker_rng = self._worker_rng.spawn(worker_id)
8484
return self
8585
return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id)))
86+
87+
def __getstate__(self):
88+
state = (
89+
self._shared_rng,
90+
self._worker_rng,
91+
)
92+
return state
93+
94+
def __setstate__(self, state):
95+
self._shared_rng, self._worker_rng = state

torchdata/dataloader2/reading_service.py

+2
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,8 @@ def initialize_iteration(
312312
) -> Optional[Callable[[DataPipe], DataPipe]]:
313313
assert self._end_datapipe is not None
314314

315+
# Set random seeds for DataPipe that are in the main process (NOT those in worker processes)
316+
# Worker seeds are set in `process_reset_fn`
315317
set_graph_random_seed(self._end_datapipe, seed_generator)
316318

317319
if self._mp:

0 commit comments

Comments
 (0)