Skip to content

Commit b57545f

Browse files
NivekTejguan
authored andcommitted
Removing delegation for 'pause', 'limit', and 'resume' (#1011)
Summary: Pull Request resolved: #1011 Test Plan: Imported from OSS Reviewed By: mingyuzh, ejguan Differential Revision: D43251818 Pulled By: NivekT fbshipit-source-id: 5c34e1a71438308366b473c5d4d075a8158088f1
1 parent 6ace7a4 commit b57545f

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

torchdata/dataloader2/dataloader2.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def resume(self) -> None:
100100
Restarts the threads within ``DataLoader2`` and allows it to yield additional batches.
101101
"""
102102
self.dataloader._resume()
103-
if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "resume"):
104-
self.dataloader._datapipe_iter.resume() # type: ignore[attr-defined]
105103

106104
def limit(self, num_batches: Optional[int]) -> None:
107105
"""
@@ -120,8 +118,7 @@ def limit(self, num_batches: Optional[int]) -> None:
120118
"""
121119
self.limit_counter = 0
122120
self.limit_threshold = num_batches
123-
if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "limit"):
124-
self.dataloader._datapipe_iter.limit(num_batches) # type: ignore[attr-defined]
121+
self.dataloader._limit(num_batches)
125122

126123
def __getattr__(self, name):
127124
"""
@@ -339,11 +336,8 @@ def _pause(self):
339336
if hasattr(self.reading_service, "_pause"):
340337
self._is_paused = True
341338
self.reading_service._pause()
342-
# TODO: the condition should be `else` once `self._datapipe_iter.pause/limit()` is no longer used
343-
elif self._datapipe_iter is None or not (
344-
hasattr(self._datapipe_iter, "limit") or hasattr(self._datapipe_iter, "pause")
345-
):
346-
warnings.warn("ReadingService doesn't support pause.")
339+
else:
340+
warnings.warn("ReadingService doesn't support `pause`.")
347341

348342
def _resume(self):
349343
if hasattr(self.reading_service, "_resume"):
@@ -352,6 +346,11 @@ def _resume(self):
352346
else:
353347
self.reading_service._resume()
354348
self._is_paused = False
355-
# TODO: the condition should be `else` once `self._datapipe_iter.resume()` is no longer used
356-
elif self._datapipe_iter is None or not hasattr(self._datapipe_iter, "resume"):
357-
warnings.warn("ReadingService doesn't support resume.")
349+
else:
350+
warnings.warn("ReadingService doesn't support `resume`.")
351+
352+
def _limit(self, num_batches: Optional[int]) -> None:
353+
if hasattr(self.reading_service, "_limit"):
354+
self.reading_service._limit(num_batches)
355+
else:
356+
warnings.warn("ReadingService doesn't support `limit`.")

torchdata/dataloader2/reading_service.py

+7
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,13 @@ def _resume(self):
406406
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
407407
self._main_prefetch_datapipe.resume() # type: ignore[union-attr]
408408

409+
def _limit(self, num_batches: Optional[int]) -> None:
410+
"""
411+
For this ReadingService, `DataLoader2Iterator` and `DataLoader2` should sufficiently handle
412+
the limit operation, such that nothing needs to be done here.
413+
"""
414+
pass
415+
409416

410417
class DistributedReadingService(ReadingServiceInterface):
411418
r"""

0 commit comments

Comments
 (0)