
Commit ad8a71c

add new torchdata.nodes doc file (#1390) (#1392)

* add new torchdata.nodes doc file
* update main readme.md, move content of nodes readme.md to torchdata.nodes.rst
* main readme
* update readme

1 parent 31d3a1a

File tree: 3 files changed, +210 −0 lines changed

README.md (+7 lines)

@@ -35,6 +35,13 @@ provides state_dict and load_state_dict functionality. See
examples [in this Colab notebook](https://colab.research.google.com/drive/1tonoovEd7Tsi8EW8ZHXf0v3yHJGwZP8M?usp=sharing).

## torchdata.nodes

torchdata.nodes is a library of composable iterators (not iterables!) that let you chain together common dataloading and pre-processing operations. It follows a streaming programming model, although "sampler + Map-style" can still be configured if you desire. See the [torchdata.nodes main page](torchdata/nodes) for more details. Stay tuned for a tutorial on torchdata.nodes coming soon!

## Installation

### Version Compatibility

docs/source/index.rst (+1 line)

@@ -37,6 +37,7 @@ Features described in this documentation are classified by release status:
   :caption: API Reference:

   torchdata.stateful_dataloader.rst
   torchdata.nodes.rst

.. toctree::

docs/source/torchdata.nodes.rst (new file, +202 lines)
torchdata.nodes
===============

What is ``torchdata.nodes``?
----------------------------

``torchdata.nodes`` is a library of composable iterators (not iterables!) that let you chain together common dataloading and pre-processing operations. It follows a streaming programming model, although "sampler + Map-style" can still be configured if you desire.

``torchdata.nodes`` adds more flexibility to the standard ``torch.utils.data`` offering, and introduces multi-threaded parallelism in addition to multi-process (the only approach supported by ``torch.utils.data.DataLoader``), as well as first-class support for mid-epoch checkpointing through a ``state_dict``/``load_state_dict`` interface.

``torchdata.nodes`` strives to include as many useful operators as possible, but it is also designed to be extensible. New nodes must subclass ``torchdata.nodes.BaseNode`` (which itself subclasses ``typing.Iterator``) and implement the ``next()``, ``reset(initial_state)``, and ``get_state()`` operations (notably, not ``__next__``, ``load_state_dict``, nor ``state_dict``).
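As an illustration, here is a minimal sketch of a custom node following that contract. The node itself (``RangeNode``) and its state key are hypothetical examples, not part of the library; check the ``BaseNode`` source for the exact expected behavior (for instance, implementations conventionally call ``super().reset(initial_state)``).

.. code:: python

   from typing import Any, Dict, Optional

   from torchdata.nodes import BaseNode


   class RangeNode(BaseNode):
       """Yields 0..n-1 and can checkpoint/restore its position."""

       def __init__(self, n: int):
           super().__init__()
           self.n = n
           self._i = 0

       def reset(self, initial_state: Optional[Dict[str, Any]] = None):
           # All (re-)initialization happens here, optionally from a saved state
           super().reset(initial_state)
           self._i = initial_state["i"] if initial_state is not None else 0

       def next(self) -> int:
           # Signal exhaustion the same way a plain Iterator would
           if self._i >= self.n:
               raise StopIteration()
           value = self._i
           self._i += 1
           return value

       def get_state(self) -> Dict[str, Any]:
           return {"i": self._i}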
Getting started
---------------

Install torchdata with pip.

.. code:: bash

   pip install "torchdata>=0.10.0"

Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) to convert it to a BaseNode and get started:

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

   node = IterableWrapper(range(10))
   node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
   loader = Loader(node)
   result = list(loader)
   print(result)
   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
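Since mid-epoch checkpointing is first-class, the same pipeline can be paused and resumed. The following is a sketch assuming ``Loader`` exposes the ``state_dict()``/``load_state_dict()`` interface described above; exact resume semantics may differ slightly between versions.

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

   def make_loader():
       node = IterableWrapper(range(10))
       node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
       return Loader(node)

   loader = make_loader()
   it = iter(loader)
   first_half = [next(it) for _ in range(5)]
   checkpoint = loader.state_dict()  # capture mid-epoch state

   # Later, e.g. after a restart: rebuild the pipeline and resume from the checkpoint
   resumed = make_loader()
   resumed.load_state_dict(checkpoint)
   second_half = list(resumed)  # should continue from where the checkpoint was taken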
Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can use your existing ``torch.utils.data.Dataset``\ s:

.. code:: python

   import torch.utils.data
   from torch.utils.data import RandomSampler
   from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


   class SquaredDataset(torch.utils.data.Dataset):
       def __getitem__(self, i: int) -> int:
           return i**2

       def __len__(self):
           return 10


   dataset = SquaredDataset()
   sampler = RandomSampler(dataset)

   # For fine-grained control of iteration order, define your own sampler
   node = SamplerWrapper(sampler)
   # Apply the dataset's __getitem__ as a map function to the indices generated by the sampler
   node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
   # Loader converts a node (an iterator) into an Iterable that may be reused across multiple epochs
   loader = Loader(node)
   print(list(loader))
   # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
   print(list(loader))
   # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]
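Larger pipelines compose the same way, node by node. The sketch below assumes your installed version also ships ``Batcher`` and ``Prefetcher`` nodes with the argument names shown; treat the exact signatures as assumptions and check the API reference.

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Batcher, Prefetcher, Loader

   node = IterableWrapper(range(10))
   node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
   node = Batcher(node, batch_size=4, drop_last=False)  # group items into lists of up to 4
   node = Prefetcher(node, prefetch_factor=2)           # keep a small buffer of ready batches
   loader = Loader(node)
   print(list(loader))
   # e.g. [[0, 1, 4, 9], [16, 25, 36, 49], [64, 81]]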
What's the point of ``torchdata.nodes``?
----------------------------------------

We get it, ``torch.utils.data`` just works for many, many use cases. However, it definitely has a bunch of rough spots:

Multiprocessing sucks
~~~~~~~~~~~~~~~~~~~~~

- You need to duplicate memory stored in your Dataset (because of Python copy-on-read).
- IPC over multiprocess queues is slow and can introduce slow startup times.
- You're forced to perform batching on the workers instead of in the main process to reduce IPC overhead, increasing peak memory.
- With GIL-releasing functions and free-threaded Python, multi-threading may not be GIL-bound like it used to be.

``torchdata.nodes`` enables both multi-threading and multi-processing, so you can choose what works best for your particular setup. Parallelism is primarily configured in Mapper operators, giving you flexibility in what, when, and how to parallelize.
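In practice, switching between the two is a one-argument change. Here is a minimal sketch; the ``method`` values ``"thread"`` and ``"process"`` follow the example above, and process-based workers generally want a picklable, module-level map function and a ``__main__`` guard.

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

   def square(x):
       # Module-level so it can be pickled for process-based workers
       return x ** 2

   if __name__ == "__main__":
       # Thread-based workers: no pickling or IPC, shares memory with the main process
       threaded = ParallelMapper(IterableWrapper(range(10)), map_fn=square,
                                 num_workers=3, method="thread")

       # Process-based workers: DataLoader-style isolation, useful for CPU-bound,
       # GIL-holding pure-Python map functions
       multiproc = ParallelMapper(IterableWrapper(range(10)), map_fn=square,
                                  num_workers=3, method="process")

       print(list(Loader(threaded)))
       print(list(Loader(multiproc)))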
Map-style and random-access doesn't scale
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The current map-style dataset approach is great for datasets that fit in memory, but true random access is not going to be very performant once your dataset grows beyond memory limitations, unless you jump through some hoops with a special sampler.

``torchdata.nodes`` follows a streaming data model, where operators are Iterators that can be combined together to define a dataloading and pre-processing pipeline. Samplers are still supported (see the example above) and can be combined with a Mapper to produce an Iterator.
Multi-Datasets do not fit well with the current implementation in ``torch.utils.data``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The current Sampler concept (one per dataloader) starts to break down when you try to combine multiple datasets. (For single datasets, samplers are a great abstraction and will continue to be supported!)

- For multi-datasets, consider this scenario: ``len(dsA): 10``, ``len(dsB): 20``. Now we want to do round-robin (or sample uniformly) between these two datasets to feed to our trainer. With just a single sampler, how can you implement that strategy? Maybe a sampler that emits tuples? What if you want to swap in a RandomSampler, or a DistributedSampler? How will ``sampler.set_epoch`` work?

``torchdata.nodes`` helps to address and scale multi-dataset dataloading by dealing only with Iterators, thereby forcing samplers and datasets together and focusing on composing smaller primitive nodes into a more complex dataloading pipeline.
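As a concrete sketch of the scenario above, the two datasets become two nodes and a weighted multi-node sampler mixes them. This assumes a ``MultiNodeWeightedSampler`` node (keyed source nodes plus a weights dict) is available in your installed version; the name and signature here are assumptions, so check the API reference.

.. code:: python

   from torchdata.nodes import IterableWrapper, MultiNodeWeightedSampler, Loader

   # Two datasets of different sizes, each already wrapped as a node
   ds_a = IterableWrapper(range(10))        # len(dsA): 10
   ds_b = IterableWrapper(range(100, 120))  # len(dsB): 20

   # Sample between the two sources according to the given weights
   node = MultiNodeWeightedSampler(
       source_nodes={"dsA": ds_a, "dsB": ds_b},
       weights={"dsA": 0.5, "dsB": 0.5},
   )
   loader = Loader(node)
   for item in loader:
       ...  # items from dsA and dsB, interleaved according to the weights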
IterableDataset + multiprocessing requires additional dataset sharding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Dataset sharding is required for data-parallel training, which is fairly reasonable. But what about sharding between dataloader workers? With map-style datasets, distribution of work between workers is handled by the main process, which distributes sampler indices to workers. With IterableDatasets, each worker needs to figure out (through ``torch.utils.data.get_worker_info``) what data it should be returning.
Design choices
--------------

No Generator BaseNodes
~~~~~~~~~~~~~~~~~~~~~~

See https://github.com/pytorch/data/pull/1362 for more thoughts.

One difficult choice we made was to disallow Generators when defining a new BaseNode implementation: we dropped Generator support and moved to an Iterator-only foundation for a few reasons around state management:

1. We require explicit state handling in BaseNode implementations. Generators store state implicitly on the stack, and we found that we needed to jump through hoops and write very convoluted code to get basic state working with Generators.
2. End-of-iteration state dict: Iterables may feel more natural, but a bunch of issues come up around state management. Consider the end-of-iteration state dict. If you load this state_dict into your iterable, should it represent the end of iteration, or the start of the next iteration?
3. Loading state: if you call load_state_dict() on an iterable, most users would expect the next iterator requested from it to start with the loaded state. But what if iter is called twice before iteration begins?
4. Multiple Live Iterator problem: if you have one instance of an Iterable but two live iterators, what does it mean to call state_dict() on the Iterable? In dataloading this is very rare, but we would still need to work around it and make a bunch of assumptions. Forcing developers who implement BaseNodes to reason about these scenarios is, in our opinion, worse than disallowing generators and Iterables.

``torchdata.nodes.BaseNode`` implementations are Iterators. Iterators define ``next()``, ``get_state()``, and ``reset(initial_state | None)``. All re-initialization should be done in ``reset()``, including initializing with a particular state if one is passed.

However, end-users are used to dealing with Iterables. For example:

::

   for epoch in range(5):
       # Most frameworks and users don't expect to call loader.reset()
       for batch in loader:
           ...
       sd = loader.state_dict()
       # Loading sd should not throw StopIteration right away, but instead start at the next epoch

To handle this, we keep all of the assumptions and special end-of-epoch handling in a single ``Loader`` class, which takes any BaseNode and makes it an Iterable, handling the reset() calls and end-of-epoch state_dict loading.
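
Conceptually (this is a simplified, hypothetical sketch, not the actual ``Loader`` implementation), that wrapper behaves roughly like this:

.. code:: python

   class SimpleLoader:
       """Illustrative only: turn a BaseNode (an Iterator) back into an Iterable."""

       def __init__(self, node):
           self.node = node
           self._pending_state = None

       def load_state_dict(self, state_dict):
           # Defer restoration until the next epoch actually starts, so an
           # end-of-epoch state_dict doesn't raise StopIteration immediately
           self._pending_state = state_dict

       def state_dict(self):
           return self.node.get_state()

       def __iter__(self):
           # reset() is called for the user at the start of every epoch,
           # with the loaded state if one is pending
           self.node.reset(self._pending_state)
           self._pending_state = None
           return self.node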
