Skip to content

Commit 3789545

Browse files
committed
initial_commit
1 parent 4ec4548 commit 3789545

File tree

1 file changed

+198
-0
lines changed

1 file changed

+198
-0
lines changed

torchdata/nodes/filter.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
from typing import Any, Callable, Dict, Iterator, Optional
8+
9+
from torchdata.nodes.base_node import BaseNode, T
10+
11+
12+
class Filter(BaseNode[T]):
13+
def __init__(
14+
self,
15+
source: BaseNode[T],
16+
predicate: Callable[[T], bool],
17+
num_workers: int = 0,
18+
in_order: bool = True,
19+
method: str = "thread",
20+
multiprocessing_context: Optional[str] = None,
21+
max_concurrent: Optional[int] = None,
22+
snapshot_frequency: int = 1,
23+
):
24+
super().__init__()
25+
self.source = source
26+
self.predicate = predicate
27+
self.num_workers = num_workers
28+
self.in_order = in_order
29+
self.method = method
30+
self.multiprocessing_context = multiprocessing_context
31+
self.max_concurrent = max_concurrent
32+
self.snapshot_frequency = snapshot_frequency
33+
self._it: Optional[Iterator[T]] = None
34+
35+
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
36+
super().reset(initial_state)
37+
if self._it is not None:
38+
del self._it
39+
if self.num_workers > 0:
40+
self._parallel_reset(initial_state)
41+
else:
42+
self._inline_reset(initial_state)
43+
44+
def _inline_reset(self, initial_state: Optional[Dict[str, Any]]):
45+
self._it = _InlineFilterIter(
46+
source=self.source,
47+
predicate=self.predicate,
48+
initial_state=initial_state,
49+
)
50+
51+
def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
52+
self._it = _ParallelFilterIter(
53+
source=self.source,
54+
predicate=self.predicate,
55+
num_workers=self.num_workers,
56+
in_order=self.in_order,
57+
method=self.method,
58+
multiprocessing_context=self.multiprocessing_context,
59+
max_concurrent=self.max_concurrent,
60+
snapshot_frequency=self.snapshot_frequency,
61+
initial_state=initial_state,
62+
)
63+
64+
def next(self):
65+
return next(self._it) # type: ignore[arg-type]
66+
67+
def get_state(self) -> Dict[str, Any]:
68+
return self._it.get_state() # type: ignore[union-attr]
69+
70+
71+
class _InlineFilterIter(Iterator[T]):
72+
def __init__(
73+
self,
74+
source: BaseNode[T],
75+
predicate: Callable[[T], bool],
76+
initial_state: Optional[Dict[str, Any]] = None,
77+
):
78+
self.source = source
79+
self.predicate = predicate
80+
if initial_state is not None:
81+
self.source.reset(initial_state["source"])
82+
else:
83+
self.source.reset()
84+
85+
def __iter__(self) -> Iterator[T]:
86+
return self
87+
88+
def __next__(self) -> T:
89+
while True:
90+
item = next(self.source)
91+
if self.predicate(item):
92+
return item
93+
94+
def get_state(self) -> Dict[str, Any]:
95+
return {"source": self.source.state_dict()}
96+
97+
98+
class _ParallelFilterIter(Iterator[T]):
99+
def __init__(
100+
self,
101+
source: BaseNode[T],
102+
predicate: Callable[[T], bool],
103+
num_workers: int,
104+
in_order: bool,
105+
method: str,
106+
multiprocessing_context: Optional[str],
107+
max_concurrent: Optional[int],
108+
snapshot_frequency: int,
109+
initial_state: Optional[Dict[str, Any]],
110+
):
111+
self.source = source
112+
self.predicate = predicate
113+
self.num_workers = num_workers
114+
self.in_order = in_order
115+
self.method = method
116+
self.multiprocessing_context = multiprocessing_context
117+
self.max_concurrent = max_concurrent
118+
self.snapshot_frequency = snapshot_frequency
119+
self._in_q: queue.Queue = queue.Queue()
120+
self._out_q: queue.Queue = queue.Queue()
121+
self._sem = threading.BoundedSemaphore(value=max_concurrent or 2 * num_workers)
122+
self._stop_event = threading.Event()
123+
self._workers: list[threading.Thread] = []
124+
for _ in range(num_workers):
125+
t = threading.Thread(
126+
target=self._filter_worker,
127+
args=(self._in_q, self._out_q, self.predicate),
128+
daemon=True,
129+
)
130+
t.start()
131+
self._workers.append(t)
132+
self._populate_queue_thread = threading.Thread(
133+
target=_populate_queue,
134+
args=(
135+
self.source,
136+
self._in_q,
137+
QueueSnapshotStore(),
138+
snapshot_frequency,
139+
self._sem,
140+
self._stop_event,
141+
),
142+
daemon=True,
143+
)
144+
145+
self._populate_queue_thread.start()
146+
if initial_state is not None:
147+
self.source.reset(initial_state["source"])
148+
else:
149+
self.source.reset()
150+
151+
def _filter_worker(
152+
self, in_q: queue.Queue, out_q: queue.Queue, predicate: Callable[[T], bool]
153+
) -> None:
154+
while True:
155+
try:
156+
item = in_q.get(block=True, timeout=0.1)
157+
except queue.Empty:
158+
if self._stop_event.is_set():
159+
break
160+
continue
161+
if isinstance(item, StopIteration):
162+
out_q.put(item)
163+
break
164+
elif predicate(item):
165+
out_q.put(item)
166+
self._sem.release()
167+
168+
def __iter__(self) -> Iterator[T]:
169+
return self
170+
171+
def __next__(self) -> T:
172+
while True:
173+
try:
174+
item = self._out_q.get(block=True, timeout=0.1)
175+
except queue.Empty:
176+
if self._stop_event.is_set():
177+
raise StopIteration()
178+
continue
179+
if isinstance(item, StopIteration):
180+
raise item
181+
return item
182+
183+
def get_state(self) -> Dict[str, Any]:
184+
return {"source": self.source.state_dict()}
185+
186+
def __del__(self):
187+
self._shutdown()
188+
189+
def _shutdown(self):
190+
self._stop_event.set()
191+
if (
192+
hasattr(self, "_populate_queue_thread")
193+
and self._populate_queue_thread.is_alive()
194+
):
195+
self._populate_queue_thread.join(timeout=0.5)
196+
for t in self._workers:
197+
if t.is_alive():
198+
t.join(timeout=0.5)

0 commit comments

Comments
 (0)