Skip to content

Resumable checkpointable dataloading#1570

Open
pzelasko wants to merge 31 commits into
masterfrom
feat/resumable-checkpointable-dataloading
Open

Resumable checkpointable dataloading#1570
pzelasko wants to merge 31 commits into
masterfrom
feat/resumable-checkpointable-dataloading

Conversation

@pzelasko

@pzelasko pzelasko commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

This is a pretty large PR that makes changes all over Lhotse to enable random-access checkpointable dataloading.
The main idea is to introduce binary indexes for JSONL manifests (similarly to StatelessSampler) that enable O(1) random sampling, and extend all iterator classes to support lookup by index.

There are some tricky parts:

  • iterators can be composed into a graph (mux, augmentation, filter, map, etc), and the indexing scheme is designed to account for that.
  • checkpointing dynamic bucketing sampler's buffers - instead of storing actual data, it only stores the information necessary to lookup a sample (including keys necessary to traverse the entire iterator graph)

Unlike previous approaches it allows quick resumption, 100% determinism, and improved sampling randomness for sharded / tarred / lhotse shar data. I noticed it helps to improve the results for models trained on many datasets blended using mux - sometimes a lot, depending on how large is the data and how it's physically sharded.

I've been validating it for a while and I managed to train some successful models using it, but I might do some more hardening before merging.

pzelasko and others added 28 commits March 2, 2026 13:21
…r archives

Introduce lhotse/indexing.py with:
- Binary index format: arrays of little-endian uint64 byte offsets
- create_jsonl_index(), create_tar_index(), create_shar_index() for index creation
- IndexedJsonlReader and IndexedTarReader for O(1) random-access reads
- LazyShuffledRange: Feistel-cipher-based O(1) memory permutation of range(N)
- Utility functions: index_file_path(), read_index(), index_exists()

This is the foundation for resumable checkpointable dataloading.
Index files enable seeking directly to any sample without sequential scanning.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Refactor lhotse/lazy.py:
- Add StatefulIterator mixin with state_dict()/load_state_dict() protocol
- Unify child attribute names: self.iterator -> self.source (single child),
  self.iterators -> self.sources (list of children)
- Add state_dict()/load_state_dict() to all 10 lazy iterator classes:
  LazyJsonlIterator, LazyManifestIterator, LazyIteratorChain,
  LazyIteratorMultiplexer, LazyInfiniteApproximateMultiplexer,
  LazyShuffler, LazyFilter, LazyMapper, LazyFlattener, LazyRepeater
- LazyJsonlIterator uses IndexedJsonlReader for position tracking when
  index files are available

The unified naming enables trivial graph traversal for recursive
state collection/restoration in the checkpoint module.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce lhotse/checkpoint.py with:
- collect_state_dict(root): Recursively collects state from all
  StatefulIterator nodes in the iterator graph
- restore_state_dict(root, state): Recursively restores state,
  validates type names match between graph and checkpoint
- DataloaderCheckpoint dataclass: Serializable container for
  num_workers, world_size, rank, worker_states, and sampler_state
  with save()/load()/validate() methods

Uses the unified source/sources naming convention from Phase 2
for trivial graph traversal.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Modify LazySharIterator (lhotse/shar/readers/lazy.py):
- Inherit from StatefulIterator
- Track current_shard_idx, position_in_shard, shard_order during iteration
- Add state_dict()/load_state_dict() for mid-iteration checkpointing
- Exclude .idx files from _init_from_dir discovery

