1
1
from typing import Any , Callable , Dict , Iterator , Literal , Optional , TypeVar
2
+
2
3
from torchdata .nodes .base_node import BaseNode
3
4
from torchdata .nodes .map import ParallelMapper
5
+
4
6
T = TypeVar ("T" , covariant = True )
5
7
8
+
6
9
class Filter (BaseNode [T ]):
7
10
"""
8
11
A node that filters data samples based on a given predicate.
@@ -16,6 +19,7 @@ class Filter(BaseNode[T]):
16
19
max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
17
20
snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
18
21
"""
22
+
19
23
def __init__ (
20
24
self ,
21
25
source : BaseNode [T ],
@@ -49,25 +53,30 @@ def __init__(
49
53
)
50
54
else :
51
55
self ._it = _InlineFilterIter (source = self .source , predicate = self .predicate )
56
+
52
57
def reset (self , initial_state : Optional [Dict [str , Any ]] = None ) -> None :
53
58
"""Resets the filter node to its initial state."""
54
59
super ().reset (initial_state )
55
60
if self ._it is not None :
56
61
self ._it .reset (initial_state )
62
+
57
63
def next (self ) -> T :
58
64
"""Returns the next filtered item."""
59
65
return next (self ._it )
66
+
60
67
def get_state (self ) -> Dict [str , Any ]:
61
68
"""Returns the current state of the filter node."""
62
69
return self ._it .get_state ()
63
70
71
+
64
72
class _InlineFilterIter (Iterator [T ]):
65
73
"""
66
74
An iterator that filters data samples inline.
67
75
Args:
68
76
source (BaseNode[T]): The source node providing data samples.
69
77
predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
70
78
"""
79
+
71
80
SOURCE_KEY = "source"
72
81
73
82
def __init__ (self , source : BaseNode [T ], predicate : Callable [[T ], bool ]) -> None :
@@ -84,6 +93,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
84
93
def __iter__ (self ) -> Iterator [T ]:
85
94
"""Returns the iterator object itself."""
86
95
return self
96
+
87
97
def __next__ (self ) -> T :
88
98
"""Returns the next filtered item."""
89
99
while True :
@@ -93,10 +103,12 @@ def __next__(self) -> T:
93
103
return item
94
104
except StopIteration :
95
105
raise
106
+
96
107
def get_state (self ) -> Dict [str , Any ]:
97
108
"""Returns the current state of the inline filter iterator."""
98
109
return {self .SOURCE_KEY : self .source .state_dict ()}
99
110
111
+
100
112
class _ParallelFilterIter (Iterator [T ]):
101
113
"""
102
114
An iterator that filters data samples in parallel.
@@ -110,7 +122,9 @@ class _ParallelFilterIter(Iterator[T]):
110
122
max_concurrent (Optional[int]): The maximum number of concurrent tasks.
111
123
snapshot_frequency (int): The frequency at which to take snapshots.
112
124
"""
125
+
113
126
MAPPER_KEY = "mapper"
127
+
114
128
def __init__ (
115
129
self ,
116
130
source : BaseNode [T ],
@@ -141,15 +155,18 @@ def __init__(
141
155
max_concurrent = self .max_concurrent ,
142
156
snapshot_frequency = self .snapshot_frequency ,
143
157
)
158
+
144
159
def reset (self , initial_state : Optional [Dict [str , Any ]] = None ) -> None :
145
160
"""Resets the parallel filter iterator to its initial state."""
146
161
if initial_state :
147
162
self .mapper .reset (initial_state [self .MAPPER_KEY ])
148
163
else :
149
164
self .mapper .reset ()
165
+
150
166
def __iter__ (self ) -> Iterator [T ]:
151
167
"""Returns the iterator object itself."""
152
168
return self
169
+
153
170
def __next__ (self ) -> T :
154
171
"""Returns the next filtered item."""
155
172
while True :
@@ -161,6 +178,7 @@ def __next__(self) -> T:
161
178
def get_state (self ) -> Dict [str , Any ]:
162
179
"""Returns the current state of the parallel filter iterator."""
163
180
return {self .MAPPER_KEY : self .mapper .get_state ()}
181
+
164
182
def __del__ (self ):
165
183
# Clean up resources when the iterator is deleted
166
184
del self .mapper
0 commit comments