Skip to content

Commit 1cc566a

Browse files
author
Mathias Burger
committed
Migrate docstrings to doctests
1 parent 1472157 commit 1cc566a

File tree

6 files changed

+225
-75
lines changed

6 files changed

+225
-75
lines changed

torchdata/dataloader2/adapter.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ class Shuffle(Adapter):
6767
dp = IterableWrapper(range(size)).shuffle()
6868
dl = DataLoader2(dp, [Shuffle(False)])
6969
assert list(range(size)) == list(dl)
70-
7170
"""
7271

7372
def __init__(self, enable=True):
@@ -86,7 +85,19 @@ class CacheTimeout(Adapter):
8685
timeout: int - number of seconds parallel processes will wait for cached files to appear.
8786
8887
Example:
89-
>>> dl = DataLoader2(dp, [CacheTimeout(600)])
88+
89+
.. testsetup::
90+
91+
from torchdata.datapipes.iter import IterableWrapper
92+
from torchdata.dataloader2 import DataLoader2
93+
from torchdata.dataloader2.adapter import CacheTimeout
94+
95+
size = 12
96+
97+
.. testcode::
98+
99+
dp = IterableWrapper(range(size)).shuffle()
100+
dl = DataLoader2(dp, [CacheTimeout(600)])
90101
"""
91102

92103
def __init__(self, timeout=None):

torchdata/datapipes/iter/load/fsspec.py

+46-12
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,16 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
4747
e.g. host, port, username, password, etc.
4848
4949
Example:
50-
>>> from torchdata.datapipes.iter import FSSpecFileLister
51-
>>> datapipe = FSSpecFileLister(root=dir_path)
50+
51+
.. testsetup::
52+
53+
dir_path = "path"
54+
55+
.. testcode::
56+
57+
from torchdata.datapipes.iter import FSSpecFileLister
58+
59+
datapipe = FSSpecFileLister(root=dir_path)
5260
"""
5361

5462
def __init__(
@@ -127,9 +135,17 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
127135
e.g. host, port, username, password, etc.
128136
129137
Example:
130-
>>> from torchdata.datapipes.iter import FSSpecFileLister
131-
>>> datapipe = FSSpecFileLister(root=dir_path)
132-
>>> file_dp = datapipe.open_files_by_fsspec()
138+
139+
.. testsetup::
140+
141+
dir_path = "path"
142+
143+
.. testcode::
144+
145+
from torchdata.datapipes.iter import FSSpecFileLister
146+
147+
datapipe = FSSpecFileLister(root=dir_path)
148+
file_dp = datapipe.open_files_by_fsspec()
133149
"""
134150

