Skip to content

Commit 7202377

Browse files
committed
update test_filter
update filter state management update filter add filter delete unnecessary files
1 parent 6112828 commit 7202377

File tree

3 files changed

+106
-87
lines changed

3 files changed

+106
-87
lines changed

test/nodes/run_filter.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,37 @@
44
from torchdata.nodes.loader import Loader
55
from torchdata.nodes.prefetch import Prefetcher
66
from torchdata.nodes.samplers.stop_criteria import StopCriteria
7+
from utils import MockSource, run_test_save_load_state, StatefulRangeNode
78

8-
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
99

10-
node = IterableWrapper(a)
1110

11+
a = list(range(60))
12+
base_node = IterableWrapper(a)
1213

1314
def is_even(x):
1415
return x % 2 == 0
16+
17+
node = Filter(base_node, is_even, num_workers=2)
1518

19+
print(node.get_state())
20+
for _ in range(28):
21+
print(next(node))
22+
print(node.get_state())
1623

17-
filtered_node = Filter(node, is_even, num_workers=2)
18-
for item in filtered_node:
24+
state = node.get_state()
25+
node.reset()
26+
27+
print(node.get_state())
28+
29+
for _ in range(2):
30+
print(next(node))
31+
32+
del node
33+
node = Filter(base_node, is_even, num_workers=2)
34+
print("state to be loaded", state)
35+
print("state before reset", node.get_state())
36+
node.reset(state)
37+
print(node.get_state())
38+
39+
for item in node:
1940
print(item)

test/nodes/test_filter.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
from torchdata.nodes.base_node import BaseNode
1717
from torchdata.nodes.filter import Filter
18+
from torchdata.nodes.batch import Batcher
1819

1920
from .utils import MockSource, run_test_save_load_state, StatefulRangeNode
20-
21+
from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
2122

2223
class TestFilter(TestCase):
2324
def _test_filter(self, num_workers, in_order, method):
@@ -52,25 +53,37 @@ def test_filter_parallel_threads(self):
5253
def test_filter_parallel_process(self):
5354
self._test_filter(num_workers=4, in_order=True, method="process")
5455

