Skip to content

Commit a1508c4

Browse files
committed
run precommit
1 parent 7202377 commit a1508c4

File tree

3 files changed

+41
-19
lines changed

3 files changed

+41
-19
lines changed

test/nodes/run_filter.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
from utils import MockSource, run_test_save_load_state, StatefulRangeNode
88

99

10-
1110
a = list(range(60))
1211
base_node = IterableWrapper(a)
1312

13+
1414
def is_even(x):
1515
return x % 2 == 0
16-
16+
17+
1718
node = Filter(base_node, is_even, num_workers=2)
1819

1920
print(node.get_state())

test/nodes/test_filter.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,20 @@
1414
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
1515

1616
from torchdata.nodes.base_node import BaseNode
17-
from torchdata.nodes.filter import Filter
1817
from torchdata.nodes.batch import Batcher
18+
from torchdata.nodes.filter import Filter
19+
from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
1920

2021
from .utils import MockSource, run_test_save_load_state, StatefulRangeNode
21-
from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
22+
2223

2324
class TestFilter(TestCase):
2425
def _test_filter(self, num_workers, in_order, method):
2526
n = 100
26-
predicate = lambda x: x["test_tensor"] % 2 == 0 # Filter even numbers
27+
28+
def predicate(x):
29+
return x["test_tensor"] % 2 == 0
30+
2731
src = MockSource(num_samples=n)
2832
node = Filter(
2933
source=src,
@@ -38,9 +42,7 @@ def _test_filter(self, num_workers, in_order, method):
3842
results.append(item)
3943

4044
expected_results = [
41-
{"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}
42-
for i in range(n)
43-
if i % 2 == 0
45+
{"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"} for i in range(n) if i % 2 == 0
4446
]
4547
self.assertEqual(results, expected_results)
4648

@@ -57,27 +59,28 @@ def test_filter_parallel_process(self):
5759
def test_filter_batcher(self, n):
5860
src = StatefulRangeNode(n=n)
5961
node = Batcher(src, batch_size=2)
60-
predicate = lambda x : (x[0]["i"]+x[1]["i"])%3==0
61-
node = Filter(node, predicate, num_workers=2)
62-
results = list(node)
63-
self.assertEqual(len(results), n//6)
64-
6562

63+
def predicate(x):
64+
return (x[0]["i"] + x[1]["i"]) % 3 == 0
6665

66+
node = Filter(node, predicate, num_workers=2)
67+
results = list(node)
68+
self.assertEqual(len(results), n // 6)
6769

6870
@parameterized.expand(
6971
itertools.product(
70-
[10, 20 , 40],
72+
[10, 20, 40],
7173
[True], # TODO: define and fix in_order = False
72-
[1, 2, 4],
74+
[1, 2, 4],
7375
)
7476
)
75-
def test_save_load_state_thread(
76-
self, midpoint: int, in_order: bool, snapshot_frequency: int
77-
):
77+
def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
7878
method = "thread"
7979
n = 100
80-
predicate = lambda x: x["i"]%2==0
80+
81+
def predicate(x):
82+
return x["i"] % 2 == 0
83+
8184
src = StatefulRangeNode(n=n)
8285

8386
node = Filter(

torchdata/nodes/filter.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from typing import Any, Callable, Dict, Iterator, Literal, Optional, TypeVar
2+
23
from torchdata.nodes.base_node import BaseNode
34
from torchdata.nodes.map import ParallelMapper
5+
46
T = TypeVar("T", covariant=True)
57

8+
69
class Filter(BaseNode[T]):
710
"""
811
A node that filters data samples based on a given predicate.
@@ -16,6 +19,7 @@ class Filter(BaseNode[T]):
1619
max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
1720
snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
1821
"""
22+
1923
def __init__(
2024
self,
2125
source: BaseNode[T],
@@ -49,25 +53,30 @@ def __init__(
4953
)
5054
else:
5155
self._it = _InlineFilterIter(source=self.source, predicate=self.predicate)
56+
5257
def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
5358
"""Resets the filter node to its initial state."""
5459
super().reset(initial_state)
5560
if self._it is not None:
5661
self._it.reset(initial_state)
62+
5763
def next(self) -> T:
5864
"""Returns the next filtered item."""
5965
return next(self._it)
66+
6067
def get_state(self) -> Dict[str, Any]:
6168
"""Returns the current state of the filter node."""
6269
return self._it.get_state()
6370

71+
6472
class _InlineFilterIter(Iterator[T]):
6573
"""
6674
An iterator that filters data samples inline.
6775
Args:
6876
source (BaseNode[T]): The source node providing data samples.
6977
predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
7078
"""
79+
7180
SOURCE_KEY = "source"
7281

7382
def __init__(self, source: BaseNode[T], predicate: Callable[[T], bool]) -> None:
@@ -84,6 +93,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
8493
def __iter__(self) -> Iterator[T]:
8594
"""Returns the iterator object itself."""
8695
return self
96+
8797
def __next__(self) -> T:
8898
"""Returns the next filtered item."""
8999
while True:
@@ -93,10 +103,12 @@ def __next__(self) -> T:
93103
return item
94104
except StopIteration:
95105
raise
106+
96107
def get_state(self) -> Dict[str, Any]:
97108
"""Returns the current state of the inline filter iterator."""
98109
return {self.SOURCE_KEY: self.source.state_dict()}
99110

111+
100112
class _ParallelFilterIter(Iterator[T]):
101113
"""
102114
An iterator that filters data samples in parallel.
@@ -110,7 +122,9 @@ class _ParallelFilterIter(Iterator[T]):
110122
max_concurrent (Optional[int]): The maximum number of concurrent tasks.
111123
snapshot_frequency (int): The frequency at which to take snapshots.
112124
"""
125+
113126
MAPPER_KEY = "mapper"
127+
114128
def __init__(
115129
self,
116130
source: BaseNode[T],
@@ -141,15 +155,18 @@ def __init__(
141155
max_concurrent=self.max_concurrent,
142156
snapshot_frequency=self.snapshot_frequency,
143157
)
158+
144159
def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
145160
"""Resets the parallel filter iterator to its initial state."""
146161
if initial_state:
147162
self.mapper.reset(initial_state[self.MAPPER_KEY])
148163
else:
149164
self.mapper.reset()
165+
150166
def __iter__(self) -> Iterator[T]:
151167
"""Returns the iterator object itself."""
152168
return self
169+
153170
def __next__(self) -> T:
154171
"""Returns the next filtered item."""
155172
while True:
@@ -161,6 +178,7 @@ def __next__(self) -> T:
161178
def get_state(self) -> Dict[str, Any]:
162179
"""Returns the current state of the parallel filter iterator."""
163180
return {self.MAPPER_KEY: self.mapper.get_state()}
181+
164182
def __del__(self):
165183
# Clean up resources when the iterator is deleted
166184
del self.mapper

0 commit comments

Comments
 (0)