Skip to content

Commit

Permalink
add polars and resolve some issues
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Richards <mrichards7@outlook.com.au>
  • Loading branch information
m-richards committed Feb 15, 2025
1 parent 41df472 commit aeea31b
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 6 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ repos:
- types-pyyaml
- types-requests
- types-setuptools
- polars
args: ["pandera", "tests", "scripts"]
exclude: (^docs/|^tests/mypy/modules/)
pass_filenames: false
Expand Down
2 changes: 1 addition & 1 deletion pandera/api/polars/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Column(ComponentSchema[PolarsCheckObjects]):

def __init__(
self,
dtype: PolarsDtypeInputTypes = None,
dtype: Optional[PolarsDtypeInputTypes] = None,
checks: Optional[CheckList] = None,
nullable: bool = False,
unique: bool = False,
Expand Down
5 changes: 3 additions & 2 deletions pandera/backends/polars/builtin_checks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Built-in checks for polars."""

import re
from collections.abc import Collection
from typing import Any, Iterable, Optional, TypeVar, Union

import polars as pl
Expand Down Expand Up @@ -140,7 +141,7 @@ def in_range(
@register_builtin_check(
error="isin({allowed_values})",
)
def isin(data: PolarsData, allowed_values: Iterable) -> pl.LazyFrame:
def isin(data: PolarsData, allowed_values: Collection) -> pl.LazyFrame:
"""Ensure only allowed values occur within a series.
This checks whether all elements of a :class:`polars.Series`
Expand All @@ -160,7 +161,7 @@ def isin(data: PolarsData, allowed_values: Iterable) -> pl.LazyFrame:
@register_builtin_check(
error="notin({forbidden_values})",
)
def notin(data: PolarsData, forbidden_values: Iterable) -> pl.LazyFrame:
def notin(data: PolarsData, forbidden_values: Collection) -> pl.LazyFrame:
"""Ensure some defined values don't occur within a series.
Like :meth:`Check.isin` this check operates on single characters if
Expand Down
4 changes: 2 additions & 2 deletions pandera/engines/polars_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def polars_coerce_failure_cases(
class DataType(dtypes.DataType):
"""Base `DataType` for boxing Polars data types."""

type: pl.DataType = dataclasses.field(repr=False, init=False)
type: Type[pl.DataType] = dataclasses.field(repr=False, init=False)

def __init__(self, dtype: Optional[Any] = None):
super().__init__()
Expand Down Expand Up @@ -673,7 +673,7 @@ class Categorical(DataType):

def __init__( # pylint:disable=super-init-not-called
self,
ordering: str = "physical",
ordering: Literal["physical", "lexical"] = "physical",
) -> None:
object.__setattr__(self, "ordering", ordering)
object.__setattr__(self, "type", pl.Categorical(ordering=ordering))
Expand Down
6 changes: 5 additions & 1 deletion tests/polars/test_polars_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
except ImportError:
from typing_extensions import Annotated # type: ignore

from typing import (
Type,
) # when python 3.9 is minimum version, use `type` instead


@pytest.fixture
def ldf_basic():
Expand Down Expand Up @@ -406,7 +410,7 @@ def test_set_defaults(ldf_basic, ldf_schema_basic):
assert validated_data.equals(expected_data.collect())


def _failure_value(column: str, dtype: Optional[pl.DataType] = None):
def _failure_value(column: str, dtype: Optional[Type[pl.DataType]] = None):
if column.startswith("string"):
return pl.lit("9", dtype=dtype or pl.Utf8)
elif column.startswith("int"):
Expand Down

0 comments on commit aeea31b

Please sign in to comment.