56+
@parameterized.expand([100, 200, 300])
57+
def test_filter_batcher(self, n):
58+
src = StatefulRangeNode(n=n)
59+
node = Batcher(src, batch_size=2)
60+
predicate = lambda x : (x[0]["i"]+x[1]["i"])%3==0
61+
node = Filter(node, predicate, num_workers=2)
62+
results = list(node)
63+
self.assertEqual(len(results), n//6)
64+
65+
66+
67+
5568
@parameterized.expand(
5669
itertools.product(
57-
[0], # , 7, 13],
70+
[10, 20 , 40],
5871
[True], # TODO: define and fix in_order = False
59-
[0], # , 1, 9], # TODO: define and fix in_order = False
72+
[1, 2, 4],
6073
)
6174
)
6275
def test_save_load_state_thread(
6376
self, midpoint: int, in_order: bool, snapshot_frequency: int
6477
):
6578
method = "thread"
6679
n = 100
67-
predicate = lambda x: True
80+
predicate = lambda x: x["i"]%2==0
6881
src = StatefulRangeNode(n=n)
6982

7083
node = Filter(
7184
source=src,
7285
predicate=predicate,
73-
num_workers=4,
86+
num_workers=1,
7487
in_order=in_order,
7588
method=method,
7689
snapshot_frequency=snapshot_frequency,

torchdata/nodes/filter.py

+63-78
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
#
3-
# All rights reserved.
4-
#
5-
# This source code is licensed under the BSD-style license found in the
6-
# LICENSE file in the root directory of this source tree.
71
from typing import Any, Callable, Dict, Iterator, Literal, Optional, TypeVar
8-
92
from torchdata.nodes.base_node import BaseNode
10-
from torchdata.nodes.map import Mapper, ParallelMapper
11-
3+
from torchdata.nodes.map import ParallelMapper
124
T = TypeVar("T", covariant=True)
135

14-
156
class Filter(BaseNode[T]):
7+
"""
8+
A node that filters data samples based on a given predicate.
9+
Args:
10+
source (BaseNode[T]): The source node providing data samples.
11+
predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
12+
num_workers (int): The number of worker processes to use for parallel filtering. Defaults to 0.
13+
in_order (bool): Whether to return items in the order from which they arrive from. Default is True.
14+
method (Literal["thread", "process"]): The method to use for parallel processing. Default is "thread".
15+
multiprocessing_context (Optional[str]): The multiprocessing context to use for parallel processing. Default is None.
16+
max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
17+
snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
18+
"""
1619
def __init__(
1720
self,
1821
source: BaseNode[T],
@@ -33,12 +36,6 @@ def __init__(
3336
self.multiprocessing_context = multiprocessing_context
3437
self.max_concurrent = max_concurrent
3538
self.snapshot_frequency = snapshot_frequency
36-
self._it: Optional[Iterator[T]] = None
37-
38-
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
39-
super().reset(initial_state)
40-
if self._it is not None:
41-
del self._it
4239
if self.num_workers > 0:
4340
self._it = _ParallelFilterIter(
4441
source=self.source,
@@ -49,54 +46,60 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
4946
multiprocessing_context=self.multiprocessing_context,
5047
max_concurrent=self.max_concurrent,
5148
snapshot_frequency=self.snapshot_frequency,
52-
initial_state=initial_state,
5349
)
54-
5550
else:
56-
self._it = _InlineFilterIter(
57-
source=self.source,
58-
predicate=self.predicate,
59-
initial_state=initial_state,
60-
)
61-
62-
def next(self):
63-
return next(self._it) # type: ignore[arg-type]
64-
51+
self._it = _InlineFilterIter(source=self.source, predicate=self.predicate)
52+
def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
53+
"""Resets the filter node to its initial state."""
54+
super().reset(initial_state)
55+
if self._it is not None:
56+
self._it.reset(initial_state)
57+
def next(self) -> T:
58+
"""Returns the next filtered item."""
59+
return next(self._it)
6560
def get_state(self) -> Dict[str, Any]:
66-
return self._it.get_state() # type: ignore[union-attr]
67-
61+
"""Returns the current state of the filter node."""
62+
return self._it.get_state()
6863

6964
class _InlineFilterIter(Iterator[T]):
70-
def __init__(
71-
self,
72-
source: BaseNode[T],
73-
predicate: Callable[[T], bool],
74-
initial_state: Optional[Dict[str, Any]] = None,
75-
):
65+
"""
66+
An iterator that filters data samples inline.
67+
Args:
68+
source (BaseNode[T]): The source node providing data samples.
69+
predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
70+
"""
71+
SOURCE_KEY = "source"
72+
73+
def __init__(self, source: BaseNode[T], predicate: Callable[[T], bool]) -> None:
7674
self.source = source
7775
self.predicate = predicate
78-
if initial_state is not None:
79-
self.source.reset(initial_state["source"])
76+
77+
def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
78+
"""Resets the inline filter iterator to its initial state."""
79+
if initial_state:
80+
self.source.reset(initial_state[self.SOURCE_KEY])
8081
else:
8182
self.source.reset()
8283

8384
def __iter__(self) -> Iterator[T]:
85+
"""Returns the iterator object itself."""
8486
return self
85-
8687
def __next__(self) -> T:
88+
"""Returns the next filtered item."""
8789
while True:
88-
item = next(self.source)
89-
if self.predicate(item):
90-
return item
91-
90+
try:
91+
item = next(self.source)
92+
if self.predicate(item):
93+
return item
94+
except StopIteration:
95+
raise
9296
def get_state(self) -> Dict[str, Any]:
93-
return {"source": self.source.state_dict()}
94-
97+
"""Returns the current state of the inline filter iterator."""
98+
return {self.SOURCE_KEY: self.source.state_dict()}
9599

96100
class _ParallelFilterIter(Iterator[T]):
97101
"""
98102
An iterator that filters data samples in parallel.
99-
100103
Args:
101104
source (BaseNode[T]): The source node providing data samples.
102105
predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
@@ -106,9 +109,8 @@ class _ParallelFilterIter(Iterator[T]):
106109
multiprocessing_context (Optional[str]): The multiprocessing context to use.
107110
max_concurrent (Optional[int]): The maximum number of concurrent tasks.
108111
snapshot_frequency (int): The frequency at which to take snapshots.
109-
initial_state (Optional[Dict[str, Any]]): The initial state to start with.
110112
"""
111-
113+
MAPPER_KEY = "mapper"
112114
def __init__(
113115
self,
114116
source: BaseNode[T],
@@ -119,7 +121,6 @@ def __init__(
119121
multiprocessing_context: Optional[str],
120122
max_concurrent: Optional[int],
121123
snapshot_frequency: int,
122-
initial_state: Optional[Dict[str, Any]] = None,
123124
):
124125
self.source = source
125126
self.predicate = predicate
@@ -140,42 +141,26 @@ def __init__(
140141
max_concurrent=self.max_concurrent,
141142
snapshot_frequency=self.snapshot_frequency,
142143
)
143-
if initial_state is not None:
144-
self.mapper.reset(initial_state)
145-
144+
def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
145+
"""Resets the parallel filter iterator to its initial state."""
146+
if initial_state:
147+
self.mapper.reset(initial_state[self.MAPPER_KEY])
148+
else:
149+
self.mapper.reset()
146150
def __iter__(self) -> Iterator[T]:
147-
"""
148-
Returns the iterator object itself.
149-
150-
Returns:
151-
Iterator[T]: The iterator object itself.
152-
"""
151+
"""Returns the iterator object itself."""
153152
return self
154-
155153
def __next__(self) -> T:
156-
"""
157-
Returns the next filtered data sample.
158-
159-
Returns:
160-
T: The next filtered data sample.
161-
"""
154+
"""Returns the next filtered item."""
162155
while True:
163-
try:
164-
item, passed_predicate = next(self.mapper)
165-
if passed_predicate:
166-
return item
167-
except StopIteration:
168-
raise
169-
170-
def get_state(self) -> Dict[str, Any]:
171-
"""
172-
Returns the current state of the parallel filter iterator.
173156

174-
Returns:
175-
Dict[str, Any]: The current state of the parallel filter iterator.
176-
"""
177-
return self.mapper.get_state()
157+
item, passed_predicate = next(self.mapper)
158+
if passed_predicate:
159+
return item
178160

161+
def get_state(self) -> Dict[str, Any]:
162+
"""Returns the current state of the parallel filter iterator."""
163+
return {self.MAPPER_KEY: self.mapper.get_state()}
179164
def __del__(self):
180165
# Clean up resources when the iterator is deleted
181166
del self.mapper

0 commit comments

Comments
 (0)