Rescalability layer #1455

Draft · daviswer wants to merge 76 commits into main

Conversation

daviswer

Implements rescaling of checkpoints to different world sizes and numbers of workers. The user specifies the number of data partitions in advance, and when saving/loading checkpoints with a different total worker count, stateful guarantees are maintained: seen data is not revisited until the next epoch.

Based on the datasets in the corresponding IBM torchtitan PR, but with an adjusted rescaling and iteration mechanism to support greater flexibility and robustness (it removes divisibility constraints from worker and shard counts, and guarantees only one open file per physical worker regardless of the number of logical shards). Uses StatefulDataLoader and DCP to manage checkpointing from the master process. An epoch completion testing script is included for demo purposes. It is possible that the IBM datasets can be merged into the existing torchdata Nodes structure.

Changes

  • Add IBM rescalable datasets and checkpointing functions to torchdata/stateful_dataloader/ibm_rescalable.py
  • Add demo script and correctness check to examples/ibm_rescaling/rescaling_demo.py
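For orientation, a rough sketch of how the pieces are meant to fit together. ScalableReader and the {save, load}_distributed_state_dict helpers are names used in this PR/discussion, but the import, constructor arguments, and call signatures shown here are assumptions for illustration, not the actual API:

```python
# Illustrative sketch only: module path taken from the file listed above, but
# the exported names, arguments, and helper signatures are assumptions.
import torch.distributed as dist
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.ibm_rescalable import (  # hypothetical exports
    ScalableReader,
    load_distributed_state_dict,
    save_distributed_state_dict,
)

# The number of logical data partitions is fixed up front and must not change
# across restarts; physical workers (ranks x dataloader workers) may change.
dataset = ScalableReader(
    datapath="/data/corpus",            # hypothetical argument names
    rank=dist.get_rank(),
    worldsize=dist.get_world_size(),
    n_logical_shards=512,
)
loader = StatefulDataLoader(dataset, batch_size=8, num_workers=2)

# Each rank contributes its loader state; DCP writes one resharding-friendly
# checkpoint, driven from the master process.
save_distributed_state_dict(loader, "/ckpt/step_1000")

# On restart, possibly with a different world size and/or num_workers, state
# is redistributed so already-seen data is not revisited until the next epoch.
load_distributed_state_dict(loader, "/ckpt/step_1000")
```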

facebook-github-bot added the CLA Signed label on Feb 25, 2025
@scotts
Contributor

scotts commented Feb 28, 2025

Thanks for the work, @daviswer! Some first-level comments:

  1. Let's create some unit tests based on the example demo. I think it makes sense for them to live in test/stateful_dataloader.
  2. Let's name the Python file after the main abstraction users will use from it, so ibm_rescalable.py should become scalable_reader.py.
  3. The name _WrapperDataset is really generic. From the code and your comments, I think a name closer to the capability it provides might be _NestedStatefulDataset.
  4. The class _ShardFileHandler should probably be an abstract base class. And since we anticipate that others may end up creating their own shard handlers for other formats, we should probably consider a public API, so we should drop the leading _. We might also want to break it out into its own file, such as shard_handler.py. Then all future shard handlers would go in there.
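For illustration, a minimal sketch of what a public abstract base class in shard_handler.py could look like; the method names and signatures here are hypothetical, not the ones in the PR:

```python
from abc import ABC, abstractmethod
from typing import Any, List


class ShardFileHandler(ABC):
    """Reads documents out of a single shard file of one particular format.

    Hypothetical sketch: method names and signatures are illustrative only.
    """

    @abstractmethod
    def is_legal(self, path: str) -> bool:
        """Return True if this handler recognizes the file at `path`."""

    @abstractmethod
    def open(self, path: str) -> Any:
        """Open the shard and return a lightweight, seekable reader handle."""

    @abstractmethod
    def length(self, reader: Any) -> int:
        """Number of documents contained in the shard."""

    @abstractmethod
    def get(self, reader: Any, index: int, offset: int, n_tokens: int) -> List[int]:
        """Return `n_tokens` tokens of document `index`, starting at `offset`."""
```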

@daviswer
Author

daviswer commented Mar 5, 2025

Thanks @scotts, I made changes 2-4 and am working on unit tests now. I'll note that _StatefulDataset and _NestedStatefulDataset largely represent legacy code, gluing things together until we decide whether we want to merge this into Nodes or use these to represent stateful datasets (in which case we'll need to rework them anyway with contracts/APIs/etc. per #1456).

A preferred format as we can load document chunks without having to ever pull
the entire document or shard file, allowing for graceful handling of large documents.
Non-standard data format, though.
"""
Contributor

I wanted to confirm my understanding of the format of the pyarrow shard files. I am imagining a very large text file made up of thousands of tokens. That file is broken into multiple PyArrow shard files. Each of those PyArrow shard files is made up of multiple RecordBatches, each with a tokens field which is a list of tokens. That means each token is a 'row' (in the sense that RecordBatches are supposed to be a batch of records/rows). Is that right?

Additionally, why do we not consider having the full list of tokens as a single row in the RecordBatch? What is the value added by using RecordBatches here? Thanks.

Author

The general assumption is that each pyarrow shard file represents a collection of documents, rather than a single large document being split over multiple files. But yes, each file is a collection of RecordBatches, each of which contains a single 'row' of text (a single document) in the tokens field. We use RecordBatches because that's how the pyarrow docs suggest reading/writing random-access memory-mapped files, and we put a single document per RecordBatch to minimize the overhead of loading individual documents in random order.

Contributor

Cool. Can every document (even large ones) fit in a single RecordBatch?
The docstring mentions "as we can load document chunks without having to ever pull the entire document"; are we still referring here to a single document being loaded in a RecordBatch?

Author

Yes: because pyarrow files are memory-mapped, RecordBatches (and slices) are loaded lazily. Until you request the actual slice, the RecordBatch is just metadata, so it can hold a document of any size without extra overhead.
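To make the lazy loading concrete, here is a minimal read path, assuming shard files written in the Arrow IPC random-access (file) format with one document per RecordBatch and a list-typed tokens column; the filename, batch index, and offsets are made up:

```python
import pyarrow as pa

# Open the shard through a memory map; the OS only pages data in when a
# buffer is actually touched.
with pa.memory_map("shard_00000.arrow", "r") as source:
    reader = pa.ipc.open_file(source)   # random-access IPC file reader
    batch = reader.get_batch(42)        # one document; buffers are mapped, not read
    tokens = batch.column("tokens")     # ListArray holding a single row
    flat = tokens.flatten()             # this document's tokens as a flat array
    chunk = flat.slice(1024, 512)       # zero-copy view of a 512-token chunk
    values = chunk.to_pylist()          # only this slice is materialized in Python
```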

@divyanshk
Contributor

divyanshk commented Mar 14, 2025

Sharing some points that we discussed over the call.

  1. The core work here is in the data access layer, not particularly in the data loader. I imagine we can figure out a way to give end users (think PyTorch users with an established Dataset class) a RescalableDataset wrapper which converts their existing Dataset into one that can be rescaled if they decide to stop and re-start a job with a different world size (a rough sketch follows this list). The ScalableReader is effectively that, although we should consider whether we want to make the user provide more inputs (like a custom file handler) or whether we can configure those inside the rescalable dataset wrapper.

    This can feed directly into a canonical StatefulDataLoader, with the {save, load}_distributed_state_dict functionality incorporated into StatefulDataLoader's state_dict / load_state_dict methods as special cases for RescalableDataset. At this point I don't know how feasible that is (@daviswer brought up a good point about whether we want to take a dependency on DCP, versus having a generic interface that any checkpointing API can plug into), but this seems like a simpler interface for users to onboard to.

  2. So far the implementation is solving for text-heavy AI workloads. We should also align on whether we want to extend the scope to include other styles, e.g. a data row being an arbitrary Dict[str, Any], typical map-style datasets, typical HuggingFace vision datasets, etc.

  3. I need to look at some internal data-access layer APIs to ensure we don't diverge too much.
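A rough sketch of the wrapper idea in point 1, purely for discussion: RescalableDataset is the name used above, while the constructor arguments, state layout, and delegation are assumptions for illustration.

```python
# Hypothetical sketch; not the API proposed or implemented in this PR.
from typing import Any, Dict, Iterator

from torch.utils.data import IterableDataset


class RescalableDataset(IterableDataset):
    """Wraps a user dataset so its progress can be resharded across restarts."""

    def __init__(self, dataset: IterableDataset, n_logical_shards: int = 256):
        self.dataset = dataset
        self.n_logical_shards = n_logical_shards  # fixed in advance across restarts
        self.shard_progress: Dict[int, int] = {}  # logical shard -> items consumed

    def __iter__(self) -> Iterator[Any]:
        # A real implementation would interleave the logical shards assigned to
        # this worker and update self.shard_progress; here we simply delegate.
        yield from self.dataset

    def state_dict(self) -> Dict[str, Any]:
        # Progress is recorded per logical shard, so a load with a different
        # world size can hand unfinished shards to whatever workers exist.
        return {
            "n_logical_shards": self.n_logical_shards,
            "shard_progress": dict(self.shard_progress),
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        assert state["n_logical_shards"] == self.n_logical_shards
        self.shard_progress = dict(state["shard_progress"])
```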

@scotts @daviswer
