
Commit 265d317

Cherry pick docs changes from main branch (#1304)

* fix docs, make sure docs build (#1302)
* fix docs, make sure docs build
* adding stateful dataloader docs
* Add stateful dataloader tutorial docs (#1303)

1 parent ba35881 · commit 265d317

7 files changed: +241 −30 lines changed

docs/Makefile (+27 −9)

@@ -1,23 +1,41 @@
 # Minimal makefile for Sphinx documentation
 #

-# You can set these variables from the command line, and also
-# from the environment for the first two.
-SPHINXOPTS    ?=
-SPHINXBUILD   ?= sphinx-build
+ifneq ($(EXAMPLES_PATTERN),)
+    EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)"
+endif
+
+# You can set these variables from the command line.
+SPHINXOPTS    = -W -j auto $(EXAMPLES_PATTERN_OPTS)
+SPHINXBUILD   = sphinx-build
+SPHINXPROJ    = torchdata
 SOURCEDIR     = source
 BUILDDIR      = build

 # Put it first so that "make" without argument is like "make help".
 help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

-doctest: html
-	$(SPHINXBUILD) -b doctest $(SPHINXOPTS) "$(SOURCEDIR)" "$(BUILDDIR)"/doctest
-	@echo "Testing of doctests in the sources finished, look at the " \
-	      "results in $(BUILDDIR)/doctest/output.txt."
+docset: html
+	doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/data/ --force $(BUILDDIR)/html/
+
+	# Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution.
+	cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png
+	convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png
+
+html-noplot: # Avoids running the gallery examples, which may take time
+	$(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html
+	@echo
+	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+
+clean:
+	rm -rf $(BUILDDIR)/*
+	rm -rf $(SOURCEDIR)/generated_examples/  # sphinx-gallery
+	rm -rf $(SOURCEDIR)/gen_modules/  # sphinx-gallery
+	rm -rf $(SOURCEDIR)/sg_execution_times.rst  # sphinx-gallery
+	rm -rf $(SOURCEDIR)/generated/  # autosummary

-.PHONY: help doctest Makefile
+.PHONY: help Makefile docset

 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).

docs/source/dp_tutorial.rst (+11 −11)

@@ -321,7 +321,7 @@ Accessing AWS S3 with ``fsspec`` DataPipes
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 This requires the installation of the libraries ``fsspec``
-(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`_) and ``s3fs``
+(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`__) and ``s3fs``
 (`s3fs GitHub repo <https://github.com/fsspec/s3fs>`_).

 You can list out the files within a S3 bucket directory by passing a path that starts
@@ -363,7 +363,7 @@ is also available for writing data to cloud.
 Accessing Google Cloud Storage (GCS) with ``fsspec`` DataPipes
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 This requires the installation of the libraries ``fsspec``
-(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`_) and ``gcsfs``
+(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`__) and ``gcsfs``
 (`gcsfs GitHub repo <https://github.com/fsspec/gcsfs>`_).

 You can list out the files within a GCS bucket directory by specifying a path that starts
@@ -400,11 +400,11 @@ Accessing Azure Blob storage with ``fsspec`` DataPipes
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 This requires the installation of the libraries ``fsspec``
-(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`_) and ``adlfs``
+(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`__) and ``adlfs``
 (`adlfs GitHub repo <https://github.com/fsspec/adlfs>`_).
-You can access data in Azure Data Lake Storage Gen2 by providing URIs staring with ``abfs://``. 
+You can access data in Azure Data Lake Storage Gen2 by providing URIs staring with ``abfs://``.
 For example,
-`FSSpecFileLister <generated/torchdata.datapipes.iter.FSSpecFileLister.html>`_ (``.list_files_by_fsspec(...)``) 
+`FSSpecFileLister <generated/torchdata.datapipes.iter.FSSpecFileLister.html>`_ (``.list_files_by_fsspec(...)``)
 can be used to list files in a directory in a container:

 .. code:: python
@@ -430,11 +430,11 @@ directory ``curated/covid-19/ecdc_cases/latest``, belonging to account ``pandemi
         .open_files_by_fsspec(account_name='pandemicdatalake') \
         .parse_csv()
     print(list(dp)[:3])
-    # [['date_rep', 'day', ..., 'iso_country', 'daterep'], 
+    # [['date_rep', 'day', ..., 'iso_country', 'daterep'],
     #  ['2020-12-14', '14', ..., 'AF', '2020-12-14'],
     #  ['2020-12-13', '13', ..., 'AF', '2020-12-13']]

-If necessary, you can also access data in Azure Data Lake Storage Gen1 by using URIs staring with 
+If necessary, you can also access data in Azure Data Lake Storage Gen1 by using URIs staring with
 ``adl://`` and ``abfs://``, as described in `README of adlfs repo <https://github.com/fsspec/adlfs/blob/main/README.md>`_

 Accessing Azure ML Datastores with ``fsspec`` DataPipes
@@ -446,11 +446,11 @@ An Azure ML datastore is a *reference* to an existing storage account on Azure.
 - Authentication is automatically handled - both *credential-based* access (service principal/SAS/key) and *identity-based* access (Azure Active Directory/managed identity) are supported. When using credential-based authentication, you do not need to expose secrets in your code.

 This requires the installation of the library ``azureml-fsspec``
-(`documentation <https://learn.microsoft.com/python/api/azureml-fsspec/?view=azure-ml-py>`_).
+(`documentation <https://learn.microsoft.com/python/api/azureml-fsspec/?view=azure-ml-py>`__).

-You can access data in an Azure ML datastore by providing URIs staring with ``azureml://``. 
+You can access data in an Azure ML datastore by providing URIs staring with ``azureml://``.
 For example,
-`FSSpecFileLister <generated/torchdata.datapipes.iter.FSSpecFileLister.html>`_ (``.list_files_by_fsspec(...)``) 
+`FSSpecFileLister <generated/torchdata.datapipes.iter.FSSpecFileLister.html>`_ (``.list_files_by_fsspec(...)``)
 can be used to list files in a directory in a container:

 .. code:: python
@@ -470,7 +470,7 @@ can be used to list files in a directory in a container:

     dp = IterableWrapper([uri]).list_files_by_fsspec()
     print(list(dp))
-    # ['azureml:///<sub_id>/resourcegroups/<rg_name>/workspaces/<ws_name>/datastores/<datastore>/paths/<folder>/file1.txt',
+    # ['azureml:///<sub_id>/resourcegroups/<rg_name>/workspaces/<ws_name>/datastores/<datastore>/paths/<folder>/file1.txt',
     #  'azureml:///<sub_id>/resourcegroups/<rg_name>/workspaces/<ws_name>/datastores/<datastore>/paths/<folder>/file2.txt', ...]

 You can also open files using `FSSpecFileOpener <generated/torchdata.datapipes.iter.FSSpecFileOpener.html>`_
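
The hunks above only touch link markup and trailing whitespace; the tutorial pattern they surround is the same for every storage provider. As a minimal sketch of that pattern (assuming ``fsspec`` and ``s3fs`` are installed and credentials are configured; ``s3://my-bucket/data/`` is a placeholder path), listing and opening files with the fsspec DataPipes looks like:

.. code:: python

    from torchdata.datapipes.iter import IterableWrapper

    # List files under an S3 prefix (placeholder bucket name).
    dp = IterableWrapper(["s3://my-bucket/data/"]).list_files_by_fsspec()

    # Open each listed file lazily and read a few bytes from each stream.
    dp = dp.open_files_by_fsspec(mode="rb")
    for path, stream in dp:
        print(path, stream.read(16))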

docs/source/index.rst (+2 −0)

@@ -36,6 +36,7 @@ Features described in this documentation are classified by release status:
    :maxdepth: 2
    :caption: API Reference:

+   torchdata.stateful_dataloader.rst
    torchdata.datapipes.iter.rst
    torchdata.datapipes.map.rst
    torchdata.datapipes.utils.rst
@@ -47,6 +48,7 @@ Features described in this documentation are classified by release status:
    :maxdepth: 2
    :caption: Tutorial and Examples:

+   stateful_dataloader_tutorial.rst
    dp_tutorial.rst
    dlv2_tutorial.rst
    examples.rst
docs/source/stateful_dataloader_tutorial.rst (new file, +177)

Stateful DataLoader Tutorial
============================

Saving and loading state
------------------------

Stateful DataLoader adds ``state_dict`` / ``load_state_dict`` methods to ``torch.utils.data.DataLoader``. State can be fetched and set as follows:

.. code:: python

    from torchdata.stateful_dataloader import StatefulDataLoader

    dataloader = StatefulDataLoader(dataset, num_workers=2)
    for i, batch in enumerate(dataloader):
        ...
        if i == 10:
            state_dict = dataloader.state_dict()
            break

    # Training run resumes with the previous checkpoint
    dataloader = StatefulDataLoader(dataset, num_workers=2)
    # Resume state with DataLoader
    dataloader.load_state_dict(state_dict)
    for i, batch in enumerate(dataloader):
        ...

Saving Custom State with Map-Style Datasets
-------------------------------------------

For efficient resuming of `Map-style datasets <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_, you can resume iteration by defining ``state_dict`` / ``load_state_dict`` methods in your sampler. If your dataset has worker-specific state (e.g. RNG transform state), you can add ``state_dict`` / ``load_state_dict`` methods to your dataset.

.. code:: python

    from typing import *
    import torch
    import torch.utils.data
    from torchdata.stateful_dataloader import StatefulDataLoader

    # If you are using the default RandomSampler and BatchSampler in
    # torch.utils.data, they are patched when you import
    # torchdata.stateful_dataloader, so defining a custom sampler here is
    # unnecessary
    class MySampler(torch.utils.data.Sampler[int]):
        def __init__(self, high: int, seed: int, limit: int):
            self.seed, self.high, self.limit = seed, high, limit
            self.g = torch.Generator()
            self.g.manual_seed(self.seed)
            self.i = 0

        def __iter__(self):
            while self.i < self.limit:
                val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
                self.i += 1
                yield val

        def load_state_dict(self, state_dict: Dict[str, Any]):
            self.i = state_dict["i"]
            self.g.set_state(state_dict["rng"])

        def state_dict(self) -> Dict[str, Any]:
            return {"i": self.i, "rng": self.g.get_state()}

    # Optional: save dataset random transform state
    class NoisyRange(torch.utils.data.Dataset):
        def __init__(self, high: int, mean: float, std: float):
            self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)

        def __len__(self):
            return self.high

        def __getitem__(self, idx: int) -> float:
            if not (0 <= idx < self.high):
                raise IndexError()
            x = torch.normal(self.mean, self.std)
            noise = x.item()
            return idx + noise

        def load_state_dict(self, state_dict):
            torch.set_rng_state(state_dict["rng"])

        def state_dict(self):
            return {"rng": torch.get_rng_state()}

    # Test both single/multiprocess dataloading
    for num_workers in [0, 2]:
        print(f"{num_workers=}")
        dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
                                batch_size=2, drop_last=False, num_workers=num_workers)

        batches = []
        for i, batch in enumerate(dl):
            batches.append(batch)
            if i == 2:
                sd = dl.state_dict()

        dl.load_state_dict(sd)
        batches2 = list(dl)

        print(batches[3:])
        print(batches2)

    """
    Output:
    num_workers=0
    [tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
    [tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
    num_workers=2
    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
    """

Saving Custom State with Iterable-Style Datasets
------------------------------------------------

Tracking iteration order with `Iterable-style datasets <https://pytorch.org/docs/stable/data.html#iterable-style-datasets>`_ requires state from each worker-level instance of the dataset to be captured. You can define ``state_dict`` / ``load_state_dict`` methods on your dataset which capture worker-level state. :class:`StatefulDataLoader` will handle aggregation across workers and distribution back to the workers. Calling ``load_state_dict`` requires the :class:`StatefulDataLoader` to have the same ``num_workers`` as the loader that produced the provided ``state_dict``.

.. code:: python

    from typing import *
    import torch
    import torch.utils.data
    from torchdata.stateful_dataloader import StatefulDataLoader


    class MyIterableDataset(torch.utils.data.IterableDataset):
        def __init__(self, high: int, seed: int):
            self.high, self.seed = high, seed
            self.g = torch.Generator()
            self.i = 0

        def __iter__(self):
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                worker_id = worker_info.id
                num_workers = worker_info.num_workers
            else:
                worker_id = 0
                num_workers = 1
            self.g.manual_seed(self.seed)
            arr = torch.randperm(self.high, generator=self.g)
            arr = arr[worker_id:self.high:num_workers]
            for idx in range(self.i, len(arr)):
                self.i += 1
                yield arr[idx]
            self.i = 0

        def state_dict(self):
            return {"i": self.i}

        def load_state_dict(self, state_dict):
            self.i = state_dict["i"]

    # Test both single/multiprocess dataloading
    for num_workers in [0, 2]:
        print(f"{num_workers=}")
        dl = StatefulDataLoader(
            MyIterableDataset(12, 0), batch_size=2, drop_last=False,
            num_workers=num_workers)

        batches = []
        for i, batch in enumerate(dl):
            batches.append(batch)
            if i == 2:
                sd = dl.state_dict()

        dl.load_state_dict(sd)
        batches2 = list(dl)

        print(batches[3:])
        print(batches2)

    """
    Output:
    num_workers=0
    [tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
    [tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
    num_workers=2
    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
    """
docs/source/torchdata.stateful_dataloader.rst (new file, +13)

:tocdepth: 3

Stateful DataLoader
===================

.. automodule:: torchdata.stateful_dataloader

StatefulDataLoader is a drop-in replacement for `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ which offers ``state_dict`` / ``load_state_dict`` methods for handling mid-epoch checkpointing; these operate on the previous/next iterator requested from the dataloader, respectively.

By default, the state includes the number of batches yielded and uses this to naively fast-forward the sampler (map-style) or the dataset (iterable-style). However, if the sampler and/or dataset include ``state_dict`` / ``load_state_dict`` methods, then it will call them during its own ``state_dict`` / ``load_state_dict`` calls. Under the hood, :class:`StatefulDataLoader` handles aggregation and distribution of state across multiprocess workers (but not across ranks).

.. autoclass:: StatefulDataLoader
   :members:
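
The default fast-forward behavior described above can be seen with a plain map-style dataset that defines no ``state_dict`` of its own. A minimal sketch (the printed tail assumes default batching and collation):

.. code:: python

    from torchdata.stateful_dataloader import StatefulDataLoader

    dataset = list(range(10))  # no state_dict/load_state_dict defined

    dl = StatefulDataLoader(dataset, batch_size=2)
    it = iter(dl)
    next(it)
    next(it)                  # two batches yielded so far
    sd = dl.state_dict()      # records the number of batches yielded

    # A fresh loader resumes by fast-forwarding past those batches.
    dl2 = StatefulDataLoader(dataset, batch_size=2)
    dl2.load_state_dict(sd)
    print(list(dl2))          # expected: the remaining batches covering 4..9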

torchdata/dataloader2/reading_service.py (+1 −1)

@@ -149,7 +149,7 @@ def __new__(cls, *args, **kwargs):

 class InProcessReadingService(ReadingServiceInterface):
     r"""
-    Default ReadingService to serve the ``DataPipe` graph in the main process,
+    Default ReadingService to serve the ``DataPipe`` graph in the main process,
     and apply graph settings like determinism control to the graph.

     Args:

torchdata/stateful_dataloader/stateful_dataloader.py (+10 −9)

@@ -92,13 +92,12 @@

 class StatefulDataLoader(DataLoader[_T_co]):
     r"""
-    This is a drop in replacement for :class:`~torch.utils.data.DataLoader`
+    This is a drop in replacement for ``torch.utils.data.DataLoader``
     that implements state_dict and load_state_dict methods, enabling mid-epoch
     checkpointing.

-    All arguments are identical to :class:`~torch.utils.data.DataLoader`, with
-    a new kwarg: `snapshot_every_n_steps: Optional[int] = `.
-    See :py:mod:`torch.utils.data` documentation page for more details.
+    All arguments are identical to ``torch.utils.data.DataLoader``, with
+    a new kwarg: ``snapshot_every_n_steps``.

     Args:
         dataset (Dataset): dataset from which to load the data.
@@ -148,11 +147,13 @@ class StatefulDataLoader(DataLoader[_T_co]):
             maintain the workers `Dataset` instances alive. (default: ``False``)
         pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
             ``True``.
+        snapshot_every_n_steps (int, optional): Defines how often the state is
+            transferred from the dataloader workers to the dataloader. By default, it is set to ``1``, i.e., state is transferred every step. If the state is large, this value can be increased (and ideally set to the frequency of training checkpointing) to reduce the overhead of transferring state every step.


     .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                  cannot be an unpicklable object, e.g., a lambda function. See
-                 :ref:`multiprocessing-best-practices` on more details related
+                 `multiprocessing-best-practices <https://pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-best-practices>`_ on more details related
                  to multiprocessing in PyTorch.

     .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
@@ -169,12 +170,12 @@ class StatefulDataLoader(DataLoader[_T_co]):
                  dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                  cases in general.

-                 See `Dataset Types`_ for more details on these two types of datasets and how
+                 See `Dataset Types <https://pytorch.org/docs/stable/data.html>`_ for more details on these two types of datasets and how
                  :class:`~torch.utils.data.IterableDataset` interacts with
-                 `Multi-process data loading`_.
+                 `Multi-process data loading <https://pytorch.org/docs/stable/data.html#multi-process-data-loading>`_.

-    .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
-                 :ref:`data-loading-randomness` notes for random seed related questions.
+    .. warning:: See `Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html#reproducibility>`_, and `Dataloader-workers-random-seed <https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed>`_, and
+                 `Data-loading-randomness <https://pytorch.org/docs/stable/data.html#data-loading-randomness>`_ notes for random seed related questions.

     .. _multiprocessing context:
         https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
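
As a usage sketch for the new kwarg documented above (``dataset`` and the interval are illustrative assumptions), aligning ``snapshot_every_n_steps`` with the training checkpoint frequency keeps worker-to-main state transfer off the per-step hot path:

.. code:: python

    from torchdata.stateful_dataloader import StatefulDataLoader

    CHECKPOINT_EVERY = 500  # illustrative checkpoint frequency

    # Workers transfer their state to the main process only every
    # CHECKPOINT_EVERY steps, so state_dict() is cheap at those boundaries.
    dl = StatefulDataLoader(dataset, num_workers=4,
                            snapshot_every_n_steps=CHECKPOINT_EVERY)

    for step, batch in enumerate(dl):
        ...  # training step
        if step and step % CHECKPOINT_EVERY == 0:
            sd = dl.state_dict()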
