Commit 6112828

add tests
1 parent 3789545 commit 6112828

3 files changed: +175, -94 lines


test/nodes/run_filter.py (+19)

@@ -0,0 +1,19 @@
+from torchdata.nodes.adapters import IterableWrapper
+from torchdata.nodes.batch import Batcher
+from torchdata.nodes.filter import Filter
+from torchdata.nodes.loader import Loader
+from torchdata.nodes.prefetch import Prefetcher
+from torchdata.nodes.samplers.stop_criteria import StopCriteria
+
+a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+node = IterableWrapper(a)
+
+
+def is_even(x):
+    return x % 2 == 0
+
+
+filtered_node = Filter(node, is_even, num_workers=2)
+for item in filtered_node:
+    print(item)
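For reference, assuming Filter yields items in source order (in_order defaults to True, per the filter.py diff below), this smoke-test script is expected to print the even values 2, 4, 6, 8, 10, one per line. A minimal sequential equivalent, for illustration only (no torchdata involved):

    # keep only the items for which is_even(x) is True
    a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print([x for x in a if x % 2 == 0])  # -> [2, 4, 6, 8, 10]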

test/nodes/test_filter.py (+79)

@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import unittest
+from typing import List
+
+import torch
+
+from parameterized import parameterized
+from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
+
+from torchdata.nodes.base_node import BaseNode
+from torchdata.nodes.filter import Filter
+
+from .utils import MockSource, run_test_save_load_state, StatefulRangeNode
+
+
+class TestFilter(TestCase):
+    def _test_filter(self, num_workers, in_order, method):
+        n = 100
+        predicate = lambda x: x["test_tensor"] % 2 == 0  # Filter even numbers
+        src = MockSource(num_samples=n)
+        node = Filter(
+            source=src,
+            predicate=predicate,
+            num_workers=num_workers,
+            in_order=in_order,
+            method=method,
+        )
+
+        results: List[int] = []
+        for item in node:
+            results.append(item)
+
+        expected_results = [
+            {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}
+            for i in range(n)
+            if i % 2 == 0
+        ]
+        self.assertEqual(results, expected_results)
+
+    def test_filter_inline(self):
+        self._test_filter(num_workers=0, in_order=True, method="thread")
+
+    def test_filter_parallel_threads(self):
+        self._test_filter(num_workers=4, in_order=True, method="thread")
+
+    def test_filter_parallel_process(self):
+        self._test_filter(num_workers=4, in_order=True, method="process")
+
+    @parameterized.expand(
+        itertools.product(
+            [0],  # , 7, 13],
+            [True],  # TODO: define and fix in_order = False
+            [0],  # , 1, 9],  # TODO: define and fix in_order = False
+        )
+    )
+    def test_save_load_state_thread(
+        self, midpoint: int, in_order: bool, snapshot_frequency: int
+    ):
+        method = "thread"
+        n = 100
+        predicate = lambda x: True
+        src = StatefulRangeNode(n=n)
+
+        node = Filter(
+            source=src,
+            predicate=predicate,
+            num_workers=4,
+            in_order=in_order,
+            method=method,
+            snapshot_frequency=snapshot_frequency,
+        )
+        node.reset()
+        run_test_save_load_state(self, node, midpoint)
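A note on the predicate in _test_filter: as the expected_results list shows, each MockSource item is a dict with "step", "test_tensor" (a one-element tensor), and "test_str" keys, so x["test_tensor"] % 2 == 0 evaluates to a one-element boolean tensor whose truthiness decides whether Filter keeps the item. A small stand-alone sketch of the same check; make_item is a hypothetical stand-in that only mirrors the item layout above, not torchdata's MockSource:

    import torch

    # hypothetical stand-in mirroring the dict layout in expected_results
    def make_item(i: int) -> dict:
        return {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}

    predicate = lambda x: x["test_tensor"] % 2 == 0  # one-element bool tensor

    kept = [item for item in map(make_item, range(6)) if predicate(item)]
    print([item["step"] for item in kept])  # -> [0, 2, 4]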

torchdata/nodes/filter.py (+77, -94)

@@ -4,9 +4,12 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Callable, Dict, Iterator, Optional
+from typing import Any, Callable, Dict, Iterator, Literal, Optional, TypeVar
 
-from torchdata.nodes.base_node import BaseNode, T
+from torchdata.nodes.base_node import BaseNode
+from torchdata.nodes.map import Mapper, ParallelMapper
+
+T = TypeVar("T", covariant=True)
 
 
 class Filter(BaseNode[T]):
@@ -16,7 +19,7 @@ def __init__(
         predicate: Callable[[T], bool],
         num_workers: int = 0,
         in_order: bool = True,
-        method: str = "thread",
+        method: Literal["thread", "process"] = "thread",
         multiprocessing_context: Optional[str] = None,
         max_concurrent: Optional[int] = None,
         snapshot_frequency: int = 1,
@@ -37,29 +40,24 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
         if self._it is not None:
             del self._it
         if self.num_workers > 0:
-            self._parallel_reset(initial_state)
-        else:
-            self._inline_reset(initial_state)
-
-    def _inline_reset(self, initial_state: Optional[Dict[str, Any]]):
-        self._it = _InlineFilterIter(
-            source=self.source,
-            predicate=self.predicate,
-            initial_state=initial_state,
-        )
+            self._it = _ParallelFilterIter(
+                source=self.source,
+                predicate=self.predicate,
+                num_workers=self.num_workers,
+                in_order=self.in_order,
+                method=self.method,
+                multiprocessing_context=self.multiprocessing_context,
+                max_concurrent=self.max_concurrent,
+                snapshot_frequency=self.snapshot_frequency,
+                initial_state=initial_state,
+            )
 
-    def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
-        self._it = _ParallelFilterIter(
-            source=self.source,
-            predicate=self.predicate,
-            num_workers=self.num_workers,
-            in_order=self.in_order,
-            method=self.method,
-            multiprocessing_context=self.multiprocessing_context,
-            max_concurrent=self.max_concurrent,
-            snapshot_frequency=self.snapshot_frequency,
-            initial_state=initial_state,
-        )
+        else:
+            self._it = _InlineFilterIter(
+                source=self.source,
+                predicate=self.predicate,
+                initial_state=initial_state,
+            )
 
     def next(self):
         return next(self._it)  # type: ignore[arg-type]
@@ -96,17 +94,32 @@ def get_state(self) -> Dict[str, Any]:
 
 
 class _ParallelFilterIter(Iterator[T]):
+    """
+    An iterator that filters data samples in parallel.
+
+    Args:
+        source (BaseNode[T]): The source node providing data samples.
+        predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
+        num_workers (int): The number of worker processes to use for parallel filtering.
+        in_order (bool): Whether to preserve the order of data samples.
+        method (Literal["thread", "process"]): The method to use for parallelization.
+        multiprocessing_context (Optional[str]): The multiprocessing context to use.
+        max_concurrent (Optional[int]): The maximum number of concurrent tasks.
+        snapshot_frequency (int): The frequency at which to take snapshots.
+        initial_state (Optional[Dict[str, Any]]): The initial state to start with.
+    """
+
     def __init__(
         self,
         source: BaseNode[T],
         predicate: Callable[[T], bool],
         num_workers: int,
         in_order: bool,
-        method: str,
+        method: Literal["thread", "process"],
         multiprocessing_context: Optional[str],
         max_concurrent: Optional[int],
         snapshot_frequency: int,
-        initial_state: Optional[Dict[str, Any]],
+        initial_state: Optional[Dict[str, Any]] = None,
     ):
         self.source = source
         self.predicate = predicate
@@ -116,83 +129,53 @@ def __init__(
         self.multiprocessing_context = multiprocessing_context
         self.max_concurrent = max_concurrent
         self.snapshot_frequency = snapshot_frequency
-        self._in_q: queue.Queue = queue.Queue()
-        self._out_q: queue.Queue = queue.Queue()
-        self._sem = threading.BoundedSemaphore(value=max_concurrent or 2 * num_workers)
-        self._stop_event = threading.Event()
-        self._workers: list[threading.Thread] = []
-        for _ in range(num_workers):
-            t = threading.Thread(
-                target=self._filter_worker,
-                args=(self._in_q, self._out_q, self.predicate),
-                daemon=True,
-            )
-            t.start()
-            self._workers.append(t)
-        self._populate_queue_thread = threading.Thread(
-            target=_populate_queue,
-            args=(
-                self.source,
-                self._in_q,
-                QueueSnapshotStore(),
-                snapshot_frequency,
-                self._sem,
-                self._stop_event,
-            ),
-            daemon=True,
+        # Create a ParallelMapper to filter items in parallel
+        self.mapper = ParallelMapper(
+            source=self.source,
+            map_fn=lambda x: (x, self.predicate(x)),
+            num_workers=self.num_workers,
+            in_order=self.in_order,
+            method=self.method,
+            multiprocessing_context=self.multiprocessing_context,
+            max_concurrent=self.max_concurrent,
+            snapshot_frequency=self.snapshot_frequency,
         )
-
-        self._populate_queue_thread.start()
         if initial_state is not None:
-            self.source.reset(initial_state["source"])
-        else:
-            self.source.reset()
-
-    def _filter_worker(
-        self, in_q: queue.Queue, out_q: queue.Queue, predicate: Callable[[T], bool]
-    ) -> None:
-        while True:
-            try:
-                item = in_q.get(block=True, timeout=0.1)
-            except queue.Empty:
-                if self._stop_event.is_set():
-                    break
-                continue
-            if isinstance(item, StopIteration):
-                out_q.put(item)
-                break
-            elif predicate(item):
-                out_q.put(item)
-            self._sem.release()
+            self.mapper.reset(initial_state)
 
     def __iter__(self) -> Iterator[T]:
+        """
+        Returns the iterator object itself.
+
+        Returns:
+            Iterator[T]: The iterator object itself.
+        """
        return self
 
     def __next__(self) -> T:
+        """
+        Returns the next filtered data sample.
+
+        Returns:
+            T: The next filtered data sample.
+        """
         while True:
             try:
-                item = self._out_q.get(block=True, timeout=0.1)
-            except queue.Empty:
-                if self._stop_event.is_set():
-                    raise StopIteration()
-                continue
-            if isinstance(item, StopIteration):
-                raise item
-            return item
+                item, passed_predicate = next(self.mapper)
+                if passed_predicate:
+                    return item
+            except StopIteration:
+                raise
 
     def get_state(self) -> Dict[str, Any]:
-        return {"source": self.source.state_dict()}
+        """
+        Returns the current state of the parallel filter iterator.
+
+        Returns:
+            Dict[str, Any]: The current state of the parallel filter iterator.
+        """
+        return self.mapper.get_state()
 
     def __del__(self):
-        self._shutdown()
-
-    def _shutdown(self):
-        self._stop_event.set()
-        if (
-            hasattr(self, "_populate_queue_thread")
-            and self._populate_queue_thread.is_alive()
-        ):
-            self._populate_queue_thread.join(timeout=0.5)
-        for t in self._workers:
-            if t.is_alive():
-                t.join(timeout=0.5)
+        # Clean up resources when the iterator is deleted
+        del self.mapper
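The core of this refactor is that _ParallelFilterIter no longer manages its own queues, semaphore, and worker threads; it delegates to ParallelMapper, mapping each item to an (item, predicate(item)) pair in parallel and then dropping items whose flag is False in __next__. A rough stand-alone sketch of the same pattern using only the standard library; names such as parallel_filter are illustrative and not part of torchdata:

    from concurrent.futures import ThreadPoolExecutor
    from typing import Callable, Iterable, Iterator, TypeVar

    T = TypeVar("T")


    def parallel_filter(
        source: Iterable[T], predicate: Callable[[T], bool], num_workers: int = 4
    ) -> Iterator[T]:
        # Evaluate the predicate in parallel by mapping each item to
        # (item, predicate(item)); executor.map preserves input order,
        # which mirrors the in_order=True behavior.
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            for item, passed in executor.map(lambda x: (x, predicate(x)), source):
                if passed:
                    yield item


    if __name__ == "__main__":
        print(list(parallel_filter(range(10), lambda x: x % 2 == 0)))  # [0, 2, 4, 6, 8]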
