From 88a1d342d626c5de7c034cb788b088534f4dd13e Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Tue, 26 Mar 2024 13:02:07 -0700
Subject: [PATCH] Address remaining mypy errors - mypy checks passing now
 (#1235)

Summary:
Address the remaining mypy errors (8 were left).

### Changes
- Imported `Literal` in `__init__.pyi.in` and added the `ForkIterDataPipeCopyOptions = Literal["shallow", "deep"]` alias used by https://github.com/pytorch/pytorch/blob/d1f58eaaf500da8040e65780b7db171dff953e8a/torch/utils/data/datapipes/iter/combining.py#L147. This resolved 3 of the 8 remaining errors. (See the sketch after the `__init__.pyi.in` hunk below for why the alias satisfies mypy.)
- For the remaining 5, I added `# type: ignore` comments to skip the mypy checks, since those code paths are not under active development.

Pull Request resolved: https://github.com/pytorch/data/pull/1235

Reviewed By: ejguan

Differential Revision: D55374270

Pulled By: gokulavasan

fbshipit-source-id: ce130c3f35e5218eddabf16a4dfee283ac0ae142
---
 torchdata/dataloader2/reading_service.py        | 2 +-
 torchdata/datapipes/iter/__init__.pyi.in        | 3 ++-
 torchdata/datapipes/iter/util/combining.py      | 4 ++--
 torchdata/datapipes/iter/util/randomsplitter.py | 4 ++--
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index adc94b78c..1af26f875 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -325,7 +325,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
         dispatching_dp = find_lca_round_robin_sharding_dp(graph)
         # TODO(ejguan): When the last DataPipe is round_robin_sharding, use InPrcoessReadingService
         if dispatching_dp is not None:
-            dummy_dp = _DummyIterDataPipe()
+            dummy_dp = _DummyIterDataPipe()  # type: ignore
             graph = replace_dp(graph, dispatching_dp, dummy_dp)  # type: ignore[arg-type]
             datapipe = list(graph.values())[0][0]
             # TODO(ejguan): Determine buffer_size at runtime or use unlimited buffer
diff --git a/torchdata/datapipes/iter/__init__.pyi.in b/torchdata/datapipes/iter/__init__.pyi.in
index 79f906a01..773ba85b1 100644
--- a/torchdata/datapipes/iter/__init__.pyi.in
+++ b/torchdata/datapipes/iter/__init__.pyi.in
@@ -15,7 +15,7 @@ from torch.utils.data import DataChunk, IterableDataset, default_collate
 from torch.utils.data.datapipes._typing import _DataPipeMeta
 from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

-from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, Hashable
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, TypeVar, Union, Hashable

 try:
     import torcharrow
@@ -24,6 +24,7 @@ except ImportError:

 T = TypeVar("T")
 T_co = TypeVar("T_co", covariant=True)
+ForkIterDataPipeCopyOptions = Literal["shallow", "deep"]

 class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
     functions: Dict[str, Callable] = ...
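[Editor's note: a minimal sketch, not part of the diff, of why the `Literal` alias above resolves the stub errors. `Literal` restricts a parameter to an exact set of string values, so mypy can check call sites against the stub. The `fork` function here is hypothetical and only stands in for the `copy` parameter of the real `ForkerIterDataPipe` referenced in the PR body.]

```python
from typing import Literal, Optional

# Same alias the stub adds; values are constrained to exactly these two strings.
ForkIterDataPipeCopyOptions = Literal["shallow", "deep"]

# Hypothetical function illustrating the pattern.
def fork(num_instances: int, copy: Optional[ForkIterDataPipeCopyOptions] = None) -> None:
    print(num_instances, copy)

fork(2, copy="shallow")  # accepted by mypy
fork(2, copy="deep")     # accepted by mypy
fork(2, copy="copy")     # flagged by mypy: not a valid Literal value
```
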
diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index aa8e2df56..ad98e4ff1 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -288,7 +288,7 @@ def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int =
             return [datapipe]

         datapipe = datapipe.enumerate()
-        container = _RoundRobinDemultiplexerIterDataPipe(datapipe, num_instances, buffer_size=buffer_size)
+        container = _RoundRobinDemultiplexerIterDataPipe(datapipe, num_instances, buffer_size=buffer_size)  # type: ignore
         return [_ChildDataPipe(container, i).map(_drop_index) for i in range(num_instances)]


@@ -357,7 +357,7 @@ def __new__(
         )

         # The implementation basically uses Forker but only yields a specific element within the sequence
-        container = _UnZipperIterDataPipe(source_datapipe, instance_ids, buffer_size)  # type: ignore[arg-type]
+        container = _UnZipperIterDataPipe(source_datapipe, instance_ids, buffer_size)  # type: ignore
         return [_ChildDataPipe(container, i) for i in range(len(instance_ids))]
diff --git a/torchdata/datapipes/iter/util/randomsplitter.py b/torchdata/datapipes/iter/util/randomsplitter.py
index 2608f93c8..2972122f9 100644
--- a/torchdata/datapipes/iter/util/randomsplitter.py
+++ b/torchdata/datapipes/iter/util/randomsplitter.py
@@ -72,7 +72,7 @@ def __new__(
                 "RandomSplitter needs `total_length`, but it is unable to infer it from "
                 f"the `source_datapipe`: {source_datapipe}."
             )
-        container = _RandomSplitterIterDataPipe(source_datapipe, total_length, weights, seed)
+        container = _RandomSplitterIterDataPipe(source_datapipe, total_length, weights, seed)  # type: ignore
         if target is None:
            return [SplitterIterator(container, k) for k in list(weights.keys())]
        else:
@@ -101,7 +101,7 @@ def __init__(
         self._rng = random.Random(self._seed)
         self._lengths: List[int] = []

-    def draw(self) -> T:
+    def draw(self) -> T:  # type: ignore
         selected_key = self._rng.choices(self.keys, self.weights)[0]
         index = self.key_to_index[selected_key]
         self.weights[index] -= 1
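
[Editor's note: a minimal sketch contrasting the two suppression styles that appear in this patch — the targeted `# type: ignore[arg-type]` kept in reading_service.py versus the blanket `# type: ignore` added to the container constructions. The `count_items` helper is hypothetical, used only so the snippet runs as written.]

```python
from typing import List

# Hypothetical helper: annotated for ints, but len() runs fine on any list.
def count_items(xs: List[int]) -> int:
    return len(xs)

strings: List[str] = ["a", "b"]

# Targeted suppression: silences only the `arg-type` error code on this line;
# any other error code mypy reports here would still surface.
n1 = count_items(strings)  # type: ignore[arg-type]

# Blanket suppression: silences every mypy error on the line, as used for the
# container constructions in this patch.
n2 = count_items(strings)  # type: ignore

print(n1, n2)
```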