Don't subclass IterableDataset, and add two-directional adapters instead (#1349)

* add basic iterable and map style adapters

* Add ToIterableDataset adapter and drop that inheritance

* fix test
andrewkho authored Oct 28, 2024
1 parent 7fdd0e9 commit 7b0de83
Showing 5 changed files with 193 additions and 5 deletions.
105 changes: 105 additions & 0 deletions test/nodes/test_adapters.py
@@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import testslide
from torch.utils.data import IterableDataset, RandomSampler
from torchdata.nodes.adapters import IterableWrapper, MapStyleWrapper, ToIterableDataset

from .utils import DummyIterableDataset, DummyMapDataset, MockSource


class TestIterableWrapper(testslide.TestCase):
    def test_iterable(self):
        n = 20
        node = IterableWrapper(range(n))
        for epoch in range(2):
            result = list(node)
            self.assertEqual(len(result), n)
            for i, j in enumerate(result):
                self.assertEqual(j, i)

    def test_generator(self):
        n = 20
        node = IterableWrapper(f"str_{i}" for i in range(n))
        result = list(node)
        self.assertEqual(len(result), n)
        for i, j in enumerate(result):
            self.assertEqual(j, f"str_{i}")

        # A second iter() on an exhausted generator stops immediately, so the list is empty
        result = list(node)
        self.assertEqual(len(result), 0)

    def test_iterable_dataset(self):
        n = 20
        node = IterableWrapper(DummyIterableDataset(n))
        for epoch in range(2):
            result = list(node)
            self.assertEqual(len(result), n)
            for i, row in enumerate(result):
                self.assertEqual(row["step"], i)
                self.assertEqual(row["test_tensor"].item(), i)
                self.assertEqual(row["test_str"], f"str_{i}")


class TestMapStyle(testslide.TestCase):
    def test_default_sampler(self):
        n = 20
        node = MapStyleWrapper(DummyMapDataset(n))
        for epoch in range(2):
            result = list(node)
            self.assertEqual(len(result), n)
            for i, row in enumerate(result):
                self.assertEqual(row["step"], i)
                self.assertEqual(row["test_tensor"].item(), i)
                self.assertEqual(row["test_str"], f"str_{i}")

    def test_random_sampler(self):
        n = 20
        ds = DummyMapDataset(n)
        node = MapStyleWrapper(ds, sampler=RandomSampler(ds))
        results = []
        for epoch in range(2):
            result = list(node)
            results.append(result)
            self.assertEqual(len(result), n)
            self.assertEqual({row["step"] for row in result}, set(range(n)))
            self.assertEqual({row["test_tensor"].item() for row in result}, set(range(n)))
            self.assertEqual(
                {row["test_str"] for row in result},
                {f"str_{i}" for i in range(n)},
            )

        self.assertNotEqual(results[0], results[1])  # Each epoch should yield a different order

    def test_dict(self):
        n = 20
        orig_ds = DummyMapDataset(n)
        d = {f"i{i}": orig_ds[i] for i in range(n)}
        sampler = list(d.keys())
        node = MapStyleWrapper(d, sampler=sampler)
        for epoch in range(2):
            result = list(node)
            self.assertEqual(len(result), n)
            for i, row in enumerate(result):
                self.assertEqual(row["step"], i)
                self.assertEqual(row["test_tensor"].item(), i)
                self.assertEqual(row["test_str"], f"str_{i}")


class TestToIterableDataset(testslide.TestCase):
    def test_to_iterable_dataset(self):
        n = 20
        node = MockSource(n)
        iterable_ds = ToIterableDataset(node)
        self.assertIsInstance(iterable_ds, IterableDataset)
        for epoch in range(2):
            result = list(iterable_ds)
            self.assertEqual(len(result), n)
            for i, row in enumerate(result):
                self.assertEqual(row["step"], i)
                self.assertEqual(row["test_tensor"].item(), i)
                self.assertEqual(row["test_str"], f"str_{i}")
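Why the isinstance assertion above matters: since this commit makes BaseNode a plain Iterable rather than an IterableDataset subclass (see base_node.py below), ToIterableDataset is the bridge back to a stock DataLoader. A minimal sketch, reusing the MockSource test helper and only the APIs exercised above; batch_size=None disables automatic batching so rows pass through one at a time:

from torch.utils.data import DataLoader
from torchdata.nodes.adapters import ToIterableDataset

from .utils import MockSource  # test helper yielding {"step": i, "test_tensor": ..., "test_str": ...}

node = MockSource(10)
loader = DataLoader(ToIterableDataset(node), batch_size=None)
for row in loader:
    print(row["step"], row["test_str"])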
20 changes: 20 additions & 0 deletions test/nodes/utils.py
@@ -48,3 +48,23 @@ def __init__(self, msg: str = "Iter Init Error") -> None:

    def iterator(self) -> Iterator[int]:
        raise ValueError(self.msg)


class DummyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, num_samples: int) -> None:
        self.num_samples = num_samples

    def __iter__(self) -> Iterator[dict]:
        for i in range(self.num_samples):
            yield {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}


class DummyMapDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples: int) -> None:
        self.num_samples = num_samples

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, i: int) -> dict:
        return {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}
1 change: 1 addition & 0 deletions torchdata/nodes/__init__.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .adapters import IterableWrapper, MapStyleWrapper
from .base_node import BaseNode, T
from .batch import Batcher
from .map import Mapper, ParallelMapper
64 changes: 64 additions & 0 deletions torchdata/nodes/adapters.py
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Generic, Iterable, Iterator, Mapping, Optional, Sized, TypeVar

from torch.utils.data import IterableDataset, Sampler, SequentialSampler

from torchdata.nodes.base_node import BaseNode, T

K = TypeVar("K", covariant=True)


class IterableWrapper(BaseNode[T]):
    """Thin wrapper that converts any Iterable (including
    torch.utils.data.IterableDataset) into a BaseNode.

    :param iterable: Iterable to wrap. IterableWrapper calls iter() on it.
    """

    iterable: Iterable[T]

    def __init__(self, iterable: Iterable[T]):
        self.iterable = iterable

    def iterator(self) -> Iterator[T]:
        return iter(self.iterable)


class MapStyleWrapper(BaseNode[T], Generic[K, T]):
    """Thin wrapper that converts any Mapping[K, T] into a BaseNode[T].

    If no sampler is provided, a SequentialSampler is used, which requires the
    dataset to be Sized. Note that if your map-style lookup is expensive, you
    might want to use __to_be_named_dataloader_drop_in__ instead, which can
    take advantage of process- or thread-based parallelism.
    """

    dataset: Mapping[K, T]
    sampler: Sampler[K]

    def __init__(self, dataset: Mapping[K, T], sampler: Optional[Sampler[K]] = None):
        self.dataset = dataset
        if sampler is None:
            if not isinstance(self.dataset, Sized):
                raise ValueError("If dataset does not implement __len__, you must pass a sampler!")
            self.sampler = SequentialSampler(self.dataset)  # type: ignore
        else:
            self.sampler = sampler

    def iterator(self) -> Iterator[T]:
        for key in self.sampler:
            yield self.dataset[key]


class ToIterableDataset(IterableDataset[T]):
    """Wraps a BaseNode as a torch.utils.data.IterableDataset."""

    def __init__(self, base_node: BaseNode[T]):
        self.base_node = base_node

    def __iter__(self) -> Iterator[T]:
        return iter(self.base_node)
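Taken together, the adapters run in both directions: into the node world and back out of it. A minimal usage sketch based only on the classes above; the sample data is illustrative:

from torch.utils.data import DataLoader

from torchdata.nodes.adapters import IterableWrapper, MapStyleWrapper, ToIterableDataset

# Iterable -> BaseNode: ranges, generators, and IterableDatasets all work.
node = IterableWrapper(range(5))
assert list(node) == [0, 1, 2, 3, 4]

# Mapping -> BaseNode: a plain dict plus an explicit iterable of keys as the sampler.
mapped = MapStyleWrapper({"a": 1, "b": 2}, sampler=["a", "b"])
assert list(mapped) == [1, 2]

# BaseNode -> IterableDataset: bridge back to a vanilla DataLoader.
loader = DataLoader(ToIterableDataset(node), batch_size=None)
for item in loader:
    print(item)  # prints 0 through 4; automatic batching is disabled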
8 changes: 3 additions & 5 deletions torchdata/nodes/base_node.py
@@ -4,15 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Generic, Iterator, TypeVar
+from typing import Iterable, Iterator, TypeVar

-import torch.utils.data
-
-T = TypeVar("T")
+T = TypeVar("T", covariant=True)


-class BaseNode(torch.utils.data.IterableDataset, Generic[T]):
+class BaseNode(Iterable[T]):
     def iterator(self) -> Iterator[T]:
         """Override this method to implement the iterator.
         Iterators are expected to raise StopIteration to signal
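With the inheritance gone, a node is just an Iterable[T]: subclasses override iterator() and get __iter__ from BaseNode. A hypothetical example, assuming only the contract visible above and the multi-epoch behavior the tests rely on (each call to iterator() starts fresh):

from typing import Iterator

from torchdata.nodes.base_node import BaseNode


class SquaresNode(BaseNode[int]):
    """Hypothetical node that yields the first n squares each epoch."""

    def __init__(self, n: int) -> None:
        self.n = n

    def iterator(self) -> Iterator[int]:
        # A generator raises StopIteration on exhaustion, satisfying the contract.
        for i in range(self.n):
            yield i * i


node = SquaresNode(4)
assert list(node) == [0, 1, 4, 9]  # epoch 1
assert list(node) == [0, 1, 4, 9]  # epoch 2: iteration restarts cleanly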
