Skip to content

Commit

Permalink
run precommit
Browse files Browse the repository at this point in the history
add comment

clean up
  • Loading branch information
ramanishsingh committed Jan 22, 2025
1 parent 7202377 commit d6e6356
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 22 deletions.
6 changes: 4 additions & 2 deletions test/nodes/run_filter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# This is a local file for testing and will be deleted in the future.
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.batch import Batcher
from torchdata.nodes.filter import Filter
Expand All @@ -7,13 +8,14 @@
from utils import MockSource, run_test_save_load_state, StatefulRangeNode



a = list(range(60))
base_node = IterableWrapper(a)


def is_even(x):
    """Return True when ``x`` is evenly divisible by 2, False otherwise."""
    return x % 2 == 0



node = Filter(base_node, is_even, num_workers=2)

print(node.get_state())
Expand Down
37 changes: 20 additions & 17 deletions test/nodes/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,20 @@
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase

from torchdata.nodes.base_node import BaseNode
from torchdata.nodes.filter import Filter
from torchdata.nodes.batch import Batcher
from torchdata.nodes.filter import Filter
from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler

from .utils import MockSource, run_test_save_load_state, StatefulRangeNode
from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler


class TestFilter(TestCase):
def _test_filter(self, num_workers, in_order, method):
n = 100
predicate = lambda x: x["test_tensor"] % 2 == 0 # Filter even numbers

def predicate(x):
return x["test_tensor"] % 2 == 0

