Add unbatcher node (#1416)
andrewkho authored Dec 30, 2024
1 parent 3f866a8 commit 88c7b96
Showing 3 changed files with 81 additions and 3 deletions.
26 changes: 25 additions & 1 deletion test/nodes/test_batch.py
@@ -9,7 +9,7 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase
-from torchdata.nodes.batch import Batcher
+from torchdata.nodes.batch import Batcher, Unbatcher

from .utils import MockSource, run_test_save_load_state

@@ -48,3 +48,27 @@ def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
src = MockSource(num_samples=20)
node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
run_test_save_load_state(self, node, midpoint)


class TestUnbatcher(TestCase):
def test_unbatcher(self) -> None:
batch_size = 6
n = 20
src = MockSource(num_samples=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = Unbatcher(node)

results = list(node)
self.assertEqual(len(results), n)
for i in range(n):
self.assertEqual(results[i]["step"], i)
self.assertEqual(results[i]["test_tensor"], torch.tensor([i]))
self.assertEqual(results[i]["test_str"], f"str_{i}")

@parameterized.expand(itertools.product([0, 2], [True, False]))
def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
batch_size = 6
src = MockSource(num_samples=20)
node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
node = Unbatcher(node)
run_test_save_load_state(self, node, midpoint)
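For reference, here is a minimal sketch (not part of this commit) of the save/restore cycle that run_test_save_load_state exercises above: capture the pipeline's state mid-stream with state_dict(), then reset() a fresh pipeline from that state to fast-forward it. It assumes, as the tests do, that MockSource yields dicts with a "step" key.

from torchdata.nodes import Batcher, Unbatcher

from .utils import MockSource

def make_pipeline():
    # Batch 10 items into lists of 4, then flatten them back out.
    return Unbatcher(Batcher(MockSource(num_samples=10), batch_size=4, drop_last=False))

node = make_pipeline()
consumed = [next(node) for _ in range(3)]  # stop partway through the first batch
state = node.state_dict()                  # checkpoint: source state + batch index

resumed = make_pipeline()
resumed.reset(state)  # re-pulls the in-flight batch and skips the consumed items
assert next(resumed)["step"] == 3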
3 changes: 2 additions & 1 deletion torchdata/nodes/__init__.py
@@ -6,7 +6,7 @@

from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
from .base_node import BaseNode, T
-from .batch import Batcher
+from .batch import Batcher, Unbatcher
from .loader import Loader
from .map import Mapper, ParallelMapper
from .pin_memory import PinMemory
@@ -31,6 +31,7 @@
"Stateful",
"StopCriteria",
"T",
"Unbatcher",
]

assert sorted(__all__) == __all__
55 changes: 54 additions & 1 deletion torchdata/nodes/batch.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Sequence

from torchdata.nodes.base_node import BaseNode, T

@@ -56,3 +56,56 @@ def next(self) -> List[T]:

def get_state(self) -> Dict[str, Any]:
return {self.SOURCE_KEY: self.source.state_dict()}


class Unbatcher(BaseNode[T]):
"""Unbatcher will flatten batches pulled from source, and
yields elements in sequential order when next() is called on it.
Args:
source (BaseNode[T]): The source node to pull batches from.
"""

SOURCE_KEY = "source"
BATCH_IDX_KEY = "batch_idx"

def __init__(self, source: BaseNode[Sequence[T]]):
super().__init__()
self.source = source

def reset(self, initial_state: Optional[Dict[str, Any]] = None):
super().reset(initial_state)
if initial_state is not None:
self.source.reset(initial_state[self.SOURCE_KEY])
self._cached_state_dict = initial_state[self.SOURCE_KEY]
try:
self._batch = next(self.source)
self._batch_idx = initial_state[self.BATCH_IDX_KEY]
except StopIteration:
# next(self.source) will be called upon subsequent self.next() call
# and raise StopIteration in the correct place.
self._batch = []
self._batch_idx = 0
else:
self.source.reset()
self._batch = []
self._cached_state_dict = None
self._batch_idx = 0

def next(self) -> T:
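# The current batch is exhausted: snapshot the source's state before pulling
# the next batch, so get_state() can restore to the start of that batch.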
while self._batch_idx >= len(self._batch):
self._cached_state_dict = self.source.state_dict()
self._batch = next(self.source)
self._batch_idx = 0

self._batch_idx += 1
return self._batch[self._batch_idx - 1]

def get_state(self) -> Dict[str, Any]:
if self._cached_state_dict is None:
self._cached_state_dict = self.source.state_dict()

return {
self.SOURCE_KEY: self._cached_state_dict,
self.BATCH_IDX_KEY: self._batch_idx,
}
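
To illustrate the new node end to end, a minimal usage sketch (not from this commit; it assumes the IterableWrapper adapter exported from torchdata.nodes accepts any iterable): Batcher groups the stream into lists, and Unbatcher flattens them back out, so the composition round-trips the input.

from torchdata.nodes import Batcher, IterableWrapper, Unbatcher

node = IterableWrapper(range(10))
node = Batcher(node, batch_size=4, drop_last=False)  # [0..3], [4..7], [8, 9]
node = Unbatcher(node)                               # 0, 1, 2, ..., 9

assert list(node) == list(range(10))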
