1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # All rights reserved.
4
- #
5
- # This source code is licensed under the BSD-style license found in the
6
- # LICENSE file in the root directory of this source tree.
7
1
from typing import Any , Callable , Dict , Iterator , Literal , Optional , TypeVar
8
-
9
2
from torchdata .nodes .base_node import BaseNode
10
- from torchdata .nodes .map import Mapper , ParallelMapper
11
-
3
+ from torchdata .nodes .map import ParallelMapper
12
4
T = TypeVar ("T" , covariant = True )
13
5
14
-
15
6
class Filter (BaseNode [T ]):
7
+ """
8
+ A node that filters data samples based on a given predicate.
9
+ Args:
10
+ source (BaseNode[T]): The source node providing data samples.
11
+ predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
12
+ num_workers (int): The number of worker processes to use for parallel filtering. Defaults to 0.
13
+ in_order (bool): Whether to return items in the order in which they arrive. Default is True.
14
+ method (Literal["thread", "process"]): The method to use for parallel processing. Default is "thread".
15
+ multiprocessing_context (Optional[str]): The multiprocessing context to use for parallel processing. Default is None.
16
+ max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
17
+ snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
18
+ """
16
19
def __init__ (
17
20
self ,
18
21
source : BaseNode [T ],
@@ -33,12 +36,6 @@ def __init__(
33
36
self .multiprocessing_context = multiprocessing_context
34
37
self .max_concurrent = max_concurrent
35
38
self .snapshot_frequency = snapshot_frequency
36
- self ._it : Optional [Iterator [T ]] = None
37
-
38
- def reset (self , initial_state : Optional [Dict [str , Any ]] = None ):
39
- super ().reset (initial_state )
40
- if self ._it is not None :
41
- del self ._it
42
39
if self .num_workers > 0 :
43
40
self ._it = _ParallelFilterIter (
44
41
source = self .source ,
@@ -49,54 +46,60 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
49
46
multiprocessing_context = self .multiprocessing_context ,
50
47
max_concurrent = self .max_concurrent ,
51
48
snapshot_frequency = self .snapshot_frequency ,
52
- initial_state = initial_state ,
53
49
)
54
-
55
50
else :
56
- self ._it = _InlineFilterIter (
57
- source = self . source ,
58
- predicate = self . predicate ,
59
- initial_state = initial_state ,
60
- )
61
-
62
- def next (self ):
63
- return next ( self . _it ) # type: ignore[arg-type]
64
-
51
+ self ._it = _InlineFilterIter (source = self . source , predicate = self . predicate )
52
+ def reset ( self , initial_state : Optional [ Dict [ str , Any ]] = None ) -> None :
53
+ """Resets the filter node to its initial state."""
54
+ super (). reset ( initial_state )
55
+ if self . _it is not None :
56
+ self . _it . reset ( initial_state )
57
+ def next (self ) -> T :
58
+ """Returns the next filtered item."""
59
+ return next ( self . _it )
65
60
def get_state (self ) -> Dict [str , Any ]:
66
- return self . _it . get_state () # type: ignore[union-attr]
67
-
61
+ """Returns the current state of the filter node."""
62
+ return self . _it . get_state ()
68
63
69
64
class _InlineFilterIter(Iterator[T]):
    """
    An iterator that filters data samples inline.

    Items are pulled from ``source`` one at a time and ``predicate`` is
    applied directly to each; only items for which the predicate returns a
    truthy value are yielded.

    Args:
        source (BaseNode[T]): The source node providing data samples.
        predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
    """

    # Key under which the source node's state is stored in get_state().
    SOURCE_KEY = "source"

    def __init__(self, source: BaseNode[T], predicate: Callable[[T], bool]) -> None:
        self.source = source
        self.predicate = predicate

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        """Reset the source node, restoring it from ``initial_state`` if given.

        Args:
            initial_state: A dict previously produced by :meth:`get_state`,
                or ``None`` to restart the source from scratch.

        Raises:
            KeyError: If ``initial_state`` is provided but lacks ``SOURCE_KEY``.
        """
        # Compare against None explicitly: a supplied-but-empty state is a
        # caller error and must not be silently treated as "start fresh".
        if initial_state is not None:
            self.source.reset(initial_state[self.SOURCE_KEY])
        else:
            self.source.reset()

    def __iter__(self) -> Iterator[T]:
        """Return the iterator object itself."""
        return self

    def __next__(self) -> T:
        """Return the next item from the source that satisfies the predicate.

        ``StopIteration`` raised by the source propagates unchanged and ends
        the iteration.
        """
        while True:
            item = next(self.source)
            if self.predicate(item):
                return item

    def get_state(self) -> Dict[str, Any]:
        """Return the current state, i.e. the source's state under ``SOURCE_KEY``."""
        return {self.SOURCE_KEY: self.source.state_dict()}
95
99
96
100
class _ParallelFilterIter (Iterator [T ]):
97
101
"""
98
102
An iterator that filters data samples in parallel.
99
-
100
103
Args:
101
104
source (BaseNode[T]): The source node providing data samples.
102
105
predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
@@ -106,9 +109,8 @@ class _ParallelFilterIter(Iterator[T]):
106
109
multiprocessing_context (Optional[str]): The multiprocessing context to use.
107
110
max_concurrent (Optional[int]): The maximum number of concurrent tasks.
108
111
snapshot_frequency (int): The frequency at which to take snapshots.
109
- initial_state (Optional[Dict[str, Any]]): The initial state to start with.
110
112
"""
111
-
113
+ MAPPER_KEY = "mapper"
112
114
def __init__ (
113
115
self ,
114
116
source : BaseNode [T ],
@@ -119,7 +121,6 @@ def __init__(
119
121
multiprocessing_context : Optional [str ],
120
122
max_concurrent : Optional [int ],
121
123
snapshot_frequency : int ,
122
- initial_state : Optional [Dict [str , Any ]] = None ,
123
124
):
124
125
self .source = source
125
126
self .predicate = predicate
@@ -140,42 +141,26 @@ def __init__(
140
141
max_concurrent = self .max_concurrent ,
141
142
snapshot_frequency = self .snapshot_frequency ,
142
143
)
143
- if initial_state is not None :
144
- self .mapper .reset (initial_state )
145
-
144
def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
    """Reset the underlying mapper, restoring it from ``initial_state`` if given.

    Args:
        initial_state: A dict previously produced by ``get_state`` (the
            mapper's state stored under ``MAPPER_KEY``), or ``None`` to
            restart the mapper from scratch.

    Raises:
        KeyError: If ``initial_state`` is provided but lacks ``MAPPER_KEY``.
    """
    # Compare against None explicitly: a supplied-but-empty state is a
    # caller error and must not be silently treated as "start fresh".
    if initial_state is not None:
        self.mapper.reset(initial_state[self.MAPPER_KEY])
    else:
        self.mapper.reset()
146
150
def __iter__ (self ) -> Iterator [T ]:
147
- """
148
- Returns the iterator object itself.
149
-
150
- Returns:
151
- Iterator[T]: The iterator object itself.
152
- """
151
+ """Returns the iterator object itself."""
153
152
return self
154
-
155
153
def __next__(self) -> T:
    """Return the next item whose predicate result is truthy.

    The mapper yields ``(item, passed_predicate)`` pairs; pairs whose flag
    is falsy are discarded. ``StopIteration`` raised by the mapper
    propagates unchanged and ends the iteration.
    """
    keep = False
    while not keep:
        candidate, keep = next(self.mapper)
    return candidate
178
160
161
+ def get_state (self ) -> Dict [str , Any ]:
162
+ """Returns the current state of the parallel filter iterator."""
163
+ return {self .MAPPER_KEY : self .mapper .get_state ()}
179
164
def __del__ (self ):
180
165
# Clean up resources when the iterator is deleted
181
166
del self .mapper
0 commit comments