src = MockSource(num_samples=n)
node = Filter(
source=src,
Expand All @@ -38,9 +42,7 @@ def _test_filter(self, num_workers, in_order, method):
results.append(item)

expected_results = [
{"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}
for i in range(n)
if i % 2 == 0
{"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"} for i in range(n) if i % 2 == 0
]
self.assertEqual(results, expected_results)

Expand All @@ -57,27 +59,28 @@ def test_filter_parallel_process(self):
def test_filter_batcher(self, n):
    """Filter batches of two items and check how many batches survive.

    The source is batched in pairs, and a batch is kept only when the sum of
    the two ``"i"`` fields is divisible by 3. The test expects ``n // 6``
    batches to pass the predicate.
    """
    src = StatefulRangeNode(n=n)
    node = Batcher(src, batch_size=2)

    # A named function (rather than an assigned lambda) keeps the predicate
    # readable and lint-clean.
    def predicate(x):
        return (x[0]["i"] + x[1]["i"]) % 3 == 0

    # Wrap the batcher exactly once; the predicate runs on two workers.
    node = Filter(node, predicate, num_workers=2)
    results = list(node)
    self.assertEqual(len(results), n // 6)

@parameterized.expand(
itertools.product(
[10, 20 , 40],
[10, 20, 40],
[True], # TODO: define and fix in_order = False
[1, 2, 4],
[1, 2, 4],
)
)
def test_save_load_state_thread(
self, midpoint: int, in_order: bool, snapshot_frequency: int
):
def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
method = "thread"
n = 100
predicate = lambda x: x["i"]%2==0

def predicate(x):
return x["i"] % 2 == 0

src = StatefulRangeNode(n=n)

node = Filter(
Expand Down
20 changes: 17 additions & 3 deletions torchdata/nodes/filter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Any, Callable, Dict, Iterator, Literal, Optional, TypeVar

from torchdata.nodes.base_node import BaseNode
from torchdata.nodes.map import ParallelMapper

T = TypeVar("T", covariant=True)


class Filter(BaseNode[T]):
"""
A node that filters data samples based on a given predicate.
Expand All @@ -16,6 +19,7 @@ class Filter(BaseNode[T]):
max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
"""

def __init__(
self,
source: BaseNode[T],
Expand Down Expand Up @@ -49,25 +53,30 @@ def __init__(
)
else:
self._it = _InlineFilterIter(source=self.source, predicate=self.predicate)

def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
    """Resets the filter node to its initial state.

    Args:
        initial_state (Optional[Dict[str, Any]]): State to restore from. The same
            object is handed to the base class and to the wrapped iterator.
            Default is None.
    """
    super().reset(initial_state)
    # Forward the reset only when a wrapped iterator is set.
    if self._it is not None:
        self._it.reset(initial_state)

def next(self) -> T:
    """Returns the next filtered item.

    Advances the wrapped iterator (``self._it``) by one step. Whatever it
    raises, including ``StopIteration`` once it is exhausted, propagates to
    the caller unchanged.
    """
    return next(self._it)

def get_state(self) -> Dict[str, Any]:
    """Returns the current state of the filter node.

    The state is produced entirely by the wrapped iterator's ``get_state()``;
    this node adds nothing of its own.
    """
    return self._it.get_state()


class _InlineFilterIter(Iterator[T]):
"""
An iterator that filters data samples inline.
Args:
source (BaseNode[T]): The source node providing data samples.
predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
"""

SOURCE_KEY = "source"

def __init__(self, source: BaseNode[T], predicate: Callable[[T], bool]) -> None:
Expand All @@ -84,6 +93,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
def __iter__(self) -> Iterator[T]:
    """Returns the iterator object itself."""
    return self

def __next__(self) -> T:
"""Returns the next filtered item."""
while True:
Expand All @@ -93,10 +103,12 @@ def __next__(self) -> T:
return item
except StopIteration:
raise

def get_state(self) -> Dict[str, Any]:
    """Returns the current state of the inline filter iterator.

    Returns:
        Dict[str, Any]: A single-entry dict mapping ``SOURCE_KEY`` to the
            source node's ``state_dict()``.
    """
    return {self.SOURCE_KEY: self.source.state_dict()}


class _ParallelFilterIter(Iterator[T]):
"""
An iterator that filters data samples in parallel.
Expand All @@ -110,7 +122,9 @@ class _ParallelFilterIter(Iterator[T]):
max_concurrent (Optional[int]): The maximum number of concurrent tasks.
snapshot_frequency (int): The frequency at which to take snapshots.
"""

MAPPER_KEY = "mapper"

def __init__(
self,
source: BaseNode[T],
Expand Down Expand Up @@ -141,15 +155,18 @@ def __init__(
max_concurrent=self.max_concurrent,
snapshot_frequency=self.snapshot_frequency,
)

def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
    """Resets the parallel filter iterator to its initial state.

    Args:
        initial_state (Optional[Dict[str, Any]]): State to restore from. When
            truthy, its ``MAPPER_KEY`` entry is passed to the mapper's
            ``reset``. When ``None`` or empty, the mapper is reset with no
            arguments. Default is None.
    """
    if initial_state:
        self.mapper.reset(initial_state[self.MAPPER_KEY])
    else:
        # A falsy state (None or an empty dict) has no mapper entry to restore.
        self.mapper.reset()

def __iter__(self) -> Iterator[T]:
    """Returns the iterator object itself."""
    return self

def __next__(self) -> T:
"""Returns the next filtered item."""
while True:
Expand All @@ -161,6 +178,3 @@ def __next__(self) -> T:
def get_state(self) -> Dict[str, Any]:
    """Returns the current state of the parallel filter iterator.

    Returns:
        Dict[str, Any]: A single-entry dict mapping ``MAPPER_KEY`` to the
            mapper's ``get_state()``.
    """
    return {self.MAPPER_KEY: self.mapper.get_state()}
def __del__(self):
    """Drops this iterator's reference to the mapper when it is deleted.

    A finalizer can run on an instance whose ``mapper`` attribute was never
    assigned or has already been removed. A bare ``del self.mapper`` would
    then raise ``AttributeError`` from inside ``__del__``, so the deletion is
    guarded.
    """
    if hasattr(self, "mapper"):
        del self.mapper

0 comments on commit d6e6356

Please sign in to comment.