-
Notifications
You must be signed in to change notification settings - Fork 161
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Don't subclass IterableDataset, and add two-directional adapters instead (#1349)
* add basic iterable and map style adapters * Add ToIterableDataset adapter and drop that inheritance * fix test
- Loading branch information
Showing
5 changed files
with
193 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import testslide | ||
from torch.utils.data import IterableDataset, RandomSampler | ||
from torchdata.nodes.adapters import IterableWrapper, MapStyleWrapper, ToIterableDataset | ||
|
||
from .utils import DummyIterableDataset, DummyMapDataset, MockSource | ||
|
||
|
||
class TestIterableWrapper(testslide.TestCase):
    """Exercise IterableWrapper over a range, a generator and an IterableDataset."""

    def test_iterable(self):
        """A re-iterable source (range) yields the full sequence on every pass."""
        size = 20
        wrapped = IterableWrapper(range(size))
        expected = list(range(size))
        for _ in range(2):
            self.assertEqual(list(wrapped), expected)

    def test_generator(self):
        """A generator source yields its items once and nothing afterwards."""
        size = 20
        wrapped = IterableWrapper(f"str_{i}" for i in range(size))
        self.assertEqual(list(wrapped), [f"str_{i}" for i in range(size)])

        # The generator is already exhausted, so a second pass produces no items.
        self.assertEqual(list(wrapped), [])

    def test_iterable_dataset(self):
        """Rows from a wrapped DummyIterableDataset arrive in order on every pass."""
        size = 20
        wrapped = IterableWrapper(DummyIterableDataset(size))
        for _ in range(2):
            rows = list(wrapped)
            self.assertEqual(len(rows), size)
            for position, row in enumerate(rows):
                self.assertEqual(row["step"], position)
                self.assertEqual(row["test_tensor"].item(), position)
                self.assertEqual(row["test_str"], f"str_{position}")
|
||
|
||
class TestMapStyle(testslide.TestCase):
    """Exercise MapStyleWrapper with the default, a random and an explicit-key sampler."""

    def _assert_rows_in_order(self, rows, expected_len):
        """Check that ``rows`` has ``expected_len`` items whose fields match their position."""
        self.assertEqual(len(rows), expected_len)
        for position, row in enumerate(rows):
            self.assertEqual(row["step"], position)
            self.assertEqual(row["test_tensor"].item(), position)
            self.assertEqual(row["test_str"], f"str_{position}")

    def test_default_sampler(self):
        """Without a sampler, rows come back in index order on every pass."""
        size = 20
        wrapped = MapStyleWrapper(DummyMapDataset(size))
        for _ in range(2):
            self._assert_rows_in_order(list(wrapped), size)

    def test_random_sampler(self):
        """With a RandomSampler, each pass covers every row and the two passes differ."""
        size = 20
        dataset = DummyMapDataset(size)
        wrapped = MapStyleWrapper(dataset, sampler=RandomSampler(dataset))
        per_epoch = []
        for _ in range(2):
            rows = list(wrapped)
            per_epoch.append(rows)
            self.assertEqual(len(rows), size)
            self.assertEqual({row["step"] for row in rows}, set(range(size)))
            self.assertEqual({row["test_tensor"].item() for row in rows}, set(range(size)))
            self.assertEqual(
                {row["test_str"] for row in rows},
                {f"str_{i}" for i in range(size)},
            )

        first_pass, second_pass = per_epoch
        self.assertNotEqual(first_pass, second_pass)  # Should have different values per epoch

    def test_dict(self):
        """A plain dict works as the dataset when its keys are supplied as the sampler."""
        size = 20
        source = DummyMapDataset(size)
        mapping = {f"i{i}": source[i] for i in range(size)}
        keys = list(mapping)
        wrapped = MapStyleWrapper(mapping, sampler=keys)
        for _ in range(2):
            self._assert_rows_in_order(list(wrapped), size)
|
||
|
||
class TestToIterableDataset(testslide.TestCase):
    """Exercise the ToIterableDataset adapter around a MockSource node."""

    def test_to_iterable_dataset(self):
        """The adapter is an IterableDataset and replays the node's rows on every pass."""
        size = 20
        adapted = ToIterableDataset(MockSource(size))
        self.assertIsInstance(adapted, IterableDataset)
        for _ in range(2):
            rows = list(adapted)
            self.assertEqual(len(rows), size)
            for position, row in enumerate(rows):
                self.assertEqual(row["step"], position)
                self.assertEqual(row["test_tensor"].item(), position)
                self.assertEqual(row["test_str"], f"str_{position}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from typing import Generic, Iterable, Iterator, Mapping, Optional, Sized, TypeVar | ||
|
||
from torch.utils.data import IterableDataset, Sampler, SequentialSampler | ||
|
||
from torchdata.nodes.base_node import BaseNode, T | ||
|
||
K = TypeVar("K", covariant=True) | ||
|
||
|
||
class IterableWrapper(BaseNode[T]):
    """Adapter that exposes an arbitrary Iterable as a BaseNode.

    Any object accepted by ``iter()`` can be wrapped, including a
    torch.utils.data.IterableDataset.

    :param iterable: the Iterable to wrap. ``iter()`` is called on it each
        time :meth:`iterator` is invoked.
    """

    iterable: Iterable[T]

    def __init__(self, iterable: Iterable[T]):
        self.iterable = iterable

    def iterator(self) -> Iterator[T]:
        """Return a new iterator obtained by calling ``iter()`` on the wrapped iterable."""
        wrapped = self.iterable
        return iter(wrapped)
|
||
|
||
class MapStyleWrapper(BaseNode[T], Generic[K, T]):
    """Adapter that turns any Mapping[K, T] into a BaseNode[T].

    Keys are drawn from ``sampler`` and each one is looked up in ``dataset``.
    When no sampler is given, a SequentialSampler over the dataset is used,
    which requires the dataset to be Sized.

    Note that if your map_style lookup is expensive, you might want
    to use __to_be_named_dataloader_drop_in__ instead which can take advantage
    of process- or thread-based parallelism.

    :param dataset: mapping from key to item.
    :param sampler: optional iterable of keys to look up, in order.
    :raises ValueError: if ``sampler`` is omitted and ``dataset`` has no ``__len__``.
    """

    dataset: Mapping[K, T]
    sampler: Sampler[K]

    def __init__(self, dataset: Mapping[K, T], sampler: Optional[Sampler[K]] = None):
        self.dataset = dataset

        # An explicit sampler is used as-is; no size check is needed.
        if sampler is not None:
            self.sampler = sampler
            return

        # The default sampler needs the dataset's length.
        if not isinstance(self.dataset, Sized):
            raise ValueError("If dataset does not implement __len__, you must pass a sampler!")
        self.sampler = SequentialSampler(self.dataset)  # type: ignore

    def iterator(self) -> Iterator[T]:
        """Yield ``dataset[key]`` for each key produced by the sampler."""
        for sampled_key in self.sampler:
            yield self.dataset[sampled_key]
|
||
|
||
class ToIterableDataset(IterableDataset[T]):
    """Adapter that presents a BaseNode as a torch.utils.data.IterableDataset.

    :param base_node: the node to wrap. Every ``__iter__`` call delegates to
        ``iter(base_node)``.
    """

    def __init__(self, base_node: BaseNode[T]):
        self.base_node = base_node

    def __iter__(self) -> Iterator[T]:
        """Return the iterator produced by calling ``iter()`` on the wrapped node."""
        node = self.base_node
        return iter(node)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters