Skip to content

Commit 73d3aa9

Browse files
ejguanfacebook-github-bot
authored andcommitted
Officially graduate ProtypeMPRS to MPRS (#1009)
Summary: Pull Request resolved: #1009 Fixes: #932 - Convert all references from `ProtypeMPRS` to `MPRS` - Remove usage of `MPRS` Reviewed By: wenleix, NivekT Differential Revision: D43245136 fbshipit-source-id: 1fd67f0e9d55a984209b8e58edecdb6070d919cd
1 parent b43c07d commit 73d3aa9

File tree

10 files changed

+39
-152
lines changed

10 files changed

+39
-152
lines changed

benchmarks/cloud/aws_s3.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pandas as pd
1616
import psutil
17-
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
17+
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
1818
from torchdata.datapipes.iter import IterableWrapper
1919

2020

@@ -64,7 +64,7 @@ def check_and_output_speed(prefix: str, create_dp_fn: Callable, n_prefetch: int,
6464
dp = create_dp_fn()
6565

6666
rs_type = "DataLoader2 w/ tar archives"
67-
new_rs = PrototypeMultiProcessingReadingService(
67+
new_rs = MultiProcessingReadingService(
6868
num_workers=n_workers, worker_prefetch_cnt=n_prefetch, main_prefetch_cnt=n_prefetch
6969
)
7070
dl: DataLoader2 = DataLoader2(dp, reading_service=new_rs)

docs/source/dataloader2.rst

-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ ReadingService
3232

3333
DistributedReadingService
3434
MultiProcessingReadingService
35-
PrototypeMultiProcessingReadingService
3635
SequentialReadingService
3736

3837
Each ``ReadingServices`` would take the ``DataPipe`` graph and rewrite it to achieve a few features like dynamic sharding, sharing random seeds and snapshoting for multi-/distributed processes. For more detail about those features, please refer to `the documentation <reading_service.html>`_.

docs/source/dlv2_tutorial.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ Here is an example of a ``DataPipe`` graph:
2424
Multiprocessing
2525
----------------
2626

27-
``PrototypeMultiProcessingReadingService`` handles multiprocessing sharding at the point of ``sharding_filter`` and synchronizes the seeds across worker processes.
27+
``MultiProcessingReadingService`` handles multiprocessing sharding at the point of ``sharding_filter`` and synchronizes the seeds across worker processes.
2828

2929
.. code:: python
3030
31-
rs = PrototypeMultiProcessingReadingService(num_workers=4)
31+
rs = MultiProcessingReadingService(num_workers=4)
3232
dl = DataLoader2(datapipe, reading_service=rs)
3333
for epoch in range(10):
3434
dl.seed(epoch)
@@ -58,7 +58,7 @@ Multiprocessing + Distributed
5858

5959
.. code:: python
6060
61-
mp_rs = PrototypeMultiProcessingReadingService(num_workers=4)
61+
mp_rs = MultiProcessingReadingService(num_workers=4)
6262
dist_rs = DistributedReadingService()
6363
rs = SequentialReadingService(dist_rs, mp_rs)
6464

docs/source/reading_service.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Features
1111
Dynamic Sharding
1212
^^^^^^^^^^^^^^^^
1313

14-
Dynamic sharding is achieved by ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService`` to shard the pipeline based on the information of corresponding multiprocessing and distributed workers. And, TorchData offers two types of ``DataPipe`` letting users to define the sharding place within the pipeline.
14+
Dynamic sharding is achieved by ``MultiProcessingReadingService`` and ``DistributedReadingService`` to shard the pipeline based on the information of corresponding multiprocessing and distributed workers. And, TorchData offers two types of ``DataPipe`` letting users to define the sharding place within the pipeline.
1515

1616
- ``sharding_filter``: When the pipeline is replicable, each distributed/multiprocessing worker loads data from one replica of the ``DataPipe`` graph, and skip the data not blonged to the corresponding worker at the place of ``sharding_filter``.
1717

@@ -121,7 +121,7 @@ Determinism
121121

122122
In ``DataLoader2``, a ``SeedGenerator`` becomes a single source of randomness and each ``ReadingService`` would access to it via ``initialize_iteration()`` and generate corresponding random seeds for random ``DataPipe`` operations.
123123

124-
In order to make sure that the Dataset shards are mutually exclusive and collectively exhaunsitve on multiprocessing processes and distributed nodes, ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService`` would help ``DataLoader2`` to synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, unique random states are generated based on the distributed rank and worker process id by each ``ReadingService``, in order to perform different random transformations.
124+
In order to make sure that the Dataset shards are mutually exclusive and collectively exhaunsitve on multiprocessing processes and distributed nodes, ``MultiProcessingReadingService`` and ``DistributedReadingService`` would help ``DataLoader2`` to synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, unique random states are generated based on the distributed rank and worker process id by each ``ReadingService``, in order to perform different random transformations.
125125

126126
Graph Mode
127127
^^^^^^^^^^^

test/dataloader2/test_dataloader2.py

+14-56
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
DataLoader2,
2828
DistributedReadingService,
2929
MultiProcessingReadingService,
30-
PrototypeMultiProcessingReadingService,
3130
ReadingServiceInterface,
3231
SequentialReadingService,
3332
)
@@ -120,16 +119,6 @@ def test_dataloader2_reading_service(self) -> None:
120119
self.assertEqual(batch, expected_batch)
121120
expected_batch += 1
122121

123-
def test_dataloader2_multi_process_reading_service(self) -> None:
124-
test_data_pipe = IterableWrapper(range(3))
125-
reading_service = MultiProcessingReadingService()
126-
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
127-
128-
expected_batch = 0
129-
for batch in iter(data_loader):
130-
self.assertEqual(batch, expected_batch)
131-
expected_batch += 1
132-
133122
def test_dataloader2_load_state_dict(self) -> None:
134123
test_data_pipe = IterableWrapper(range(3))
135124
reading_service = TestReadingService()
@@ -165,7 +154,7 @@ def test_dataloader2_iterates_correctly(self) -> None:
165154
None,
166155
TestReadingService(),
167156
MultiProcessingReadingService(num_workers=4),
168-
PrototypeMultiProcessingReadingService(num_workers=4, worker_prefetch_cnt=0),
157+
MultiProcessingReadingService(num_workers=4, worker_prefetch_cnt=0),
169158
]
170159
for reading_service in reading_services:
171160
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
@@ -232,18 +221,10 @@ def _get_no_reading_service():
232221
def _get_mp_reading_service():
233222
return MultiProcessingReadingService(num_workers=2)
234223

235-
@staticmethod
236-
def _get_proto_reading_service():
237-
return PrototypeMultiProcessingReadingService(num_workers=2)
238-
239224
@staticmethod
240225
def _get_mp_reading_service_zero_workers():
241226
return MultiProcessingReadingService(num_workers=0)
242227

243-
@staticmethod
244-
def _get_proto_reading_service_zero_workers():
245-
return PrototypeMultiProcessingReadingService(num_workers=0)
246-
247228
def _collect_data(self, datapipe, reading_service_gen):
248229
dl: DataLoader2 = DataLoader2(datapipe, reading_service=reading_service_gen())
249230
result = []
@@ -265,9 +246,7 @@ def test_dataloader2_batch_collate(self) -> None:
265246

266247
reading_service_generators = (
267248
self._get_mp_reading_service,
268-
self._get_proto_reading_service,
269249
self._get_mp_reading_service_zero_workers,
270-
self._get_proto_reading_service_zero_workers,
271250
)
272251
for reading_service_gen in reading_service_generators:
273252
actual = self._collect_data(dp, reading_service_gen=reading_service_gen)
@@ -279,27 +258,6 @@ def test_dataloader2_shuffle(self) -> None:
279258
pass
280259

281260

282-
class DataLoader2IntegrationTest(TestCase):
283-
@staticmethod
284-
def _get_mp_reading_service():
285-
return MultiProcessingReadingService(num_workers=2)
286-
287-
def test_lazy_load(self):
288-
source_dp = IterableWrapper([(i, i) for i in range(10)])
289-
map_dp = source_dp.to_map_datapipe()
290-
291-
reading_service_generators = (self._get_mp_reading_service,)
292-
for reading_service_gen in reading_service_generators:
293-
dl: DataLoader2 = DataLoader2(datapipe=map_dp, reading_service=reading_service_gen())
294-
# Lazy loading
295-
dp = dl.datapipe
296-
self.assertTrue(dp._map is None)
297-
it = iter(dl)
298-
self.assertEqual(list(it), list(range(10)))
299-
# Lazy loading in multiprocessing
300-
self.assertTrue(map_dp._map is None)
301-
302-
303261
@unittest.skipIf(
304262
TEST_WITH_TSAN,
305263
"Fails with TSAN with the following error: starting new threads after multi-threaded "
@@ -382,7 +340,7 @@ def is_replicable(self):
382340
return False
383341

384342

385-
class PrototypeMultiProcessingReadingServiceTest(TestCase):
343+
class MultiProcessingReadingServiceTest(TestCase):
386344
@staticmethod
387345
def _worker_init_fn(datapipe, worker_info):
388346
datapipe = datapipe.sharding_filter()
@@ -403,7 +361,7 @@ def _worker_reset_fn(datapipe, worker_info, worker_seed_generator: SeedGenerator
403361
def test_worker_fns(self, ctx):
404362
dp: IterDataPipe = IterableWrapper(range(100)).batch(2).shuffle()
405363

406-
rs = PrototypeMultiProcessingReadingService(
364+
rs = MultiProcessingReadingService(
407365
num_workers=2,
408366
multiprocessing_context=ctx,
409367
worker_init_fn=self._worker_init_fn,
@@ -448,7 +406,7 @@ def _assert_deterministic_dl_res(dl, exp):
448406
sf_dp = single_br_dp.sharding_filter()
449407
replace_dp(graph, single_br_dp, sf_dp)
450408
dl = DataLoader2(
451-
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
409+
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
452410
)
453411
# Determinism and dynamic sharding
454412
# _assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])
@@ -462,7 +420,7 @@ def _assert_deterministic_dl_res(dl, exp):
462420
sf_dp = map_dp.sharding_filter()
463421
replace_dp(graph, map_dp, sf_dp)
464422
dl = DataLoader2(
465-
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
423+
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
466424
)
467425
# Determinism for non-replicable pipeline
468426
_assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])
@@ -476,7 +434,7 @@ def _assert_deterministic_dl_res(dl, exp):
476434
round_robin_dispatcher = ShardingRoundRobinDispatcher(map_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
477435
replace_dp(graph, map_dp, round_robin_dispatcher)
478436
dl = DataLoader2(
479-
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
437+
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
480438
)
481439
# Determinism for non-replicable pipeline
482440
_assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])
@@ -518,7 +476,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2):
518476
replace_dp(graph, branch1_dp, sf1_dp)
519477
replace_dp(graph, branch2_dp, sf2_dp)
520478
dl = DataLoader2(
521-
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
479+
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
522480
)
523481
# Determinism and dynamic sharding
524482
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))
@@ -533,7 +491,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2):
533491
sf_dp = branch2_dp.sharding_filter()
534492
replace_dp(graph, branch2_dp, sf_dp)
535493
dl = DataLoader2(
536-
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
494+
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
537495
)
538496
# Determinism for non-replicable pipeline
539497
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))
@@ -547,7 +505,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2):
547505
non_replicable_dp2 = ShardingRoundRobinDispatcher(branch2_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
548506
replace_dp(graph, branch2_dp, non_replicable_dp2)
549507
dl = DataLoader2(
550-
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
508+
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
551509
)
552510
# Determinism for non-replicable pipeline
553511
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))
@@ -558,7 +516,7 @@ def test_multi_worker_determinism(self, ctx):
558516
dp = dp.shuffle().sharding_filter()
559517
dp = dp.batch(2)
560518

