From 4c2507f2e9d4b9eb06669048242ebc5f4aa2703c Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Fri, 8 Nov 2024 11:34:44 +0000 Subject: [PATCH] 2024-11-08 nightly release (1fcd5e1d022c3967634b30a3afb45ea7b661a76a) --- packaging/pre_build_script_linux.sh | 2 +- test/nodes/test_adapters.py | 69 +++++++--- test/nodes/test_base_node.py | 31 +++++ test/nodes/test_batch.py | 16 ++- test/nodes/test_map.py | 63 ++++++++- test/nodes/test_pin_memory.py | 22 +++- test/nodes/test_prefetch.py | 18 ++- test/nodes/test_snapshot_store.py | 53 ++++++++ test/nodes/utils.py | 55 +++++++- test/requirements.txt | 2 +- torchdata/nodes/__init__.py | 8 +- torchdata/nodes/_apply_udf.py | 4 +- torchdata/nodes/_populate_queue.py | 61 +++++---- torchdata/nodes/adapters.py | 80 ++++++------ torchdata/nodes/base_node.py | 93 ++++++++++++-- torchdata/nodes/batch.py | 14 +- torchdata/nodes/map.py | 190 ++++++++++++++++++++++------ torchdata/nodes/pin_memory.py | 55 +++++++- torchdata/nodes/prefetch.py | 25 +++- torchdata/nodes/snapshot_store.py | 50 ++++++++ torchdata/nodes/types.py | 17 +++ 21 files changed, 754 insertions(+), 174 deletions(-) create mode 100644 test/nodes/test_base_node.py create mode 100644 test/nodes/test_snapshot_store.py create mode 100644 torchdata/nodes/snapshot_store.py create mode 100644 torchdata/nodes/types.py diff --git a/packaging/pre_build_script_linux.sh b/packaging/pre_build_script_linux.sh index 2a74a5642..aeecf07ce 100644 --- a/packaging/pre_build_script_linux.sh +++ b/packaging/pre_build_script_linux.sh @@ -2,7 +2,7 @@ set -ex source packaging/manylinux/python_helper.sh -yum -y install ninja-build zlib-static +yum -y install ninja-build zlib # Docker path is /__w by default export WORKSPACE="/__w" # Install static OpenSSL/libcrypto library diff --git a/test/nodes/test_adapters.py b/test/nodes/test_adapters.py index 31139cb59..83e2e6c2d 100644 --- a/test/nodes/test_adapters.py +++ b/test/nodes/test_adapters.py @@ -4,14 +4,40 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import testslide -from torch.utils.data import IterableDataset, RandomSampler -from torchdata.nodes.adapters import IterableWrapper, MapStyleWrapper, ToIterableDataset +from typing import Any, Dict, Iterator -from .utils import DummyIterableDataset, DummyMapDataset, MockSource +from parameterized import parameterized +from torch.testing._internal.common_utils import TestCase +from torch.utils.data import RandomSampler +from torchdata.nodes.adapters import IterableWrapper, MapStyleWrapper -class TestIterableWrapper(testslide.TestCase): +from torchdata.nodes.types import Stateful + +from .utils import DummyIterableDataset, DummyMapDataset, run_test_save_load_state + + +class _StatefulRange(Stateful): + def __init__(self, n: int) -> None: + self.n = n + self._num_yielded = 0 + self._next_start = 0 + + def __iter__(self) -> Iterator[int]: + self._num_yielded = self._next_start # Reset for next iter call + self._next_start = 0 + for i in range(self._num_yielded, self.n): + self._num_yielded += 1 + yield i + + def state_dict(self) -> Dict[str, Any]: + return {"_num_yielded": self._num_yielded} + + def load_state_dict(self, state_dict: Dict[str, Any]): + self._next_start = state_dict["_num_yielded"] + + +class TestIterableWrapper(TestCase): def test_iterable(self): n = 20 node = IterableWrapper(range(n)) @@ -44,11 +70,19 @@ def test_iterable_dataset(self): self.assertEqual(row["test_tensor"].item(), i) self.assertEqual(row["test_str"], f"str_{i}") + @parameterized.expand([0, 5]) + def test_save_load_state_fast_forward(self, midpoint: int): + run_test_save_load_state(self, IterableWrapper(range(10)), midpoint) + + @parameterized.expand([0, 5]) + def test_save_load_state_stateful(self, midpoint: int): + run_test_save_load_state(self, IterableWrapper(_StatefulRange(10)), midpoint) -class TestMapStyle(testslide.TestCase): + +class TestMapStyle(TestCase): def test_default_sampler(self): n = 20 - node = MapStyleWrapper(DummyMapDataset(n)) + node = MapStyleWrapper(DummyMapDataset(n), sampler=range(n)) for epoch in range(2): result = list(node) self.assertEqual(len(result), n) @@ -89,17 +123,14 @@ def test_dict(self): self.assertEqual(row["test_tensor"].item(), i) self.assertEqual(row["test_str"], f"str_{i}") + @parameterized.expand([0, 7]) + def test_save_load_state_fast_forward(self, midpoint: int): + n = 20 + node = MapStyleWrapper(DummyMapDataset(n), sampler=range(n)) + run_test_save_load_state(self, node, midpoint) -class TestToIterableDataset(testslide.TestCase): - def test_to_iterable_dataset(self): + @parameterized.expand([0, 7]) + def test_save_load_state_stateful(self, midpoint: int): n = 20 - node = MockSource(n) - iterable_ds = ToIterableDataset(node) - self.assertIsInstance(iterable_ds, IterableDataset) - for epoch in range(2): - result = list(iterable_ds) - self.assertEqual(len(result), n) - for i, row in enumerate(result): - self.assertEqual(row["step"], i) - self.assertEqual(row["test_tensor"].item(), i) - self.assertEqual(row["test_str"], f"str_{i}") + node = MapStyleWrapper(DummyMapDataset(n), sampler=_StatefulRange(n)) + run_test_save_load_state(self, node, midpoint) diff --git a/test/nodes/test_base_node.py b/test/nodes/test_base_node.py new file mode 100644 index 000000000..6c5f8761a --- /dev/null +++ b/test/nodes/test_base_node.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.testing._internal.common_utils import TestCase +from torchdata.nodes.adapters import IterableWrapper +from torchdata.nodes.base_node import BaseNodeIterator + +from .utils import run_test_save_load_state + + +class TestBaseNode(TestCase): + def test_started_finished(self) -> None: + x = IterableWrapper(range(10)) + for _ in range(3): # test multi-epoch + it = iter(x) + self.assertIsInstance(it, BaseNodeIterator) + self.assertFalse(it.started()) + self.assertFalse(it.finished()) + + for _ in it: + self.assertTrue(it.started()) + self.assertFalse(it.finished()) + + self.assertTrue(it.started()) + self.assertTrue(it.finished()) + + def test_save_load_state(self): + run_test_save_load_state(self, IterableWrapper(range(10)), 5) diff --git a/test/nodes/test_batch.py b/test/nodes/test_batch.py index 1d6c0ade8..1e2dd7fa8 100644 --- a/test/nodes/test_batch.py +++ b/test/nodes/test_batch.py @@ -4,14 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import testslide +import itertools + import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import TestCase from torchdata.nodes.batch import Batcher -from .utils import MockSource +from .utils import MockSource, run_test_save_load_state -class TestBatcher(testslide.TestCase): +class TestBatcher(TestCase): def test_batcher(self) -> None: batch_size = 6 src = MockSource(num_samples=20) @@ -38,3 +41,10 @@ def test_batcher_drop_last_false(self) -> None: self.assertEqual(results[i][j]["step"], i * batch_size + j) self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j])) self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}") + + @parameterized.expand(itertools.product([0, 2], [True, False])) + def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool): + batch_size = 6 + src = MockSource(num_samples=20) + node = Batcher(src, batch_size=batch_size, drop_last=drop_last) + run_test_save_load_state(self, node, midpoint) diff --git a/test/nodes/test_map.py b/test/nodes/test_map.py index 4e48a415b..a45c9b62b 100644 --- a/test/nodes/test_map.py +++ b/test/nodes/test_map.py @@ -4,20 +4,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import itertools + import unittest from typing import List -import testslide -from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA +from parameterized import parameterized +from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase from torchdata.nodes.batch import Batcher + from torchdata.nodes.map import Mapper, ParallelMapper from torchdata.nodes.pin_memory import PinMemory from torchdata.nodes.prefetch import Prefetcher -from .utils import MockSource, RandomSleepUdf, udf_raises +from .utils import MockSource, RandomSleepUdf, run_test_save_load_state, udf_raises -class TestMap(testslide.TestCase): +class TestMap(TestCase): def _test_exception_handling_mapper(self, pin_memory, method): batch_size = 6 multiprocessing_context = None if IS_WINDOWS else "forkserver" @@ -104,3 +107,55 @@ def test_in_order_process(self): def test_out_of_order_process(self): self._test_map(False, "process") + + @parameterized.expand( + itertools.product( + [0, 7, 13], + [True], # TODO: define and fix in_order = False + [0, 1, 9], # TODO: define and fix in_order = False + ) + ) + def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int): + method = "thread" + batch_size = 6 + n = 80 + multiprocessing_context = None if IS_WINDOWS else "forkserver" + src = MockSource(num_samples=n) + node = Batcher(src, batch_size=batch_size, drop_last=False) + node = ParallelMapper( + node, + RandomSleepUdf(), + num_workers=4, + in_order=in_order, + method=method, + multiprocessing_context=multiprocessing_context, + snapshot_frequency=snapshot_frequency, + ) + node = Prefetcher(node, prefetch_factor=2) + run_test_save_load_state(self, node, midpoint) + + @parameterized.expand( + itertools.product( + [0, 7, 13], + [True], # TODO: define and fix in_order = False + [0, 1, 9], # TODO: define and fix in_order = False + ) + ) + def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_frequency: int): + method = "process" + batch_size = 6 + n = 80 + multiprocessing_context = None if IS_WINDOWS else "forkserver" + src = MockSource(num_samples=n) + node = Batcher(src, batch_size=batch_size, drop_last=False) + node = ParallelMapper( + node, + RandomSleepUdf(), + num_workers=4, + in_order=in_order, + method=method, + multiprocessing_context=multiprocessing_context, + snapshot_frequency=snapshot_frequency, + ) + node = Prefetcher(node, prefetch_factor=2) + run_test_save_load_state(self, node, midpoint) diff --git a/test/nodes/test_pin_memory.py b/test/nodes/test_pin_memory.py index f92ae1769..3931786a8 100644 --- a/test/nodes/test_pin_memory.py +++ b/test/nodes/test_pin_memory.py @@ -4,23 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import itertools import unittest -import testslide import torch -from torch.testing._internal.common_utils import TEST_CUDA +from parameterized import parameterized + +from torch.testing._internal.common_utils import TEST_CUDA, TestCase from torchdata.nodes.batch import Batcher from torchdata.nodes.map import Mapper from torchdata.nodes.pin_memory import PinMemory from torchdata.nodes.prefetch import Prefetcher -from .utils import Collate, IterInitError, MockSource +from .utils import Collate, IterInitError, MockSource, run_test_save_load_state @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") -class TestPinMemory(testslide.TestCase): +class TestPinMemory(TestCase): def test_pin_memory(self) -> None: batch_size = 6 src = MockSource(num_samples=20) @@ -62,3 +64,15 @@ def test_iter_init_error(self): with self.assertRaisesRegex(ValueError, "Iter Init Error"): list(root) + + @parameterized.expand(itertools.product([0, 7, 33], [0, 1, 9])) + def test_save_load_state_stateful(self, midpoint: int, snapshot_frequency: int): + batch_size = 6 + n = 200 + node = MockSource(num_samples=n) + node = Batcher(node, batch_size=batch_size, drop_last=False) + node = Mapper(node, Collate()) + node = PinMemory(node, snapshot_frequency=snapshot_frequency) + node = Prefetcher(node, prefetch_factor=8) + + run_test_save_load_state(self, node, midpoint) diff --git a/test/nodes/test_prefetch.py b/test/nodes/test_prefetch.py index 82d7d9545..2e9820c67 100644 --- a/test/nodes/test_prefetch.py +++ b/test/nodes/test_prefetch.py @@ -4,15 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import testslide +import itertools + import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import TestCase from torchdata.nodes.batch import Batcher from torchdata.nodes.prefetch import Prefetcher -from .utils import IterInitError, MockSource +from .utils import IterInitError, MockSource, run_test_save_load_state -class TestPrefetcher(testslide.TestCase): +class TestPrefetcher(TestCase): def test_prefetcher(self) -> None: batch_size = 6 src = MockSource(num_samples=20) @@ -35,3 +38,12 @@ def test_iter_init_error(self): with self.assertRaisesRegex(ValueError, "Iter Init Error"): list(root) + + @parameterized.expand(itertools.product([0, 7, 32], [0, 1, 9])) + def test_save_load_state_stateful(self, midpoint: int, snapshot_frequency: int): + batch_size = 6 + n = 200 + src = MockSource(num_samples=n) + node = Batcher(src, batch_size=batch_size, drop_last=False) + node = Prefetcher(node, prefetch_factor=8, snapshot_frequency=snapshot_frequency) + run_test_save_load_state(self, node, midpoint) diff --git a/test/nodes/test_snapshot_store.py b/test/nodes/test_snapshot_store.py new file mode 100644 index 000000000..2e82f56ac --- /dev/null +++ b/test/nodes/test_snapshot_store.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.testing._internal.common_utils import TestCase +from torchdata.nodes.adapters import IterableWrapper +from torchdata.nodes.base_node import BaseNodeIterator +from torchdata.nodes.snapshot_store import DequeSnapshotStore + +from .utils import run_test_save_load_state + + +class TestDequeSnapshotStore(TestCase): + def test_snapshot_store(self) -> None: + store = DequeSnapshotStore() + store.append({"a": 1}, 0) + store.append({"a": 2}, 10) + + self.assertEqual(len(store._deque), 2) + + val = store.pop_version(0) + self.assertEqual(val, {"a": 1}) + self.assertEqual(len(store._deque), 1) + val = store.pop_version(1) + self.assertIsNone(val) + self.assertEqual(len(store._deque), 1) + val = store.pop_version(7) + self.assertIsNone(val) + self.assertEqual(len(store._deque), 1) + val = store.pop_version(10) + self.assertEqual(val, {"a": 2}) + self.assertEqual(len(store._deque), 0) + + val = store.pop_version(11) + self.assertIsNone(val) + self.assertEqual(len(store._deque), 0) + + with self.assertRaisesRegex(ValueError, "is not strictly greater than"): + store.append({"a": 3}, 3) + + self.assertEqual(len(store._deque), 0) + + with self.assertRaisesRegex(ValueError, "is not strictly greater than"): + store.append({"a": 4}, 10) + self.assertEqual(len(store._deque), 0) + + store.append({"a": 4}, 11) + store.append({"a": 5}, 19) + val = store.pop_version(19) + self.assertEqual(val, {"a": 5}) + self.assertEqual(len(store._deque), 0) diff --git a/test/nodes/utils.py b/test/nodes/utils.py index bfaf18327..cfe70206d 100644 --- a/test/nodes/utils.py +++ b/test/nodes/utils.py @@ -6,21 +6,26 @@ import random import time -from typing import Iterator +from typing import Any, Dict, Iterator, Optional import torch -from torchdata.nodes import BaseNode +from torchdata.nodes.adapters import IterableWrapper +from torchdata.nodes.base_node import BaseNode -class MockSource(BaseNode[dict]): +class MockGenerator: def __init__(self, num_samples: int) -> None: self.num_samples = num_samples - def iterator(self) -> Iterator[dict]: + def __iter__(self): for i in range(self.num_samples): yield {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"} +def MockSource(num_samples: int) -> BaseNode[dict]: + return IterableWrapper(MockGenerator(num_samples)) + + def udf_raises(item): raise ValueError("test exception") @@ -46,9 +51,12 @@ class IterInitError(BaseNode[int]): def __init__(self, msg: str = "Iter Init Error") -> None: self.msg = msg - def iterator(self) -> Iterator[int]: + def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[int]: raise ValueError(self.msg) + def get_state(self) -> Dict[str, Any]: + return {} + class DummyIterableDataset(torch.utils.data.IterableDataset): def __init__(self, num_samples: int) -> None: @@ -68,3 +76,40 @@ def __len__(self) -> int: def __getitem__(self, i: int) -> dict: return {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"} + + +def run_test_save_load_state(test, x: BaseNode, midpoint: int): + # Test before iter call + initial_state_dict = x.state_dict() + it = iter(x) + results = [] + for _ in range(midpoint): + results.append(next(it)) + state_dict = x.state_dict() + for val in it: + results.append(val) + + state_dict_0_end = x.state_dict() + + # store epoch 1's results + results_1 = list(x) + + x.load_state_dict(state_dict) + results_after = list(x) + test.assertEqual(results_after, results[midpoint:]) + + # Test for second epoch after resume + results_after_1 = list(x) + test.assertEqual(results_after_1, results_1) + + # Test initialize from beginning after resume + x.load_state_dict(initial_state_dict) + full_results = list(x) + test.assertEqual(full_results, results) + full_results_1 = list(x) + test.assertEqual(full_results_1, results_1) + + # Test restoring from end of epoch 0 + x.load_state_dict(state_dict_0_end) + results_after_dict_0 = list(x) + test.assertEqual(results_after_dict_0, results_1) diff --git a/test/requirements.txt b/test/requirements.txt index 904437976..bbd7270d9 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -1,5 +1,4 @@ pytest -testslide expecttest fsspec numpy<2 @@ -8,3 +7,4 @@ graphviz adlfs awscli>=1.27.66 psutil +parameterized diff --git a/torchdata/nodes/__init__.py b/torchdata/nodes/__init__.py index 3e7a42535..a1c376c96 100644 --- a/torchdata/nodes/__init__.py +++ b/torchdata/nodes/__init__.py @@ -10,14 +10,20 @@ from .map import Mapper, ParallelMapper from .pin_memory import PinMemory from .prefetch import Prefetcher +from .types import Stateful __all__ = [ "BaseNode", "Batcher", + "IterableWrapper", + "MapStyleWrapper", "Mapper", - "Prefetcher", "ParallelMapper", "PinMemory", + "Prefetcher", + "Stateful", "T", ] + +assert sorted(__all__) == __all__ diff --git a/torchdata/nodes/_apply_udf.py b/torchdata/nodes/_apply_udf.py index 019ed7b44..ae272b4db 100644 --- a/torchdata/nodes/_apply_udf.py +++ b/torchdata/nodes/_apply_udf.py @@ -41,9 +41,9 @@ def _apply_udf( continue if isinstance(item, ExceptionWrapper): - out_q.put((item, idx)) + out_q.put((item, idx), block=False) elif isinstance(item, StopIteration): - out_q.put((item, idx)) + out_q.put((item, idx), block=False) else: try: y = udf(item) diff --git a/torchdata/nodes/_populate_queue.py b/torchdata/nodes/_populate_queue.py index c7e288c7c..a3329baf9 100644 --- a/torchdata/nodes/_populate_queue.py +++ b/torchdata/nodes/_populate_queue.py @@ -6,40 +6,32 @@ import queue import threading -from dataclasses import dataclass -from typing import Iterable - -from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper - -from .constants import QUEUE_TIMEOUT +from typing import Any, Dict, Optional, Union +import torch.multiprocessing as mp -@dataclass -class _MonotonicIndex: - initial: int = 0 +from torchdata.nodes.base_node import BaseNode - def __post_init__(self): - self._idx = self.initial +from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper +from torchdata.nodes.snapshot_store import MonotonicIndex, SnapshotStore - def get(self) -> int: - idx = self._idx - self._idx += 1 - return idx +from .constants import QUEUE_TIMEOUT def _populate_queue( - source: Iterable, - q: queue.Queue, + source: BaseNode, + q: Union[queue.Queue, mp.Queue], + snapshot_store: SnapshotStore, + snapshot_frequency: int, semaphore: threading.BoundedSemaphore, stop_event: threading.Event, - add_index: bool = False, ): """_populate_queue calls `iter(source)` to get an iterator `it`, waits for semaphore.acquire, and puts its outputs onto q. It never releases the sempahore. It continues to put items on the q as long as it can acquire the sempahore, stop_event is not set, and StopIteration has not been thrown by the `it`. - If add_index = True, this function will always put tuples of (x, idx) on the q where idx + This function will always put tuples of (x, idx) on the q where idx starts from 0 and is monotonically increasing. x may be the output of next(it), StopIteration, or an ExceptionWrapper. @@ -52,32 +44,39 @@ def _populate_queue( """ # Include a monotonic index starting from 0 to each item in the queue - idx = _MonotonicIndex() + idx = MonotonicIndex() - def _put(item, block: bool = True): - if add_index: - q.put((item, idx.get()), block=block) - else: - q.put(item, block=block) + def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None): + _idx = idx.get() + if snapshot: + snapshot_store.append(snapshot=snapshot, version=_idx) + q.put((item, _idx), block=block, timeout=1.0 if block else None) try: + assert ( + isinstance(snapshot_frequency, int) and snapshot_frequency >= 0 + ), f"snapshot_frequency must be non-negative integer! Got {snapshot_frequency}" src_iter = iter(source) except Exception: e = StartupExceptionWrapper(where="in _populate_queue startup for device") - _put(e) + _put(e, block=False) return + yielded = 0 while not stop_event.is_set(): if not semaphore.acquire(blocking=True, timeout=QUEUE_TIMEOUT): continue try: item = next(src_iter) # FIXME: This may hang! + yielded += 1 + snapshot = None + if snapshot_frequency > 0 and yielded % snapshot_frequency == 0: + snapshot = source.state_dict() + _put(item, block=False, snapshot=snapshot) except StopIteration as e: - _put(e) + _put(e, block=False) break except Exception: item = ExceptionWrapper(where="in _populate_queue") - try: - _put(item, block=False) # Semaphore should prevent this from throwing - except queue.Full: - raise RuntimeError("Queue should not be full") + _put(item, block=False) + break diff --git a/torchdata/nodes/adapters.py b/torchdata/nodes/adapters.py index e07e38192..05f640ffc 100644 --- a/torchdata/nodes/adapters.py +++ b/torchdata/nodes/adapters.py @@ -5,12 +5,14 @@ # LICENSE file in the root directory of this source tree. -from typing import Generic, Iterable, Iterator, Mapping, Optional, Sized, TypeVar - -from torch.utils.data import IterableDataset, Sampler, SequentialSampler +from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, TypeVar from torchdata.nodes.base_node import BaseNode, T +from .map import Mapper + +from .types import Stateful + K = TypeVar("K", covariant=True) @@ -18,47 +20,53 @@ class IterableWrapper(BaseNode[T]): """Thin Wrapper that converts any Iterable (including torch.utils.data.IterableDataset) in to a BaseNode. + If iterable implements the Stateful Protocol, it will be saved and restored with its + state_dict/load_state_dict methods. + + If the iterator resulting from iter(iterable) is Stateful it is IGNORED. + :param iterable: Iterable to wrap. IterableWrapper calls iter() on it. """ - iterable: Iterable[T] + NUM_YIELDED_KEY = "_num_yielded" + ITERABLE_KEY = "iterable" def __init__(self, iterable: Iterable[T]): self.iterable = iterable - - def iterator(self) -> Iterator[T]: - return iter(self.iterable) - - -class MapStyleWrapper(BaseNode[T], Generic[K, T]): - """Thin Wrapper that converts any Mapping[K, T] into a BaseNode[T]. - If no sampler is provided, a SequentialSampler is used and requires dataset to be Sized. - - Note that if your map_style lookup is expensive, you might want - to use __to_be_named_dataloader_drop_in__ instead which can take advantage - of process- or thread-based parallelism. - """ - - dataset: Mapping[K, T] - sampler: Sampler[K] - - def __init__(self, dataset: Mapping[K, T], sampler: Optional[Sampler[K]] = None): - self.dataset = dataset - if sampler is None: - if not isinstance(self.dataset, Sized): - raise ValueError("If dataset does not implement __len__, you must pass a sampler!") - self.sampler = SequentialSampler(self.dataset) # type: ignore + self._num_yielded = 0 + + def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]: + if initial_state is not None: + self._num_yielded = initial_state[self.NUM_YIELDED_KEY] + if isinstance(self.iterable, Stateful): + self.iterable.load_state_dict(initial_state[self.ITERABLE_KEY]) + it = iter(self.iterable) + else: + it = iter(self.iterable) + # Naively fast-forwarding + for _ in range(self._num_yielded): + next(it) else: - self.sampler = sampler + it = iter(self.iterable) + + for item in it: + self._num_yielded += 1 + yield item - def iterator(self) -> Iterator[T]: - for key in self.sampler: - yield self.dataset[key] + def get_state(self) -> Dict[str, Any]: + state_dict: Dict[str, Any] = {self.NUM_YIELDED_KEY: self._num_yielded} + if isinstance(self.iterable, Stateful): + state_dict[self.ITERABLE_KEY] = self.iterable.state_dict() + return state_dict -class ToIterableDataset(IterableDataset[T]): - def __init__(self, base_node: BaseNode[T]): - self.base_node = base_node +def MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Iterable[K]) -> BaseNode[T]: + """Thin Wrapper that converts any MapDataset in to a torchdata.node + If you want parallelism, copy this and replace Mapper with ParallelMapper. - def __iter__(self) -> Iterator[T]: - return iter(self.base_node) + :param map_dataset: Mapping to wrap. + :param sampler: Optional[Iterable]. + """ + sampler_node = IterableWrapper(sampler) + mapper_node = Mapper(sampler_node, map_dataset.__getitem__) + return mapper_node diff --git a/torchdata/nodes/base_node.py b/torchdata/nodes/base_node.py index e72825aa9..96b4f683b 100644 --- a/torchdata/nodes/base_node.py +++ b/torchdata/nodes/base_node.py @@ -4,33 +4,106 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Iterator, TypeVar +import logging +from typing import Any, Dict, Iterable, Iterator, Optional, TypeVar + +logger = logging.getLogger(__name__) T = TypeVar("T", covariant=True) +class BaseNodeIterator(Iterator[T]): + def state_dict(self) -> Dict[str, Any]: + raise NotImplementedError() + + def started(self) -> bool: + raise NotImplementedError() + + def finished(self) -> bool: + raise NotImplementedError() + + class BaseNode(Iterable[T]): - def iterator(self) -> Iterator[T]: + __it: Optional[BaseNodeIterator[T]] = None # holds pointer to last iter() requested + __initial_state: Optional[Dict[str, Any]] = None + + def iterator(self, initial_state: Optional[dict]) -> Iterator[T]: """Override this method to implement the iterator. Iterators are expected to raise StopIteration to signal end of iteration, so they can be used in for loops. Generators just need to return, as usual. + + initial_state will be passed if `load_state_dict(initial_state)` was called + on this node before the __iter__ is requested, otherwise None will be passed """ raise NotImplementedError() - def __iter__(self) -> "_EagerIter[T]": - return _EagerIter(self) + def get_state(self) -> Dict[str, Any]: + """Return a dictionary that can be passed to iterator(...) which + can be used to initialize iterator at a certain state. + """ + raise NotImplementedError() + + def __iter__(self) -> BaseNodeIterator[T]: + if self.__it is not None and not self.__it.started(): + # Only create a new iter if the last requested one did not start + return self.__it + + if self.__initial_state is not None: + self.__it = _EagerIter(self, self.__initial_state) + self.__initial_state = None + if not self.__it.has_next(): + self.__it = _EagerIter(self, self.__initial_state) + else: + self.__it = _EagerIter(self, self.__initial_state) + return self.__it + + def state_dict(self) -> Dict[str, Any]: + return self.get_state() + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.__initial_state = state_dict -class _EagerIter(Iterator[T]): + +class _EagerIter(BaseNodeIterator[T]): """ Basic iterator which will runs next-calls eagerly """ - def __init__(self, parent: BaseNode[T]): - self.parent = parent - self.it = self.parent.iterator() + def __init__(self, base_node: BaseNode[T], initial_state: Optional[Dict[str, Any]]): + self.base_node = base_node + self._started = False + self._finished = False + if initial_state is not None: + self._it = self.base_node.iterator(initial_state) + else: + self._it = self.base_node.iterator(None) + + self._next_val: Optional[T] = None + + def __next__(self) -> T: + self._started = True + if self._next_val is not None: + val = self._next_val + self._next_val = None + return val + try: + return next(self._it) + except StopIteration: + self._finished = True + raise + + def has_next(self) -> bool: + if self._next_val is None: + try: + self._next_val = next(self._it) + except StopIteration: + pass + return self._next_val is not None + + def started(self) -> bool: + return self._started - def __next__(self): - return next(self.it) + def finished(self) -> bool: + return self._finished diff --git a/torchdata/nodes/batch.py b/torchdata/nodes/batch.py index 482368135..e4cbbd924 100644 --- a/torchdata/nodes/batch.py +++ b/torchdata/nodes/batch.py @@ -4,18 +4,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterator, List +from typing import Any, Dict, Iterator, List, Optional -from torchdata.nodes import BaseNode, T +from torchdata.nodes.base_node import BaseNode, T class Batcher(BaseNode[List[T]]): + SOURCE_KEY = "source" + def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True): self.source = source self.batch_size = batch_size self.drop_last = drop_last - def iterator(self) -> Iterator[List[T]]: + def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[List[T]]: + if initial_state is not None: + self.source.load_state_dict(initial_state[self.SOURCE_KEY]) + batch = [] for item in self.source: batch.append(item) @@ -25,3 +30,6 @@ def iterator(self) -> Iterator[List[T]]: if len(batch) and not self.drop_last: yield batch + + def get_state(self) -> Dict[str, Any]: + return {self.SOURCE_KEY: self.source.state_dict()} diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py index 8e83e90b8..4751dbe6b 100644 --- a/torchdata/nodes/map.py +++ b/torchdata/nodes/map.py @@ -9,8 +9,9 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, TypeVar, Union import torch.multiprocessing as mp -from torchdata.nodes import BaseNode, T +from torchdata.nodes.base_node import BaseNode, T from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper +from torchdata.nodes.snapshot_store import DequeSnapshotStore, SnapshotStore from ._apply_udf import _apply_udf @@ -35,6 +36,8 @@ def Queue(self, *args, **kwargs): class Mapper(BaseNode[T]): + SOURCE_KEY = "source" + def __init__( self, source: BaseNode[X], @@ -43,10 +46,15 @@ def __init__( self.source = source self.map_fn = map_fn - def iterator(self) -> Iterator[T]: + def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]: + if initial_state is not None: + self.source.load_state_dict(initial_state[self.SOURCE_KEY]) for item in self.source: yield self.map_fn(item) + def get_state(self) -> Dict[str, Any]: + return {self.SOURCE_KEY: self.source.state_dict()} + def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_event: threading.Event): buffer: Dict[int, Any] = {} @@ -94,6 +102,8 @@ def __init__( method: Literal["thread", "process"], mp_context: _MultiprocessContext, max_concurrent: Optional[int], + snapshot_frequency: int, + initial_state: Optional[Dict[str, Any]], ): self.source = source self.map_fn = map_fn @@ -101,6 +111,7 @@ def __init__( self.in_order = in_order self.method = method self.mp_context = mp_context + self.snapshot_frequency = snapshot_frequency self._in_q: Union[queue.Queue, mp.Queue] = queue.Queue() if method == "thread" else mp_context.Queue() self._intermed_q: Union[queue.Queue, mp.Queue] = queue.Queue() if method == "thread" else mp_context.Queue() @@ -112,12 +123,29 @@ def __init__( self._stop = threading.Event() self._mp_stop = mp_context.Event() + self._steps_since_snapshot = 0 + fast_forward = 0 + if initial_state is not None: + self._snapshot = initial_state["snapshot"] + fast_forward = initial_state["steps_since_snapshot"] + self.source.load_state_dict(self._snapshot) + else: + self._snapshot = self.source.state_dict() + self._snapshot_store = DequeSnapshotStore() + self._read_thread = threading.Thread( target=_populate_queue, - args=(self.source, self._in_q, self._sem, self._stop, True), + args=( + self.source, + self._in_q, + self._snapshot_store, + self.snapshot_frequency, + self._sem, + self._stop, + ), daemon=True, ) - self._map_threads: List[Union[threading.Thread, mp.Process]] = [] + self._workers: List[Union[threading.Thread, mp.Process]] = [] for worker_id in range(self.num_workers): args = ( worker_id, @@ -126,10 +154,10 @@ def __init__( self.map_fn, self._stop if self.method == "thread" else self._mp_stop, ) - self._map_threads.append( - threading.Thread(target=_apply_udf, args=args) + self._workers.append( + threading.Thread(target=_apply_udf, args=args, daemon=True) if self.method == "thread" - else mp_context.Process(target=_apply_udf, args=args) + else mp_context.Process(target=_apply_udf, args=args, daemon=True) ) self._sort_q: queue.Queue = queue.Queue() self._sort_thread = threading.Thread( @@ -143,11 +171,14 @@ def __init__( self._out_q = self._sort_q self._read_thread.start() - for t in self._map_threads: + for t in self._workers: t.start() if self.in_order: self._sort_thread.start() + for _ in range(fast_forward): + next(self) + def __iter__(self) -> Iterator[T]: return self @@ -174,9 +205,22 @@ def __next__(self) -> T: if not isinstance(item, StartupExceptionWrapper): self._sem.release() item.reraise() - else: - self._sem.release() - return item + + self._steps_since_snapshot += 1 + self._sem.release() + self._maybe_update_snapshot(idx) + return item + + def get_state(self) -> Dict[str, Any]: + return { + "snapshot": self._snapshot, + "steps_since_snapshot": self._steps_since_snapshot, + } + + def _maybe_update_snapshot(self, idx: int): + if (snapshot := self._snapshot_store.pop_version(idx)) is not None: + self._snapshot = snapshot + self._steps_since_snapshot = 0 def __del__(self): self._shutdown() @@ -185,12 +229,12 @@ def _shutdown(self): self._stop.set() self._mp_stop.set() if self._read_thread.is_alive(): - self._read_thread.join(timeout=QUEUE_TIMEOUT) + self._read_thread.join(timeout=QUEUE_TIMEOUT * 5) if self._sort_thread.is_alive(): - self._sort_thread.join(timeout=QUEUE_TIMEOUT) - for t in self._map_threads: + self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5) + for t in self._workers: if t.is_alive(): - t.join(timeout=QUEUE_TIMEOUT) + t.join(timeout=QUEUE_TIMEOUT * 5) class ParallelMapper(BaseNode[T]): @@ -216,6 +260,7 @@ def __init__( method: Literal["thread", "process"] = "thread", multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, + snapshot_frequency: int = 1, ): assert method in ["thread", "process"] self.source = source @@ -232,8 +277,11 @@ def __init__( if not isinstance(max_concurrent, int) and max_concurrent > num_workers: raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!") self.max_concurrent = max_concurrent + self.snapshot_frequency = snapshot_frequency + self._it: Optional[_ParallelMapperIter[T]] = None + self._iter_for_state_dict: bool = False - def iterator(self) -> Iterator[T]: + def _get_iterator(self, initial_state: Optional[Dict[str, Any]]) -> _ParallelMapperIter[T]: return _ParallelMapperIter( source=self.source, map_fn=self.map_fn, @@ -242,10 +290,36 @@ def iterator(self) -> Iterator[T]: method=self.method, mp_context=self._mp_context, max_concurrent=self.max_concurrent, + snapshot_frequency=self.snapshot_frequency, + initial_state=initial_state, ) - -_WorkerType = Callable[[BaseNode, queue.Queue, threading.BoundedSemaphore, threading.Event], None] + def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]: + if self._iter_for_state_dict: + self._iter_for_state_dict = False + else: + self._it = self._get_iterator(initial_state) + assert self._it is not None + return self._it + + def get_state(self) -> Dict[str, Any]: + if self._it is None: + self._it = self._get_iterator(None) + self._iter_for_state_dict = True + return self._it.get_state() + + +_WorkerType = Callable[ + [ + BaseNode, + queue.Queue, + SnapshotStore, + int, + threading.BoundedSemaphore, + threading.Event, + ], + None, +] class _SingleThreadedMapper(Iterator[T]): @@ -275,52 +349,86 @@ class _SingleThreadedMapper(Iterator[T]): processed in _populate_queue, in the _q, or about to be returned by an in-flight next() call. """ - def __init__(self, source: BaseNode[T], prefetch_factor: int, worker: _WorkerType): + def __init__( + self, + source: BaseNode[T], + prefetch_factor: int, + worker: _WorkerType, + snapshot_frequency: int, + initial_state: Optional[Dict[str, Any]], + ): self.source = source self.prefetch_factor = prefetch_factor self.worker = worker + self.snapshot_frequency = snapshot_frequency self._q: queue.Queue = queue.Queue() self._sem = threading.BoundedSemaphore(value=prefetch_factor) self._stop_event = threading.Event() + self._steps_since_snapshot = 0 + fast_forward = 0 + if initial_state is not None: + self._snapshot = initial_state["snapshot"] + fast_forward = initial_state["steps_since_snapshot"] + self.source.load_state_dict(self._snapshot) + else: + self._snapshot = self.source.state_dict() + self._snapshot_store = DequeSnapshotStore() self._thread = threading.Thread( target=self.worker, - args=(self.source, self._q, self._sem, self._stop_event), + args=( + self.source, + self._q, + self._snapshot_store, + self.snapshot_frequency, + self._sem, + self._stop_event, + ), daemon=True, ) self._thread.start() - self._stopped = False + for _ in range(fast_forward): + next(self) def __iter__(self) -> Iterator[T]: return self def __next__(self) -> T: - if self._stopped: - raise StopIteration() - while True: + if self._stop_event.is_set(): + raise StopIteration() try: - item = self._q.get(block=True, timeout=QUEUE_TIMEOUT) - break + item, idx = self._q.get(block=True, timeout=QUEUE_TIMEOUT) except queue.Empty: continue - if isinstance(item, StopIteration): - self._stopped = True - self._sem.release() - self._stop_event.set() - raise item - elif isinstance(item, ExceptionWrapper): - self._stopped = True - if not isinstance(item, StartupExceptionWrapper): - # We don't need to release for startup exceptions + if isinstance(item, StopIteration): self._sem.release() - self._stop_event.set() - item.reraise() - else: - self._sem.release() - return item + self._stop_event.set() + raise item + elif isinstance(item, ExceptionWrapper): + if not isinstance(item, StartupExceptionWrapper): + # We don't need to release for startup exceptions + self._sem.release() + self._stop_event.set() + item.reraise() + else: + self._sem.release() + self._steps_since_snapshot += 1 + self._maybe_update_snapshot(idx) + return item + + def get_state(self) -> Dict[str, Any]: + return { + "snapshot": self._snapshot, + "steps_since_snapshot": self._steps_since_snapshot, + } + + def _maybe_update_snapshot(self, idx: int): + if (snapshot := self._snapshot_store.pop_version(idx)) is not None: + self._snapshot = snapshot + self._steps_since_snapshot = 0 def __del__(self): self._shutdown() @@ -328,4 +436,4 @@ def __del__(self): def _shutdown(self): self._stop_event.set() if self._thread.is_alive(): - self._thread.join(timeout=QUEUE_TIMEOUT) + self._thread.join(timeout=QUEUE_TIMEOUT * 5) diff --git a/torchdata/nodes/pin_memory.py b/torchdata/nodes/pin_memory.py index c7d0b57e2..7dd04653c 100644 --- a/torchdata/nodes/pin_memory.py +++ b/torchdata/nodes/pin_memory.py @@ -8,21 +8,24 @@ import queue import threading -from typing import Iterator, Optional, Union +from typing import Any, Dict, Iterator, Optional, Union import torch import torch.multiprocessing from torch.utils.data._utils.pin_memory import pin_memory -from torchdata.nodes import BaseNode, T +from torchdata.nodes.base_node import BaseNode, T from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper from torchdata.nodes.map import _SingleThreadedMapper +from torchdata.nodes.snapshot_store import MonotonicIndex, SnapshotStore def _pin_memory_loop( source: BaseNode, q: queue.Queue, + snapshot_store: SnapshotStore, + snapshot_frequency: int, semaphore: threading.BoundedSemaphore, stop_event: threading.Event, device_id: Union[int, str], @@ -33,6 +36,15 @@ def _pin_memory_loop( # This setting is thread local, and prevents the copy in pin_memory from # consuming all CPU cores. + + idx = MonotonicIndex() + + def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None): + _idx = idx.get() + if snapshot: + snapshot_store.append(snapshot=snapshot, version=_idx) + q.put((item, _idx), block=block) + try: torch.set_num_threads(1) @@ -46,26 +58,34 @@ def _pin_memory_loop( custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) custom_device_mod.set_device(device_id) + assert ( + isinstance(snapshot_frequency, int) and snapshot_frequency >= 0 + ), f"snapshot_frequency must be non-negative integer! Got {snapshot_frequency}" src_iter = iter(source) except Exception: e = StartupExceptionWrapper(where=f"in _pin_memory_loop startup for device {device_id}") - q.put(e) + _put(e, block=False) return + yielded = 0 while not stop_event.is_set(): if not semaphore.acquire(blocking=True, timeout=0.1): continue try: item = next(src_iter) item = pin_memory(item, device) - q.put(item, block=False) + yielded += 1 + snapshot = None + if snapshot_frequency > 0 and yielded % snapshot_frequency == 0: + snapshot = source.state_dict() + _put(item, block=False, snapshot=snapshot) except StopIteration as e: item = e - q.put(item, block=False) + _put(item, block=False) break except Exception: item = ExceptionWrapper(where=f"in _pin_memory_loop for device {device_id}") - q.put(item, block=False) + _put(item, block=False) break @@ -74,8 +94,10 @@ def __init__( self, source: BaseNode[T], pin_memory_device: str = "", + snapshot_frequency: int = 1, ): self.source = source + self.snapshot_frequency = snapshot_frequency self._pin_memory = torch.cuda.is_available() if len(pin_memory_device) == 0: self._pin_memory_device = None @@ -90,7 +112,10 @@ def __init__( else: self._current_device = torch.cuda.current_device() - def iterator(self) -> Iterator[T]: + self._it: Optional[_SingleThreadedMapper[T]] = None + self._iter_for_state_dict: bool = False + + def _get_iterator(self, initial_state: Optional[Dict[str, Any]]) -> _SingleThreadedMapper[T]: return _SingleThreadedMapper( source=self.source, prefetch_factor=1, @@ -99,4 +124,20 @@ def iterator(self) -> Iterator[T]: device_id=self._current_device, device=self._pin_memory_device, ), + snapshot_frequency=self.snapshot_frequency, + initial_state=initial_state, ) + + def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]: + if self._iter_for_state_dict: + self._iter_for_state_dict = False + else: + self._it = self._get_iterator(initial_state) + assert self._it is not None + return self._it + + def get_state(self) -> Dict[str, Any]: + if self._it is None: + self._it = self._get_iterator(None) + self._iter_for_state_dict = True + return self._it.get_state() diff --git a/torchdata/nodes/prefetch.py b/torchdata/nodes/prefetch.py index a2715bbd4..0127d9938 100644 --- a/torchdata/nodes/prefetch.py +++ b/torchdata/nodes/prefetch.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterator +from typing import Any, Dict, Iterator, Optional from torchdata.nodes import BaseNode, T @@ -14,13 +14,32 @@ class Prefetcher(BaseNode[T]): - def __init__(self, source: BaseNode[T], prefetch_factor: int): + def __init__(self, source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1): self.source = source self.prefetch_factor = prefetch_factor + self.snapshot_frequency = snapshot_frequency + self._it: Optional[_SingleThreadedMapper[T]] = None + self._iter_for_state_dict: bool = False - def iterator(self) -> Iterator[T]: + def _get_iterator(self, initial_state: Optional[Dict[str, Any]]) -> _SingleThreadedMapper[T]: return _SingleThreadedMapper( source=self.source, prefetch_factor=self.prefetch_factor, worker=_populate_queue, + snapshot_frequency=self.snapshot_frequency, + initial_state=initial_state, ) + + def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]: + if self._iter_for_state_dict: + self._iter_for_state_dict = False + else: + self._it = self._get_iterator(initial_state) + assert self._it is not None + return self._it + + def get_state(self) -> Dict[str, Any]: + if self._it is None: + self._it = self._get_iterator(None) + self._iter_for_state_dict = True + return self._it.get_state() diff --git a/torchdata/nodes/snapshot_store.py b/torchdata/nodes/snapshot_store.py new file mode 100644 index 000000000..e92001778 --- /dev/null +++ b/torchdata/nodes/snapshot_store.py @@ -0,0 +1,50 @@ +from collections import deque +from dataclasses import dataclass +from typing import Any, Optional, Protocol + + +@dataclass +class MonotonicIndex: + initial: int = 0 + + def __post_init__(self): + self._idx = self.initial + + def get(self) -> int: + idx = self._idx + self._idx += 1 + return idx + + +class SnapshotStore(Protocol): + """Protocol for passing snapshot state around between threads and processes""" + + def append(self, snapshot: Any, version: int): + ... + + def pop_version(self, version: int) -> Optional[Any]: + ... + + +class DequeSnapshotStore(SnapshotStore): + """A snapshot store that uses a deque to store snapshots""" + + def __init__(self, max_size: Optional[int] = None) -> None: + self._deque: deque = deque(maxlen=max_size) + self._max_version: int = -1 + + def append(self, snapshot: Any, version: int) -> None: + if version <= self._max_version: + raise ValueError(f"{version=} is not strictly greater than {self._max_version=}") + self._max_version = version + self._deque.append((version, snapshot)) + + def pop_version(self, version: int) -> Optional[Any]: + ver, val = None, None + while self._deque and version >= self._deque[0][0]: + ver, val = self._deque.popleft() + + if ver == version: + return val + else: + return None diff --git a/torchdata/nodes/types.py b/torchdata/nodes/types.py new file mode 100644 index 000000000..ed11f580b --- /dev/null +++ b/torchdata/nodes/types.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any, Dict, Protocol, runtime_checkable + + +@runtime_checkable +class Stateful(Protocol): + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ...