torchdata.nodes
===============

What is ``torchdata.nodes``?
----------------------------

``torchdata.nodes`` is a library of composable iterators (not
iterables!) that let you chain together common dataloading and pre-proc
operations. It follows a streaming programming model, although "sampler
+ Map-style" can still be configured if you desire.

``torchdata.nodes`` adds more flexibility to the standard
``torch.utils.data`` offering, and introduces multi-threaded parallelism
in addition to multi-process (the only supported approach in
``torch.utils.data.DataLoader``), as well as first-class support for
mid-epoch checkpointing through a ``state_dict/load_state_dict``
interface.
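
As a rough sketch of what mid-epoch checkpointing can look like (using
the ``Loader`` and ``IterableWrapper`` nodes shown in the examples
below; the exact resumed values depend on your pipeline):

.. code:: python

   from torchdata.nodes import IterableWrapper, Loader

   loader = Loader(IterableWrapper(range(10)))
   it = iter(loader)
   for _ in range(3):
       next(it)  # consume a few items

   # Capture a mid-epoch checkpoint
   state = loader.state_dict()

   # Later (or in another process), resume from roughly where we left off
   resumed = Loader(IterableWrapper(range(10)))
   resumed.load_state_dict(state)
   print(list(resumed))
   # e.g. [3, 4, 5, 6, 7, 8, 9]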

``torchdata.nodes`` strives to include as many useful operators as
possible; however, it is also designed to be extensible. New nodes are
required to subclass ``torchdata.nodes.BaseNode`` (which itself
subclasses ``typing.Iterator``) and implement the ``next()``,
``reset(initial_state)`` and ``get_state()`` operations (notably, not
``__next__``, ``load_state_dict``, nor ``state_dict``).
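
As an illustrative sketch of a custom node (this assumes ``BaseNode``
is generic over its item type and that ``reset()`` should call
``super().reset()``; check the ``BaseNode`` documentation for the exact
contract):

.. code:: python

   from typing import Optional

   from torchdata.nodes import BaseNode


   class RangeNode(BaseNode[int]):
       """Yields 0..n-1 and supports saving/restoring its position."""

       def __init__(self, n: int):
           super().__init__()
           self.n = n
           self._i = 0

       def reset(self, initial_state: Optional[dict] = None):
           super().reset(initial_state)
           # Restore position from a saved state, or start from scratch
           self._i = initial_state["i"] if initial_state is not None else 0

       def next(self) -> int:
           if self._i >= self.n:
               raise StopIteration()
           value = self._i
           self._i += 1
           return value

       def get_state(self) -> dict:
           return {"i": self._i}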

Getting started
---------------

Install torchdata with pip.

.. code:: bash

   pip install "torchdata>=0.10.0"

Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) to convert it to a ``BaseNode`` and
get started:

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

   node = IterableWrapper(range(10))
   node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
   loader = Loader(node)
   result = list(loader)
   print(result)
   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can use your existing
``torch.utils.data.Dataset``\ s:

.. code:: python

   import torch.utils.data
   from torch.utils.data import RandomSampler

   from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


   class SquaredDataset(torch.utils.data.Dataset):
       def __getitem__(self, i: int) -> int:
           return i**2

       def __len__(self):
           return 10


   dataset = SquaredDataset()
   sampler = RandomSampler(dataset)

   # For fine-grained control of iteration order, define your own sampler
   node = SamplerWrapper(sampler)
   # Apply the dataset's __getitem__ as a map function over the indices produced by the sampler
   node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
   # Loader converts a node (an iterator) into an Iterable that may be reused for multiple epochs
   loader = Loader(node)
   print(list(loader))
   # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
   print(list(loader))
   # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]

What’s the point of ``torchdata.nodes``?
----------------------------------------

We get it, ``torch.utils.data`` just works for many, many use cases.
However, it definitely has a bunch of rough spots:

Multiprocessing sucks
~~~~~~~~~~~~~~~~~~~~~

- You need to duplicate memory stored in your Dataset (because of
  Python copy-on-read)
- IPC over multiprocess queues is slow and can introduce long startup
  times
- You're forced to perform batching on the workers instead of in the
  main process to reduce IPC overhead, which increases peak memory
- With GIL-releasing functions and Free-Threaded Python,
  multi-threading may not be GIL-bound like it used to be

``torchdata.nodes`` enables both multi-threading and multi-processing, so
you can choose what works best for your particular setup. Parallelism
is primarily configured in Mapper operators, giving you flexibility in
what, when, and how to parallelize.

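
For example, switching between thread- and process-based parallelism is
roughly a one-argument change to ``ParallelMapper`` (a sketch building on
the Generator Example above; note that with ``method="process"`` the
``map_fn`` must be picklable, and spawn-based platforms need the usual
``if __name__ == "__main__"`` guard):

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader


   def square(x: int) -> int:
       return x * x


   # Threads: cheap to start, and effective when map_fn releases the GIL
   # (I/O, decoding, numpy, etc.)
   threaded = ParallelMapper(
       IterableWrapper(range(10)), map_fn=square, num_workers=4, method="thread"
   )

   # Processes: sidestep the GIL entirely for pure-Python CPU-bound work
   multiproc = ParallelMapper(
       IterableWrapper(range(10)), map_fn=square, num_workers=4, method="process"
   )

   print(list(Loader(threaded)))
   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
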
Map-style and random-access doesn’t scale
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The current map-style dataset approach works well for datasets that fit
in memory, but true random access is not going to be very performant
once your dataset grows beyond memory limits, unless you jump through
some hoops with a special sampler.

``torchdata.nodes`` follows a streaming data model, where operators are
Iterators that can be combined together to define a dataloading and
pre-proc pipeline. Samplers are still supported (see the example above)
and can be combined with a Mapper to produce an Iterator, as in the
sketch below.

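
As a sketch of composing such a pipeline (this assumes ``Batcher`` and
``Prefetcher`` nodes are available in your torchdata version, alongside
the nodes used earlier):

.. code:: python

   from torchdata.nodes import Batcher, IterableWrapper, Loader, ParallelMapper, Prefetcher

   node = IterableWrapper(range(12))
   node = ParallelMapper(node, map_fn=lambda x: x + 100, num_workers=2, method="thread")
   node = Batcher(node, batch_size=4)          # group items into lists of 4
   node = Prefetcher(node, prefetch_factor=2)  # read ahead of the consumer
   print(list(Loader(node)))
   # [[100, 101, 102, 103], [104, 105, 106, 107], [108, 109, 110, 111]]
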
Multi-Datasets do not fit well with the current implementation in ``torch.utils.data``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The current Sampler concept (one per dataloader) starts to break down
when you try to combine multiple datasets. (For single datasets,
samplers are a great abstraction and will continue to be supported!)

- For multi-datasets, consider this scenario: ``len(dsA): 10``,
  ``len(dsB): 20``. Now we want to round-robin (or sample uniformly)
  between these two datasets to feed our trainer. With just a single
  sampler, how can you implement that strategy? Maybe a sampler that
  emits tuples? What if you want to swap in a RandomSampler or a
  DistributedSampler? How would ``sampler.set_epoch`` work?

``torchdata.nodes`` helps to address and scale multi-dataset dataloading
by dealing only with Iterators, thereby forcing samplers and datasets
together, and by focusing on composing smaller primitive nodes into a
more complex dataloading pipeline; see the sketch below.

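
As a sketch of what multi-dataset mixing can look like (this assumes a
``MultiNodeWeightedSampler`` node is available in your torchdata
version; the dataset names and weights here are illustrative):

.. code:: python

   from torchdata.nodes import IterableWrapper, Loader, MultiNodeWeightedSampler

   ds_a = IterableWrapper(["a"] * 10)
   ds_b = IterableWrapper(["b"] * 20)

   # Sample from the two source nodes with the given probabilities
   node = MultiNodeWeightedSampler(
       source_nodes={"A": ds_a, "B": ds_b},
       weights={"A": 0.5, "B": 0.5},
   )
   loader = Loader(node)
   it = iter(loader)
   print([next(it) for _ in range(8)])
   # e.g. ['a', 'b', 'b', 'a', 'a', 'b', 'a', 'b']
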
IterableDataset + multiprocessing requires additional dataset sharding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Dataset sharding is required for data-parallel training, which is fairly
reasonable. But what about sharding between dataloader workers? With
Map-style datasets, distribution of work between workers is handled by
the main process, which distributes sampler indices to workers. With
IterableDatasets, each worker needs to figure out (through
``torch.utils.data.get_worker_info``) what data it should be returning,
as in the sketch below.

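
For reference, the manual sharding that ``torch.utils.data`` requires
looks roughly like this standard pattern (a sketch of plain PyTorch
code, not ``torchdata.nodes``):

.. code:: python

   import torch.utils.data


   class ShardedIterable(torch.utils.data.IterableDataset):
       """Each dataloader worker slices out its own shard of the stream."""

       def __init__(self, n: int):
           self.n = n

       def __iter__(self):
           info = torch.utils.data.get_worker_info()
           if info is None:
               # Single-process loading: yield everything
               worker_id, num_workers = 0, 1
           else:
               # Inside a worker: keep every num_workers-th item
               worker_id, num_workers = info.id, info.num_workers
           yield from range(worker_id, self.n, num_workers)
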
Design choices
--------------

No Generator BaseNodes
~~~~~~~~~~~~~~~~~~~~~~

See https://github.com/pytorch/data/pull/1362 for more thoughts.

One difficult choice we made was to disallow Generators when defining a
new BaseNode implementation. We dropped Generator support and moved to
an Iterator-only foundation for a few reasons around state management:

1. We require explicit state handling in BaseNode implementations.
   Generators store state implicitly on the stack, and we found that we
   needed to jump through hoops and write very convoluted code to get
   basic state working with Generators.
2. End-of-iteration state dict: Iterables may feel more natural, but a
   number of issues come up around state management. Consider the
   end-of-iteration state dict. If you load this state_dict into your
   iterable, should it represent the end of the current iteration or
   the start of the next one?
3. Loading state: if you call load_state_dict() on an Iterable, most
   users would expect the next iterator requested from it to start with
   the loaded state. But what if ``iter`` is called twice before
   iteration begins?
4. Multiple live iterators: if you have one instance of an Iterable but
   two live iterators, what does it mean to call state_dict() on the
   Iterable? In dataloading this is very rare, yet we would still need
   to work around it and make a number of assumptions. Forcing
   developers who implement BaseNodes to reason about these scenarios
   is, in our opinion, worse than disallowing generators and Iterables.

``torchdata.nodes.BaseNode`` implementations are Iterators. Iterators
define ``next()``, ``get_state()``, and ``reset(initial_state | None)``.
All re-initialization should be done in reset(), including initializing
with a particular state if one is passed.

However, end-users are used to dealing with Iterables, for example:

::

   for epoch in range(5):
       # Most frameworks and users don't expect to call loader.reset()
       for batch in loader:
           ...
       sd = loader.state_dict()
       # Loading sd should not throw StopIteration right away, but instead start at the next epoch

To handle this, we keep all of the assumptions and special end-of-epoch
handling in a single ``Loader`` class, which takes any BaseNode and makes
it an Iterable, handling the reset() calls and end-of-epoch state_dict
loading.
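
In other words, the usual multi-epoch loop works without explicit
resets, and loading an end-of-epoch state_dict starts the next epoch
rather than immediately raising StopIteration. A rough sketch (the
values are illustrative):

.. code:: python

   from torchdata.nodes import IterableWrapper, Loader

   loader = Loader(IterableWrapper(range(3)))

   for epoch in range(2):
       print(list(loader))  # [0, 1, 2] each epoch; no manual reset() needed

   sd = loader.state_dict()  # taken at end-of-epoch

   fresh = Loader(IterableWrapper(range(3)))
   fresh.load_state_dict(sd)
   print(list(fresh))  # starts the next epoch: [0, 1, 2]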