561-
rs = PrototypeMultiProcessingReadingService(
519+
rs = MultiProcessingReadingService(
562520
num_workers=2,
563521
multiprocessing_context=ctx,
564522
)
@@ -589,7 +547,7 @@ def test_dispatching_worker_determinism(self, ctx):
589547
dp = dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
590548
dp = dp.batch(2)
591549

592-
rs = PrototypeMultiProcessingReadingService(
550+
rs = MultiProcessingReadingService(
593551
num_workers=2,
594552
multiprocessing_context=ctx,
595553
)
@@ -625,7 +583,7 @@ def test_non_replicable_datapipe(self, ctx) -> None:
625583
dp = dp.batch(2)
626584
non_rep_dp = NonReplicableDataPipe(dp)
627585

628-
rs = PrototypeMultiProcessingReadingService(
586+
rs = MultiProcessingReadingService(
629587
num_workers=2,
630588
multiprocessing_context=ctx,
631589
)
@@ -775,7 +733,7 @@ def _make_dispatching_dp(data_length):
775733

776734
@staticmethod
777735
def _make_rs(num_workers, ctx):
778-
mp_rs = PrototypeMultiProcessingReadingService(
736+
mp_rs = MultiProcessingReadingService(
779737
num_workers=num_workers,
780738
multiprocessing_context=ctx,
781739
)
@@ -850,7 +808,7 @@ def test_sequential_reading_service_dispatching_dp(self, ctx):
850808
self.assertNotEqual(result[1][rank][1], result[3][rank][1])
851809

852810

853-
instantiate_parametrized_tests(PrototypeMultiProcessingReadingServiceTest)
811+
instantiate_parametrized_tests(MultiProcessingReadingServiceTest)
854812
instantiate_parametrized_tests(SequentialReadingServiceTest)
855813

856814

test/dataloader2/test_random.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616

1717
from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
18-
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
18+
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
1919
from torchdata.dataloader2.graph.settings import set_graph_random_seed
2020
from torchdata.dataloader2.random import SeedGenerator
2121
from torchdata.datapipes.iter import IterableWrapper
@@ -40,7 +40,7 @@ def test_proto_rs_determinism(self, num_workers):
4040

4141
data_source = IterableWrapper(exp)
4242
dp = data_source.shuffle().sharding_filter().map(_random_fn)
43-
rs = PrototypeMultiProcessingReadingService(num_workers=num_workers)
43+
rs = MultiProcessingReadingService(num_workers=num_workers)
4444
dl = DataLoader2(dp, reading_service=rs)
4545

4646
# No seed

test/test_distributed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
2424
from torch.utils.data import DataLoader
2525

26-
from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
26+
from torchdata.dataloader2 import DataLoader2, DistributedReadingService
2727
from torchdata.datapipes.iter import IterableWrapper
2828
from torchdata.datapipes.iter.util.distributed import PrefetchTimeoutError
2929

test/test_graph.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch.utils.data import IterDataPipe
1717
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
1818

19-
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
19+
from torchdata.dataloader2 import DataLoader2, ReadingServiceInterface
2020
from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, replace_dp, traverse_dps
2121
from torchdata.dataloader2.graph.utils import _find_replicable_branches
2222
from torchdata.dataloader2.random import SeedGenerator
@@ -254,15 +254,6 @@ def test_reading_service(self) -> None:
254254

255255
self.assertEqual(res, list(dl))
256256

257-
@unittest.skipIf(IS_WINDOWS, "Fork is required for lambda")
258-
def test_multiprocessing_reading_service(self) -> None:
259-
_, (*_, dp) = self._get_datapipes() # pyre-ignore
260-
rs = MultiProcessingReadingService(2, persistent_workers=True, multiprocessing_context="fork")
261-
dl = DataLoader2(dp, reading_service=rs)
262-
d1 = list(dl)
263-
d2 = list(dl)
264-
self.assertEqual(d1, d2)
265-
266257

267258
def insert_round_robin_sharding(graph, datapipe):
268259
dispatch_dp = ShardingRoundRobinDispatcher(datapipe, SHARDING_PRIORITIES.MULTIPROCESSING)

0 commit comments

Comments
 (0)