Modify SharWriter (lhotse/shar/writers/shar.py):
- Add compress_jsonl parameter (default True for backward compat)
- Add create_index parameter (default True) to auto-create .idx files
- _create_indexes() skips non-local paths (pipe:, http://, s3://)
- Support uncompressed JSONL output for indexing compatibility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Enhance CutSampler.state_dict() (base.py):
- Capture CutSet iterator graph state via collect_state_dict() when
  cuts_iter contains StatefulIterator instances
- Store as 'cuts_state' key; consumed by load_state_dict()

Modify DynamicCutSampler._fast_forward() and
DynamicBucketingSampler._fast_forward():
- When cuts_state is available, use O(1) indexed state restoration
  via restore_state_dict() instead of O(N) batch re-iteration
- Fall back to legacy _fast_forward for old state_dicts or when
  indexed restoration fails

Add Stateful protocol to IterableDatasetWrapper:
- state_dict() returns epoch + sampler state
- load_state_dict() restores both, enabling torchdata
  StatefulDataLoader integration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add 'lhotse index' command group with subcommands:
- lhotse index jsonl PATH  — create index for uncompressed JSONL
- lhotse index tar PATH    — create index for uncompressed tar archive
- lhotse index shar DIR    — create indexes for all files in a Shar directory

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add new 'Resumable Stateful Dataloading (Indexed)' section to
docs/datasets.rst covering:
- Motivation: O(1) vs O(N) restore
- Setting up indexed data (SharWriter + CLI)
- Using StatefulDataLoader for per-worker checkpointing
- Requirements and limitations

Add 'torchdata' as optional dependency under 'checkpoint' extra
in setup.py for StatefulDataLoader support.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rchdata

Fix IterableDatasetWrapper.load_state_dict() to create the sampler iter
immediately instead of setting it to None. torchdata's StatefulDataLoader
calls iter(dataset) before load_state_dict(), so nulling _sampler_iter
caused a TypeError on __next__. The sampler's __iter__ returns self when
_just_restored_state=True, preserving the restored position.

Add e2e tests exercising StatefulDataLoader with num_workers in {0, 2}:
basic mux, augmentation, full pipeline, and various checkpoint positions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…pickling fixes

- LazySharIterator: add has_constant_time_access, __getitem__ for O(1)
  random access across shards, and pickle-safe __getstate__/__setstate__
- IndexedJsonlReader/IndexedTarReader: add __getstate__/__setstate__
- IdentityDataset: new importable pass-through dataset utility for use
  with IterableDatasetWrapper (avoids __main__ pickling issues)
- CutSet.from_file(shuffle=True, seed=N): thread shuffle/seed params
  through Serializable.from_file → load_manifest_lazy_or_eager →
  load_manifest_lazy for Feistel-cipher shuffled iteration
- CutSet.from_jsonl_lazy(shuffle=True, seed=N): same shuffle support
- Fix num_workers>0 pickling: defer _fast_forward() in samplers and
  iter(sampler) in IterableDatasetWrapper via lazy flags
- CutSet.to_shar: expose compress_jsonl/create_index params
- LazyMapper, LazyRepeater, LazySlicer, LazyIteratorChain: add
  has_constant_time_access property and __getitem__ for indexed access
- Tutorial notebook: full rewrite using user-facing APIs only
- Tests: Shar random access, indexed manifest iterator, CutSet
  state_dict, e2e checkpoint restore with file-backed lazy data

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Review fixes:
- Replace assert with raise ValueError in LazyShuffledRange.load_state_dict
- Raise ValueError in LazyIndexedManifestIterator.load_state_dict when
  shuffle state is missing from checkpoint
- Add rank validation to DataloaderCheckpoint.validate()
- Cache cumulative lengths in LazyIteratorChain.__getitem__ (O(log m)
  via bisect instead of O(m) linear scan)
- Add thread safety warning to StatefulIterator docstring
- Add test for CutSet.from_file(shuffle=True, seed=42) API path

New test:
- test/dataset/test_multinode_resume.py: E2E multi-node training
  resumption using torchdata StatefulDataLoader with seed="randomized",
  make_worker_init_fn, infinite .repeat() + .mux() pipelines, and
  world_size=2 / num_workers=2

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…extend tutorial

- Split LazySharIterator into streaming (LazySharIterator) and indexed
  (LazyIndexedSharIterator) classes for clean separation of concerns
- Add LazyIndexedSharIterator with O(1) random access, shuffle, and
  checkpoint/restore support in lhotse/shar/readers/indexed.py
- Add indexed parameter to CutSet.from_files() and CutSet.from_shar()
  with auto-detection when .idx files exist
- Extend LazyIteratorChain to auto-upgrade shuffle_iters to global
  Feistel-cipher shuffling when all sources are indexed
- Add "lhotse_shar_fields" origin type for fields-based Shar reload
- Extend tutorial notebook with torchdata disk checkpoint and Shar
  dataloading examples
- Add comprehensive tests for all new functionality

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… data

Add an optional `index_path` parameter throughout the indexed reading
stack so users can point to .idx files at a different (local) location.
This enables indexed random access when data lives on object stores
(S3, GCS) where np.memmap-based .idx files cannot be written next to
the manifests.

Changes bottom-up:
- indexing.py: output_path for create_*_index, output_dir for
  create_shar_index, index_path for Indexed*Reader and index_exists
- lazy.py: index_path for LazyJsonlIterator and
  LazyIndexedManifestIterator; extend _origin tuples to carry index_path
- shar/readers/indexed.py: index_path for LazyIndexedSharIterator
  (directory or dict), forwarded to all per-shard readers and pickling
- serialization.py: index_path threaded through from_file →
  load_manifest_lazy_or_eager → load_manifest_lazy
- cut/set.py: index_path on CutSet.from_files (list) and
  CutSet.from_shar (dir or dict), with validation
- checkpoint.py: variable-length origin tuple handling, index_path
  forwarded in all three origin loaders
- bin/modes/index.py: --output-dir / -o CLI option for jsonl, tar, shar

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Resolve conflicts in LazyCutMixer._mix_one (keep feature branch
structure, add master's tag parameter) and librispeech formatting.
Update test_indexed_read.py to use numpy instead of lilcom,
consistent with master's make-lilcom-optional changes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
…called

Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
…idual GET requests

Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
…ptr fallback

lazy.py + dataset/dataloading.py: resolve (shard_id, num_shards) at iter
time via get_worker_partition() / LHOTSE_USE_WORKER_PARTITION so
force_map_dataset:false is correct under DP×worker fan-out; map-style
path bit-identical. LazyIteratorMultiplexer rejects seed='randomized'
under multi-shard partition.

ais/batch_loader.py + dataset/input_strategies.py: rename
prefer_individual -> force_individual. Add byte-range fallback for
shar_ptr requests via per-object get_reader(byte_range=...) when MOSS
BatchRequest.add lacks byte-range or force_individual is set;
previously shar_ptr was silently dropped from the batch.

indexing.py: atomic _write_index (stage-and-rename) so concurrent
auto_create_index racers never observe a 0-byte .idx; size-aware
index_exists treats 0-byte / non-uint64-aligned files as missing so
a stale truncated sidecar triggers re-create instead of producing
len(reader)=-1 -> "__len__() should return >= 0" deep in the
sampler. create_tar_index falls back to tf.offset when the reader
can't seek (AIS ObjectFileReader inherits BufferedIOBase.tell which
delegates to seek -> raises).

test/test_partition.py: 49 tests pinning partition edge cases (rank-
without-signal, topology mismatch, randomized rejection, worker × DP
composition, mixed indexed/non-indexed chains).

test/test_indexing.py: +12 partition tests on LazyShuffledRange, +2
race tests reproducing the 0-byte / concurrent auto-create crash
seen in production.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tes to it; refresh stale memmap tests

* Add PartitionedIndexedIterator in lhotse/dataset/dataloading.py: the
  shared partition-aware iteration driver for indexed leaf iterators.
  Encapsulates (shard_id, num_shards) lookup, position tracking across
  DataLoader worker subprocesses, and topology-validated resume. Supports
  both stride and Feistel-shuffled modes (shuffle=True delegates to
  LazyShuffledRange). Iterators plug in by storing self._iter_state and
  forwarding state_dict()/load_state_dict().
* Refactor LazyIndexedManifestIterator to delegate to the helper. Drops
  ~50 lines of duplicated position/shuffle/topology plumbing; the body
  is now a single for-loop yielding from self._iter_state.iterate(...).
* test/test_partition.py: 13 new helper-level tests pinning the contract
  (single-rank full coverage, multi-rank disjoint x [2,3,4,7], resume
  from middle, topology-mismatch raises, map-style fallthrough, empty
  manifest, n<world_size, neutral state_dict). Update the chain-level
  topology-mismatch test to match the helper's new error message
  ("topology mismatch on resume" instead of LazyShuffledRange's
  "state mismatch", which the helper's earlier check now preempts).
* test/test_indexing.py: rename test_read_index_remote_path_is_cached_and_memmapped
  -> ..._is_cached_locally and drop the isinstance(np.memmap) assertions.
  read_index switched from np.memmap to np.fromfile (deliberately, to
  avoid exhausting kernel vm.max_map_count at 80k+ shards); the tests
  were checking for the old behavior and had been failing pre-emptively.
  Cache-correctness assertions are kept; the dtype check is replaced by
  np.array_equal against the source bytes.

62/62 partition tests + 117/117 indexing tests pass under nemo312-hf5.
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
…ator

The shar reader's own ``_get_worker_indices`` + ``_position``
machinery was the last indexed-iterator path still using legacy
``split_for_dataloading``, which (a) defaults to False so iterable
NeMo dataloaders never partitioned shar leaves across DP ranks,
and (b) doesn't record ``(shard_id, num_shards)`` in state_dict
so resume under a different topology silently desyncs. Caught by
the 0909 SALM validator: 4072 AMI cuts duplicated across 8 ranks.

Delegate iteration + state to ``PartitionedIndexedIterator``,
matching every other indexed leaf (``LazyNeMoTarredIterator``,
``LazyIndexedManifestIterator``, parquet/jsonl text adapters).
``split_for_dataloading`` is kept as accepted-but-ignored for
signature compat with ``CutSet.from_shar(split_for_dataloading=...)``.

Regression test exercises 1/2/4 ranks (disjoint slices + complete
union), the env-var-off collapse for map-style mode, same-topology
resume, topology-mismatch raise, and shuffle correctness.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…GET operations

Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
@review-notebook-app

Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant