
Commit ba48bac

NivekT and SvenDS9 authored
Single key option for Slicer and doc improvements (#1041) (#1060)
Summary: Single key option for Slicer and doc improvements

### Changes
- Enable Slicer to also work for a single key + functional test
- Fix typos in doc
- Add laion-example to examples page

Pull Request resolved: #1041
Reviewed By: NivekT
Differential Revision: D43622504
Pulled By: ejguan
fbshipit-source-id: b656082598f4a790dc457dddb0213a1a180239fd
Co-authored-by: SvenDS9 <[email protected]>
1 parent b57545f commit ba48bac

File tree

7 files changed (+47, -20 lines)


docs/source/dp_tutorial.rst (+1, -1)

@@ -72,7 +72,7 @@ into the ``DataLoader``. For detailed documentation related to ``DataLoader``,
 please visit `this PyTorch Core page <https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading>`_.
 
-Please refer to `this page <dlv2_tutorial.html>`_ about using ``DataPipe`` with ``DataLoader2``.
+Please refer to :doc:`this page <dlv2_tutorial>` about using ``DataPipe`` with ``DataLoader2``.
 
 
 For this example, we will first have a helper function that generates some CSV files with random label and data.

docs/source/examples.rst (+7, -0)

@@ -75,6 +75,13 @@ semantic classes. Here is a
 <https://github.com/tcapelle/torchdata/blob/main/01_Camvid_segmentation_with_datapipes.ipynb>`_
 created by our community.
 
+laion2B-en-joined
+^^^^^^^^^^^^^^^^^^^^^^
+The `laion2B-en-joined dataset <https://huggingface.co/datasets/laion/laion2B-en-joined>`_ is a subset of the `LAION-5B dataset <https://laion.ai/blog/laion-5b/>`_ containing English captions, URLs pointing to images,
+and other metadata. It contains around 2.32 billion entries.
+Currently (February 2023) around 86% of the URLs still point to valid images. Here is a `DataPipe implementation of laion2B-en-joined
+<https://github.com/pytorch/data/blob/main/examples/vision/laion5b.py>`_ that filters out unsafe images and images with watermarks and loads the images from the URLs.
+
 Additional Datasets in TorchVision
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 In a separate PyTorch domain library `TorchVision <https://github.com/pytorch/vision>`_, you will find some of the most

docs/source/reading_service.rst (+5, -5)

@@ -13,11 +13,11 @@ Features
 Dynamic Sharding
 ^^^^^^^^^^^^^^^^
 
-Dynamic sharding is achieved by ``MultiProcessingReadingService`` and ``DistributedReadingService`` to shard the pipeline based on the information of corresponding multiprocessing and distributed workers. And, TorchData offers two types of ``DataPipe`` letting users to define the sharding place within the pipeline.
+Dynamic sharding is achieved by ``MultiProcessingReadingService`` and ``DistributedReadingService`` to shard the pipeline based on the information of corresponding multiprocessing and distributed workers. And, TorchData offers two types of ``DataPipe`` letting users define the sharding place within the pipeline.
 
 - ``sharding_filter`` (:class:`ShardingFilter`): When the pipeline is replicable, each distributed/multiprocessing worker loads data from its own replica of the ``DataPipe`` graph, while skipping samples that do not belong to the corresponding worker at the point where ``sharding_filter`` is placed.
 
-- ``sharding_round_robin_dispatch`` (:class:`ShardingRoundRobinDispatcher`): When there is any ``sharding_round_robin_dispatch`` ``DataPipe`` in the pipeline, that branch (i.e. all DataPipes prior to ``sharding_round_robin_dispatch``) will be treated as a non-replicable branch (in the context of multiprocessing). A single dispatching process will be created to load data from the non-replicable branch and distributed data to the subsequent worker processes.
+- ``sharding_round_robin_dispatch`` (:class:`ShardingRoundRobinDispatcher`): When there is any ``sharding_round_robin_dispatch`` ``DataPipe`` in the pipeline, that branch (i.e. all DataPipes prior to ``sharding_round_robin_dispatch``) will be treated as a non-replicable branch (in the context of multiprocessing). A single dispatching process will be created to load data from the non-replicable branch and distribute data to the subsequent worker processes.
 
 The following is an example of having two types of sharding strategies in the pipeline.
 
@@ -116,21 +116,21 @@ When multiprocessing takes place, the graph becomes:
   end [shape=box];
 }
 
-``Client`` in the graph is a ``DataPipe`` that send request and receive response from multiprocessing queues.
+``Client`` in the graph is a ``DataPipe`` that sends a request and receives a response from multiprocessing queues.
 
 .. module:: torchdata.dataloader2
 
 Determinism
 ^^^^^^^^^^^^
 
-In ``DataLoader2``, a ``SeedGenerator`` becomes a single source of randomness and each ``ReadingService`` would access to it via ``initialize_iteration()`` and generate corresponding random seeds for random ``DataPipe`` operations.
+In ``DataLoader2``, a ``SeedGenerator`` becomes a single source of randomness and each ``ReadingService`` would access it via ``initialize_iteration()`` and generate corresponding random seeds for random ``DataPipe`` operations.
 
 In order to make sure that the Dataset shards are mutually exclusive and collectively exhaustive on multiprocessing processes and distributed nodes, ``MultiProcessingReadingService`` and ``DistributedReadingService`` would help :class:`DataLoader2` to synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, unique random states are generated based on the distributed rank and worker process id by each ``ReadingService``, in order to perform different random transformations.
 
 Graph Mode
 ^^^^^^^^^^^
 
-This also allows easier transition of data-preprocessing pipeline from research to production. After the ``DataPipe`` graph is created and validated with the ``ReadingServices``, a different ``ReadingService`` that configures and connects to the production service/infra such as ``AIStore`` can be provided to :class:`DataLoader2` as a drop-in replacement. The ``ReadingService`` could potentially search the graph, and find ``DataPipe`` operations that can be delegated to the production service/infra, then modify the graph correspondingly to achieve higher-performant execution.
+This also allows easier transition of data-preprocessing pipeline from research to production. After the ``DataPipe`` graph is created and validated with the ``ReadingServices``, a different ``ReadingService`` that configures and connects to the production service/infrastructure such as ``AIStore`` can be provided to :class:`DataLoader2` as a drop-in replacement. The ``ReadingService`` could potentially search the graph, and find ``DataPipe`` operations that can be delegated to the production service/infrastructure, then modify the graph correspondingly to achieve higher-performant execution.
 
 Extend ReadingService
 ----------------------
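The ``sharding_filter`` semantics documented in this file (disjoint, exhaustive shards per worker) can be sketched in a few lines of plain Python. This is a simplified illustration, not the torchdata implementation; the function name is made up for this sketch:

```python
def sharding_filter(stream, num_workers, worker_id):
    # Each worker iterates the same stream but keeps only every
    # num_workers-th sample, offset by its own id, so the shards are
    # mutually exclusive and collectively exhaustive.
    for position, sample in enumerate(stream):
        if position % num_workers == worker_id:
            yield sample

shards = [list(sharding_filter(range(10), num_workers=3, worker_id=w)) for w in range(3)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Merging the three shards recovers the full stream exactly once, which is the property the reading services rely on.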

docs/source/torchdata.datapipes.utils.rst (+3, -2)

@@ -15,7 +15,7 @@ DataPipe Graph Visualization
 
 to_graph
 
-Commond Utility Functions
+Common Utility Functions
 --------------------------------------
 .. currentmodule:: torchdata.datapipes.utils
 
@@ -47,7 +47,8 @@ For documentation related to DataLoader, please refer to the
 ``torch.utils.data`` `documentation <https://pytorch.org/docs/stable/data.html>`_. Or, more specifically, the
 `DataLoader API section <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
 
-DataLoader v2 is currently in development. You should see an update here by mid-2022.
+DataLoader v2 is currently in development. For more information please refer to :doc:`dataloader2`.
+
 
 Sampler
 -------------------------------------

test/test_iterdatapipe.py (+7, -3)

@@ -1230,7 +1230,7 @@ def test_slice_iterdatapipe(self):
         slice_dp = input_dp.slice(0, 2, 2)
         self.assertEqual([(0,), (3,), (6,)], list(slice_dp))
 
-        # Functional Test: filter with list of indices for tuple
+        # Functional Test: slice with list of indices for tuple
         slice_dp = input_dp.slice([0, 1])
         self.assertEqual([(0, 1), (3, 4), (6, 7)], list(slice_dp))
 
@@ -1245,14 +1245,18 @@ def test_slice_iterdatapipe(self):
         slice_dp = input_dp.slice(0, 2)
         self.assertEqual([[0, 1], [3, 4], [6, 7]], list(slice_dp))
 
-        # Functional Test: filter with list of indices for list
+        # Functional Test: slice with list of indices for list
        slice_dp = input_dp.slice(0, 2)
         self.assertEqual([[0, 1], [3, 4], [6, 7]], list(slice_dp))
 
         # dict tests
         input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3}, {"a": 3, "b": 4, "c": 5}, {"a": 5, "b": 6, "c": 7}])
 
-        # Functional Test: filter with list of indices for dict
+        # Functional Test: slice with key for dict
+        slice_dp = input_dp.slice("a")
+        self.assertEqual([{"a": 1}, {"a": 3}, {"a": 5}], list(slice_dp))
+
+        # Functional Test: slice with list of keys for dict
         slice_dp = input_dp.slice(["a", "b"])
         self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(slice_dp))
 
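The dict behavior these tests exercise, including the single-key option this commit adds, can be sketched in plain Python. This is a minimal stand-in for the dict branch of ``SliceIterDataPipe``, not the torchdata implementation; ``slice_items`` is an illustrative name, not a torchdata API:

```python
def slice_items(items, index):
    """Sketch of dict slicing: a list of keys keeps the listed keys, a
    single present key (the new behavior) keeps just that key, and an
    unknown key passes the item through unchanged."""
    result = []
    for item in items:
        if isinstance(index, list):
            # list of keys: keep the subset of the dict
            result.append({k: v for k, v in item.items() if k in index})
        elif index in item:
            # single key: keep a one-entry dict
            result.append({index: item[index]})
        else:
            # key not present: item is passed through unchanged
            result.append(item)
    return result

data = [{"a": 1, "b": 2, "c": 3}, {"a": 3, "b": 4, "c": 5}, {"a": 5, "b": 6, "c": 7}]
print(slice_items(data, "a"))         # [{'a': 1}, {'a': 3}, {'a': 5}]
print(slice_items(data, ["a", "b"]))  # [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}, {'a': 5, 'b': 6}]
```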

torchdata/datapipes/iter/transform/callable.py (+23, -8)

@@ -26,14 +26,15 @@ def _no_op_fn(*args):
 class BatchMapperIterDataPipe(IterDataPipe[T_co]):
     r"""
     Combines elements from the source DataPipe to batches and applies a function
-    over each batch, then flattens the outpus to a single, unnested IterDataPipe
+    over each batch, then flattens the outputs to a single, unnested IterDataPipe
     (functional name: ``map_batches``).
 
     Args:
         datapipe: Source IterDataPipe
         fn: The function to be applied to each batch of data
         batch_size: The size of batch to be aggregated from ``datapipe``
         input_col: Index or indices of data which ``fn`` is applied, such as:
+
             - ``None`` as default to apply ``fn`` to the data directly.
             - Integer(s) is used for list/tuple.
             - Key(s) is used for dict.
@@ -114,6 +115,7 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
         datapipe: Source IterDataPipe
         fn: the function to be applied to each element in the DataPipe, the output must be a Sequence
         input_col: Index or indices of data which ``fn`` is applied, such as:
+
             - ``None`` as default to apply ``fn`` to the data directly.
             - Integer(s) is/are used for list/tuple.
             - Key(s) is/are used for dict.
@@ -133,9 +135,9 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
         [1, 2, 3, 4, 5, 6]
     """
     datapipe: IterDataPipe
-    fn: Callable
+    fn: Optional[Callable]
 
-    def __init__(self, datapipe: IterDataPipe, fn: Callable = None, input_col=None) -> None:
+    def __init__(self, datapipe: IterDataPipe, fn: Optional[Callable] = None, input_col=None) -> None:
         self.datapipe = datapipe
 
         if fn is None:
@@ -147,12 +149,12 @@ def __init__(self, datapipe: IterDataPipe, fn: Callable = None, input_col=None)
 
     def _apply_fn(self, data):
         if self.input_col is None:
-            return self.fn(data)
+            return self.fn(data)  # type: ignore[misc]
         elif isinstance(self.input_col, (list, tuple)):
             args = tuple(data[col] for col in self.input_col)
-            return self.fn(*args)
+            return self.fn(*args)  # type: ignore[misc]
         else:
-            return self.fn(data[self.input_col])
+            return self.fn(data[self.input_col])  # type: ignore[misc]
 
     def __iter__(self) -> Iterator[T_co]:
         for d in self.datapipe:
@@ -171,6 +173,9 @@ class DropperIterDataPipe(IterDataPipe[T_co]):
         datapipe: IterDataPipe with columns to be dropped
         indices: a single column index to be dropped or a list of indices
 
+            - Integer(s) is/are used for list/tuple.
+            - Key(s) is/are used for dict.
+
     Example:
         >>> from torchdata.datapipes.iter import IterableWrapper, ZipperMapDataPipe
         >>> dp1 = IterableWrapper(range(5))
@@ -237,8 +242,13 @@ class SliceIterDataPipe(IterDataPipe[T_co]):
     Args:
         datapipe: IterDataPipe with iterable elements
         index: a single start index for the slice or a list of indices to be returned instead of a start/stop slice
-        stop: the slice stop. ignored if index is a list
-        step: step to be taken from start to stop. ignored if index is a list
+
+            - Integer(s) is/are used for list/tuple.
+            - Key(s) is/are used for dict.
+
+
+        stop: the slice stop. ignored if index is a list or if element is a dict
+        step: step to be taken from start to stop. ignored if index is a list or if element is a dict
 
     Example:
         >>> from torchdata.datapipes.iter import IterableWrapper
@@ -285,6 +295,8 @@ def __iter__(self) -> Iterator[T_co]:
             elif isinstance(old_item, dict):
                 if isinstance(self.index, list):
                     new_item = {k: v for (k, v) in old_item.items() if k in self.index}  # type: ignore[assignment]
+                elif self.index in old_item.keys():
+                    new_item = {self.index: old_item.get(self.index)}  # type: ignore[assignment]
                 else:
                     new_item = old_item  # type: ignore[assignment]
                     warnings.warn(
@@ -329,6 +341,9 @@ class FlattenIterDataPipe(IterDataPipe[T_co]):
         datapipe: IterDataPipe with iterable elements
         indices: a single index/key for the item to flatten from an iterator item or a list of indices/keys to be flattened
 
+            - Integer(s) is/are used for list/tuple.
+            - Key(s) is/are used for dict.
+
     Example:
         >>> from torchdata.datapipes.iter import IterableWrapper
         >>> dp = IterableWrapper([(0, 10, (100, 1000)), (1, 11, (111, 1001)), (2, 12, (122, 1002)), (3, 13, (133, 1003)), (4, 14, (144, 1004))])
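The batch-apply-flatten behavior the ``BatchMapper`` docstring describes (``map_batches``) can be sketched as a plain generator. This is an illustrative sketch of the semantics, not the torchdata implementation:

```python
from itertools import islice

def map_batches(stream, fn, batch_size):
    # Group the stream into batches of batch_size, apply fn to each
    # batch, and flatten the per-batch outputs back into one stream.
    it = iter(stream)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:  # stream exhausted
            return
        yield from fn(batch)

out = list(map_batches(range(5), lambda batch: [x * 2 for x in batch], batch_size=2))
print(out)  # [0, 2, 4, 6, 8]
```

Note that the last batch may be smaller than ``batch_size``, so ``fn`` should not assume a fixed batch length.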

torchdata/datapipes/map/util/unzipper.py (+1, -1)

@@ -31,7 +31,7 @@ class UnZipperMapDataPipe(MapDataPipe):
             an integer from 0 to sequence_length - 1)
 
     Example:
-        >>> from torchdata.datapipes.iter import SequenceWrapper
+        >>> from torchdata.datapipes.map import SequenceWrapper
         >>> source_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(3)])
         >>> dp1, dp2, dp3 = source_dp.unzip(sequence_length=3)
         >>> list(dp1)
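The ``unzip`` semantics shown in that docstring example can be sketched with plain lists. This is illustrative only; the real ``UnZipperMapDataPipe`` yields lazy ``MapDataPipe`` instances rather than materialized lists:

```python
def unzip(rows, sequence_length):
    # Split a sequence of fixed-length tuples into sequence_length
    # parallel lists, one list per tuple position.
    columns = [[] for _ in range(sequence_length)]
    for row in rows:
        for i in range(sequence_length):
            columns[i].append(row[i])
    return columns

source = [(i, i + 10, i + 20) for i in range(3)]
dp1, dp2, dp3 = unzip(source, sequence_length=3)
print(dp1)  # [0, 1, 2]
print(dp2)  # [10, 11, 12]
```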
