Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/763 improve type annotations for DataFrameModel.validate #1905

Merged
merged 10 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ jobs:
types-pytz \
types-pyyaml \
types-requests \
types-setuptools
types-setuptools \
polars
- name: Pip info
run: python -m pip list

Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ repos:
- types-pyyaml
- types-requests
- types-setuptools
- polars
args: ["pandera", "tests", "scripts"]
exclude: (^docs/|^tests/mypy/modules/)
pass_filenames: false
Expand Down
16 changes: 15 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]
ignore_missing_imports = True
follow_imports = skip
follow_imports = normal
allow_redefinition = True
warn_return_any = False
warn_unused_configs = True
Expand All @@ -12,3 +12,17 @@ exclude=(?x)(
| ^pandera/backends/pyspark
| ^tests/pyspark
)
[mypy-pandera.api.pyspark.*]
follow_imports = skip

[mypy-docs.*]
follow_imports = skip

[mypy-pandera.engines.polars_engine]
ignore_errors = True

[mypy-pandera.backends.polars.builtin_checks]
ignore_errors = True

[mypy-tests.polars.*]
ignore_errors = True
24 changes: 23 additions & 1 deletion pandera/api/pandas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import copy
import sys
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

import pandas as pd

from pandera.api.base.schema import BaseSchema
from pandera.api.checks import Check
from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
from pandera.api.dataframe.model import get_dtype_kwargs
Expand All @@ -22,6 +23,7 @@
AnnotationInfo,
DataFrame,
)
from pandera.utils import docstring_substitution

