Skip to content

Commit 327e225

Browse files
authored
Add Filter (#1454)
* adding filter node and its test * test fixed * add more tests. add num_yielded * linting is done ~
1 parent 83d09ff commit 327e225

File tree

3 files changed

+243
-0
lines changed

3 files changed

+243
-0
lines changed

test/nodes/test_filter.py

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import itertools
2+
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import TestCase
5+
from torchdata.nodes import Batcher, Filter
6+
from torchdata.nodes.adapters import IterableWrapper
7+
8+
from .utils import MockSource, run_test_save_load_state, StatefulRangeNode
9+
10+
11+
class TestFilter(TestCase):
12+
def test_filter_basic(self) -> None:
13+
# Test with a simple range
14+
source = IterableWrapper(range(10))
15+
node = Filter(source, lambda x: x % 2 == 0) # Keep even numbers
16+
17+
results = list(node)
18+
self.assertEqual(results, [0, 2, 4, 6, 8])
19+
20+
# Verify counters
21+
self.assertEqual(node._num_yielded, 5) # 5 even numbers were yielded
22+
self.assertEqual(node._num_filtered, 5) # 5 odd numbers were filtered out
23+
24+
# Test with a different predicate
25+
source = IterableWrapper(range(10))
26+
node = Filter(source, lambda x: x > 5) # Keep numbers greater than 5
27+
28+
results = list(node)
29+
self.assertEqual(results, [6, 7, 8, 9])
30+
31+
# Verify counters
32+
self.assertEqual(node._num_yielded, 4) # 4 numbers > 5 were yielded
33+
self.assertEqual(node._num_filtered, 6) # 6 numbers <= 5 were filtered out
34+
35+
def test_filter_with_mock_source(self) -> None:
36+
num_samples = 20
37+
source = MockSource(num_samples=num_samples)
38+
node = Filter(source, lambda x: x["step"] % 3 == 0) # Keep items where step is divisible by 3
39+
40+
# Test multi epoch
41+
for epoch in range(2):
42+
node.reset()
43+
results = list(node)
44+
expected_steps = [i for i in range(num_samples) if i % 3 == 0]
45+
self.assertEqual(len(results), len(expected_steps))
46+
47+
# Verify counters after each epoch
48+
self.assertEqual(node._num_yielded, len(expected_steps))
49+
self.assertEqual(node._num_filtered, num_samples - len(expected_steps))
50+
51+
for i, result in enumerate(results):
52+
expected_step = expected_steps[i]
53+
self.assertEqual(result["step"], expected_step)
54+
self.assertEqual(result["test_tensor"].item(), expected_step)
55+
self.assertEqual(result["test_str"], f"str_{expected_step}")
56+
57+
def test_filter_empty_result(self) -> None:
58+
source = IterableWrapper(range(10))
59+
node = Filter(source, lambda x: x > 100) # No items will pass this filter
60+
61+
results = list(node)
62+
self.assertEqual(results, [])
63+
64+
# Verify counters when no items pass the filter
65+
self.assertEqual(node._num_yielded, 0) # No items were yielded
66+
self.assertEqual(node._num_filtered, 10) # All 10 items were filtered out
67+
68+
@parameterized.expand(itertools.product([0, 3, 7]))
69+
def test_save_load_state(self, midpoint: int):
70+
n = 50
71+
source = StatefulRangeNode(n=n)
72+
node = Filter(source, lambda x: x["i"] % 3 == 0) # Keep items where 'i' is divisible by 3
73+
run_test_save_load_state(self, node, midpoint)
74+
75+
def test_filter_reset_state(self) -> None:
76+
source = IterableWrapper(range(10))
77+
node = Filter(source, lambda x: x % 2 == 0)
78+
79+
# Consume first two items
80+
self.assertEqual(next(node), 0)
81+
self.assertEqual(next(node), 2)
82+
83+
# Check counters after consuming two items
84+
self.assertEqual(node._num_yielded, 2) # 2 even numbers were yielded
85+
self.assertEqual(node._num_filtered, 1) # 1 odd number was filtered out
86+
87+
# Get state and reset
88+
state = node.state_dict()
89+
node.reset(state)
90+
91+
# Counters should be preserved after reset with state
92+
self.assertEqual(node._num_yielded, 2)
93+
self.assertEqual(node._num_filtered, 1)
94+
95+
# Should continue from where we left off
96+
self.assertEqual(next(node), 4)
97+
self.assertEqual(next(node), 6)
98+
self.assertEqual(next(node), 8)
99+
100+
# Counters should be updated after consuming more items
101+
self.assertEqual(node._num_yielded, 5) # Total of 5 even numbers yielded
102+
self.assertEqual(node._num_filtered, 4) # Total of 4 odd numbers filtered out
103+
104+
# Should raise StopIteration after all items are consumed
105+
with self.assertRaises(StopIteration):
106+
next(node)
107+
108+
def test_filter_with_batcher(self) -> None:
109+
# Test Filter node with Batcher
110+
111+
# Create a source with numbers 0-19
112+
source = IterableWrapper(range(20))
113+
114+
# Batch into groups of 4
115+
batch_node = Batcher(source, batch_size=4)
116+
117+
# Filter to keep only batches where the sum is divisible by 10
118+
filter_node = Filter(batch_node, lambda batch: sum(batch) % 10 == 0)
119+
120+
# Let's calculate the expected batches and their sums
121+
# Batch 1: [0, 1, 2, 3] -> sum = 6
122+
# Batch 2: [4, 5, 6, 7] -> sum = 22
123+
# Batch 3: [8, 9, 10, 11] -> sum = 38
124+
# Batch 4: [12, 13, 14, 15] -> sum = 54
125+
# Batch 5: [16, 17, 18, 19] -> sum = 70
126+
# Batches with sum divisible by 10: Batch 5 (70)
127+
128+
results = list(filter_node)
129+
130+
# We expect only one batch to pass the filter (sum divisible by 10)
131+
self.assertEqual(len(results), 1)
132+
self.assertEqual(results[0], [16, 17, 18, 19]) # sum = 70
133+
134+
# Check that the filter node tracked both filtered and yielded items
135+
self.assertEqual(filter_node._num_yielded, 1) # 1 batch was yielded
136+
self.assertEqual(filter_node._num_filtered, 4) # 4 batches were filtered out
137+
138+
# Verify total number of batches processed
139+
self.assertEqual(filter_node._num_yielded + filter_node._num_filtered, 5) # Total of 5 batches
140+
141+
def test_counter_reset(self) -> None:
142+
# Test that counters are properly reset
143+
source = IterableWrapper(range(10))
144+
node = Filter(source, lambda x: x % 2 == 0)
145+
146+
# Consume all items
147+
list(node)
148+
149+
# Verify counters after first pass
150+
self.assertEqual(node._num_yielded, 5)
151+
self.assertEqual(node._num_filtered, 5)
152+
153+
# Reset without state
154+
node.reset()
155+
156+
# Counters should be reset to 0
157+
self.assertEqual(node._num_yielded, 0)
158+
self.assertEqual(node._num_filtered, 0)
159+
160+
# Consume some items
161+
next(node) # 0
162+
next(node) # 2
163+
164+
# Verify counters after partial consumption
165+
self.assertEqual(node._num_yielded, 2)
166+
self.assertEqual(node._num_filtered, 1)

torchdata/nodes/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
88
from .base_node import BaseNode, T
99
from .batch import Batcher, Unbatcher
10+
from .filter import Filter
1011
from .loader import Loader
1112
from .map import Mapper, ParallelMapper
1213
from .pin_memory import PinMemory

torchdata/nodes/filter.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Any, Callable, Dict, Optional, TypeVar
2+
3+
from torchdata.nodes import BaseNode
4+
5+
6+
T = TypeVar("T")
7+
8+
9+
class Filter(BaseNode[T]):
10+
"""Node that filters items from source node based on predicate function.
11+
12+
This node applies a filter function to each item from the source node and only yields
13+
items that satisfy the condition (when filter_fn returns True). It keeps track of both
14+
the number of items that were filtered out (rejected) and the number of items that were
15+
yielded (accepted).
16+
17+
Args:
18+
source_node (BaseNode[T]): The source node to filter items from.
19+
filter_fn (Callable[[T], bool]): A function that takes an item and returns True if the item
20+
should be included, False otherwise.
21+
"""
22+
23+
SOURCE_KEY = "source"
24+
NUM_FILTERED_KEY = "num_filtered"
25+
NUM_YIELDED_KEY = "num_yielded"
26+
27+
def __init__(self, source_node: BaseNode[T], filter_fn: Callable[[T], bool]):
28+
super().__init__()
29+
self.source = source_node
30+
self.filter_fn = filter_fn
31+
self._num_filtered = 0 # Count of items that did NOT pass the filter
32+
self._num_yielded = 0 # Count of items that DID pass the filter
33+
34+
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
35+
"""Reset the node to its initial state or to the provided state.
36+
37+
Args:
38+
initial_state: Optional state dictionary to restore from.
39+
"""
40+
super().reset(initial_state)
41+
if initial_state is not None:
42+
self.source.reset(initial_state.get(self.SOURCE_KEY))
43+
self._num_filtered = initial_state.get(self.NUM_FILTERED_KEY, 0)
44+
self._num_yielded = initial_state.get(self.NUM_YIELDED_KEY, 0)
45+
else:
46+
self.source.reset(None)
47+
self._num_filtered = 0
48+
self._num_yielded = 0
49+
50+
def next(self) -> T:
51+
"""Get the next item that passes the filter.
52+
53+
Returns:
54+
The next item that satisfies the filter condition.
55+
56+
Raises:
57+
StopIteration: If there are no more items in the source node.
58+
"""
59+
while True:
60+
item = next(self.source)
61+
if self.filter_fn(item):
62+
self._num_yielded += 1
63+
return item
64+
self._num_filtered += 1
65+
66+
def get_state(self) -> Dict[str, Any]:
67+
"""Get the current state of the node.
68+
69+
Returns:
70+
A dictionary containing the state of the source node and counters.
71+
"""
72+
return {
73+
self.SOURCE_KEY: self.source.state_dict(),
74+
self.NUM_FILTERED_KEY: self._num_filtered,
75+
self.NUM_YIELDED_KEY: self._num_yielded,
76+
}

0 commit comments

Comments
 (0)