Resumable checkpointable dataloading#1570
Open
pzelasko wants to merge 31 commits into
Open
Conversation
…r archives Introduce lhotse/indexing.py with: - Binary index format: arrays of little-endian uint64 byte offsets - create_jsonl_index(), create_tar_index(), create_shar_index() for index creation - IndexedJsonlReader and IndexedTarReader for O(1) random-access reads - LazyShuffledRange: Feistel-cipher-based O(1) memory permutation of range(N) - Utility functions: index_file_path(), read_index(), index_exists() This is the foundation for resumable checkpointable dataloading. Index files enable seeking directly to any sample without sequential scanning. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Refactor lhotse/lazy.py: - Add StatefulIterator mixin with state_dict()/load_state_dict() protocol - Unify child attribute names: self.iterator -> self.source (single child), self.iterators -> self.sources (list of children) - Add state_dict()/load_state_dict() to all 10 lazy iterator classes: LazyJsonlIterator, LazyManifestIterator, LazyIteratorChain, LazyIteratorMultiplexer, LazyInfiniteApproximateMultiplexer, LazyShuffler, LazyFilter, LazyMapper, LazyFlattener, LazyRepeater - LazyJsonlIterator uses IndexedJsonlReader for position tracking when index files are available The unified naming enables trivial graph traversal for recursive state collection/restoration in the checkpoint module. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce lhotse/checkpoint.py with: - collect_state_dict(root): Recursively collects state from all StatefulIterator nodes in the iterator graph - restore_state_dict(root, state): Recursively restores state, validates type names match between graph and checkpoint - DataloaderCheckpoint dataclass: Serializable container for num_workers, world_size, rank, worker_states, and sampler_state with save()/load()/validate() methods Uses the unified source/sources naming convention from Phase 2 for trivial graph traversal. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Modify LazySharIterator (lhotse/shar/readers/lazy.py): - Inherit from StatefulIterator - Track current_shard_idx, position_in_shard, shard_order during iteration - Add state_dict()/load_state_dict() for mid-iteration checkpointing - Exclude .idx files from _init_from_dir discovery Modify SharWriter (lhotse/shar/writers/shar.py): - Add compress_jsonl parameter (default True for backward compat) - Add create_index parameter (default True) to auto-create .idx files - _create_indexes() skips non-local paths (pipe:, http://, s3://) - Support uncompressed JSONL output for indexing compatibility Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Enhance CutSampler.state_dict() (base.py): - Capture CutSet iterator graph state via collect_state_dict() when cuts_iter contains StatefulIterator instances - Store as 'cuts_state' key; consumed by load_state_dict() Modify DynamicCutSampler._fast_forward() and DynamicBucketingSampler._fast_forward(): - When cuts_state is available, use O(1) indexed state restoration via restore_state_dict() instead of O(N) batch re-iteration - Fall back to legacy _fast_forward for old state_dicts or when indexed restoration fails Add Stateful protocol to IterableDatasetWrapper: - state_dict() returns epoch + sampler state - load_state_dict() restores both, enabling torchdata StatefulDataLoader integration Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add 'lhotse index' command group with subcommands: - lhotse index jsonl PATH — create index for uncompressed JSONL - lhotse index tar PATH — create index for uncompressed tar archive - lhotse index shar DIR — create indexes for all files in a Shar directory Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add new 'Resumable Stateful Dataloading (Indexed)' section to docs/datasets.rst covering: - Motivation: O(1) vs O(N) restore - Setting up indexed data (SharWriter + CLI) - Using StatefulDataLoader for per-worker checkpointing - Requirements and limitations Add 'torchdata' as optional dependency under 'checkpoint' extra in setup.py for StatefulDataLoader support. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rchdata
Fix IterableDatasetWrapper.load_state_dict() to create the sampler iter
immediately instead of setting it to None. torchdata's StatefulDataLoader
calls iter(dataset) before load_state_dict(), so nulling _sampler_iter
caused a TypeError on __next__. The sampler's __iter__ returns self when
_just_restored_state=True, preserving the restored position.
Add e2e tests exercising StatefulDataLoader with num_workers in {0, 2}:
basic mux, augmentation, full pipeline, and various checkpoint positions.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…pickling fixes - LazySharIterator: add has_constant_time_access, __getitem__ for O(1) random access across shards, and pickle-safe __getstate__/__setstate__ - IndexedJsonlReader/IndexedTarReader: add __getstate__/__setstate__ - IdentityDataset: new importable pass-through dataset utility for use with IterableDatasetWrapper (avoids __main__ pickling issues) - CutSet.from_file(shuffle=True, seed=N): thread shuffle/seed params through Serializable.from_file → load_manifest_lazy_or_eager → load_manifest_lazy for Feistel-cipher shuffled iteration - CutSet.from_jsonl_lazy(shuffle=True, seed=N): same shuffle support - Fix num_workers>0 pickling: defer _fast_forward() in samplers and iter(sampler) in IterableDatasetWrapper via lazy flags - CutSet.to_shar: expose compress_jsonl/create_index params - LazyMapper, LazyRepeater, LazySlicer, LazyIteratorChain: add has_constant_time_access property and __getitem__ for indexed access - Tutorial notebook: full rewrite using user-facing APIs only - Tests: Shar random access, indexed manifest iterator, CutSet state_dict, e2e checkpoint restore with file-backed lazy data Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Review fixes: - Replace assert with raise ValueError in LazyShuffledRange.load_state_dict - Raise ValueError in LazyIndexedManifestIterator.load_state_dict when shuffle state is missing from checkpoint - Add rank validation to DataloaderCheckpoint.validate() - Cache cumulative lengths in LazyIteratorChain.__getitem__ (O(log m) via bisect instead of O(m) linear scan) - Add thread safety warning to StatefulIterator docstring - Add test for CutSet.from_file(shuffle=True, seed=42) API path New test: - test/dataset/test_multinode_resume.py: E2E multi-node training resumption using torchdata StatefulDataLoader with seed="randomized", make_worker_init_fn, infinite .repeat() + .mux() pipelines, and world_size=2 / num_workers=2 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…extend tutorial - Split LazySharIterator into streaming (LazySharIterator) and indexed (LazyIndexedSharIterator) classes for clean separation of concerns - Add LazyIndexedSharIterator with O(1) random access, shuffle, and checkpoint/restore support in lhotse/shar/readers/indexed.py - Add indexed parameter to CutSet.from_files() and CutSet.from_shar() with auto-detection when .idx files exist - Extend LazyIteratorChain to auto-upgrade shuffle_iters to global Feistel-cipher shuffling when all sources are indexed - Add "lhotse_shar_fields" origin type for fields-based Shar reload - Extend tutorial notebook with torchdata disk checkpoint and Shar dataloading examples - Add comprehensive tests for all new functionality Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… data Add an optional `index_path` parameter throughout the indexed reading stack so users can point to .idx files at a different (local) location. This enables indexed random access when data lives on object stores (S3, GCS) where np.memmap-based .idx files cannot be written next to the manifests. Changes bottom-up: - indexing.py: output_path for create_*_index, output_dir for create_shar_index, index_path for Indexed*Reader and index_exists - lazy.py: index_path for LazyJsonlIterator and LazyIndexedManifestIterator; extend _origin tuples to carry index_path - shar/readers/indexed.py: index_path for LazyIndexedSharIterator (directory or dict), forwarded to all per-shard readers and pickling - serialization.py: index_path threaded through from_file → load_manifest_lazy_or_eager → load_manifest_lazy - cut/set.py: index_path on CutSet.from_files (list) and CutSet.from_shar (dir or dict), with validation - checkpoint.py: variable-length origin tuple handling, index_path forwarded in all three origin loaders - bin/modes/index.py: --output-dir / -o CLI option for jsonl, tar, shar Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Resolve conflicts in LazyCutMixer._mix_one (keep feature branch structure, add master's tag parameter) and librispeech formatting. Update test_indexed_read.py to use numpy instead of lilcom, consistent with master's make-lilcom-optional changes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…called Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
…idual GET requests Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
…ptr fallback lazy.py + dataset/dataloading.py: resolve (shard_id, num_shards) at iter time via get_worker_partition() / LHOTSE_USE_WORKER_PARTITION so force_map_dataset:false is correct under DP×worker fan-out; map-style path bit-identical. LazyIteratorMultiplexer rejects seed='randomized' under multi-shard partition. ais/batch_loader.py + dataset/input_strategies.py: rename prefer_individual -> force_individual. Add byte-range fallback for shar_ptr requests via per-object get_reader(byte_range=...) when MOSS BatchRequest.add lacks byte-range or force_individual is set; previously shar_ptr was silently dropped from the batch. indexing.py: atomic _write_index (stage-and-rename) so concurrent auto_create_index racers never observe a 0-byte .idx; size-aware index_exists treats 0-byte / non-uint64-aligned files as missing so a stale truncated sidecar triggers re-create instead of producing len(reader)=-1 -> "__len__() should return >= 0" deep in the sampler. create_tar_index falls back to tf.offset when the reader can't seek (AIS ObjectFileReader inherits BufferedIOBase.tell which delegates to seek -> raises). test/test_partition.py: 49 tests pinning partition edge cases (rank- without-signal, topology mismatch, randomized rejection, worker × DP composition, mixed indexed/non-indexed chains). test/test_indexing.py: +12 partition tests on LazyShuffledRange, +2 race tests reproducing the 0-byte / concurrent auto-create crash seen in production. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tes to it; refresh stale memmap tests
* Add PartitionedIndexedIterator in lhotse/dataset/dataloading.py: the
shared partition-aware iteration driver for indexed leaf iterators.
Encapsulates (shard_id, num_shards) lookup, position tracking across
DataLoader worker subprocesses, and topology-validated resume. Supports
both stride and Feistel-shuffled modes (shuffle=True delegates to
LazyShuffledRange). Iterators plug in by storing self._iter_state and
forwarding state_dict()/load_state_dict().
* Refactor LazyIndexedManifestIterator to delegate to the helper. Drops
~50 lines of duplicated position/shuffle/topology plumbing; the body
is now a single for-loop yielding from self._iter_state.iterate(...).
* test/test_partition.py: 13 new helper-level tests pinning the contract
(single-rank full coverage, multi-rank disjoint x [2,3,4,7], resume
from middle, topology-mismatch raises, map-style fallthrough, empty
manifest, n<world_size, neutral state_dict). Update the chain-level
topology-mismatch test to match the helper's new error message
("topology mismatch on resume" instead of LazyShuffledRange's
"state mismatch", which the helper's earlier check now preempts).
* test/test_indexing.py: rename test_read_index_remote_path_is_cached_and_memmapped
-> ..._is_cached_locally and drop the isinstance(np.memmap) assertions.
read_index switched from np.memmap to np.fromfile (deliberately, to
avoid exhausting kernel vm.max_map_count at 80k+ shards); the tests
were checking for the old behavior and had been failing pre-emptively.
Cache-correctness assertions are kept; the dtype check is replaced by
np.array_equal against the source bytes.
62/62 partition tests + 117/117 indexing tests pass under nemo312-hf5.
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
…ator The shar reader's own ``_get_worker_indices`` + ``_position`` machinery was the last indexed-iterator path still using legacy ``split_for_dataloading``, which (a) defaults to False so iterable NeMo dataloaders never partitioned shar leaves across DP ranks, and (b) doesn't record ``(shard_id, num_shards)`` in state_dict so resume under a different topology silently desyncs. Caught by the 0909 SALM validator: 4072 AMI cuts duplicated across 8 ranks. Delegate iteration + state to ``PartitionedIndexedIterator``, matching every other indexed leaf (``LazyNeMoTarredIterator``, ``LazyIndexedManifestIterator``, parquet/jsonl text adapters). ``split_for_dataloading`` is kept as accepted-but-ignored for signature compat with ``CutSet.from_shar(split_for_dataloading=...)``. Regression test exercises 1/2/4 ranks (disjoint slices + complete union), the env-var-off collapse for map-style mode, same-topology resume, topology-mismatch raise, and shuffle correctness. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…GET operations Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
8 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This is a pretty large PR that makes changes all over Lhotse to enable random-access checkpointable dataloading.
The main idea is to introduce binary indexes for JSONL manifests (similarly to
StatelessSampler) that enable O(1) random sampling, and extend all iterator classes to support lookup by index.There are some tricky parts:
Unlike previous approaches it allows quick resumption, 100% determinism, and improved sampling randomness for sharded / tarred / lhotse shar data. I noticed it helps to improve the results for models trained on many datasets blended using
mux- sometimes a lot, depending on how large is the data and how it's physically sharded.I've been validating it for a while and I managed to train some successful models using it, but I might do some more hardening before merging.