
Commit 7f97f48

Cherry pick commits for 0.10.1 (#1400)
1 parent dabbdcc commit 7f97f48

21 files changed: +1673 -258 lines changed
@@ -0,0 +1,57 @@
Getting Started With ``torchdata.nodes`` (beta)
===============================================

Install torchdata with pip.

.. code:: bash

    pip install "torchdata>=0.10.0"

Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) to convert it to a BaseNode and get started:

.. code:: python

    from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

    node = IterableWrapper(range(10))
    node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
    loader = Loader(node)
    result = list(loader)
    print(result)
    # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can use your existing
``torch.utils.data.Dataset``\ s. See :ref:`migrate-to-nodes-from-utils` for an in-depth
example.

.. code:: python

    import torch.utils.data
    from torch.utils.data import RandomSampler
    from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


    class SquaredDataset(torch.utils.data.Dataset):
        def __getitem__(self, i: int) -> int:
            return i**2

        def __len__(self):
            return 10


    dataset = SquaredDataset()
    sampler = RandomSampler(dataset)

    # For fine-grained control of iteration order, define your own sampler
    node = SamplerWrapper(sampler)
    # Apply the dataset's __getitem__ as a map function to the indices generated by the sampler
    node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
    # Loader converts a node (an iterator) into an Iterable that may be reused across multiple epochs
    loader = Loader(node)
    print(list(loader))
    # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
    print(list(loader))
    # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]
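
``Loader`` also provides ``state_dict()`` and ``load_state_dict()`` methods, so a
dataloading pipeline can be checkpointed and resumed mid-epoch (see
:ref:`migrate-to-nodes-from-utils` for a full walkthrough). A minimal sketch, reusing the
generator example above:

.. code:: python

    from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

    node = IterableWrapper(range(10))
    node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
    loader = Loader(node)

    results = []
    for idx, x in enumerate(loader):
        if idx == 1:
            state = loader.state_dict()  # capture state after the first two items
        results.append(x)

    # Restore the saved state and continue from the same point
    loader.load_state_dict(state)
    print(results[2:])
    # [4, 9, 16, 25, 36, 49, 64, 81]
    print(list(loader))
    # expected to match results[2:]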

docs/source/index.rst

+8 -1
@@ -31,19 +31,26 @@ Features described in this documentation are classified by release status:
 binary distributions like PyPI or Conda, except sometimes behind run-time
 flags, and are at an early stage for feedback and testing.
 
+.. toctree::
+   :maxdepth: 2
+   :caption: Developer Notes:
+
+   what_is_torchdata_nodes.rst
 
 .. toctree::
    :maxdepth: 2
    :caption: API Reference:
 
-   torchdata.stateful_dataloader.rst
    torchdata.nodes.rst
+   torchdata.stateful_dataloader.rst
 
 
 .. toctree::
    :maxdepth: 2
    :caption: Tutorial and Examples:
 
+   getting_started_with_torchdata_nodes.rst
+   migrate_to_nodes_from_utils.rst
    stateful_dataloader_tutorial.rst
 
 
+184
@@ -0,0 +1,184 @@
.. _migrate-to-nodes-from-utils:

Migrating to ``torchdata.nodes`` from ``torch.utils.data``
==========================================================

This guide is intended to help people familiar with ``torch.utils.data``, or
:class:`~torchdata.stateful_dataloader.StatefulDataLoader`,
to get started with ``torchdata.nodes``, and provide a starting ground for defining
your own dataloading pipelines.

We'll demonstrate how to achieve the most common DataLoader features, re-use existing samplers and datasets,
and load/save dataloader state. ``torchdata.nodes`` performs at least as well as ``DataLoader`` and
``StatefulDataLoader``; see :ref:`how-does-nodes-perform`.

Map-Style Datasets
~~~~~~~~~~~~~~~~~~

Let's look at the ``DataLoader`` constructor arguments and go from there:

.. code:: python

    class DataLoader:
        def __init__(
            self,
            dataset: Dataset[_T_co],
            batch_size: Optional[int] = 1,
            shuffle: Optional[bool] = None,
            sampler: Union[Sampler, Iterable, None] = None,
            batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
            num_workers: int = 0,
            collate_fn: Optional[_collate_fn_t] = None,
            pin_memory: bool = False,
            drop_last: bool = False,
            timeout: float = 0,
            worker_init_fn: Optional[_worker_init_fn_t] = None,
            multiprocessing_context=None,
            generator=None,
            *,
            prefetch_factor: Optional[int] = None,
            persistent_workers: bool = False,
            pin_memory_device: str = "",
            in_order: bool = True,
        ):
            ...

As a refresher, here is roughly how dataloading works in ``torch.utils.data.DataLoader``:
``DataLoader`` begins by generating indices from a ``sampler`` and creates batches of ``batch_size`` indices.
If no sampler is provided, a ``RandomSampler`` or ``SequentialSampler`` is created by default.
The indices are passed to ``Dataset.__getitem__()``, and then a ``collate_fn`` is applied to the batch
of samples. If ``num_workers > 0``, ``DataLoader`` uses multi-processing to create
subprocesses, and passes the batches of indices to the worker processes, which then call ``Dataset.__getitem__()`` and apply ``collate_fn``
before returning the batches to the main process. At that point, ``pin_memory`` may be applied to the tensors in the batch.
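
Ignoring workers, ``pin_memory``, and error handling, that flow boils down to roughly the
following sketch (the function name and arguments here are illustrative stand-ins, not a
real API):

.. code:: python

    def dataloader_flow(dataset, sampler, batch_size, drop_last, collate_fn):
        batch_of_indices = []
        for idx in sampler:  # 1. generate indices from the sampler
            batch_of_indices.append(idx)
            if len(batch_of_indices) == batch_size:  # 2. group indices into batches of batch_size
                samples = [dataset[i] for i in batch_of_indices]  # 3. Dataset.__getitem__
                yield collate_fn(samples)  # 4. collate the batch of samples
                batch_of_indices = []
        if batch_of_indices and not drop_last:  # 5. optionally emit the last partial batch
            yield collate_fn([dataset[i] for i in batch_of_indices])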
Now let's look at what an equivalent implementation of ``DataLoader`` might look like, built with ``torchdata.nodes``.

.. code:: python

    from typing import List, Callable

    import torchdata.nodes as tn
    from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset


    class MapAndCollate:
        """A simple transform that takes a batch of indices, maps with dataset, and then applies
        collate.
        TODO: make this a standard utility in torchdata.nodes
        """
        def __init__(self, dataset, collate_fn):
            self.dataset = dataset
            self.collate_fn = collate_fn

        def __call__(self, batch_of_indices: List[int]):
            batch = [self.dataset[i] for i in batch_of_indices]
            return self.collate_fn(batch)


    # To keep things simple, let's assume that the following args are provided by the caller
    def NodesDataLoader(
        dataset: Dataset,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
        collate_fn: Callable | None,
        pin_memory: bool,
        drop_last: bool,
    ):
        # Assume we're working with a map-style dataset
        assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")
        # Start with a sampler, since the caller did not provide one
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        # SamplerWrapper converts a Sampler to a BaseNode
        node = tn.SamplerWrapper(sampler)

        # Now let's batch the sampler indices together
        node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

        # Create a map function that accepts a list of indices, applies __getitem__ to each,
        # and then collates the resulting samples
        map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

        # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
        # choose process or thread workers. Note that if you're not using Free-Threaded
        # Python (e.g. 3.13t) with -X gil=0, then multi-threading might result in GIL contention
        # and slow down training.
        node = tn.ParallelMapper(
            node,
            map_fn=map_and_collate,
            num_workers=num_workers,
            method="process",  # Set this to "thread" for multi-threading
            in_order=True,
        )

        # Optionally apply pin-memory, and we usually do some pre-fetching
        if pin_memory:
            node = tn.PinMemory(node)
        node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

        # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
        # on it to start a new epoch.
        # Instead, we wrap the node in a Loader, which is an iterable and handles reset. It
        # also provides state_dict and load_state_dict methods.
        return tn.Loader(node)

Now let's test this out with a trivial dataset, and demonstrate how state management works.

.. code:: python

    class SquaredDataset(Dataset):
        def __init__(self, len: int):
            self.len = len

        def __len__(self):
            return self.len

        def __getitem__(self, i: int) -> int:
            return i**2


    loader = NodesDataLoader(
        dataset=SquaredDataset(14),
        batch_size=3,
        shuffle=False,
        num_workers=2,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
    )

    batches = []
    for idx, batch in enumerate(loader):
        if idx == 2:
            state_dict = loader.state_dict()
            # Saves the state_dict after batch 2 has been returned
        batches.append(batch)

    loader.load_state_dict(state_dict)
    batches_after_loading = list(loader)
    print(batches[3:])
    # [tensor([ 81, 100, 121]), tensor([144, 169])]
    print(batches_after_loading)
    # [tensor([ 81, 100, 121]), tensor([144, 169])]

Let's also compare this to ``torch.utils.data.DataLoader``, as a sanity check.

.. code:: python

    import torch.utils.data

    loaderv1 = torch.utils.data.DataLoader(
        dataset=SquaredDataset(14),
        batch_size=3,
        shuffle=False,
        num_workers=2,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        persistent_workers=False,  # Coming soon to torchdata.nodes!
    )
    print(list(loaderv1))
    # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
    print(batches)
    # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
IterableDatasets
~~~~~~~~~~~~~~~~

Coming soon! While you can already plug your IterableDataset into a ``tn.IterableWrapper``, some functions like
``get_worker_info`` are not supported yet. However, we believe that sharding work between
multi-process workers is often not actually necessary: you can keep some sort of indexing in the main process while
only parallelizing some of the heavier transforms, similar to how Map-Style Datasets work above.
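
For example, here is a minimal sketch of that pattern: iterate the dataset in the main
process, and only parallelize a heavier transform. The ``RangeDataset`` and ``decode``
names below are illustrative placeholders, not part of torchdata.

.. code:: python

    import torchdata.nodes as tn
    from torch.utils.data import IterableDataset


    class RangeDataset(IterableDataset):
        """Placeholder IterableDataset that just yields indices."""
        def __iter__(self):
            yield from range(10)


    def decode(x):
        # Stand-in for a heavier transform (decoding, augmentation, ...)
        return x**2


    # The dataset is consumed in the main process; only the transform runs in workers
    node = tn.IterableWrapper(RangeDataset())
    node = tn.ParallelMapper(node, map_fn=decode, num_workers=3, method="thread")
    loader = tn.Loader(node)
    print(list(loader))
    # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]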
