From 2b831b38939705c239dfcb064c6c61c8443ed3e9 Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Mon, 25 Mar 2024 22:34:13 -0700 Subject: [PATCH] Fix two types of issues raised by mypy - implicit optional issues and arg-type mismatch (#1234) Summary: Title. Here is the mypy run failing with 24 errors - https://github.com/pytorch/data/actions/runs/8428471909/job/23081016202. With these two changes, the mypy errors are now down to 8 - https://github.com/pytorch/data/actions/runs/8428898766/job/23082258705?pr=1234 ### Changes - Add optional where it was implicit earlier - Change ignore type to arg-type Pull Request resolved: https://github.com/pytorch/data/pull/1234 Reviewed By: ejguan Differential Revision: D55350451 Pulled By: gokulavasan fbshipit-source-id: 54a99c8879eaf84aea0fb833cbbd9f37bc504db9 --- torchdata/datapipes/iter/load/online.py | 6 +++--- torchdata/datapipes/iter/util/cacheholder.py | 2 +- torchdata/datapipes/iter/util/rows2columnar.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/torchdata/datapipes/iter/load/online.py b/torchdata/datapipes/iter/load/online.py index 2b34fafeb..7b9d8d9dc 100644 --- a/torchdata/datapipes/iter/load/online.py +++ b/torchdata/datapipes/iter/load/online.py @@ -36,7 +36,7 @@ def _get_response_from_http( ) -> Tuple[str, StreamWrapper]: with requests.Session() as session: proxies = _get_proxies() - r = session.get(url, timeout=timeout, proxies=proxies, stream=True, **query_params) # type: ignore[attr-defined] + r = session.get(url, timeout=timeout, proxies=proxies, stream=True, **query_params) # type: ignore[arg-type] r.raise_for_status() return url, StreamWrapper(r.raw) @@ -112,7 +112,7 @@ def _get_response_from_google_drive( confirm_token = None with requests.Session() as session: - response = session.get(url, timeout=timeout, stream=True, **query_params) # type: ignore[attr-defined] + response = session.get(url, timeout=timeout, stream=True, **query_params) # type: ignore[arg-type] response.raise_for_status() for k, v in response.cookies.items(): @@ -129,7 +129,7 @@ def _get_response_from_google_drive( if confirm_token: url = url + "&confirm=" + confirm_token - response = session.get(url, timeout=timeout, stream=True, **query_params) # type: ignore[attr-defined] + response = session.get(url, timeout=timeout, stream=True, **query_params) # type: ignore[arg-type] response.raise_for_status() if "content-disposition" not in response.headers: diff --git a/torchdata/datapipes/iter/util/cacheholder.py b/torchdata/datapipes/iter/util/cacheholder.py index ae2c02e79..ca4c705b5 100644 --- a/torchdata/datapipes/iter/util/cacheholder.py +++ b/torchdata/datapipes/iter/util/cacheholder.py @@ -202,7 +202,7 @@ def __init__( self, source_datapipe: IterDataPipe, filepath_fn: Optional[Callable] = None, - hash_dict: Dict[str, str] = None, + hash_dict: Optional[Dict[str, str]] = None, hash_type: str = "sha256", extra_check_fn: Optional[Callable[[str], bool]] = None, ): diff --git a/torchdata/datapipes/iter/util/rows2columnar.py b/torchdata/datapipes/iter/util/rows2columnar.py index 14ffdf93c..5764a3d60 100644 --- a/torchdata/datapipes/iter/util/rows2columnar.py +++ b/torchdata/datapipes/iter/util/rows2columnar.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from collections import defaultdict -from typing import Dict, Iterator, List, Union +from typing import Dict, Iterator, List, Optional, Union from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -50,7 +50,9 @@ class Rows2ColumnarIterDataPipe(IterDataPipe[Dict]): """ column_names: List[str] - def __init__(self, source_datapipe: IterDataPipe[List[Union[Dict, List]]], column_names: List[str] = None) -> None: + def __init__( + self, source_datapipe: IterDataPipe[List[Union[Dict, List]]], column_names: Optional[List[str]] = None + ) -> None: self.source_datapipe: IterDataPipe[List[Union[Dict, List]]] = source_datapipe self.column_names: List[str] = [] if column_names is None else column_names