|
| 1 | +.. _migrate-to-nodes-from-utils: |
| 2 | + |
| 3 | +Migrating to ``torchdata.nodes`` from ``torch.utils.data`` |
| 4 | +========================================================== |
| 5 | + |
| 6 | +This guide is intended to help people familiar with ``torch.utils.data``, or |
| 7 | +:class:`~torchdata.stateful_dataloader.StatefulDataLoader`, |
| 8 | +to get started with ``torchdata.nodes``, and provide a starting ground for defining |
| 9 | +your own dataloading pipelines. |
| 10 | + |
| 11 | +We'll demonstrate how to achieve the most common DataLoader features, re-use existing samplers and datasets, |
| 12 | +and load/save dataloader state. It performs at least as well as ``DataLoader`` and ``StatefulDataLoader``, |
| 13 | +see :ref:`how-does-nodes-perform`. |
| 14 | + |
| 15 | +Map-Style Datasets |
| 16 | +~~~~~~~~~~~~~~~~~~ |
| 17 | + |
| 18 | +Let's look at the ``DataLoader`` constructor args and go from there |
| 19 | + |
| 20 | +.. code:: python |
| 21 | +
|
| 22 | + class DataLoader: |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + dataset: Dataset[_T_co], |
| 26 | + batch_size: Optional[int] = 1, |
| 27 | + shuffle: Optional[bool] = None, |
| 28 | + sampler: Union[Sampler, Iterable, None] = None, |
| 29 | + batch_sampler: Union[Sampler[List], Iterable[List], None] = None, |
| 30 | + num_workers: int = 0, |
| 31 | + collate_fn: Optional[_collate_fn_t] = None, |
| 32 | + pin_memory: bool = False, |
| 33 | + drop_last: bool = False, |
| 34 | + timeout: float = 0, |
| 35 | + worker_init_fn: Optional[_worker_init_fn_t] = None, |
| 36 | + multiprocessing_context=None, |
| 37 | + generator=None, |
| 38 | + *, |
| 39 | + prefetch_factor: Optional[int] = None, |
| 40 | + persistent_workers: bool = False, |
| 41 | + pin_memory_device: str = "", |
| 42 | + in_order: bool = True, |
| 43 | + ): |
| 44 | + ... |
| 45 | +
|
| 46 | +As a referesher, here is roughly how dataloading works in ``torch.utils.data.DataLoader``: |
| 47 | +``DataLoader`` begins by generating indices from a ``sampler`` and creates batches of `batch_size` indices. |
| 48 | +If no sampler is provided, then a RandomSampler or SequentialSampler is created by default. |
| 49 | +The indices are passed to ``Dataset.__getitem__()``, and then a ``collate_fn`` is applied to the batch |
| 50 | +of samples. If ``num_workers > 0``, it will use multi-processing to create |
| 51 | +subprocesses, and pass the batches of indices to the worker processes, who will then call ``Dataset.__getitem__()`` and apply ``collate_fn`` |
| 52 | +before returning the batches to the main process. At that point, ``pin_memory`` may be applied to the tensors in the batch. |
| 53 | + |
| 54 | +Now let's look at what an equivalent implementation for DataLoader might look like, built with ``torchdata.nodes``. |
| 55 | + |
| 56 | +.. code:: python |
| 57 | +
|
| 58 | + from typing import List, Callable |
| 59 | + import torchdata.nodes as tn |
| 60 | + from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset |
| 61 | +
|
| 62 | + class MapAndCollate: |
| 63 | + """A simple transform that takes a batch of indices, maps with dataset, and then applies |
| 64 | + collate. |
| 65 | + TODO: make this a standard utility in torchdata.nodes |
| 66 | + """ |
| 67 | + def __init__(self, dataset, collate_fn): |
| 68 | + self.dataset = dataset |
| 69 | + self.collate_fn = collate_fn |
| 70 | +
|
| 71 | + def __call__(self, batch_of_indices: List[int]): |
| 72 | + batch = [self.dataset[i] for i in batch_of_indices] |
| 73 | + return self.collate_fn(batch) |
| 74 | +
|
| 75 | + # To keep things simple, let's assume that the following args are provided by the caller |
| 76 | + def NodesDataLoader( |
| 77 | + dataset: Dataset, |
| 78 | + batch_size: int, |
| 79 | + shuffle: bool, |
| 80 | + num_workers: int, |
| 81 | + collate_fn: Callable | None, |
| 82 | + pin_memory: bool, |
| 83 | + drop_last: bool, |
| 84 | + ): |
| 85 | + # Assume we're working with a map-style dataset |
| 86 | + assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__") |
| 87 | + # Start with a sampler, since caller did not provide one |
| 88 | + sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset) |
| 89 | + # Sampler wrapper converts a Sampler to a BaseNode |
| 90 | + node = tn.SamplerWrapper(sampler) |
| 91 | +
|
| 92 | + # Now let's batch sampler indices together |
| 93 | + node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last) |
| 94 | +
|
| 95 | + # Create a Map Function that accepts a list of indices, applies getitem to it, and |
| 96 | + # then collates them |
| 97 | + map_and_collate = MapAndCollate(dataset, collate_fn or default_collate) |
| 98 | +
|
| 99 | + # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could |
| 100 | + # choose process or thread workers. Note that if you're not using Free-Threaded |
| 101 | + # Python (eg 3.13t) with -Xgil=0, then multi-threading might result in GIL contention, |
| 102 | + # and slow down training. |
| 103 | + node = tn.ParallelMapper( |
| 104 | + node, |
| 105 | + map_fn=map_and_collate, |
| 106 | + num_workers=num_workers, |
| 107 | + method="process", # Set this to "thread" for multi-threading |
| 108 | + in_order=True, |
| 109 | + ) |
| 110 | +
|
| 111 | + # Optionally apply pin-memory, and we usually do some pre-fetching |
| 112 | + if pin_memory: |
| 113 | + node = tn.PinMemory(node) |
| 114 | + node = tn.Prefetcher(node, prefetch_factor=num_workers * 2) |
| 115 | +
|
| 116 | + # Note that node is an iterator, and once it's exhausted, you'll need to call .reset() |
| 117 | + # on it to start a new Epoch. |
| 118 | + # Insteaad, we wrap the node in a Loader, which is an iterable and handles reset. It |
| 119 | + # also provides state_dict and load_state_dict methods. |
| 120 | + return tn.Loader(node) |
| 121 | +
|
| 122 | +Now let's test this out with a trivial dataset, and demonstrate how state management works. |
| 123 | + |
| 124 | +.. code:: python |
| 125 | +
|
| 126 | + class SquaredDataset(Dataset): |
| 127 | + def __init__(self, len: int): |
| 128 | + self.len = len |
| 129 | + def __len__(self): |
| 130 | + return self.len |
| 131 | + def __getitem__(self, i: int) -> int: |
| 132 | + return i**2 |
| 133 | +
|
| 134 | + loader = NodesDataLoader( |
| 135 | + dataset=SquaredDataset(14), |
| 136 | + batch_size=3, |
| 137 | + shuffle=False, |
| 138 | + num_workers=2, |
| 139 | + collate_fn=None, |
| 140 | + pin_memory=False, |
| 141 | + drop_last=False, |
| 142 | + ) |
| 143 | +
|
| 144 | + batches = [] |
| 145 | + for idx, batch in enumerate(loader): |
| 146 | + if idx == 2: |
| 147 | + state_dict = loader.state_dict() |
| 148 | + # Saves the state_dict after batch 2 has been returned |
| 149 | + batches.append(batch) |
| 150 | +
|
| 151 | + loader.load_state_dict(state_dict) |
| 152 | + batches_after_loading = list(loader) |
| 153 | + print(batches[3:]) |
| 154 | + # [tensor([ 81, 100, 121]), tensor([144, 169])] |
| 155 | + print(batches_after_loading) |
| 156 | + # [tensor([ 81, 100, 121]), tensor([144, 169])] |
| 157 | +
|
| 158 | +Let's also compare this to torch.utils.data.DataLoader, as a sanity check. |
| 159 | + |
| 160 | +.. code:: python |
| 161 | +
|
| 162 | + loaderv1 = torch.utils.data.DataLoader( |
| 163 | + dataset=SquaredDataset(14), |
| 164 | + batch_size=3, |
| 165 | + shuffle=False, |
| 166 | + num_workers=2, |
| 167 | + collate_fn=None, |
| 168 | + pin_memory=False, |
| 169 | + drop_last=False, |
| 170 | + persistent_workers=False, # Coming soon to torchdata.nodes! |
| 171 | + ) |
| 172 | + print(list(loaderv1)) |
| 173 | + # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])] |
| 174 | + print(batches) |
| 175 | + # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])] |
| 176 | +
|
| 177 | +
|
| 178 | +IterableDatasets |
| 179 | +~~~~~~~~~~~~~~~~ |
| 180 | + |
| 181 | +Coming soon! While you can already plug your IterableDataset into an ``tn.IterableWrapper``, some functions like |
| 182 | +``get_worker_info`` are not currently supported yet. However we believe that often, sharding work between |
| 183 | +multi-process workers is not actually necessary, and you can keep some sort of indexing in the main process while |
| 184 | +only parallelizing some of the heavier transforms, similar to how Map-style Datasets work above. |
0 commit comments