From aeea31badf0db8affbbdf3088cefc13f1103de81 Mon Sep 17 00:00:00 2001 From: Matt Richards Date: Sat, 15 Feb 2025 13:07:11 +1100 Subject: [PATCH] add polars and resolve some issues Signed-off-by: Matt Richards --- .pre-commit-config.yaml | 1 + pandera/api/polars/components.py | 2 +- pandera/backends/polars/builtin_checks.py | 5 +++-- pandera/engines/polars_engine.py | 4 ++-- tests/polars/test_polars_container.py | 6 +++++- 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d6520d23..77fd1c5f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,6 +54,7 @@ repos: - types-pyyaml - types-requests - types-setuptools + - polars args: ["pandera", "tests", "scripts"] exclude: (^docs/|^tests/mypy/modules/) pass_filenames: false diff --git a/pandera/api/polars/components.py b/pandera/api/polars/components.py index 823912d1d..e697c10be 100644 --- a/pandera/api/polars/components.py +++ b/pandera/api/polars/components.py @@ -23,7 +23,7 @@ class Column(ComponentSchema[PolarsCheckObjects]): def __init__( self, - dtype: PolarsDtypeInputTypes = None, + dtype: Optional[PolarsDtypeInputTypes] = None, checks: Optional[CheckList] = None, nullable: bool = False, unique: bool = False, diff --git a/pandera/backends/polars/builtin_checks.py b/pandera/backends/polars/builtin_checks.py index 19a736010..3da2ddfde 100644 --- a/pandera/backends/polars/builtin_checks.py +++ b/pandera/backends/polars/builtin_checks.py @@ -1,6 +1,7 @@ """Built-in checks for polars.""" import re +from collections.abc import Collection from typing import Any, Iterable, Optional, TypeVar, Union import polars as pl @@ -140,7 +141,7 @@ def in_range( @register_builtin_check( error="isin({allowed_values})", ) -def isin(data: PolarsData, allowed_values: Iterable) -> pl.LazyFrame: +def isin(data: PolarsData, allowed_values: Collection) -> pl.LazyFrame: """Ensure only allowed values occur within a series. This checks whether all elements of a :class:`polars.Series` @@ -160,7 +161,7 @@ def isin(data: PolarsData, allowed_values: Iterable) -> pl.LazyFrame: @register_builtin_check( error="notin({forbidden_values})", ) -def notin(data: PolarsData, forbidden_values: Iterable) -> pl.LazyFrame: +def notin(data: PolarsData, forbidden_values: Collection) -> pl.LazyFrame: """Ensure some defined values don't occur within a series. Like :meth:`Check.isin` this check operates on single characters if diff --git a/pandera/engines/polars_engine.py b/pandera/engines/polars_engine.py index 9ba452c0e..8f073de7e 100644 --- a/pandera/engines/polars_engine.py +++ b/pandera/engines/polars_engine.py @@ -124,7 +124,7 @@ def polars_coerce_failure_cases( class DataType(dtypes.DataType): """Base `DataType` for boxing Polars data types.""" - type: pl.DataType = dataclasses.field(repr=False, init=False) + type: Type[pl.DataType] = dataclasses.field(repr=False, init=False) def __init__(self, dtype: Optional[Any] = None): super().__init__() @@ -673,7 +673,7 @@ class Categorical(DataType): def __init__( # pylint:disable=super-init-not-called self, - ordering: str = "physical", + ordering: Literal["physical", "lexical"] = "physical", ) -> None: object.__setattr__(self, "ordering", ordering) object.__setattr__(self, "type", pl.Categorical(ordering=ordering)) diff --git a/tests/polars/test_polars_container.py b/tests/polars/test_polars_container.py index edcaf10d7..19280c949 100644 --- a/tests/polars/test_polars_container.py +++ b/tests/polars/test_polars_container.py @@ -22,6 +22,10 @@ except ImportError: from typing_extensions import Annotated # type: ignore +from typing import ( + Type, +) # when python 3.9 is minimum version, use `type` instead + @pytest.fixture def ldf_basic(): @@ -406,7 +410,7 @@ def test_set_defaults(ldf_basic, ldf_schema_basic): assert validated_data.equals(expected_data.collect()) -def _failure_value(column: str, dtype: Optional[pl.DataType] = None): +def _failure_value(column: str, dtype: Optional[Type[pl.DataType]] = None): if column.startswith("string"): return pl.lit("9", dtype=dtype or pl.Utf8) elif column.startswith("int"):