Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Filter node #1427

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions test/nodes/run_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#This is a local file for testing and will be deleted in the future.
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.batch import Batcher
from torchdata.nodes.filter import Filter
from torchdata.nodes.loader import Loader
from torchdata.nodes.prefetch import Prefetcher
from torchdata.nodes.samplers.stop_criteria import StopCriteria
from utils import MockSource, run_test_save_load_state, StatefulRangeNode


# Throwaway driver: exercise Filter's iteration and state save/restore by hand.
source_values = list(range(60))
base_node = IterableWrapper(source_values)


def is_even(x):
    """Predicate handed to Filter: keep only even integers."""
    return x % 2 == 0


node = Filter(base_node, is_even, num_workers=2)

# Show the initial state, consume 28 items, then show the state again.
print(node.get_state())
remaining = 28
while remaining:
    print(next(node))
    remaining -= 1
print(node.get_state())

# Capture the mid-stream state, then reset back to the beginning.
state = node.get_state()
node.reset()

print(node.get_state())

for _ in range(2):
    print(next(node))

# Rebuild the node from scratch and restore the captured state into it.
del node
node = Filter(base_node, is_even, num_workers=2)
print("state to be loaded", state)
print("state before reset", node.get_state())
node.reset(state)
print(node.get_state())

# Drain whatever remains after the restored position.
for item in node:
    print(item)
95 changes: 95 additions & 0 deletions test/nodes/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import unittest
from typing import List

import torch

from parameterized import parameterized
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase

from torchdata.nodes.base_node import BaseNode
from torchdata.nodes.batch import Batcher
from torchdata.nodes.filter import Filter
from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler

from .utils import MockSource, run_test_save_load_state, StatefulRangeNode