135151
def __init__(
@@ -169,13 +185,31 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
169185
170186
171187
Example:
172-
>>> from torchdata.datapipes.iter import IterableWrapper
173-
>>> def filepath_fn(name: str) -> str:
174-
>>> return dir_path + name
175-
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
176-
>>> source_dp = IterableWrapper(sorted(name_to_data.items()))
177-
>>> fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
178-
>>> res_file_paths = list(fsspec_saver_dp)
188+
189+
.. testsetup::
190+
191+
file_prefix = "file"
192+
193+
.. testcode::
194+
195+
from torchdata.datapipes.iter import IterableWrapper
196+
197+
198+
def filepath_fn(name: str) -> str:
199+
return file_prefix + name
200+
201+
202+
name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
203+
source_dp = IterableWrapper(sorted(name_to_data.items()))
204+
fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
205+
res_file_paths = list(fsspec_saver_dp)
206+
207+
.. testcleanup::
208+
209+
import os
210+
211+
for name in name_to_data.keys():
212+
os.remove(file_prefix + name)
179213
"""
180214

181215
def __init__(

torchdata/datapipes/iter/load/iopath.py

+46-12
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,16 @@ class IoPathFileListerIterDataPipe(IterDataPipe[str]):
5555
S3 URL is supported only with ``iopath``>=0.1.9.
5656
5757
Example:
58-
>>> from torchdata.datapipes.iter import IoPathFileLister
59-
>>> datapipe = IoPathFileLister(root=S3URL)
58+
59+
.. testsetup::
60+
61+
s3_url = "path"
62+
63+
.. testcode::
64+
65+
from torchdata.datapipes.iter import IoPathFileLister
66+
67+
datapipe = IoPathFileLister(root=s3_url)
6068
"""
6169

6270
def __init__(
@@ -113,9 +121,17 @@ class IoPathFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
113121
S3 URL is supported only with ``iopath``>=0.1.9.
114122
115123
Example:
116-
>>> from torchdata.datapipes.iter import IoPathFileLister
117-
>>> datapipe = IoPathFileLister(root=S3URL)
118-
>>> file_dp = datapipe.open_files_by_iopath()
124+
125+
.. testsetup::
126+
127+
s3_url = "path"
128+
129+
.. testcode::
130+
131+
from torchdata.datapipes.iter import IoPathFileLister
132+
133+
datapipe = IoPathFileLister(root=s3_url)
134+
file_dp = datapipe.open_files_by_iopath()
119135
"""
120136

121137
def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", pathmgr=None) -> None:
@@ -161,13 +177,31 @@ class IoPathSaverIterDataPipe(IterDataPipe[str]):
161177
S3 URL is supported only with ``iopath``>=0.1.9.
162178
163179
Example:
164-
>>> from torchdata.datapipes.iter import IterableWrapper
165-
>>> def filepath_fn(name: str) -> str:
166-
>>> return S3URL + name
167-
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
168-
>>> source_dp = IterableWrapper(sorted(name_to_data.items()))
169-
>>> iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
170-
>>> res_file_paths = list(iopath_saver_dp)
180+
181+
.. testsetup::
182+
183+
s3_url = "url"
184+
185+
.. testcode::
186+
187+
from torchdata.datapipes.iter import IterableWrapper
188+
189+
190+
def filepath_fn(name: str) -> str:
191+
return s3_url + name
192+
193+
194+
name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
195+
source_dp = IterableWrapper(sorted(name_to_data.items()))
196+
iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
197+
res_file_paths = list(iopath_saver_dp)
198+
199+
.. testcleanup::
200+
201+
import os
202+
203+
for file in ["1.txt", "1.txt.lock", "2.txt", "2.txt.lock", "3.txt", "3.txt.lock"]:
204+
os.remove(s3_url + file)
171205
"""
172206

173207
def __init__(

torchdata/datapipes/iter/load/online.py

+61-32
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,25 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
5454
**kwargs: a dictionary of optional keyword arguments that ``requests`` accepts. For the full list, see https://docs.python-requests.org/en/master/api/
5555
5656
Example:
57-
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
58-
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
59-
>>> query_params = {"auth" : ("fake_username", "fake_password"), "allow_redirects" : True}
60-
>>> timeout = 120
61-
>>> http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, query_params)
62-
>>> reader_dp = http_reader_dp.readlines()
63-
>>> it = iter(reader_dp)
64-
>>> path, line = next(it)
65-
>>> path
66-
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
67-
>>> line
68-
b'BSD 3-Clause License'
57+
58+
.. testcode::
59+
60+
from torchdata.datapipes.iter import IterableWrapper, HttpReader
61+
62+
file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
63+
query_params = {"auth" : ("fake_username", "fake_password"), "allow_redirects" : True}
64+
timeout = 120
65+
http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params)
66+
reader_dp = http_reader_dp.readlines()
67+
it = iter(reader_dp)
68+
path, line = next(it)
69+
print((path, line))
70+
71+
Output:
72+
73+
.. testoutput::
74+
75+
('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License')
6976
"""
7077

7178
def __init__(
@@ -154,16 +161,31 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
154161
**kwargs: a dictionary of optional keyword arguments that ``requests`` accepts. For the full list, see https://docs.python-requests.org/en/master/api/
155162
156163
Example:
157-
>>> from torchdata.datapipes.iter import IterableWrapper, GDriveReader
158-
>>> gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
159-
>>> gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
160-
>>> reader_dp = gdrive_reader_dp.readlines()
161-
>>> it = iter(reader_dp)
162-
>>> path, line = next(it)
163-
>>> path
164-
https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile
165-
>>> line
166-
<First line from the GDrive File>
164+
165+
.. testsetup::
166+
167+
from torchdata.datapipes.iter import GDriveReader
168+
169+
GDriveReader.readlines = lambda self: [
170+
("https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile", b"<First line from the GDrive File>")
171+
]
172+
173+
.. testcode::
174+
175+
from torchdata.datapipes.iter import IterableWrapper, GDriveReader
176+
177+
gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
178+
gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
179+
reader_dp = gdrive_reader_dp.readlines()
180+
it = iter(reader_dp)
181+
path, line = next(it)
182+
print((path, line))
183+
184+
Output:
185+
186+
.. testoutput::
187+
188+
('https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile', b'<First line from the GDrive File>')
167189
"""
168190
source_datapipe: IterDataPipe[str]
169191

@@ -207,16 +229,23 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
207229
**kwargs: a dictionary of optional keyword arguments that ``requests`` accepts. For the full list, see https://docs.python-requests.org/en/master/api/
208230
209231
Example:
210-
>>> from torchdata.datapipes.iter import IterableWrapper, OnlineReader
211-
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
212-
>>> online_reader_dp = OnlineReader(IterableWrapper([file_url]))
213-
>>> reader_dp = online_reader_dp.readlines()
214-
>>> it = iter(reader_dp)
215-
>>> path, line = next(it)
216-
>>> path
217-
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
218-
>>> line
219-
b'BSD 3-Clause License'
232+
233+
.. testcode::
234+
235+
from torchdata.datapipes.iter import IterableWrapper, OnlineReader
236+
237+
file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
238+
online_reader_dp = OnlineReader(IterableWrapper([file_url]))
239+
reader_dp = online_reader_dp.readlines()
240+
it = iter(reader_dp)
241+
path, line = next(it)
242+
print((path, line))
243+
244+
Output:
245+
246+
.. testoutput::
247+
248+
('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License')
220249
"""
221250
source_datapipe: IterDataPipe[str]
222251

torchdata/datapipes/iter/load/s3io.py

+53-17
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,33 @@ class S3FileListerIterDataPipe(IterDataPipe[str]):
4040
region: region for access files (inferred from credentials by default)
4141
4242
Example:
43-
>>> from torchdata.datapipes.iter import IterableWrapper, S3FileLister
44-
>>> s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...])
45-
>>> dp_s3_urls = S3FileLister(s3_prefixes)
46-
>>> for d in dp_s3_urls:
47-
... pass
43+
44+
.. testsetup::
45+
46+
from unittest import mock
47+
from torchdata.datapipes.iter import IterableWrapper, S3FileLister
48+
49+
file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([]))
50+
file_lister_patch.start()
51+
52+
.. testcode::
53+
54+
from torchdata.datapipes.iter import IterableWrapper, S3FileLister
55+
56+
s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...])
57+
58+
dp_s3_urls = S3FileLister(s3_prefixes)
59+
for d in dp_s3_urls:
60+
pass
61+
4862
# Functional API
49-
>>> dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100)
50-
>>> for d in dp_s3_urls:
51-
... pass
63+
dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100)
64+
for d in dp_s3_urls:
65+
pass
66+
67+
.. testcleanup::
68+
69+
file_lister_patch.stop()
5270
"""
5371

5472
def __init__(
@@ -108,20 +126,38 @@ class S3FileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
108126
multi_part_download: flag to split each chunk into small packets and download those packets in parallel (enabled by default)
109127
110128
Example:
111-
>>> from torchdata.datapipes.iter import IterableWrapper, S3FileLoader
112-
>>> dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3()
129+
130+
.. testsetup::
131+
132+
from unittest import mock
133+
from torchdata.datapipes.iter import S3FileLister
134+
135+
file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([]))
136+
file_lister_patch.start()
137+
138+
.. testcode::
139+
140+
from torchdata.datapipes.iter import IterableWrapper, S3FileLoader
141+
142+
dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3()
113143
# In order to make sure data are shuffled and sharded in the
114144
# distributed environment, `shuffle` and `sharding_filter`
115145
# are required. For details, please check our tutorial at:
116146
# https://pytorch.org/data/main/tutorial.html#working-with-dataloader
117-
>>> sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()
118-
>>> dp_s3_files = S3FileLoader(sharded_s3_urls)
119-
>>> for url, fd in dp_s3_files: # Start loading data
120-
... data = fd.read()
147+
sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()
148+
149+
dp_s3_files = S3FileLoader(sharded_s3_urls)
150+
for url, fd in dp_s3_files: # Start loading data
151+
data = fd.read()
152+
121153
# Functional API
122-
>>> dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256)
123-
>>> for url, fd in dp_s3_files:
124-
... data = fd.read()
154+
dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256)
155+
for url, fd in dp_s3_files:
156+
data = fd.read()
157+
158+
.. testcleanup::
159+
160+
file_lister_patch.stop()
125161
"""
126162

127163
def __init__(

0 commit comments

Comments
 (0)