diff --git a/docs/source/getting_started_with_torchdata_nodes.rst b/docs/source/getting_started_with_torchdata_nodes.rst
new file mode 100644
index 000000000..3f732f9fa
--- /dev/null
+++ b/docs/source/getting_started_with_torchdata_nodes.rst
@@ -0,0 +1,57 @@
+Getting Started With ``torchdata.nodes`` (beta)
+===============================================
+
+Install torchdata with pip.
+
+.. code:: bash
+
+   pip install "torchdata>=0.10.0"
+
+Generator Example
+~~~~~~~~~~~~~~~~~
+
+Wrap a generator (or any iterable) to convert it to a BaseNode and get started:
+
+.. code:: python
+
+   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader
+
+   node = IterableWrapper(range(10))
+   node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
+   loader = Loader(node)
+   result = list(loader)
+   print(result)
+   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
+
+Sampler Example
+~~~~~~~~~~~~~~~
+
+Samplers are still supported, and you can use your existing
+``torch.utils.data.Dataset`` instances. See :ref:`migrate-to-nodes-from-utils`
+for an in-depth example.
+
+.. code:: python
+
+   import torch.utils.data
+   from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader
+
+
+   class SquaredDataset(torch.utils.data.Dataset):
+       def __getitem__(self, i: int) -> int:
+           return i**2
+       def __len__(self):
+           return 10
+
+   dataset = SquaredDataset()
+   # For fine-grained control of iteration order, define your own sampler
+   sampler = torch.utils.data.RandomSampler(dataset)
+   node = SamplerWrapper(sampler)
+   # Simply apply the dataset's __getitem__ as a map function to the indices generated by the sampler
+   node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
+   # Loader converts a node (an iterator) into an Iterable that can be reused for multiple epochs
+   loader = Loader(node)
+   print(list(loader))
+   # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
+   print(list(loader))
+   # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]
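+
+Checkpointing Example
+~~~~~~~~~~~~~~~~~~~~~
+
+``Loader`` also exposes ``state_dict()`` and ``load_state_dict()`` for mid-epoch
+checkpointing. Below is a minimal sketch that reuses the ``node`` built above; see
+:ref:`migrate-to-nodes-from-utils` for a complete, runnable walk-through.
+
+.. code:: python
+
+   loader = Loader(node)
+   it = iter(loader)
+   for _ in range(3):
+       next(it)
+
+   # Capture the dataloading state mid-epoch...
+   state = loader.state_dict()
+
+   # ...and later resume iteration from exactly that point
+   loader.load_state_dict(state)
+   remaining = list(loader)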
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 46b92cabc..b13b99803 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -31,19 +31,26 @@ Features described in this documentation are classified by release status:
    binary distributions like PyPI or Conda, except sometimes behind run-time
    flags, and are at an early stage for feedback and testing.
 
+.. toctree::
+   :maxdepth: 2
+   :caption: Developer Notes:
+
+   what_is_torchdata_nodes.rst
 
 .. toctree::
    :maxdepth: 2
    :caption: API Reference:
 
-   torchdata.stateful_dataloader.rst
    torchdata.nodes.rst
+   torchdata.stateful_dataloader.rst
 
 .. toctree::
    :maxdepth: 2
    :caption: Tutorial and Examples:
 
+   getting_started_with_torchdata_nodes.rst
+   migrate_to_nodes_from_utils.rst
    stateful_dataloader_tutorial.rst
diff --git a/docs/source/migrate_to_nodes_from_utils.rst b/docs/source/migrate_to_nodes_from_utils.rst
new file mode 100644
index 000000000..7d424791e
--- /dev/null
+++ b/docs/source/migrate_to_nodes_from_utils.rst
@@ -0,0 +1,184 @@
+.. _migrate-to-nodes-from-utils:
+
+Migrating to ``torchdata.nodes`` from ``torch.utils.data``
+==========================================================
+
+This guide is intended to help people familiar with ``torch.utils.data``, or
+:class:`~torchdata.stateful_dataloader.StatefulDataLoader`,
+to get started with ``torchdata.nodes``, and provide a starting ground for defining
+your own dataloading pipelines.
+
+We'll demonstrate how to achieve the most common DataLoader features, re-use existing
+samplers and datasets, and load/save dataloader state. ``torchdata.nodes`` performs at
+least as well as ``DataLoader`` and ``StatefulDataLoader``; see :ref:`how-does-nodes-perform`.
+
+Map-Style Datasets
+~~~~~~~~~~~~~~~~~~
+
+Let's look at the ``DataLoader`` constructor args and go from there:
+
+.. code:: python
+
+   class DataLoader:
+       def __init__(
+           self,
+           dataset: Dataset[_T_co],
+           batch_size: Optional[int] = 1,
+           shuffle: Optional[bool] = None,
+           sampler: Union[Sampler, Iterable, None] = None,
+           batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
+           num_workers: int = 0,
+           collate_fn: Optional[_collate_fn_t] = None,
+           pin_memory: bool = False,
+           drop_last: bool = False,
+           timeout: float = 0,
+           worker_init_fn: Optional[_worker_init_fn_t] = None,
+           multiprocessing_context=None,
+           generator=None,
+           *,
+           prefetch_factor: Optional[int] = None,
+           persistent_workers: bool = False,
+           pin_memory_device: str = "",
+           in_order: bool = True,
+       ):
+           ...
+
+As a refresher, here is roughly how dataloading works in ``torch.utils.data.DataLoader``:
+``DataLoader`` begins by generating indices from a ``sampler`` and creates batches of
+``batch_size`` indices. If no sampler is provided, a ``RandomSampler`` or
+``SequentialSampler`` is created by default. The indices are passed to
+``Dataset.__getitem__()``, and then a ``collate_fn`` is applied to the batch of samples.
+If ``num_workers > 0``, it uses multi-processing to create subprocesses, and passes the
+batches of indices to the worker processes, which then call ``Dataset.__getitem__()`` and
+apply ``collate_fn`` before returning the batches to the main process. At that point,
+``pin_memory`` may be applied to the tensors in the batch.
+
+Now let's look at what an equivalent implementation for DataLoader might look like,
+built with ``torchdata.nodes``.
+
+.. code:: python
+
+   from typing import List, Callable
+   import torchdata.nodes as tn
+   from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset
+
+   class MapAndCollate:
+       """A simple transform that takes a batch of indices, maps with dataset, and then applies
+       collate.
+       TODO: make this a standard utility in torchdata.nodes
+       """
+       def __init__(self, dataset, collate_fn):
+           self.dataset = dataset
+           self.collate_fn = collate_fn
+
+       def __call__(self, batch_of_indices: List[int]):
+           batch = [self.dataset[i] for i in batch_of_indices]
+           return self.collate_fn(batch)
+
+   # To keep things simple, let's assume that the following args are provided by the caller
+   def NodesDataLoader(
+       dataset: Dataset,
+       batch_size: int,
+       shuffle: bool,
+       num_workers: int,
+       collate_fn: Callable | None,
+       pin_memory: bool,
+       drop_last: bool,
+   ):
+       # Assume we're working with a map-style dataset
+       assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")
+       # Start with a sampler, since caller did not provide one
+       sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
+       # Sampler wrapper converts a Sampler to a BaseNode
+       node = tn.SamplerWrapper(sampler)
+
+       # Now let's batch sampler indices together
+       node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)
+
+       # Create a Map Function that accepts a list of indices, applies getitem to it, and
+       # then collates them
+       map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)
+
+       # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
+       # choose process or thread workers. Note that if you're not using Free-Threaded
+       # Python (e.g. 3.13t) with -Xgil=0, then multi-threading might result in GIL contention,
+       # and slow down training.
+       node = tn.ParallelMapper(
+           node,
+           map_fn=map_and_collate,
+           num_workers=num_workers,
+           method="process",  # Set this to "thread" for multi-threading
+           in_order=True,
+       )
+
+       # Optionally apply pin-memory, and we usually do some pre-fetching
+       if pin_memory:
+           node = tn.PinMemory(node)
+       node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)
+
+       # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
+       # on it to start a new epoch.
+       # Instead, we wrap the node in a Loader, which is an iterable and handles reset. It
+       # also provides state_dict and load_state_dict methods.
+       return tn.Loader(node)
+
+Now let's test this out with a trivial dataset, and demonstrate how state management works.
+
+.. code:: python
+
+   class SquaredDataset(Dataset):
+       def __init__(self, len: int):
+           self.len = len
+       def __len__(self):
+           return self.len
+       def __getitem__(self, i: int) -> int:
+           return i**2
+
+   loader = NodesDataLoader(
+       dataset=SquaredDataset(14),
+       batch_size=3,
+       shuffle=False,
+       num_workers=2,
+       collate_fn=None,
+       pin_memory=False,
+       drop_last=False,
+   )
+
+   batches = []
+   for idx, batch in enumerate(loader):
+       if idx == 2:
+           state_dict = loader.state_dict()
+           # Saves the state_dict after batch 2 has been returned
+       batches.append(batch)
+
+   loader.load_state_dict(state_dict)
+   batches_after_loading = list(loader)
+   print(batches[3:])
+   # [tensor([ 81, 100, 121]), tensor([144, 169])]
+   print(batches_after_loading)
+   # [tensor([ 81, 100, 121]), tensor([144, 169])]
+
+Let's also compare this to ``torch.utils.data.DataLoader``, as a sanity check.
+
+.. code:: python
+
+   loaderv1 = torch.utils.data.DataLoader(
+       dataset=SquaredDataset(14),
+       batch_size=3,
+       shuffle=False,
+       num_workers=2,
+       collate_fn=None,
+       pin_memory=False,
+       drop_last=False,
+       persistent_workers=False,  # Coming soon to torchdata.nodes!
+   )
+   print(list(loaderv1))
+   # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
+   print(batches)
+   # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
+
+
+IterableDatasets
+~~~~~~~~~~~~~~~~
+
+Coming soon! While you can already plug your IterableDataset into a ``tn.IterableWrapper``,
+some functions like ``get_worker_info`` are not supported yet. However, we believe that
+sharding work between multi-process workers is often not actually necessary: you can keep
+some sort of indexing in the main process while only parallelizing some of the heavier
+transforms, similar to how Map-style Datasets work above.
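+
+As a minimal sketch of that pattern (``RangeDataset`` is a toy IterableDataset, made up
+purely for illustration, and needs no worker-level sharding):
+
+.. code:: python
+
+   import torchdata.nodes as tn
+   from torch.utils.data import IterableDataset
+
+   class RangeDataset(IterableDataset):
+       def __iter__(self):
+           yield from range(10)
+
+   # IterableWrapper accepts any Iterable, including an IterableDataset
+   node = tn.IterableWrapper(RangeDataset())
+   # Keep light-weight iteration in the main process; parallelize only the heavy transform
+   node = tn.ParallelMapper(node, map_fn=lambda x: x**2, num_workers=2, method="thread")
+   loader = tn.Loader(node)
+   print(list(loader))
+   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]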
diff --git a/docs/source/torchdata.nodes.rst b/docs/source/torchdata.nodes.rst
index d5197c0e0..2a358b6a3 100644
--- a/docs/source/torchdata.nodes.rst
+++ b/docs/source/torchdata.nodes.rst
@@ -1,202 +1,6 @@
-torchdata.nodes
-===============
+``torchdata.nodes`` (beta)
+==========================
-
-What is ``torchdata.nodes``?
-----------------------------
-
-``torchdata.nodes`` is a library of composable iterators (not
-iterables!) that let you chain together common dataloading and pre-proc
-operations. It follows a streaming programming model, although “sampler
-+ Map-style” can still be configured if you desire.
- -``torchdata.nodes`` adds more flexibility to the standard -``torch.utils.data`` offering, and introduces multi-threaded parallelism -in addition to multi-process (the only supported approach in -``torch.utils.data.DataLoader``), as well as first-class support for -mid-epoch checkpointing through a ``state_dict/load_state_dict`` -interface. - -``torchdata.nodes`` strives to include as many useful operators as -possible, however it’s designed to be extensible. New nodes are required -to subclass ``torchdata.nodes.BaseNode``, (which itself subclasses -``typing.Iterator``) and implement ``next()``, ``reset(initial_state)`` -and ``get_state()`` operations (notably, not ``__next__``, -``load_state_dict``, nor ``state_dict``) - -Getting started ---------------- - -Install torchdata with pip. - -.. code:: bash - - pip install torchdata>=0.10.0 - -Generator Example -~~~~~~~~~~~~~~~~~ - -Wrap a generator (or any iterable) to convert it to a BaseNode and get -started - -.. code:: python - - from torchdata.nodes import IterableWrapper, ParallelMapper, Loader - - node = IterableWrapper(range(10)) - node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread") - loader = Loader(node) - result = list(loader) - print(result) - # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] - -Sampler Example -~~~~~~~~~~~~~~~ - -Samplers are still supported, and you can use your existing -``torch.utils.data.Dataset``\ s - -.. code:: python - - import torch.utils.data - from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader - - - class SquaredDataset(torch.utils.data.Dataset): - def __getitem__(self, i: int) -> int: - return i**2 - def __len__(self): - return 10 - - dataset = SquaredDataset() - sampler = RandomSampler(dataset) - - # For fine-grained control of iteration order, define your own sampler - node = SamplerWrapper(sampler) - # Simply apply dataset's __getitem__ as a map function to the indices generated from sampler - node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread") - # Loader is used to convert a node (iterator) into an Iterable that may be reused for multi epochs - loader = Loader(node) - print(list(loader)) - # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1] - print(list(loader)) - # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36] - -What’s the point of ``torchdata.nodes``? ----------------------------------------- - -We get it, ``torch.utils.data`` just works for many many use cases. -However it definitely has a bunch of rough spots: - -Multiprocessing sucks -~~~~~~~~~~~~~~~~~~~~~ - -- You need to duplicate memory stored in your Dataset (because of - Python copy-on-read) -- IPC is slow over multiprocess queues and can introduce slow startup - times -- You’re forced to perform batching on the workers instead of - main-process to reduce IPC overhead, increasing peak memory. -- With GIL-releasing functions and Free-Threaded Python, - multi-threading may not be GIL-bound like it used to be. - -``torchdata.nodes`` enables both multi-threading and multi-processing so -you can choose what works best for your particular set up. Parallelism -is primarily configured in Mapper operators giving you flexibility in -the what, when, and how to parallelize. 
- -Map-style and random-access doesn’t scale -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Current map dataset approach is great for datasets that fit in memory, -but true random-access is not going to be very performant once your -dataset grows beyond memory limitations unless you jump through some -hoops with a special sampler. - -``torchdata.nodes`` follows a streaming data model, where operators are -Iterators that can be combined together to define a dataloading and -pre-proc pipeline. Samplers are still supported (see example above) and -can be combined with a Mapper to produce an Iterator - -Multi-Datasets do not fit well with the current implementation in ``torch.utils.data`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The current Sampler (one per dataloader) concepts start to break down -when you start trying to combine multiple datasets. (For single -datasets, they’re a great abstraction and will continue to be -supported!) - -- For multi-datasets, consider this scenario: ``len(dsA): 10`` - ``len(dsB): 20``. Now we want to do round-robin (or sample uniformly) - between these two datasets to feed to our trainer. With just a single - sampler, how can you implement that strategy? Maybe a sampler that - emits tuples? What if you want to swap with RandomSampler, or - DistributedSampler? How will ``sampler.set_epoch`` work? - -``torchdata.nodes`` helps to address and scale multi-dataset dataloading -by only dealing with Iterators, thereby forcing samplers and datasets -together, focusing on composing smaller primitives nodes into a more -complex dataloading pipeline. - -IterableDataset + multiprocessing requires additional dataset sharding -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Dataset sharding is required for data-parallel training, which is fairly -reasonable. But what about sharding between dataloader workers? With -Map-style datasets, distribution of work between workers is handled by -the main process, which distributes sampler indices to workers. With -IterableDatasets, each worker needs to figure out (through -``torch.utils.data.get_worker_info``) what data it should be returning. - -Design choices --------------- - -No Generator BaseNodes -~~~~~~~~~~~~~~~~~~~~~~ - -See https://github.com/pytorch/data/pull/1362 for more thoughts. - -One difficult choice we made was to disallow Generators when defining a -new BaseNode implementation. However we dropped it and moved to an -Iterator-only foundation for a few reasons around state management: - -1. We require explicit state handling in BaseNode implementations. - Generators store state implicitly on the stack and we found that we - needed to jump through hoops and write very convoluted code to get - basic state working with Generators -2. End-of-iteration state dict: Iterables may feel more natural, however - a bunch of issues come up around state management. Consider the - end-of-iteration state dict. If you load this state_dict into your - iterable, should this represent the end-of-iteration or the start of - the next iteration? -3. Loading state: If you call load_state_dict() on an iterable, most - users would expect the next iterator requested from it to start with - the loaded state. However what if iter is called twice before - iteration begins? -4. Multiple Live Iterator problem: if you have one instance of an - Iterable, but two live iterators, what does it mean to call - state_dict() on the Iterable? 
In dataloading, this is very rare, - however we still need to work around it and make a bunch of - assumptions. Forcing devs that are implementing BaseNodes to reason - about these scenarios is, in our opinion, worse than disallowing - generators and Iterables. - -``torchdata.nodes.BaseNode`` implementations are Iterators. Iterators -define ``next()``, ``get_state()``, and ``reset(initial_state | None)``. -All re-initialization should be done in reset(), including initializing -with a particular state if one is passed. - -However, end-users are used to dealing with Iterables, for example, - -:: - - for epoch in range(5): - # Most frameworks and users don't expect to call loader.reset() - for batch in loader: - ... - sd = loader.state_dict() - # Loading sd should not throw StopIteration right away, but instead start at the next epoch - -To handle this we keep all of the assumptions and special end-of-epoch -handling in a single ``Loader`` class which takes any BaseNode and makes -it an Iterable, handling the reset() calls and end-of-epoch state_dict -loading. +.. automodule:: torchdata.nodes + :members: + :show-inheritance: diff --git a/docs/source/what_is_torchdata_nodes.rst b/docs/source/what_is_torchdata_nodes.rst new file mode 100644 index 000000000..842809c31 --- /dev/null +++ b/docs/source/what_is_torchdata_nodes.rst @@ -0,0 +1,165 @@ +What is ``torchdata.nodes`` (beta)? +=================================== + +``torchdata.nodes`` is a library of composable iterators (not +iterables!) that let you chain together common dataloading and pre-proc +operations. It follows a streaming programming model, although “sampler ++ Map-style” can still be configured if you desire. + +``torchdata.nodes`` adds more flexibility to the standard +``torch.utils.data`` offering, and introduces multi-threaded parallelism +in addition to multi-process (the only supported approach in +``torch.utils.data.DataLoader``), as well as first-class support for +mid-epoch checkpointing through a ``state_dict/load_state_dict`` +interface. + +``torchdata.nodes`` strives to include as many useful operators as +possible, however it’s designed to be extensible. New nodes are required +to subclass ``torchdata.nodes.BaseNode``, (which itself subclasses +``typing.Iterator``) and implement ``next()``, ``reset(initial_state)`` +and ``get_state()`` operations (notably, not ``__next__``, +``load_state_dict``, nor ``state_dict``) + +See :doc:`getting_started_with_torchdata_nodes` to get started + +Why ``torchdata.nodes``? +---------------------------------------- + +We get it, ``torch.utils.data`` just works for many many use cases. +However it definitely has a bunch of rough spots: + +Multiprocessing sucks +~~~~~~~~~~~~~~~~~~~~~ + +- You need to duplicate memory stored in your Dataset (because of + Python copy-on-read) +- IPC is slow over multiprocess queues and can introduce slow startup + times +- You’re forced to perform batching on the workers instead of + main-process to reduce IPC overhead, increasing peak memory. +- With GIL-releasing functions and Free-Threaded Python, + multi-threading may not be GIL-bound like it used to be. + +``torchdata.nodes`` enables both multi-threading and multi-processing so +you can choose what works best for your particular set up. Parallelism +is primarily configured in Mapper operators giving you flexibility in +the what, when, and how to parallelize. 
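+
+For example, the same pipeline can switch between the two approaches via the
+``method`` argument of ``ParallelMapper``. A minimal sketch (``heavy_transform``
+is a hypothetical, potentially GIL-releasing function, for illustration only):
+
+.. code:: python
+
+   import torchdata.nodes as tn
+
+   def heavy_transform(item):
+       ...  # e.g. image decoding; many such libraries release the GIL
+
+   node = tn.IterableWrapper(range(10))
+   # method="thread" shares memory with the main process, while method="process"
+   # avoids GIL contention on non-free-threaded Python at the cost of IPC
+   node = tn.ParallelMapper(node, map_fn=heavy_transform, num_workers=4, method="thread")
+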
Map-style and random-access doesn’t scale
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The current map-style dataset approach is great for datasets that fit in memory,
+but true random access is not going to be very performant once your
+dataset grows beyond memory limitations, unless you jump through some
+hoops with a special sampler.
+
+``torchdata.nodes`` follows a streaming data model, where operators are
+Iterators that can be combined together to define a dataloading and
+pre-proc pipeline. Samplers are still supported (see
+:doc:`getting_started_with_torchdata_nodes`) and can be combined with a
+Mapper to produce an Iterator.
+
+Multi-Datasets do not fit well with the current implementation in ``torch.utils.data``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The current Sampler (one per dataloader) concepts start to break down
+when you start trying to combine multiple datasets. (For single
+datasets, they’re a great abstraction and will continue to be
+supported!)
+
+- For multi-datasets, consider this scenario: ``len(dsA): 10``
+  ``len(dsB): 20``. Now we want to do round-robin (or sample uniformly)
+  between these two datasets to feed to our trainer. With just a single
+  sampler, how can you implement that strategy? Maybe a sampler that
+  emits tuples? What if you want to swap in a RandomSampler, or a
+  DistributedSampler? How will ``sampler.set_epoch`` work?
+
+``torchdata.nodes`` helps to address and scale multi-dataset dataloading
+by only dealing with Iterators, thereby forcing samplers and datasets
+together, focusing on composing smaller primitive nodes into a more
+complex dataloading pipeline.
+
+IterableDataset + multiprocessing requires additional dataset sharding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Dataset sharding is required for data-parallel training, which is fairly
+reasonable. But what about sharding between dataloader workers? With
+Map-style datasets, distribution of work between workers is handled by
+the main process, which distributes sampler indices to workers. With
+IterableDatasets, each worker needs to figure out (through
+``torch.utils.data.get_worker_info``) what data it should be returning.
+
+.. _how-does-nodes-perform:
+
+How does ``torchdata.nodes`` perform?
+-------------------------------------
+
+We presented some results from an early version of ``torchdata.nodes``
+on a video-decoding benchmark at `PyTorch Conf 2024 `_
+where we showed that:
+
+* torchdata.nodes performs on par with, or better than,
+  ``torch.utils.data.DataLoader`` when using multi-processing
+  (see :ref:`migrate-to-nodes-from-utils`)
+
+* With GIL Python, torchdata.nodes with multi-threading performs better than
+  multi-processing in some scenarios, and makes features like GPU pre-proc
+  easier to perform, which can boost throughput
+
+We also ran a benchmark loading the Imagenet dataset from disk, and
+managed to saturate main-memory bandwidth with Free-Threaded Python (3.13t)
+at a significantly lower CPU utilization than with multi-process workers
+(blogpost expected early 2025). See ``examples/nodes/imagenet_benchmark.py``.
+
+
+Design choices
+--------------
+
+No Generator BaseNodes
+~~~~~~~~~~~~~~~~~~~~~~
+
+See https://github.com/pytorch/data/pull/1362 for more thoughts.
+
+One difficult choice we made was to disallow Generators when defining a
+new BaseNode implementation. We had initially supported them, however we
+dropped that and moved to an Iterator-only foundation for a few reasons
+around state management:
+
+1. We require explicit state handling in BaseNode implementations.
+ Generators store state implicitly on the stack and we found that we + needed to jump through hoops and write very convoluted code to get + basic state working with Generators +2. End-of-iteration state dict: Iterables may feel more natural, however + a bunch of issues come up around state management. Consider the + end-of-iteration state dict. If you load this state_dict into your + iterable, should this represent the end-of-iteration or the start of + the next iteration? +3. Loading state: If you call load_state_dict() on an iterable, most + users would expect the next iterator requested from it to start with + the loaded state. However what if iter is called twice before + iteration begins? +4. Multiple Live Iterator problem: if you have one instance of an + Iterable, but two live iterators, what does it mean to call + state_dict() on the Iterable? In dataloading, this is very rare, + however we still need to work around it and make a bunch of + assumptions. Forcing devs that are implementing BaseNodes to reason + about these scenarios is, in our opinion, worse than disallowing + generators and Iterables. + +``torchdata.nodes.BaseNode`` implementations are Iterators. Iterators +define ``next()``, ``get_state()``, and ``reset(initial_state | None)``. +All re-initialization should be done in reset(), including initializing +with a particular state if one is passed. + +However, end-users are used to dealing with Iterables, for example, + +.. code:: python + + for epoch in range(5): + # Most frameworks and users don't expect to call loader.reset() + for batch in loader: + ... + sd = loader.state_dict() + # Loading sd should not throw StopIteration right away, but instead start at the next epoch + +To handle this we keep all of the assumptions and special end-of-epoch +handling in a single ``Loader`` class which takes any BaseNode and makes +it an Iterable, handling the reset() calls and end-of-epoch state_dict +loading. diff --git a/examples/nodes/hf_datasets_nodes_mnist.ipynb b/examples/nodes/hf_datasets_nodes_mnist.ipynb new file mode 100644 index 000000000..92a478ad8 --- /dev/null +++ b/examples/nodes/hf_datasets_nodes_mnist.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1cb04783-4c83-43bc-91c6-b980d836de34", + "metadata": {}, + "source": [ + "### Loading and processing the MNIST dataset\n", + "In this example, we will load the MNIST dataset from Hugging Face, \n", + "use `torchdata.nodes` to process it and generate training batches." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "93026478-6dbd-4ac0-8507-360a3a2000c5", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "# Load the mnist dataset from HuggingFace datasets and convert the format to \"torch\"\n", + "dataset = load_dataset(\"ylecun/mnist\").with_format(\"torch\")\n", + "\n", + "# Getting the train dataset\n", + "dataset = dataset[\"train\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b0c46c4b-0194-4127-a218-e24ec54a3149", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import default_collate, RandomSampler, SequentialSampler\n", + "\n", + "torch.manual_seed(42)\n", + "\n", + "# Defining samplers\n", + "# Since datasets is a Map-style dataset, we can setup a sampler to shuffle the data\n", + "sampler = RandomSampler(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6f643c48-c6fb-4e8a-9461-fdf96b45b04b", + "metadata": {}, + "outputs": [], + "source": [ + "# Now we can set up some torchdata.nodes to create our pre-proc pipeline\n", + "from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n", + "\n", + "# All torchdata.nodes.BaseNode implementations are Iterators.\n", + "# MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n", + "#\n", + "# Under the hood, MapStyleWrapper just does:\n", + "# > node = IterableWrapper(sampler)\n", + "# > node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n", + "\n", + "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n", + "\n", + "# Now we want to transform the raw inputs. We can just use another Mapper with\n", + "# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n", + "# threads (or processes) to parallelize this work and have it run in the background\n", + "# We need a mapper function to convert a dtype and also normalize\n", + "def map_fn(item):\n", + " image = item[\"image\"].to(torch.float32)/255\n", + " label = item[\"label\"]\n", + "\n", + " return {\"image\":image, \"label\":label}\n", + " \n", + "node = ParallelMapper(node, map_fn=map_fn, num_workers=2) # output items are Dict[str, tensor]\n", + "\n", + "\n", + "# Hyperparameters\n", + "batch_size = 2 \n", + "\n", + "# Next we batch the inputs, and then apply a collate_fn with another Mapper\n", + "# to stack the tensor. 
We use torch.utils.data.default_collate for this\n", + "node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n", + "node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n", + "\n", + "# we can optionally apply pin_memory to the batches\n", + "if torch.cuda.is_available():\n", + " node = PinMemory(node)\n", + "\n", + "# Since nodes are iterators, they need to be manually .reset() between epochs.\n", + "# Instead, we can wrap the root node in Loader to convert it to a more conventional Iterable.\n", + "loader = Loader(node)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c97a79ba-e6b3-4ac7-a4c5-edc8f9c58ff4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'image': tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.]]]]), 'label': tensor([1, 4])}\n", + "There are 2 samples in this batch\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOUAAAH4CAYAAAC19irnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUqElEQVR4nO3dbWyV9f3H8c+x3JQxBSmtDqKFWkjpUqezgwZLLN6sTsnSLui2LMPGhCUOXcdA0QdSdpOxTpkE8aaZE2x4BhbjBnHLYskyU1uJA4QJFEKHNA5a6lpYw013rv+Daf+ycl0t5fT008P7lfCA871+h98xvP0BV89pLAiCQABsXDXcGwBwIaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaI01dLSolgspmeffTZhz7ljxw7FYjHt2LEjYc+JxCPKBNq4caNisZh27tw53FsZEgcOHNDSpUs1d+5cpaenKxaLqaWlZbi3lXKIEgPW0NCgdevW6dSpU5o1a9ZwbydlESUG7Jvf/Kb+9a9/6YMPPtD3vve94d5OyiLKJDt37pxWrlyp2267TRMmTND48eM1b9481dfXh6557rnnlJ2drXHjxumOO+7Q3r17+1yzf/9+LVy4UJMmTVJ6eroKCwv15ptv9ruf7u5u7d+/X+3t7f1eO2nSJF199dX9XofLQ5RJ1tXVpVdeeUUlJSWqrq7WqlWr1NbWptLSUu3atavP9bW1tVq3bp2WLFmip556Snv37tWdd96p48eP916zb98+FRUV6cMPP9STTz6pNWvWaPz48SorK9PWrVsj99PU1KRZs2Zp/fr1iX6pGKRRw72BK821116rlpYWjRkzpvexxYsXKy8vT88//7x+97vfXXD9oUOH1NzcrKlTp0qS7r33Xs2ZM0fV1dX6zW9+I0mqrKzUjTfeqPfee09jx46VJP3whz9UcXGxVqxYofLy8iS9OiQCJ2WSpaWl9QYZj8fV0dGhnp4eFRYW6v333+9zfVlZWW+QkjR79mzNmTNH27dvlyR1dHTo7bff1oMPPqhTp06pvb1d7e3tOnnypEpLS9Xc3KzW1tbQ/ZSUlCgIAq1atSqxLxSDRpTD4LXXXtPNN9+s9PR0ZWRkKDMzU9u2bVNnZ2efa2fMmNHnsZkzZ/beijh06JCCINDTTz+tzMzMC35UVVVJkk6cODGkrweJxR9fk2zTpk2qqKhQWVmZHn/8cWVlZSktLU2rV6/W4cOHL/n54vG4JGn58uUqLS296DW5ubmXtWckF1Em2ZYtW5STk6O6ujrFYrHexz871f5Xc3Nzn8cOHjyoadOmSZJycnIkSaNHj9bdd9+d+A0j6fjja5KlpaVJkj7/eWWNjY1qaGi46PVvvPHGBX8nbGpqUmNjo77xjW9IkrKyslRSUqKamhp9/PHHfda3tbVF7udSbokgOTgph8Crr76qt956q8/jlZWVWrBggerq6lReXq77779fR44c0csvv6z8/HydPn26z5rc3FwVFxfrkUce0dmzZ7V27VplZGToiSee6L3mhRdeUHFxsQoKCrR48WLl5OTo+PHjamho0LFjx7R79+7QvTY1NWn+/Pmqqqrq9x97Ojs79fzzz0uS3nnnHUnS+vXrNXHiRE2cOFGPPvroQP7zoD8BEmbDhg2BpNAfH330URCPx4Nf/vKXQXZ2djB27Njg1ltvDf7whz8EDz30UJCdnd37XEeOHAkkBc8880ywZs2a4IYbbgjGjh0bzJs3L9i9e3efX/vw4cPBokWLguuvvz4YPXp0MHXq1GDBggXBli1beq+pr68PJAX19fV9Hquqqur39X22p4v9+PzecXliQcDnvgJO+DslYIYoATNECZghSsAMUQJmiBIwQ5SAmQF/Rc/nv04TwOAM5MsCOCkBM0QJmCFKwAxRAmaIEjBDlIAZ3uScQn7wgx+EzmpqaiLXTp8+PXTG9wtJLk5KwAxRAmaIEjBDlIAZogTMECVghigBM9ynTCHvvvtu6Ky/twzdfvvtoTPuUyYXJyVghigBM0QJmCFKwAxRAmaIEjDDLRFIkqZMmTLcW8CnOCkBM0QJmCFKwAxRAmaIEjBDlIAZogT
MECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjAzarg3AA8HDhwY7i3gU5yUgBmiBMwQJWCGKAEzRAmYIUrADLdEUkheXt6g1+7ZsyeBO8Hl4KQEzBAlYIYoATNECZghSsAMUQJmiBIww33KFPLlL395uLeABOCkBMwQJWCGKAEzRAmYIUrADFECZrglkkK+/vWvD/cWkACclIAZogTMECVghigBM0QJmCFKwAxRAma4T5lCrrnmmtDZvn37Ite2trYmejsYJE5KwAxRAmaIEjBDlIAZogTMECVghlsiI8h1110XOc/KygqdtbS0RK49f/78YLaEIcBJCZghSsAMUQJmiBIwQ5SAGaIEzBAlYIb7lCNIRkbGoOcvvvhioreDIcJJCZghSsAMUQJmiBIwQ5SAGaIEzHBL5AqxefPm4d4CBoiTEjBDlIAZogTMECVghigBM0QJmCFKwAz3KUeQhx9+eNBrz5w5k8CdYChxUgJmiBIwQ5SAGaIEzBAlYIYoATPcEkmyoqKiyPnEiRNDZ+Xl5YP+defNmxc5X7JkSeiss7Mzcu2uXbtCZ9u3b49ce+7cucj5lYiTEjBDlIAZogTMECVghigBM0QJmCFKwEwsCIJgQBfGYkO9FyvZ2dmhs7Kyssi1CxcuDJ31d58yLS0tcj7S/Pvf/46cR/32+/3vfx+5dtGiRaGz//znP9EbGyYDyY2TEjBDlIAZogTMECVghigBM0QJmBnRb93q7/bBY489FjqrrKyMXDtlypTQ2ejRo6M3FqG5uTly3tjYGDr705/+NOhftz8nT54Mne3Zsydy7fz580NnCxYsiFw7ZsyY0Nmf//znyLXxeDxyPlJxUgJmiBIwQ5SAGaIEzBAlYIYoATNECZgZ0W/dys3NjZwfPHhw0M/d1dUVOnvooYci127bti101t+9tVS994b/4q1bwAhElIAZogTMECVghigBM0QJmBnRt0SAkYZbIsAIRJSAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYGbUQC8MgmAo9wHgU5yUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKE21tLQoFovp2WefTdhz7tixQ7FYTDt27EjYcyLxiDKBNm7cqFgspp07dw73VpLinnvuUSwW06OPPjrcW0kpRIlBqaurU0NDw3BvIyURJS7ZmTNntGzZMq1YsWK4t5KSiDLJzp07p5UrV+q2227ThAkTNH78eM2bN0/19fWha5577jllZ2dr3LhxuuOOO7R3794+1+zfv18LFy7UpEmTlJ6ersLCQr355pv97qe7u1v79+9Xe3v7gF/Dr3/9a8XjcS1fvnzAazBwRJlkXV1deuWVV1RSUqLq6mqtWrVKbW1tKi0t1a5du/pcX1tbq3Xr1mnJkiV66qmntHfvXt155506fvx47zX79u1TUVGRPvzwQz355JNas2aNxo8fr7KyMm3dujVyP01NTZo1a5bWr18/oP0fPXpUv/rVr1RdXa1x48Zd0mvHAAVImA0bNgSSgvfeey/0mp6enuDs2bMXPPbJJ58E1113XfDwww/3PnbkyJFAUjBu3Ljg2LFjvY83NjYGkoKlS5f2PnbXXXcFBQUFwZkzZ3ofi8fjwdy5c4MZM2b0PlZfXx9ICurr6/s8VlVVNaDXuHDhwmDu3Lm9P5cULFmyZEBrMTCclEmWlpamMWPGSJLi8bg6OjrU09OjwsJCvf/++32uLysr09SpU3t/Pnv2bM2ZM0fbt2+XJHV0dOjtt9/Wgw8+qFOnTqm9vV3t7e06efKkSktL1dzcrNbW1tD9lJSUKAgCrVq1qt+919fX6/XXX9fatWsv7UXjkhDlMHjttdd08803Kz09XRkZGcrMzNS2bdvU2dnZ59oZM2b0eWzmzJlqaWmRJB06dEhBEOjpp59WZmbmBT+qqqokSSdOnLjsPff09OhHP/qRvv/97+trX/vaZT8fwg34G/wgMTZt2qSKigqVlZXp8ccfV1ZWltLS0rR69WodPnz4kp8vHo9LkpYvX67S0tKLXpObm3tZe5b++3fbAwcOqKampvd/CJ85deqUWlpalJWVpS984QuX/Wtd6YgyybZs2aKcnBzV1dUpFov1Pv7Zqfa/mpub+zx28OBBTZs2TZKUk5MjSRo9erTuvvvuxG/4U0ePHtX58+d1++2395nV1taqtrZWW7duVVlZ2ZDt4UpBlEmWlpYm6b/fWvCzKBsbG9XQ0KAbb7yxz/VvvPGGWltbe/9e2dTUpMbGRv34xz+WJGVlZamkpEQ1NTV67LHH9KUvfemC9W1tbcrMzAzdT3d3t44eParJkydr8uTJodd95zvf0S233NLn8fLyct13331avHix5syZE/naMTBEOQReffVVvfXWW30er6ys1IIFC1RXV6fy8nLdf//9OnLkiF5++WXl5+fr9OnTfdbk5uaquLhYjzzyiM6ePau1a9cqIyNDTzzxRO81L7zwgoqLi1VQUKDFixcrJydHx48fV0NDg44dO6bdu3eH7rWpqUnz589XVVVV5D/25OXlKS8v76Kz6dOnc0ImEFEOgZdeeumij1dUVKiiokL//Oc/VVNToz/+8Y/Kz8/Xpk2btHnz5ot+ofiiRYt01VVXae3atTpx4oRmz56t9evXX3Ai5ufna+fOnfrpT3+qjRs36uTJk8rKytKtt96qlStXDtXLxBCJBQHfohlwwi0RwAxRAmaIEjBDlIAZogTMECVghigBMwP+4oHPf50mgMEZyJcFcFICZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMyMGu4NXI7CwsLI+bJly0Jn3/3udxO9nZS1YsWKyPlVV4X/v3316tWJ3k7K46QEzBAlYIYoATNECZghSsAMUQJmiBIwM6LvU/b09ETO77nnntBZf/feqqurB7WnVNTd3R05LygoSNJOrgyclIAZogTMECVghigBM0QJmCFKwMyIviWya9euyHltbW3orLS0NHItt0T+36hR0b9NiouLk7STKwMnJWCGKAEzRAmYIUrADFECZogSMEOUgJkRfZ+yP11dXcO9hZSwefPmyPkvfvGL0NnkyZMj17a3tw9qT6
mMkxIwQ5SAGaIEzBAlYIYoATNECZhJ6VsiSIxjx45Fzs+dOxc6e+CBByLXvvTSS4PaUyrjpATMECVghigBM0QJmCFKwAxRAmaIEjCT0vcpz58/HzqLxWJJ3Elqa2xsDJ3NnTs3ci33KfvipATMECVghigBM0QJmCFKwAxRAmZS+pbItm3bQmd33XVXEneS2t55553Q2U9+8pMk7iQ1cFICZogSMEOUgBmiBMwQJWCGKAEzRAmYiQVBEAzowhH4VqdbbrkldPaXv/wlcm1+fn7orL+PXLzSTJgwIXTW0tISufbaa69N8G68DSQ3TkrADFECZogSMEOUgBmiBMwQJWAmpd+6dfr06dDZqFHRL/3b3/526GzNmjWD3lMq6uzsDJ1dc801kWuLiopCZ+++++6g9zSScVICZogSMEOUgBmiBMwQJWCGKAEzRAmYSem3bkV5/fXXI+dRr/db3/pWoreTsuLxeOT8mWeeCZ2tWLEi0dsZdrx1CxiBiBIwQ5SAGaIEzBAlYIYoATMp/datKFu2bImcX2lvz0pLSwudXX311YN+3r/97W+R85kzZ4bOsrOzI9d2dXWFzj755JPojRnjpATMECVghigBM0QJmCFKwAxRAmaIEjBzxd6n/PjjjyPn119/fejspptuilx7+PDh0FleXl7k2nvvvTdyHuWGG24InX31q1+NXPvFL34xdDZ9+vTItVH3C6dNmxa5tqCgIHT2la98JXLtz3/+89DZhg0bItc646QEzBAlYIYoATNECZghSsAMUQJmrthPs+vvn+obGhpCZ/297eujjz4KnVVVVUWu7ejoCJ3t2bMncm3Uvvp7C1WU1tbWyHlbW1vobOXKlZFrKyoqQmc5OTmRa0ciPs0OGIGIEjBDlIAZogTMECVghigBM0QJmLli71P2Z9myZaGzn/3sZ5Frm5qaQmeVlZWRa1taWkJnUW+RcjVlypTI+c6dO0NnUR8/KUmnT58e1J6GE/cpgRGIKAEzRAmYIUrADFECZogSMHPFfppdf/7617+Gzj744IPItUuXLg2d9ff2q1TT3d0dOY/61MAHHnggcu1I/sS6KJyUgBmiBMwQJWCGKAEzRAmYIUrADFECZrhPGaKxsTF0VlRUlMSdpLaotwRmZ2cncSc+OCkBM0QJmCFKwAxRAmaIEjBDlIAZbolgWEV9utsAP2gx5XBSAmaIEjBDlIAZogTMECVghigBM0QJmOG7bmFIpaenR87//ve/h87+8Y9/RK6dP3/+oPY0nPiuW8AIRJSAGaIEzBAlYIYoATNECZjhrVsYUmfOnImc//a3vw2d9fddt1IVJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZ7lNiWG3evDl0xn1KABaIEjBDlIAZogTMECVghigBM3yaHZBEfJodMAIRJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZgb8XbcG+EmUAC4TJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVg5v8AWWLQZxmV+6kAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Once we have the loader, we can get batches from it over multiple epochs, to train the model\n", + "# Let us look at one batch \n", + "import matplotlib.pyplot as plt\n", + "fig, axs = plt.subplots(2, figsize=(8, 6))\n", + "\n", + "batch = next(iter(loader))\n", + " \n", + "\n", + "print(batch)\n", + "print(f\"There are {len(batch)} samples in this batch\")\n", + "\n", + "# Since we used default_collate, each batch is a dictionary, with two keys: \"image\" and \"label\"\n", + "# The value of key \"image\" is a stacked tensor of images in the batch\n", + "# Similarly, the value of key \"label\" is a stacked tensor of labels in the batch\n", + "images = batch[\"image\"]\n", + "labels = batch[\"label\"]\n", + "\n", + "#let's also display the two items\n", + "for i in range(len(images)):\n", + " axs[i].imshow(images[i].squeeze(), cmap='gray')\n", + " axs[i].set_title(f\"Label: {labels[i]}\") \n", + " axs[i].set_axis_off()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/nodes/hf_imdb_bert.ipynb b/examples/nodes/hf_imdb_bert.ipynb new file mode 100644 index 000000000..dd13cad59 --- /dev/null +++ b/examples/nodes/hf_imdb_bert.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d8513771-36ac-4d03-b890-35108bce2211", + "metadata": {}, + "source": [ + "### Loading and processing IMDB movie review dataset\n", + "In this example, we will load the IMDB dataset from Hugging Face, \n", + "use `torchdata.nodes` to process it and generate training batches." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "eb3b507c-2ad1-410d-a834-6847182de684", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from transformers import BertTokenizer, BertForSequenceClassification" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "089f1126-7125-4274-9d71-5c949ccc7bbd", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import default_collate, RandomSampler, SequentialSampler" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2afac7d9-3d66-4195-8647-dc7034d306f2", + "metadata": {}, + "outputs": [], + "source": [ + "# Load IMDB dataset from huggingface datasets and select the \"train\" split\n", + "dataset = load_dataset(\"imdb\", streaming=False)\n", + "dataset = dataset[\"train\"]\n", + "# Since dataset is a Map-style dataset, we can setup a sampler to shuffle the data\n", + "# Please refer to the migration guide here https://pytorch.org/data/main/migrate_to_nodes_from_utils.html\n", + "# to migrate from torch.utils.data to torchdata.nodes\n", + "\n", + "sampler = RandomSampler(dataset)\n", + "# Use a standard bert tokenizer\n", + "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", + "# Now we can set up some torchdata.nodes to create our pre-proc pipeline" + ] + }, + { + "cell_type": "markdown", + "id": "09e08a47-573c-4d32-9a02-36cd8150db60", + "metadata": {}, + "source": [ + "All torchdata.nodes.BaseNode implementations are Iterators.\n", + "MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n", + "Under the hood, MapStyleWrapper just does:\n", + "```python\n", + "node = IterableWrapper(sampler)\n", + "node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "02af5479-ee69-41d8-ab2d-bf154b84bc15", + "metadata": {}, + "outputs": [], + "source": [ + "from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n", + "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n", + "\n", + "# Now we want to transform the raw inputs. We can just use another Mapper with\n", + "# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n", + "# threads (or processes) to parallelize this work and have it run in the background\n", + "max_len = 512\n", + "batch_size = 2\n", + "def bert_transform(item):\n", + " encoding = tokenizer.encode_plus(\n", + " item[\"text\"],\n", + " add_special_tokens=True,\n", + " max_length=max_len,\n", + " padding=\"max_length\",\n", + " truncation=True,\n", + " return_attention_mask=True,\n", + " return_tensors=\"pt\",\n", + " )\n", + " return {\n", + " \"input_ids\": encoding[\"input_ids\"].flatten(),\n", + " \"attention_mask\": encoding[\"attention_mask\"].flatten(),\n", + " \"labels\": torch.tensor(item[\"label\"], dtype=torch.long),\n", + " }\n", + "node = ParallelMapper(node, map_fn=bert_transform, num_workers=2) # output items are Dict[str, tensor]\n", + "\n", + "# Next we batch the inputs, and then apply a collate_fn with another Mapper\n", + "# to stack the tensors between. 
We use torch.utils.data.default_collate for this\n", + "node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n", + "node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n", + "\n", + "# we can optionally apply pin_memory to the batches\n", + "if torch.cuda.is_available():\n", + " node = PinMemory(node)\n", + "\n", + "# Since nodes are iterators, they need to be manually .reset() between epochs.\n", + "# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n", + "loader = Loader(node)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "60fd54f3-62ef-47aa-a790-853cb4899f13", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[ 101, 1045, 2572, ..., 2143, 2000, 102],\n", + " [ 101, 2004, 1037, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", + " [1, 1, 1, ..., 0, 0, 0]]), 'labels': tensor([0, 1])}\n" + ] + } + ], + "source": [ + "# Inspect a batch\n", + "batch = next(iter(loader))\n", + "print(batch)\n", + "# In a batch we get three keys, as defined in the method `bert_transform`.\n", + "# Since the batch size is 2, two samples are stacked together for each key." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/nodes/imagenet_benchmark.py b/examples/nodes/imagenet_benchmark.py new file mode 100644 index 000000000..0d02f738b --- /dev/null +++ b/examples/nodes/imagenet_benchmark.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# +# From within the data directory run: +# > IMGNET_TRAIN=/path/to/imagenet/train +# > python examples/nodes/imagenet_benchmark.py --loader=process -d $IMGNET_TRAIN --max-steps 1000 --num-workers 4 +# +# For FT-python, you need python 3.13t and run as: +# > python -Xgil=0 examples/nodes/imagenet_benchmark.py --loader=process -d $IMGNET_TRAIN --max-steps 1000 --num-workers 4 +# +# Some example runs on Linux, with Python 3.13t below, using 4 workers +# ================================================================================ +# Baseline, with torch.utils.data.DataLoader: +# > python -Xgil=1 examples/nodes/imagenet_benchmark.py --loader=classic -d $IMGNET_TRAIN --max-steps 1000 --num-workers 4 +# 835.2034686705912 img/sec, 52.20021679191195 batches/sec +# +# torchdata.nodes with Multi-Processing: +# > python -Xgil=1 examples/nodes/imagenet_benchmark.py --loader=process -d $IMGNET_TRAIN --max-steps 1000 --num-workers 4 +# 905.5019281357543 img/sec, 56.59387050848464 batches/sec +# +# torchdata.nodes with Multi-Threading with the GIL: +# > python -Xgil=1 examples/nodes/imagenet_benchmark.py --loader=thread -d $IMGNET_TRAIN --max-steps 1000 --num-workers 4 +# 692.0924763926637 img/sec, 43.25577977454148 batches/sec +# +# torchdata.nodes with Multi-Threading with no GIL: +# > python -Xgil=0 examples/nodes/imagenet_benchmark.py --loader=thread -d $IMGNET_TRAIN --max-steps 1000 --num-workers 4 +# 922.3858393659006 img/sec, 57.649114960368784 batches/sec + +import argparse + +import os +import time +from typing import Any, Iterator + +import torch.utils.data +import torchdata.nodes as tn + +import torchvision.transforms.functional as F +from PIL import Image +from torch.utils.data import default_collate + + +class ImagenetTransform: + """Decode, transform, and crop to 224x224. + If called with a list of dicts, collates the results. + """ + + def __call__(self, data): + if isinstance(data, list): + return default_collate([self.transform_one(x) for x in data]) + else: + return self.transform_one(data) + + def transform_one(self, data): + img = Image.open(data["img_path"]).convert("RGB") + img_tensor = F.pil_to_tensor(img) + img_tensor = F.center_crop(img_tensor, [224, 224]) + data["img"] = img_tensor + return data + + +class ImagenetLister: + """Access imagenet data through either __getitem__, or an iterator. + If using an iterator, will loop forever, in order + """ + + def __init__(self, path: str): + self.path = path + self.img_labels = [] + self.img_paths = [] + for label in os.listdir(path): + for img_path in os.listdir(os.path.join(path, label)): + self.img_labels.append(label) + self.img_paths.append(os.path.join(path, label, img_path)) + + assert len(self.img_labels) == len(self.img_paths), ( + len(self.img_labels), + len(self.img_paths), + ) + + def __getitem__(self, i: int) -> dict: + data = {"img_path": self.img_paths[i]} + return data + + def __len__(self): + return len(self.img_labels) + + def __iter__(self) -> Iterator[dict]: + while True: # Loop forever + for i in range(len(self.img_labels)): + yield {"img_path": self.img_paths[i]} + + +class ImagenetDataset(torch.utils.data.Dataset): + """Classic DataLoader v1-style dataset (map style). Applies ImagenetTransform when + retrieving items. 
+ """ + + def __init__(self, path: str): + self.imagenet_data = ImagenetLister(path) + self.tx = ImagenetTransform() + + def __len__(self): + return len(self.imagenet_data) + + def __getitem__(self, i: int) -> dict: + return self.tx(self.imagenet_data[i]) + + +def setup_classic(args): + dataset = ImagenetDataset(args.imagenet_dir) + assert args.in_order is False, "torch.utils.data.DataLoader does not support out-of-order iteration yet!" + loader = torch.utils.data.DataLoader( + dataset, + num_workers=args.num_workers, + batch_size=args.batch_size, + pin_memory=args.pin_memory, + shuffle=args.shuffle, + ) + return loader + + +def setup(args): + assert args.loader in ("thread", "process") + if args.shuffle: + dataset = ImagenetLister(args.imagenet_dir) + sampler = torch.utils.data.RandomSampler(dataset) + node = tn.MapStyleWrapper(map_dataset=dataset, sampler=sampler) + else: + node = tn.IterableWrapper(ImagenetLister(args.imagenet_dir)) + + node = tn.Batcher(node, batch_size=args.batch_size) + node = tn.ParallelMapper( + node, + map_fn=ImagenetTransform(), + num_workers=args.num_workers, + method=args.loader, + ) + if args.pin_memory: + node = tn.PinMemory(node) + node = tn.Prefetcher(node, prefetch_factor=2) + + return tn.Loader(node) + + +def run_benchmark(args): + print(f"Running benchmark with {args=}...") + loader: Any + if args.loader == "classic": + loader = setup_classic(args) + elif args.loader in ("thread", "process"): + loader = setup(args) + else: + raise ValueError(f"Unknown loader {args.loader}") + + start = time.perf_counter() + it = iter(loader) + create_iter_dt = time.perf_counter() - start + print(f"create iter took {create_iter_dt} seconds") + + start = time.perf_counter() + if args.warmup_steps: + for i in range(args.warmup_steps): + next(it) + print(f"{args.warmup_steps} warmup steps took {time.perf_counter() - start} seconds") + warmup_dt = time.perf_counter() - start + + i: int = 0 + progress_freq = 100 + last_reported: float = time.perf_counter() + start = time.perf_counter() + for i in range(args.max_steps): + if i % progress_freq == 0 or time.perf_counter() - last_reported > 5.0: + print(f"{i} / {args.max_steps}, {time.perf_counter() - start} seconds elapsed") + last_reported = time.perf_counter() + next(it) + if time.perf_counter() - start > args.max_duration: + print(f"reached {args.max_duration=}") + break + + iter_time = time.perf_counter() - start + print( + "=" * 80 + "\n" + f"{args=}\n" + f"Benchmark complete, {i} steps took {iter_time} seconds, " + f"for a total of {i * args.batch_size} images\n" + f"{i * args.batch_size / iter_time} img/sec, {i / iter_time} batches/sec\n" + f"{create_iter_dt=}, {warmup_dt=}, {sum((create_iter_dt, warmup_dt, iter_time))=}", + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--loader", + default="thread", + choices=["thread", "process", "classic"], + help="Whether to use multi-threaded parallelism, multi-process parallelism, or the classic torch.utils.data.DataLoader (multi-process only)", + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of workers to parallelize with", + ) + parser.add_argument("--batch-size", type=int, default=16, help="Batch size for dataloading") + parser.add_argument("--in-order", type=bool, default=False, help="Whether to enforce ordering") + parser.add_argument("--shuffle", type=bool, default=False, help="Whether to shuffle the data") + parser.add_argument( + "--max-steps", + type=int, + default=10000, + help="Maximum number of 
batches to load for the benchmark",
+    )
+    parser.add_argument(
+        "--max-duration",
+        type=int,
+        default=60,
+        help="Stop after this many seconds of benchmarking, if max-steps is not reached",
+    )
+    parser.add_argument(
+        "--warmup-steps",
+        type=int,
+        default=0,
+        help="Number of warmup steps to take before starting timing",
+    )
+    parser.add_argument(
+        "--pin-memory",
+        type=bool,
+        default=False,
+        help="Whether to apply pin-memory to the batches",
+    )
+    parser.add_argument("--imagenet-dir", "-d", type=str, required=True)
+    args = parser.parse_args()
+    run_benchmark(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/nodes/multi_dataset_weighted_sampling.ipynb b/examples/nodes/multi_dataset_weighted_sampling.ipynb
new file mode 100644
index 000000000..0a775552c
--- /dev/null
+++ b/examples/nodes/multi_dataset_weighted_sampling.ipynb
@@ -0,0 +1,100 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "79a14c63-a085-493f-8db9-6af3e1d744b5",
+   "metadata": {},
+   "source": [
+    "### `MultiNodeWeightedSampler` example\n",
+    "In this notebook, we will explore the usage of `MultiNodeWeightedSampler` in `torchdata.nodes`.\n",
+    "\n",
+    "`MultiNodeWeightedSampler` allows us to sample with given probabilities from multiple datasets. We will make three datasets, and then see how the composition of the output depends on the weights defined in the `MultiNodeWeightedSampler`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "0b283748-9b3f-4b9e-bbc5-db0791f4d900",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torchdata.nodes import Mapper, MultiNodeWeightedSampler, IterableWrapper, Loader\n",
+    "import collections\n",
+    "\n",
+    "# defining a simple map_fn as a placeholder example\n",
+    "def map_fn(item):\n",
+    "    return {\"x\": item}\n",
+    "\n",
+    "\n",
+    "def constant_stream(value: int):\n",
+    "    while True:\n",
+    "        yield value\n",
+    "\n",
+    "# First, we create a dictionary of three datasets, with each dataset converted into a BaseNode using IterableWrapper\n",
+    "num_datasets = 3\n",
+    "datasets = {\n",
+    "    \"ds0\": IterableWrapper(constant_stream(0)),\n",
+    "    \"ds1\": IterableWrapper(constant_stream(1)),\n",
+    "    \"ds2\": IterableWrapper(constant_stream(2)),\n",
+    "}\n",
+    "\n",
+    "# Next, we have to define weights for sampling from a particular dataset\n",
+    "weights = {\"ds0\": 0.5, \"ds1\": 0.25, \"ds2\": 0.25}\n",
+    "\n",
+    "# Finally, we instantiate the MultiNodeWeightedSampler to sample from our datasets\n",
+    "multi_node_sampler = MultiNodeWeightedSampler(datasets, weights)\n",
+    "\n",
+    "# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
+    "# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
+    "loader = Loader(multi_node_sampler)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "77784ba3-b917-4083-aed4-dba2374110d5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "fractions = {0: 0.49791, 2: 0.25067, 1: 0.25142}\n",
+      "The original weights were = {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Let's draw 100k samples, compute the fraction coming from each dataset,\n",
+    "# and compare the composition with our given weights\n",
+    "n = 100000\n",
+    "it = iter(loader)\n",
+    "samples = [next(it) for _ in range(n)]\n",
+    "fractions = {k: v/len(samples) for k, v in collections.Counter(samples).items()}\n",
+    "print(f\"fractions = {fractions}\")\n",
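+    "\n",
+    "# As a rough sanity check (the 0.01 tolerance is arbitrary, for illustration only),\n",
+    "# verify that each sampled fraction lands close to its configured weight\n",
+    "assert all(abs(fractions[i] - weights[f\"ds{i}\"]) < 0.01 for i in range(num_datasets))\n",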
"print(f\"The original weights were = {weights}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/nodes/torchata_nodes_basics.ipynb b/examples/nodes/torchata_nodes_basics.ipynb new file mode 100644 index 000000000..97fe01b78 --- /dev/null +++ b/examples/nodes/torchata_nodes_basics.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ca0ddf27-b42d-4004-9f0c-4f43e55a245b", + "metadata": {}, + "source": [ + "### `torchdata.nodes` Basics" + ] + }, + { + "cell_type": "markdown", + "id": "f0ff94f5-fc60-4a7e-a3e5-0d7a4bfa2ddd", + "metadata": {}, + "source": [ + "All torchdata.nodes.BaseNode implementations are Iterators, adhering to the following API:\n", + "```Python\n", + "class BaseNode(Iterator[T]):\n", + " def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:\n", + " \"\"\"Resets the node to its initial state or a specified state.\"\"\"\n", + " ...\n", + " def __next__(self) -> T:\n", + " \"\"\"Returns the next value in the sequence.\"\"\"\n", + " ...\n", + " def get_state(self) -> Dict[str, Any]:\n", + " \"\"\"Returns a dictionary representing the current state of the node.\"\"\"\n", + " ...\n", + "```\n", + "This standardized interface enables seamless chaining of iterators, allowing for flexible, efficient, and composable data processing pipelines." + ] + }, + { + "cell_type": "markdown", + "id": "5542e957-f40c-4624-9e3d-b4130d1fd03a", + "metadata": {}, + "source": [ + "Let's see the functionalities of `torchdata.nodes` through the help of a very simple example." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b5168248-7c08-419d-a9a5-f831bd7b46e7",
+   "metadata": {},
+   "source": [
+    "#### IterableWrapper"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "f42819f7-1019-48ca-ad23-00b38c72d63e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torchdata.nodes import IterableWrapper\n",
+    "# This wrapper converts any Iterable into a BaseNode.\n",
+    "\n",
+    "dataset = range(10)  # creating a very simple dataset, and then converting it into a BaseNode\n",
+    "source = IterableWrapper(dataset)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "8bb313fe-5cad-4c5f-b7f3-4bd27b0645a3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0\n",
+      "1\n",
+      "2\n",
+      "3\n",
+      "4\n",
+      "5\n",
+      "6\n",
+      "7\n",
+      "8\n",
+      "9\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Let's take a look at the items in the node\n",
+    "for item in source:\n",
+    "    print(item)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "247845d5-ef4b-4bc3-92c3-93c55972578a",
+   "metadata": {},
+   "source": [
+    "#### Integrating with torch.utils.data Datasets and Samplers"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "97fb43c8-16c0-4c87-bb3a-ff6a0e126764",
+   "metadata": {},
+   "source": [
+    "We can also use `torch.utils.data`-style datasets and samplers, and then wrap them into nodes.\n",
+    "Please refer to this [migration guide](https://pytorch.org/data/main/migrate_to_nodes_from_utils.html) to migrate from torch.utils.data to torchdata.nodes\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "a14e03a0-29f7-4824-8860-75f1f2a91533",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "6\n",
+      "9\n",
+      "2\n",
+      "1\n",
+      "7\n",
+      "5\n",
+      "4\n",
+      "3\n",
+      "8\n",
+      "0\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchdata.nodes import MapStyleWrapper\n",
+    "from torch.utils.data import RandomSampler\n",
+    "\n",
+    "sampler = RandomSampler(dataset)\n",
+    "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
+    "\n",
+    "for item in node:\n",
+    "    print(item)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "73f6e1ce-6e33-49cd-b0bb-c49fee0e7cd6",
+   "metadata": {},
+   "source": [
+    "#### Map"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2a5246cc-6f40-4942-b514-df76c3946eee",
+   "metadata": {},
+   "source": [
+    "We can use the `Mapper` class to apply a transformation, defined via `map_fn`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "71a07a13-5d6f-46d4-b316-102698eb13c5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0\n",
+      "1\n",
+      "4\n",
+      "9\n",
+      "16\n",
+      "25\n",
+      "36\n",
+      "49\n",
+      "64\n",
+      "81\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchdata.nodes import Mapper\n",
+    "node = Mapper(source, map_fn=lambda x: x**2)\n",
+    "for item in node:\n",
+    "    print(item)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0701898c-6c72-4065-aee9-23a6b516876d",
+   "metadata": {},
+   "source": [
+    "It can also be executed in parallel, using multi-threading or multi-processing, depending on the specified `method`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "0771661d-f0dc-4d35-a5af-abab84d7efd3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0\n",
+      "1\n",
+      "4\n",
+      "9\n",
+      "16\n",
+      "25\n",
+      "36\n",
+      "49\n",
+      "64\n",
+      "81\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchdata.nodes import ParallelMapper\n",
+    "mapper = ParallelMapper(source, map_fn=lambda x: x**2, num_workers=2, method=\"thread\")\n",
+    "for item in mapper:\n",
+    "    print(item)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b517ebcf-9f92-43bf-b332-3875933ea244",
+   "metadata": {},
+   "source": [
+    "#### Batch"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6156bcb4-891f-484a-bf06-cbf39e282175",
+   "metadata": {},
+   "source": [
+    "A BaseNode can be passed into a Batcher to get batches of size `batch_size`.\n",
+    "By default, `drop_last` is True, meaning that if the last batch is smaller than `batch_size`, it is not produced."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "dc40c637-90ed-4c09-9961-3b0d488fc6d8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[0, 1, 2, 3]\n",
+      "[4, 5, 6, 7]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchdata.nodes import Batcher\n",
+    "batcher = Batcher(source, batch_size=4)\n",
+    "for batch in batcher:\n",
+    "    print(batch)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5487e01d-6245-4940-a0df-c32aee0849cf",
+   "metadata": {},
+   "source": [
+    "We can set `drop_last=False` to produce the last batch."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "092d8e5b-aeb2-4de1-9b13-8b0dc0af31ce",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[0, 1, 2, 3]\n",
+      "[4, 5, 6, 7]\n",
+      "[8, 9]\n"
+     ]
+    }
+   ],
+   "source": [
+    "batcher = Batcher(source, batch_size=4, drop_last=False)\n",
+    "for batch in batcher:\n",
+    "    print(batch)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e11543b9-c2dd-4861-af24-0eb3a1f3ddb9",
+   "metadata": {},
+   "source": [
+    "If we use this batcher over multiple epochs, we need to reset it after every epoch."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "d89e36a9-cc9c-4b55-acba-dbcaf8d53407",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch = 0  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
+      "Epoch = 1  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n"
+     ]
+    }
+   ],
+   "source": [
+    "batcher = Batcher(source, batch_size=10)\n",
+    "num_epochs = 2\n",
+    "\n",
+    "for epoch in range(num_epochs):\n",
+    "    for batch in batcher:\n",
+    "        print(f\"Epoch = {epoch}\", f\" Batch = {batch}\")\n",
+    "    batcher.reset()\n",
+    "\n",
+    "# This is one extra step compared to a traditional DataLoader; we can wrap the batcher in a Loader to skip it.\n",
+    "# Let's look at Loader in the next cell"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d2232823-0927-432b-9171-fb2bf38a7205",
+   "metadata": {},
+   "source": [
+    "#### Loader\n",
+    "With `Loader`, we get a batch in every epoch, without even needing to reset the node manually!\n",
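+    "\n",
+    "Under the hood, `Loader` takes care of restarting the root node between epochs (see its `restart_on_stop_iteration` flag, which defaults to `True`), so the manual `reset()` from the previous cell goes away. Roughly, each epoch amounts to the following sketch (illustrative, not the actual implementation):\n",
+    "```python\n",
+    "batcher.reset()  # performed for you when a new iteration of the Loader starts\n",
+    "for batch in batcher:\n",
+    "    ...\n",
+    "```"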
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "4f8a1830-49a2-4dfd-b17b-acaca7ac2e98",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch = 0  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
+      "Epoch = 1  Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchdata.nodes import Loader\n",
+    "batcher = Batcher(source, batch_size=10)\n",
+    "loader = Loader(batcher)\n",
+    "\n",
+    "for epoch in range(num_epochs):\n",
+    "    for batch in loader:\n",
+    "        print(f\"Epoch = {epoch}\", f\" Batch = {batch}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "860142c7-a39b-441b-b2b2-1cbd81a7a0f3",
+   "metadata": {},
+   "source": [
+    "#### SamplerWrapper"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "18134e7d-d0e8-4096-a674-23e715fa16a3",
+   "metadata": {},
+   "source": [
+    "As mentioned earlier, we can use `torch.utils.data` samplers via `MapStyleWrapper`.\n",
+    "Alternatively, we can employ `SamplerWrapper`, which converts a `Sampler` into a `BaseNode`. `SamplerWrapper` differs from `IterableWrapper` in that it tracks the number of epochs, and calls the sampler's `set_epoch` method if it is implemented."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "f794c4d9-b3ed-4b7b-8c85-131d495e6605",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch = 0  Batch = [2, 9, 0, 5, 8, 6, 7, 3, 1, 4]\n",
+      "Epoch = 1  Batch = [2, 3, 6, 8, 5, 0, 1, 9, 7, 4]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchdata.nodes import SamplerWrapper\n",
+    "\n",
+    "sampler = RandomSampler(dataset)\n",
+    "node = SamplerWrapper(sampler)\n",
+    "batcher = Batcher(node, batch_size=10)\n",
+    "loader = Loader(batcher)\n",
+    "for epoch in range(num_epochs):\n",
+    "    for batch in loader:\n",
+    "        print(f\"Epoch = {node.epoch}\", f\" Batch = {batch}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3e956793-df2d-4d0f-94b0-daa7fb860bf5",
+   "metadata": {},
+   "source": [
+    "`torchdata` nodes are composable: many BaseNode-type nodes can be chained together to build the desired transformations."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5f1a6291-d688-45eb-bf2e-56153419513e",
+   "metadata": {},
+   "source": [
+    "#### Chaining multiple operations together\n",
+    "\n",
+    "Base nodes are iterators, and are designed to be chained together to create more complex dataloading graphs."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "cc4bcd30-9a76-4a1e-aa38-d7db8b9a8c2e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sampler = RandomSampler(dataset)\n",
+    "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
+    "node = Mapper(node, map_fn=lambda x: x**3)\n",
+    "node = Batcher(node, batch_size=4, drop_last=False)\n",
+    "loader = Loader(node)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "798f64af-102d-4b24-8327-b47d4d0fc4f2",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[64, 216, 27, 729]\n",
+      "[0, 8, 343, 1]\n",
+      "[512, 125]\n"
+     ]
+    }
+   ],
+   "source": [
+    "for batch in loader:\n",
+    "    print(batch)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py
index 4e55db28f..33aff5e6e 100644
--- a/test/smoke_test/smoke_test.py
+++ b/test/smoke_test/smoke_test.py
@@ -6,6 +6,7 @@


 def stateful_dataloader_test():
+    from torchdata.nodes import Loader
     from torchdata.stateful_dataloader import StatefulDataLoader
diff --git a/torchdata/nodes/adapters.py b/torchdata/nodes/adapters.py
index be392c72e..99402daa2 100644
--- a/torchdata/nodes/adapters.py
+++ b/torchdata/nodes/adapters.py
@@ -25,9 +25,11 @@ class IterableWrapper(BaseNode[T]):
     If iterable implements the Stateful Protocol, it will be saved and restored with its
     state_dict/load_state_dict methods.

-    If the iterator resulting from iter(iterable) is Stateful it is IGNORED.
+    Args:
+        iterable (Iterable[T]): Iterable to convert to BaseNode. IterableWrapper calls iter() on it.

-    :param iterable: Iterable to wrap. IterableWrapper calls iter() on it.
+    :warning: Note the distinction between state_dict/load_state_dict defined on Iterable, vs Iterator.
+        Only the Iterable's state_dict/load_state_dict are used.
     """

     NUM_YIELDED_KEY = "_num_yielded"
@@ -77,8 +79,9 @@ def MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) -> BaseNode
     """Thin wrapper that converts any MapDataset into a torchdata.node

     If you want parallelism, copy this and replace Mapper with ParallelMapper.

-    :param map_dataset: Mapping[K, T] - Apply map_dataset.__getitem__ to the outputs of sampler.
-    :param sampler: Sampler[K]
+    Args:
+        map_dataset (Mapping[K, T]): Apply map_dataset.__getitem__ to the outputs of sampler.
+        sampler (Sampler[K]): Sampler whose outputs are used as keys for map_dataset.
     """
     sampler_node: SamplerWrapper[K] = SamplerWrapper(sampler)
     mapper_node = Mapper(sampler_node, map_dataset.__getitem__)
@@ -91,9 +94,10 @@ class SamplerWrapper(BaseNode[T]):
     IterableWrapper except it includes a hook to call set_epoch on the sampler,
     if it supports it.

-    :param sampler: Sampler - to wrap.
-    :param initial_epoch: int - initial epoch to set on the sampler
-    :param epoch_updater: Optional[Callable[[int], int]] = None - callback to update epoch at start of new iteration. It's called at the beginning of each iterator request, except the first one.
+    Args:
+        sampler (Sampler): Sampler to wrap.
+        initial_epoch (int): Initial epoch to set on the sampler.
+        epoch_updater (Optional[Callable[[int], int]]): Callback to update the epoch at the start of a new iteration. It's called at the beginning of each iterator request, except the first one.
     """

     NUM_YIELDED_KEY = "_num_yielded"
diff --git a/torchdata/nodes/base_node.py b/torchdata/nodes/base_node.py
index 30dedae64..33ab65043 100644
--- a/torchdata/nodes/base_node.py
+++ b/torchdata/nodes/base_node.py
@@ -14,54 +14,58 @@


 class BaseNode(Iterator[T]):
-    """BaseNodes are iterators. They have the following **public** interface:
-
-    * reset(initial_state: Optional[dict] = None) - resets iterator to either initial_state or beginning if None is passed
-    * state_dict() -> Dict[str, Any] - returns a state_dict that may be passed to reset() at some point in the future
-    * __next__() -> T - users should call next(my_instance) on the iterator in order to iterate through it.
-
-    Base nodes also work in for loops as usual, if they are wrapped with an iter.
-    They can also be used directly, eg when composing BaseNodes, with a slight modification eg:
-    ```python
-    node = MyBaseNodeImpl()
-    loader = Loader(node)
-    # loader also supports state_dict() and load_state_dict()
-    for epoch in range(5):
-        for idx, batch in enumerate(loader):
-            ...
-
-    # or if using node directly:
-    node = MyBaseNodeImpl()
-    for epoch in range(5):
-        node.reset()
-        for idx, batch in enumerate(loader):
-            ...
-    ```
-
-    Subclasses of base node must implement the following methods:
-
-    * __init__() - must call super().__init__()
-    * reset(initial_state: Optional[dict]=None) - As above. Reset is a good place to put expensive
-        initialization, as it will be lazily called when next() or state_dict() is called.
-        Must call super().reset(initial_state)
-    * next() -> T - logic for returning the next value in the sequence, or throw StopIteration
-    * get_state(self) -> dict: returns a dictionary representing state that may be passed to reset()
+    """BaseNode is the base class for creating composable dataloading DAGs in ``torchdata.nodes``.
+
+    Most end-users will not iterate over a BaseNode instance directly, but instead
+    wrap it in a :class:`torchdata.nodes.Loader`, which converts the DAG into a more familiar Iterable.
+
+    .. code-block:: python
+
+        node = MyBaseNodeImpl()
+        loader = Loader(node)
+        # loader supports state_dict() and load_state_dict()
+
+        for epoch in range(5):
+            for idx, batch in enumerate(loader):
+                ...
+
+        # or if using node directly:
+        node = MyBaseNodeImpl()
+        for epoch in range(5):
+            node.reset()
+            for idx, batch in enumerate(node):
+                ...
     """

     def __init__(self, *args, **kwargs):
+        """Subclasses must implement this method and call super().__init__(*args, **kwargs)."""
         self.__initialized = False

     def __iter__(self):
         return self

     def reset(self, initial_state: Optional[dict] = None):
+        """Resets the iterator to the beginning, or to the state passed in by initial_state.
+
+        Reset is a good place to put expensive initialization, as it will be lazily called when
+        next() or state_dict() is called. Subclasses must call ``super().reset(initial_state)``.
+
+        Args:
+            initial_state (Optional[dict]): A state dict to restore from. If None, reset to the beginning.
+        """
+
         self.__initialized = True

     def get_state(self) -> Dict[str, Any]:
+        """Subclasses must implement this method, instead of state_dict(). Should only be called by BaseNode.
+        :return: Dict[str, Any] - a state dict that may be passed to reset() at some point in the future
+        """
         raise NotImplementedError(type(self))

     def next(self) -> T:
+        """Subclasses must implement this method, instead of ``__next__``. Should only be called by BaseNode.
+        :return: T - the next value in the sequence, or raise StopIteration
+        """
         raise NotImplementedError(type(self))

     def __next__(self):
@@ -78,6 +82,9 @@ def __next__(self):
         return self.next()

     def state_dict(self) -> Dict[str, Any]:
+        """Get a state_dict for this BaseNode.
+        :return: Dict[str, Any] - a state dict that may be passed to reset() at some point in the future.
+        """
         try:
             self.__initialized
         except AttributeError:
diff --git a/torchdata/nodes/batch.py b/torchdata/nodes/batch.py
index b16d8d89d..184608907 100644
--- a/torchdata/nodes/batch.py
+++ b/torchdata/nodes/batch.py
@@ -15,13 +15,10 @@ class Batcher(BaseNode[List[T]]):
     If drop_last is True, the last batch will be dropped if it is smaller than batch_size.
     If drop_last is False, the last batch will be returned even if it is smaller than batch_size.

-    Parameters:
+    Args:
         source (BaseNode[T]): The source node to batch the data from.
         batch_size (int): The size of the batch.
         drop_last (bool): Whether to drop the last batch if it is smaller than batch_size. Default is True.
-
-    Attributes:
-        SOURCE_KEY (str): The key for the source node in the state dict.
     """

     SOURCE_KEY = "source"
diff --git a/torchdata/nodes/loader.py b/torchdata/nodes/loader.py
index fe0dacb56..7543cb4f8 100644
--- a/torchdata/nodes/loader.py
+++ b/torchdata/nodes/loader.py
@@ -4,12 +4,12 @@


 class Loader(Generic[T]):
-    """Wraps the root node (iterator) and provides a stateful iterable interface.
+    """Wraps the root BaseNode (an iterator) and provides a stateful iterable interface.

     The state of the last-returned iterator is returned by the state_dict() method, and can be
     loaded using the load_state_dict() method.

-    Parameters:
+    Args:
         root (BaseNode[T]): The root node of the data pipeline.
         restart_on_stop_iteration (bool): Whether to restart the iterator when it reaches the end. Default is True.
     """
@@ -49,9 +49,21 @@ def __iter__(self):
         return self._it

     def load_state_dict(self, state_dict: Dict[str, Any]):
+        """Loads a state_dict which will be used to initialize the next iter() requested
+        from this loader.
+
+        Args:
+            state_dict (Dict[str, Any]): The state_dict to load. Should be generated from a call to state_dict().
+        """
         self._next_iter_state_dict = state_dict

     def state_dict(self) -> Dict[str, Any]:
+        """Returns a state_dict which can be passed to load_state_dict() in the future to
+        resume iteration.
+
+        The state_dict will come from the iterator returned by the most recent call to iter().
+        If no iterator has been created, a new iterator will be created and the state_dict returned from it.
+        """
         if self._it is None:
             iter(self)
             self._iter_for_state_dict = True
@@ -65,7 +77,7 @@ class LoaderIterator(BaseNode[T]):
     the iterator is exhausted or on a reset call. We look one step ahead to determine
     if the iterator is exhausted. The state of the iterator is saved in the state_dict() method,
     and can be loaded on reset calls.

-    Parameters:
+    Args:
         loader (Loader[T]): The loader object that contains the root node.
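+
+    Example (an illustrative sketch; ``LoaderIterator`` instances are created for you by ``Loader``):
+
+    .. code-block:: python
+
+        loader = Loader(root_node)   # root_node is any BaseNode
+        it = iter(loader)            # yields a LoaderIterator wrapping root_node
+        first_batch = next(it)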
""" diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py index 1a1a6a900..2b110cd0d 100644 --- a/torchdata/nodes/map.py +++ b/torchdata/nodes/map.py @@ -39,6 +39,12 @@ def Queue(self, *args, **kwargs): def Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) -> "ParallelMapper[T]": + """Returns a :class:`ParallelMapper` node with num_workers=0, which will execute map_fn in the current process/thread. + + Args: + source (BaseNode[X]): The source node to map over. + map_fn (Callable[[X], T]): The function to apply to each item from the source node. + """ return ParallelMapper( source=source, map_fn=map_fn, @@ -279,7 +285,7 @@ class ParallelMapper(BaseNode[T]): If in_order is true, the iterator will return items in the order from which they arrive from source's iterator, potentially blocking even if other items are available. - Parameters: + Args: source (BaseNode[X]): The source node to map over. map_fn (Callable[[X], T]): The function to apply to each item from the source node. num_workers (int): The number of workers to use for parallel processing. diff --git a/torchdata/nodes/pin_memory.py b/torchdata/nodes/pin_memory.py index 5acf2021e..96069e50b 100644 --- a/torchdata/nodes/pin_memory.py +++ b/torchdata/nodes/pin_memory.py @@ -97,7 +97,7 @@ def _put( class PinMemory(BaseNode[T]): """Pins the data of the underlying node to a device. This is backed by torch.utils.data._utils.pin_memory._pin_memory_loop. - Parameters: + Args: source (BaseNode[T]): The source node to pin the data from. pin_memory_device (str): The device to pin the data to. Default is "". snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is diff --git a/torchdata/nodes/prefetch.py b/torchdata/nodes/prefetch.py index 2f22d33fa..c35a593d2 100644 --- a/torchdata/nodes/prefetch.py +++ b/torchdata/nodes/prefetch.py @@ -16,7 +16,7 @@ class Prefetcher(BaseNode[T]): """Prefetches data from the source node and stores it in a queue. - Parameters: + Args: source (BaseNode[T]): The source node to prefetch data from. prefetch_factor (int): The number of items to prefetch ahead of time. snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is diff --git a/torchdata/nodes/samplers/multi_node_weighted_sampler.py b/torchdata/nodes/samplers/multi_node_weighted_sampler.py index c89ae4942..f4f4c7b4a 100644 --- a/torchdata/nodes/samplers/multi_node_weighted_sampler.py +++ b/torchdata/nodes/samplers/multi_node_weighted_sampler.py @@ -31,14 +31,13 @@ class MultiNodeWeightedSampler(BaseNode[T]): - WEIGHTED_SAMPLER_STATE_KEY: The state of the weighted sampler. We support multiple stopping criteria: - - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets - are exhausted. This is the default behavior. + - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets are exhausted. This is the default behavior. - FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted. - ALL_DATASETS_EXHAUSTED: Stop when all datasets are exhausted. On complete exhaustion of the source nodes, the node will raise StopIteration. - Parameters: + Args: source_nodes (Mapping[str, BaseNode[T]]): A dictionary of source nodes. weights (Dict[str, float]): A dictionary of weights for each source node. stop_criteria (str): The stopping criteria. Default is CYCLE_UNTIL_ALL_DATASETS_EXHAUST @@ -208,7 +207,7 @@ class _WeightedSampler: - g_rank_state: The state of the random number generator for the rank. 
    - offset: The offset of the batch of indices.

-    Parameters:
+    Args:
         weights (Dict[str, float]): A dictionary of weights for each source node.
         seed (int): The seed for the random number generator.
         rank (int): The rank of the current process.
diff --git a/torchdata/nodes/samplers/stop_criteria.py b/torchdata/nodes/samplers/stop_criteria.py
index ea128ce06..187d9ac60 100644
--- a/torchdata/nodes/samplers/stop_criteria.py
+++ b/torchdata/nodes/samplers/stop_criteria.py
@@ -10,11 +10,11 @@ class StopCriteria:
     Stopping criteria for the dataset samplers.

     1) CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Stop once the last unseen dataset is exhausted.
-       All datasets are seen at least once. In certain cases, some datasets may be
-       seen more than once when there are still non-exhausted datasets.
+        All datasets are seen at least once. In certain cases, some datasets may be
+        seen more than once when there are still non-exhausted datasets.

     2) ALL_DATASETS_EXHAUSTED: Stop once all of the datasets are exhausted. Each
-       dataset is seen exactly once. No wraparound or restart will be performed.
+        dataset is seen exactly once. No wraparound or restart will be performed.

     3) FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
     """
diff --git a/torchdata/nodes/types.py b/torchdata/nodes/types.py
index ed11f580b..57a505b6f 100644
--- a/torchdata/nodes/types.py
+++ b/torchdata/nodes/types.py
@@ -10,6 +10,8 @@

 @runtime_checkable
 class Stateful(Protocol):
+    """Protocol for objects implementing both ``state_dict()`` and ``load_state_dict(state_dict: Dict[str, Any])``."""
+
     def state_dict(self) -> Dict[str, Any]: ...
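
To make the newly documented ``Stateful`` protocol concrete, here is a minimal sketch (not part of the patch above; the class name is illustrative, and the import path is assumed from the hunk's location). Because ``Stateful`` is decorated with ``@runtime_checkable``, conformance can be verified structurally with ``isinstance``:

.. code:: python

    from typing import Any, Dict

    from torchdata.nodes.types import Stateful  # import path assumed from the hunk above


    class Counter:
        """Illustrative object satisfying the Stateful protocol."""

        def __init__(self) -> None:
            self.count = 0

        def state_dict(self) -> Dict[str, Any]:
            # capture everything needed to resume later
            return {"count": self.count}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            # restore from a previously captured state
            self.count = state_dict["count"]


    # runtime_checkable protocols support structural isinstance checks
    assert isinstance(Counter(), Stateful)

Objects like this (for example, an Iterable passed to ``IterableWrapper``) have their state saved and restored through the wrapping node's own state_dict/load_state_dict machinery, per the adapters.py docstring above.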