From a1508c441fbba29d0f120797fb180c25be1629ad Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Wed, 22 Jan 2025 15:30:49 -0800
Subject: [PATCH] run precommit

---
 test/nodes/run_filter.py  |  5 +++--
 test/nodes/test_filter.py | 37 ++++++++++++++++++++-----------------
 torchdata/nodes/filter.py | 18 ++++++++++++++++++
 3 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/test/nodes/run_filter.py b/test/nodes/run_filter.py
index a41a39101..ba0ea1ff9 100644
--- a/test/nodes/run_filter.py
+++ b/test/nodes/run_filter.py
@@ -7,13 +7,14 @@
 from utils import MockSource, run_test_save_load_state, StatefulRangeNode
-
 a = list(range(60))
 base_node = IterableWrapper(a)
+
 def is_even(x):
     return x % 2 == 0
-
+
+
 node = Filter(base_node, is_even, num_workers=2)
 print(node.get_state())
diff --git a/test/nodes/test_filter.py b/test/nodes/test_filter.py
index 0244e7ea2..3c758403d 100644
--- a/test/nodes/test_filter.py
+++ b/test/nodes/test_filter.py
@@ -14,16 +14,20 @@
 from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
 from torchdata.nodes.base_node import BaseNode
-from torchdata.nodes.filter import Filter
 from torchdata.nodes.batch import Batcher
+from torchdata.nodes.filter import Filter
+from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
 from .utils import MockSource, run_test_save_load_state, StatefulRangeNode
-from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
+
 class TestFilter(TestCase):
     def _test_filter(self, num_workers, in_order, method):
         n = 100
-        predicate = lambda x: x["test_tensor"] % 2 == 0  # Filter even numbers
+
+        def predicate(x):
+            return x["test_tensor"] % 2 == 0
+
         src = MockSource(num_samples=n)
         node = Filter(
             source=src,
@@ -38,9 +42,7 @@ def _test_filter(self, num_workers, in_order, method):
             results.append(item)
         expected_results = [
-            {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}
-            for i in range(n)
-            if i % 2 == 0
+            {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"} for i in range(n) if i % 2 == 0
         ]
         self.assertEqual(results, expected_results)
@@ -57,27 +59,28 @@ def test_filter_parallel_process(self):
     def test_filter_batcher(self, n):
         src = StatefulRangeNode(n=n)
         node = Batcher(src, batch_size=2)
-        predicate = lambda x : (x[0]["i"]+x[1]["i"])%3==0
-        node = Filter(node, predicate, num_workers=2)
-        results = list(node)
-        self.assertEqual(len(results), n//6)
-
+        def predicate(x):
+            return (x[0]["i"] + x[1]["i"]) % 3 == 0
+        node = Filter(node, predicate, num_workers=2)
+        results = list(node)
+        self.assertEqual(len(results), n // 6)
     @parameterized.expand(
         itertools.product(
-            [10, 20 , 40],
+            [10, 20, 40],
             [True],  # TODO: define and fix in_order = False
-            [1, 2, 4],
+            [1, 2, 4],
         )
     )
-    def test_save_load_state_thread(
-        self, midpoint: int, in_order: bool, snapshot_frequency: int
-    ):
+    def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
         method = "thread"
         n = 100
-        predicate = lambda x: x["i"]%2==0
+
+        def predicate(x):
+            return x["i"] % 2 == 0
+
         src = StatefulRangeNode(n=n)
         node = Filter(
diff --git a/torchdata/nodes/filter.py b/torchdata/nodes/filter.py
index 7f48ab8ac..0c5035ea1 100644
--- a/torchdata/nodes/filter.py
+++ b/torchdata/nodes/filter.py
@@ -1,8 +1,11 @@
 from typing import Any, Callable, Dict, Iterator, Literal, Optional, TypeVar
+
 from torchdata.nodes.base_node import BaseNode
 from torchdata.nodes.map import ParallelMapper
+
 T = TypeVar("T", covariant=True)
+
 class Filter(BaseNode[T]):
     """
     A node that filters data samples based on a given predicate.
     Args:
         source (BaseNode[T]): The source node providing data samples.
         predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
         num_workers (int): The number of worker processes to use for parallel filtering. Default is 0 (no parallelism).
@@ -16,6 +19,7 @@ class Filter(BaseNode[T]):
         max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
         snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
     """
+
     def __init__(
         self,
         source: BaseNode[T],
@@ -49,18 +53,22 @@ def __init__(
             )
         else:
             self._it = _InlineFilterIter(source=self.source, predicate=self.predicate)
+
     def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
         """Resets the filter node to its initial state."""
         super().reset(initial_state)
         if self._it is not None:
             self._it.reset(initial_state)
+
     def next(self) -> T:
         """Returns the next filtered item."""
         return next(self._it)
+
     def get_state(self) -> Dict[str, Any]:
         """Returns the current state of the filter node."""
         return self._it.get_state()
+
 class _InlineFilterIter(Iterator[T]):
     """
     An iterator that filters data samples inline.
@@ -68,6 +76,7 @@ class _InlineFilterIter(Iterator[T]):
     Args:
         source (BaseNode[T]): The source node providing data samples.
         predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
     """
+
     SOURCE_KEY = "source"
     def __init__(self, source: BaseNode[T], predicate: Callable[[T], bool]) -> None:
@@ -84,6 +93,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
     def __iter__(self) -> Iterator[T]:
         """Returns the iterator object itself."""
         return self
+
     def __next__(self) -> T:
         """Returns the next filtered item."""
@@ -93,10 +103,12 @@ def __next__(self) -> T:
         while True:
             try:
                 item = next(self.source)
                 if self.predicate(item):
                     return item
             except StopIteration:
                 raise
+
     def get_state(self) -> Dict[str, Any]:
         """Returns the current state of the inline filter iterator."""
         return {self.SOURCE_KEY: self.source.state_dict()}
+
 class _ParallelFilterIter(Iterator[T]):
     """
     An iterator that filters data samples in parallel.
     Args:
         source (BaseNode[T]): The source node providing data samples.
         predicate (Callable[[T], bool]): A function that takes a data sample and returns a boolean indicating whether to include it.
         num_workers (int): The number of worker processes to use.
         in_order (bool): Whether to process items in order.
         method (Literal["thread", "process"]): The method to use for parallel processing.
         multiprocessing_context (Optional[str]): The multiprocessing context to use.
@@ -110,7 +122,9 @@ class _ParallelFilterIter(Iterator[T]):
         max_concurrent (Optional[int]): The maximum number of concurrent tasks.
         snapshot_frequency (int): The frequency at which to take snapshots.
     """
+
     MAPPER_KEY = "mapper"
+
     def __init__(
         self,
         source: BaseNode[T],
@@ -141,15 +155,18 @@ def __init__(
             max_concurrent=self.max_concurrent,
             snapshot_frequency=self.snapshot_frequency,
         )
+
     def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
         """Resets the parallel filter iterator to its initial state."""
         if initial_state:
             self.mapper.reset(initial_state[self.MAPPER_KEY])
         else:
             self.mapper.reset()
+
     def __iter__(self) -> Iterator[T]:
         """Returns the iterator object itself."""
         return self
+
     def __next__(self) -> T:
         """Returns the next filtered item."""
         while True:
@@ -161,6 +178,7 @@ def __next__(self) -> T:
     def get_state(self) -> Dict[str, Any]:
         """Returns the current state of the parallel filter iterator."""
         return {self.MAPPER_KEY: self.mapper.get_state()}
+
     def __del__(self):
         # Clean up resources when the iterator is deleted
         del self.mapper
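
For quick reference, the sketch below shows how the Filter node touched by this patch is driven, mirroring test/nodes/run_filter.py above. It is illustrative only and not part of the patch: the IterableWrapper import path is an assumption (run_filter.py's import block is outside the hunks shown), while the Filter call, iteration, get_state(), and reset() usage follow the signatures visible in the diff.

    # Illustrative usage sketch (not part of the patch), mirroring test/nodes/run_filter.py.
    # Assumption: IterableWrapper is exported from torchdata.nodes; adjust the import
    # to match your torchdata version.
    from torchdata.nodes import IterableWrapper
    from torchdata.nodes.filter import Filter


    def is_even(x):
        return x % 2 == 0


    base_node = IterableWrapper(list(range(60)))

    # num_workers=0 exercises the inline iterator path; num_workers > 0 exercises the
    # ParallelMapper-backed path ("thread" or "process" method).
    node = Filter(base_node, is_even, num_workers=2)

    results = list(node)  # the even numbers 0, 2, ..., 58
    print(len(results))   # 30

    # The node's state can be snapshotted and restored, as the save/load tests do.
    state = node.get_state()
    node.reset(state)

Note that the patch itself is formatting-only (import reordering, blank-line insertion, and lambda-to-def rewrites enforced by the repo's pre-commit hooks); the behaviour of the Filter node is unchanged.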