# if python version is < 3.11, import Self from typing_extensions
if sys.version_info < (3, 11):
Expand Down Expand Up @@ -171,6 +173,26 @@ def _build_columns_index( # pylint:disable=too-many-locals,too-many-branches

return columns, _build_schema_index(indices, **multiindex_kwargs)

@classmethod
@docstring_substitution(validate_doc=BaseSchema.validate.__doc__)
def validate(
cls: Type[Self],
check_obj: pd.DataFrame,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> DataFrame[Self]:
"""%(validate_doc)s"""
return cast(
Copy link
Collaborator Author

@m-richards m-richards Feb 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could potentially also all be in an if typing.TYPE_CHECKING block if that's cleaner than providing an implementation the same as the parent. I also don't know if it's preferable to use super() instead?

Edit: I also briefly looked at seeing if an overload on receiving a GeoDataFrame is possible. I had a solution working in my standalone test case, but it was pretty ugly with conditional overloads, and pre-commit mypy was not very happy with that.

DataFrame[Self],
cls.to_schema().validate(
check_obj, head, tail, sample, random_state, lazy, inplace
),
)

@classmethod
def to_json_schema(cls):
"""Serialize schema metadata into json-schema format.
Expand Down
2 changes: 1 addition & 1 deletion pandera/api/polars/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Column(ComponentSchema[PolarsCheckObjects]):

def __init__(
self,
dtype: PolarsDtypeInputTypes = None,
dtype: Optional[PolarsDtypeInputTypes] = None,
checks: Optional[CheckList] = None,
nullable: bool = False,
unique: bool = False,
Expand Down
54 changes: 52 additions & 2 deletions pandera/api/polars/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Class-based api for polars models."""

import inspect
from typing import Dict, List, Tuple, Type
from typing import Dict, List, Tuple, Type, cast, Optional, overload, Union
from typing_extensions import Self

import pandas as pd
import polars as pl

from pandera.api.base.schema import BaseSchema
from pandera.api.checks import Check
from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
from pandera.api.dataframe.model import get_dtype_kwargs
Expand All @@ -16,7 +18,8 @@
from pandera.engines import polars_engine as pe
from pandera.errors import SchemaInitError
from pandera.typing import AnnotationInfo
from pandera.typing.polars import Series
from pandera.typing.polars import Series, LazyFrame, DataFrame
from pandera.utils import docstring_substitution


class DataFrameModel(_DataFrameModel[pl.LazyFrame, DataFrameSchema]):
Expand Down Expand Up @@ -109,6 +112,53 @@

return columns

@classmethod
@overload
def validate(
cls: Type[Self],
check_obj: pl.DataFrame,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> DataFrame[Self]: ...

@classmethod
@overload
def validate(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With mypy in pre-commit, this gives:
pandera\api\polars\model.py:130: error: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader [misc]

but that doesn't seem to be the case. I have this as a test case:

import typing
from typing import reveal_type

import pandera as pda
import pandera.polars as pla

import pandas as pd
import polars as pl

from pandera.typing import Series
from pandera.typing.geopandas import GeoSeries
import geopandas as gpd


class SchemaPandas(pda.DataFrameModel):
    col1: Series[int]
    col2: Series[int]

class SchemaGeoPandas(pda.DataFrameModel):
    col1: Series[int]
    col2: GeoSeries


class SchemaPolars(pla.DataFrameModel):
    col1: Series[int]
    col2: Series[int]

pandas_df = pd.DataFrame({"col1": [1, 2, 3], "col2": [1, 2, 3]})
gdf = gpd.GeoDataFrame(pandas_df.assign(col2=gpd.GeoSeries.from_xy(x=[1,1,1], y=[2,2,2]))).set_geometry("col2")

polars_df = pl.from_pandas(pandas_df)
lazyframe = polars_df.lazy()

reveal_type(pandas_df)

print("pd.DataFrame")
result = SchemaPandas.validate(pandas_df)
# test operations for methods accepting pandas dataframes
result.to_csv("test")
pd.concat([result, result])
reveal_type(result)
print("gpd.GeoDataFrame")
reveal_type(SchemaGeoPandas.validate(gdf))

print("pl.DataFrame")
result2 = SchemaPolars.validate(polars_df)

reveal_type(result2)
print("pl.LazyFrame")
reveal_type(SchemaPolars.validate(lazyframe))
if typing.TYPE_CHECKING:
    # this should fail mypy, I don't want it crashing my test script though
    SchemaPolars.validate(pandas_df) # should fail

which then shows

fork\tester.py:34: note: Revealed type is "pandas.core.frame.DataFrame"
fork\tester.py:41: note: Revealed type is "pandera.typing.pandas.DataFrame[tester.SchemaPandas]"
fork\tester.py:43: note: Revealed type is "pandera.typing.pandas.DataFrame[tester.SchemaGeoPandas]"
fork\tester.py:48: note: Revealed type is "pandera.typing.polars.DataFrame[tester.SchemaPolars]"
fork\tester.py:50: note: Revealed type is "pandera.typing.polars.LazyFrame[tester.SchemaPolars]"
fork\tester.py:53: error: No overload variant of "validate" of "DataFrameModel" matches argument type "DataFrame"  [call-overload]

under mypy. I thought at first this might be because pl.LazyFrame and pl.DataFrame are resolving to Any in the pre-commit environment, but I've tried adding polars there explicitly and it still shows the same.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @cosmicBboy I had another look at this and I think it's ready for another look if you have a chance. I worked out how to resolve the

pandera\api\polars\model.py:130: error: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader [misc]
I had above. This was happening for two reasons. Polars wasn't installed in the mypy environment, so pl.DataFrame and pl.LazyFrame were resolving to Any. The second reason was the global mypy config setting follow_imports=skip. I swapped this to be a per module follow_imports = skip, and then relaxed the required cases so that this import would get treated properly.

I did however notice that adding polars to the mypy environment resulted in a lot of new errors so I've temporarily suppressed some errors with more rules in the mypy.ini. I was hoping I would just be able to fix them, but they probably need a bit of input (particularly around how the PolarsData namedtuple works), so I've left that over in #1911 and kept only the minimal mypy changes here.

I guess question, are you happy with mypy changes in this PR, or would be better to revert to the previous state and I just type:ignore the error I quoted up above?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with this approach @m-richards. It does add a boilerplate method for each dataframe model type, but I think it's acceptable in order to make the type linters happy.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did however notice that adding polars to the mypy environment resulted in a lot of new errors so I've temporarily supressed some errors with more rules in the mypy.ini. I was hoping I would just be able to fix them, but they probably need a bit of input (particularly around how the PolarsData namedtuple works), so I've left that over in #1911 and kept only the minimal mypy changes here.

Thanks for opening up #1911! Let me know if you need any additional help addressing the mypy errors: I'm okay with type: ignore'ing them and slowly chipping away at the errors.

cls: Type[Self],
check_obj: pl.LazyFrame,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> LazyFrame[Self]: ...

@classmethod
@docstring_substitution(validate_doc=BaseSchema.validate.__doc__)
def validate(
cls: Type[Self],
check_obj: Union[pl.LazyFrame, pl.DataFrame],
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> Union[LazyFrame[Self], DataFrame[Self]]:
"""%(validate_doc)s"""
result = cls.to_schema().validate(
check_obj, head, tail, sample, random_state, lazy, inplace
)
if isinstance(check_obj, pl.LazyFrame):
return cast(LazyFrame[Self], result)
else:
return cast(DataFrame[Self], result)

Check warning on line 160 in pandera/api/polars/model.py

View check run for this annotation

Codecov / codecov/patch

pandera/api/polars/model.py#L160

Added line #L160 was not covered by tests

@classmethod
def to_json_schema(cls):
"""Serialize schema metadata into json-schema format.
Expand Down
5 changes: 3 additions & 2 deletions pandera/api/pyspark/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from pandera.errors import SchemaInitError
from pandera.typing import AnnotationInfo
from pandera.typing.common import DataFrameBase
from pandera.typing.pyspark import DataFrame

try:
from typing_extensions import get_type_hints
Expand Down Expand Up @@ -300,10 +301,10 @@ def validate(
random_state: Optional[int] = None,
lazy: bool = True,
inplace: bool = False,
) -> Optional[DataFrameBase[TDataFrameModel]]:
) -> DataFrame[TDataFrameModel]:
"""%(validate_doc)s"""
return cast(
DataFrameBase[TDataFrameModel],
DataFrame[TDataFrameModel],
cls.to_schema().validate(
check_obj, head, tail, sample, random_state, lazy, inplace
),
Expand Down
5 changes: 4 additions & 1 deletion pandera/backends/polars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def subsample(
obj_subsample.append(check_obj.tail(tail))
if sample is not None:
obj_subsample.append(
check_obj.sample(sample, random_state=random_state)
# mypy is detecting a bug https://github.com/unionai-oss/pandera/issues/1912
check_obj.sample( # type:ignore [attr-defined]
sample, random_state=random_state
)
)
return (
check_obj
Expand Down