4
4
#
5
5
# This source code is licensed under the BSD-style license found in the
6
6
# LICENSE file in the root directory of this source tree.
7
- from typing import Any, Callable, Dict, Iterator, Optional
7
+ from typing import Any, Callable, Dict, Iterator, Literal, Optional, TypeVar
8
8
9
- from torchdata.nodes.base_node import BaseNode, T
9
+ from torchdata.nodes.base_node import BaseNode
10
+ from torchdata.nodes.map import Mapper, ParallelMapper
11
+
12
+ T = TypeVar("T", covariant=True)
10
13
11
14
12
15
class Filter (BaseNode [T ]):
@@ -16,7 +19,7 @@ def __init__(
16
19
predicate : Callable [[T ], bool ],
17
20
num_workers : int = 0 ,
18
21
in_order : bool = True ,
19
- method: str = "thread",
22
+ method: Literal["thread", "process"] = "thread",
20
23
multiprocessing_context : Optional [str ] = None ,
21
24
max_concurrent : Optional [int ] = None ,
22
25
snapshot_frequency : int = 1 ,
@@ -37,29 +40,24 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
37
40
if self ._it is not None :
38
41
del self ._it
39
42
if self .num_workers > 0 :
40
- self ._parallel_reset (initial_state )
41
- else :
42
- self ._inline_reset (initial_state )
43
-
44
- def _inline_reset (self , initial_state : Optional [Dict [str , Any ]]):
45
- self ._it = _InlineFilterIter (
46
- source = self .source ,
47
- predicate = self .predicate ,
48
- initial_state = initial_state ,
49
- )
43
+ self ._it = _ParallelFilterIter (
44
+ source = self .source ,
45
+ predicate = self .predicate ,
46
+ num_workers = self .num_workers ,
47
+ in_order = self .in_order ,
48
+ method = self .method ,
49
+ multiprocessing_context = self .multiprocessing_context ,
50
+ max_concurrent = self .max_concurrent ,
51
+ snapshot_frequency = self .snapshot_frequency ,
52
+ initial_state = initial_state ,
53
+ )
50
54
51
- def _parallel_reset (self , initial_state : Optional [Dict [str , Any ]]):
52
- self ._it = _ParallelFilterIter (
53
- source = self .source ,
54
- predicate = self .predicate ,
55
- num_workers = self .num_workers ,
56
- in_order = self .in_order ,
57
- method = self .method ,
58
- multiprocessing_context = self .multiprocessing_context ,
59
- max_concurrent = self .max_concurrent ,
60
- snapshot_frequency = self .snapshot_frequency ,
61
- initial_state = initial_state ,
62
- )
55
+ else :
56
+ self ._it = _InlineFilterIter (
57
+ source = self .source ,
58
+ predicate = self .predicate ,
59
+ initial_state = initial_state ,
60
+ )
63
61
64
62
def next (self ):
65
63
return next (self ._it ) # type: ignore[arg-type]
@@ -96,17 +94,32 @@ def get_state(self) -> Dict[str, Any]:
96
94
97
95
98
96
class _ParallelFilterIter (Iterator [T ]):
97
+ """
98
+ An iterator that filters data samples in parallel.
99
+
100
+ Args:
101
+ source (BaseNode[T]): The source node providing data samples.
102
+ predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
103
+ num_workers (int): The number of worker processes to use for parallel filtering.
104
+ in_order (bool): Whether to preserve the order of data samples.
105
+ method (Literal["thread", "process"]): The method to use for parallelization.
106
+ multiprocessing_context (Optional[str]): The multiprocessing context to use.
107
+ max_concurrent (Optional[int]): The maximum number of concurrent tasks.
108
+ snapshot_frequency (int): The frequency at which to take snapshots.
109
+ initial_state (Optional[Dict[str, Any]]): The initial state to start with.
110
+ """
111
+
99
112
def __init__ (
100
113
self ,
101
114
source : BaseNode [T ],
102
115
predicate : Callable [[T ], bool ],
103
116
num_workers : int ,
104
117
in_order : bool ,
105
- method: str,
118
+ method: Literal["thread", "process"],
106
119
multiprocessing_context : Optional [str ],
107
120
max_concurrent : Optional [int ],
108
121
snapshot_frequency : int ,
109
- initial_state: Optional[Dict[str, Any]],
122
+ initial_state: Optional[Dict[str, Any]] = None,
110
123
):
111
124
self .source = source
112
125
self .predicate = predicate
@@ -116,83 +129,53 @@ def __init__(
116
129
self .multiprocessing_context = multiprocessing_context
117
130
self .max_concurrent = max_concurrent
118
131
self .snapshot_frequency = snapshot_frequency
119
- self ._in_q : queue .Queue = queue .Queue ()
120
- self ._out_q : queue .Queue = queue .Queue ()
121
- self ._sem = threading .BoundedSemaphore (value = max_concurrent or 2 * num_workers )
122
- self ._stop_event = threading .Event ()
123
- self ._workers : list [threading .Thread ] = []
124
- for _ in range (num_workers ):
125
- t = threading .Thread (
126
- target = self ._filter_worker ,
127
- args = (self ._in_q , self ._out_q , self .predicate ),
128
- daemon = True ,
129
- )
130
- t .start ()
131
- self ._workers .append (t )
132
- self ._populate_queue_thread = threading .Thread (
133
- target = _populate_queue ,
134
- args = (
135
- self .source ,
136
- self ._in_q ,
137
- QueueSnapshotStore (),
138
- snapshot_frequency ,
139
- self ._sem ,
140
- self ._stop_event ,
141
- ),
142
- daemon = True ,
132
+ # Create a ParallelMapper to filter items in parallel
133
+ self .mapper = ParallelMapper (
134
+ source = self .source ,
135
+ map_fn = lambda x : (x , self .predicate (x )),
136
+ num_workers = self .num_workers ,
137
+ in_order = self .in_order ,
138
+ method = self .method ,
139
+ multiprocessing_context = self .multiprocessing_context ,
140
+ max_concurrent = self .max_concurrent ,
141
+ snapshot_frequency = self .snapshot_frequency ,
143
142
)
144
-
145
- self ._populate_queue_thread .start ()
146
143
if initial_state is not None :
147
- self .source .reset (initial_state ["source" ])
148
- else :
149
- self .source .reset ()
150
-
151
- def _filter_worker (
152
- self , in_q : queue .Queue , out_q : queue .Queue , predicate : Callable [[T ], bool ]
153
- ) -> None :
154
- while True :
155
- try :
156
- item = in_q .get (block = True , timeout = 0.1 )
157
- except queue .Empty :
158
- if self ._stop_event .is_set ():
159
- break
160
- continue
161
- if isinstance (item , StopIteration ):
162
- out_q .put (item )
163
- break
164
- elif predicate (item ):
165
- out_q .put (item )
166
- self ._sem .release ()
144
+ self .mapper .reset (initial_state )
167
145
168
146
def __iter__ (self ) -> Iterator [T ]:
147
+ """
148
+ Returns the iterator object itself.
149
+
150
+ Returns:
151
+ Iterator[T]: The iterator object itself.
152
+ """
169
153
return self
170
154
171
155
def __next__ (self ) -> T :
156
+ """
157
+ Returns the next filtered data sample.
158
+
159
+ Returns:
160
+ T: The next filtered data sample.
161
+ """
172
162
while True :
173
163
try :
174
- item = self ._out_q .get (block = True , timeout = 0.1 )
175
- except queue .Empty :
176
- if self ._stop_event .is_set ():
177
- raise StopIteration ()
178
- continue
179
- if isinstance (item , StopIteration ):
180
- raise item
181
- return item
164
+ item , passed_predicate = next (self .mapper )
165
+ if passed_predicate :
166
+ return item
167
+ except StopIteration :
168
+ raise
182
169
183
170
def get_state (self ) -> Dict [str , Any ]:
184
- return {"source" : self .source .state_dict ()}
171
+ """
172
+ Returns the current state of the parallel filter iterator.
173
+
174
+ Returns:
175
+ Dict[str, Any]: The current state of the parallel filter iterator.
176
+ """
177
+ return self .mapper .get_state ()
185
178
186
179
def __del__ (self ):
187
- self ._shutdown ()
188
-
189
- def _shutdown (self ):
190
- self ._stop_event .set ()
191
- if (
192
- hasattr (self , "_populate_queue_thread" )
193
- and self ._populate_queue_thread .is_alive ()
194
- ):
195
- self ._populate_queue_thread .join (timeout = 0.5 )
196
- for t in self ._workers :
197
- if t .is_alive ():
198
- t .join (timeout = 0.5 )
180
+ # Clean up resources when the iterator is deleted
181
+ del self .mapper
0 commit comments