class TestFilter(TestCase):
    """Tests for torchdata.nodes.filter.Filter (inline, threaded, and process modes)."""

    def _test_filter(self, num_workers, in_order, method):
        """Filter MockSource samples, keeping those whose ``test_tensor`` is even.

        Args:
            num_workers: 0 exercises the inline path; >0 the parallel path.
            in_order: Whether the Filter should preserve source order.
            method: Parallelism backend, "thread" or "process".
        """
        n = 100

        def predicate(x):
            return x["test_tensor"] % 2 == 0

        src = MockSource(num_samples=n)
        node = Filter(
            source=src,
            predicate=predicate,
            num_workers=num_workers,
            in_order=in_order,
            method=method,
        )

        # Fix: items yielded by Filter are dicts, not ints, so the annotation
        # is List[dict] (was wrongly List[int]).
        results: List[dict] = []
        for item in node:
            results.append(item)

        expected_results = [
            {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"} for i in range(n) if i % 2 == 0
        ]
        self.assertEqual(results, expected_results)

    def test_filter_inline(self):
        # num_workers=0 uses the inline (non-parallel) iterator.
        self._test_filter(num_workers=0, in_order=True, method="thread")

    def test_filter_parallel_threads(self):
        self._test_filter(num_workers=4, in_order=True, method="thread")

    def test_filter_parallel_process(self):
        self._test_filter(num_workers=4, in_order=True, method="process")

    @parameterized.expand([100, 200, 300])
    def test_filter_batcher(self, n):
        """Filter stacked on Batcher: keep pairs whose index sum is divisible by 3."""
        src = StatefulRangeNode(n=n)
        node = Batcher(src, batch_size=2)

        def predicate(x):
            return (x[0]["i"] + x[1]["i"]) % 3 == 0

        node = Filter(node, predicate, num_workers=2)
        results = list(node)
        # Pair sums are 1, 5, 9, 13, ...: exactly one of every three consecutive
        # pairs is divisible by 3, so n/2 pairs yield n/6 surviving batches.
        self.assertEqual(len(results), n // 6)

    @parameterized.expand(
        itertools.product(
            [10, 20, 40],
            [True],  # TODO: define and fix in_order = False
            [1, 2, 4],
        )
    )
    def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
        """State saved at ``midpoint`` must restore and replay the remaining items."""
        method = "thread"
        n = 100

        def predicate(x):
            return x["i"] % 2 == 0

        src = StatefulRangeNode(n=n)

        node = Filter(
            source=src,
            predicate=predicate,
            num_workers=1,
            in_order=in_order,
            method=method,
            snapshot_frequency=snapshot_frequency,
        )
        node.reset()
        run_test_save_load_state(self, node, midpoint)
180 changes: 180 additions & 0 deletions torchdata/nodes/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from typing import Any, Callable, Dict, Iterator, Literal, Optional, TypeVar

from torchdata.nodes.base_node import BaseNode
from torchdata.nodes.map import ParallelMapper

# Covariant item type flowing through Filter and its iterators.
T = TypeVar("T", covariant=True)


class Filter(BaseNode[T]):
    """Yields only the items from ``source`` for which ``predicate`` returns True.

    Args:
        source (BaseNode[T]): The source node providing data samples.
        predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
        num_workers (int): The number of worker processes to use for parallel filtering. Defaults to 0 (inline filtering).
        in_order (bool): Whether to yield items in the order they arrive from the source. Default is True.
        method (Literal["thread", "process"]): The method to use for parallel processing. Default is "thread".
        multiprocessing_context (Optional[str]): The multiprocessing context to use for parallel processing. Default is None.
        max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
        snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
    """

    def __init__(
        self,
        source: BaseNode[T],
        predicate: Callable[[T], bool],
        num_workers: int = 0,
        in_order: bool = True,
        method: Literal["thread", "process"] = "thread",
        multiprocessing_context: Optional[str] = None,
        max_concurrent: Optional[int] = None,
        snapshot_frequency: int = 1,
    ):
        super().__init__()
        self.source = source
        self.predicate = predicate
        self.num_workers = num_workers
        self.in_order = in_order
        self.method = method
        self.multiprocessing_context = multiprocessing_context
        self.max_concurrent = max_concurrent
        self.snapshot_frequency = snapshot_frequency
        # Delegate to the parallel iterator when workers are requested,
        # otherwise filter inline on the calling thread.
        if num_workers > 0:
            inner: Iterator[T] = _ParallelFilterIter(
                source=source,
                predicate=predicate,
                num_workers=num_workers,
                in_order=in_order,
                method=method,
                multiprocessing_context=multiprocessing_context,
                max_concurrent=max_concurrent,
                snapshot_frequency=snapshot_frequency,
            )
        else:
            inner = _InlineFilterIter(source=source, predicate=predicate)
        self._it = inner

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        """Resets this node and its inner iterator, optionally restoring ``initial_state``."""
        super().reset(initial_state)
        if self._it is not None:
            self._it.reset(initial_state)

    def next(self) -> T:
        """Returns the next item that passed the filter."""
        return next(self._it)

    def get_state(self) -> Dict[str, Any]:
        """Returns the inner iterator's state dictionary."""
        return self._it.get_state()


class _InlineFilterIter(Iterator[T]):
"""
An iterator that filters data samples inline.
Args:
source (BaseNode[T]): The source node providing data samples.
predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
"""

SOURCE_KEY = "source"

def __init__(self, source: BaseNode[T], predicate: Callable[[T], bool]) -> None:
self.source = source
self.predicate = predicate

def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
"""Resets the inline filter iterator to its initial state."""
if initial_state:
self.source.reset(initial_state[self.SOURCE_KEY])
else:
self.source.reset()

def __iter__(self) -> Iterator[T]:
"""Returns the iterator object itself."""
return self

def __next__(self) -> T:
"""Returns the next filtered item."""
while True:
try:
item = next(self.source)
if self.predicate(item):
return item
except StopIteration:
raise

def get_state(self) -> Dict[str, Any]:
"""Returns the current state of the inline filter iterator."""
return {self.SOURCE_KEY: self.source.state_dict()}


class _PredicateMapFn:
    """Picklable ``map_fn`` that pairs each item with its predicate verdict.

    ParallelMapper with method="process" may need to pickle its map_fn to ship
    it to worker processes (e.g. under the "spawn" or "forkserver" start
    methods, the defaults on macOS/Windows). A lambda cannot be pickled by the
    standard pickler, so this module-level callable class is used instead; an
    instance pickles as long as the wrapped predicate itself is picklable.
    """

    def __init__(self, predicate: Callable[[T], bool]) -> None:
        self.predicate = predicate

    def __call__(self, item: T):
        # Keep the item alongside the verdict so the consumer can drop
        # rejected items without re-evaluating the predicate.
        return item, self.predicate(item)


class _ParallelFilterIter(Iterator[T]):
    """
    An iterator that filters data samples in parallel.
    Args:
        source (BaseNode[T]): The source node providing data samples.
        predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
        num_workers (int): The number of worker processes to use for parallel filtering.
        in_order (bool): Whether to preserve the order of data samples.
        method (Literal["thread", "process"]): The method to use for parallelization.
        multiprocessing_context (Optional[str]): The multiprocessing context to use.
        max_concurrent (Optional[int]): The maximum number of concurrent tasks.
        snapshot_frequency (int): The frequency at which to take snapshots.
    """

    # Key under which the mapper's state is stored in this iterator's state dict.
    MAPPER_KEY = "mapper"

    def __init__(
        self,
        source: BaseNode[T],
        predicate: Callable[[T], bool],
        num_workers: int,
        in_order: bool,
        method: Literal["thread", "process"],
        multiprocessing_context: Optional[str],
        max_concurrent: Optional[int],
        snapshot_frequency: int,
    ):
        self.source = source
        self.predicate = predicate
        self.num_workers = num_workers
        self.in_order = in_order
        self.method = method
        self.multiprocessing_context = multiprocessing_context
        self.max_concurrent = max_concurrent
        self.snapshot_frequency = snapshot_frequency
        # The predicate is evaluated on the workers via ParallelMapper; the
        # actual filtering (dropping rejected items) happens in __next__.
        # Fix: use a picklable callable instead of a lambda so that
        # method="process" works under non-fork start methods.
        self.mapper = ParallelMapper(
            source=self.source,
            map_fn=_PredicateMapFn(self.predicate),
            num_workers=self.num_workers,
            in_order=self.in_order,
            method=self.method,
            multiprocessing_context=self.multiprocessing_context,
            max_concurrent=self.max_concurrent,
            snapshot_frequency=self.snapshot_frequency,
        )

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        """Resets the parallel filter iterator to its initial state."""
        if initial_state:
            self.mapper.reset(initial_state[self.MAPPER_KEY])
        else:
            self.mapper.reset()

    def __iter__(self) -> Iterator[T]:
        """Returns the iterator object itself."""
        return self

    def __next__(self) -> T:
        """Returns the next item whose predicate evaluated to True."""
        while True:
            item, passed_predicate = next(self.mapper)
            if passed_predicate:
                return item

    def get_state(self) -> Dict[str, Any]:
        """Returns the current state of the parallel filter iterator."""
        return {self.MAPPER_KEY: self.mapper.get_state()}
Loading