Skip to content

Commit

Permalink
Merge pull request #603 from lenskit/tweak/component-config-type-hint
Browse files Browse the repository at this point in the history
Update component configuration type warnings
  • Loading branch information
mdekstrand authored Jan 14, 2025
2 parents f5a0422 + 90eba72 commit 07f70c3
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 69 deletions.
82 changes: 45 additions & 37 deletions lenskit/lenskit/pipeline/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
# pyright: strict
from __future__ import annotations

import inspect
import json
import warnings
from abc import abstractmethod
from abc import ABC, abstractmethod
from importlib import import_module
from types import FunctionType
from typing import (
from inspect import isabstract
from types import FunctionType, NoneType

from pydantic import JsonValue, TypeAdapter
from typing_extensions import (
Any,
Callable,
Generic,
Expand All @@ -25,17 +27,26 @@
TypeAlias,
TypeVar,
get_origin,
get_type_hints,
runtime_checkable,
)

from pydantic import JsonValue, TypeAdapter

from .types import Lazy

P = ParamSpec("P")
T = TypeVar("T")
CArgs = ParamSpec("CArgs", default=...)
"""
Parameter specification for a component's call arguments. The arguments are
difficult to specify precisely, so this parameter spec defaults to ``...``:
the base class :meth:`~Component.__call__` is thereby declared to take unknown
parameters, which allows :class:`Component` subclasses to typecheck.
"""
# COut is used only in return position, so Component[U] can be assigned to Component[T] if U ≼ T.
COut = TypeVar("COut", covariant=True)
COut = TypeVar("COut", covariant=True, default=Any)
"""
Return type for a component.
"""
PipelineFunction: TypeAlias = Callable[..., COut]


Expand Down Expand Up @@ -85,7 +96,7 @@ def load_params(self, params: dict[str, object]) -> None:
raise NotImplementedError()


class Component(Generic[COut]):
class Component(ABC, Generic[COut, CArgs]):
"""
Base class for pipeline component objects. Any component that is not just a
function should extend this class.
Expand Down Expand Up @@ -132,12 +143,15 @@ class MyComponent(Component):

def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)
annots = inspect.get_annotations(cls)
if annots.get("config", None) == Any:
warnings.warn(
"component class {} does not define a config attribute".format(cls.__qualname__),
stacklevel=2,
)
if not isabstract(cls):
ct = cls._config_class(return_any=True)
if ct == Any:
warnings.warn(
"component class {} does not define a config attribute".format(
cls.__qualname__
),
stacklevel=2,
)

def __init__(self, config: object | None = None, **kwargs: Any):
if config is None:
Expand All @@ -152,26 +166,29 @@ def __init__(self, config: object | None = None, **kwargs: Any):
self.config = config

@classmethod
def _config_class(cls) -> type | None:
for base in cls.__mro__:
annots = inspect.get_annotations(base, eval_str=True)
ct = annots.get("config", None)
if ct == Any:
return None

if isinstance(ct, type):
def _config_class(cls, return_any: bool = False) -> type | None:
hints = get_type_hints(cls)
ct = hints.get("config", None)
if ct == NoneType:
return None
elif ct is None or ct == Any:
if return_any:
return ct
elif ct is not None: # pragma: nocover
warnings.warn("config attribute is not annotated with a plain type")
return get_origin(ct)
else:
return None
elif isinstance(ct, type):
return ct
else:
warnings.warn("config attribute is not annotated with a plain type")
return get_origin(ct)

def dump_config(self) -> dict[str, JsonValue]:
"""
Dump the configuration to JSON-serializable format.
"""
cfg_cls = self._config_class()
if cfg_cls:
return TypeAdapter(cfg_cls).dump_python(self.config, mode="json")
return TypeAdapter(cfg_cls).dump_python(self.config, mode="json") # type: ignore
else:
return {}

Expand All @@ -184,7 +201,7 @@ def validate_config(cls, data: Mapping[str, JsonValue] | None = None) -> object
data = {}
cfg_cls = cls._config_class()
if cfg_cls:
return TypeAdapter(cfg_cls).validate_python(data)
return TypeAdapter(cfg_cls).validate_python(data) # type: ignore
elif data: # pragma: nocover
raise RuntimeError(
"supplied configuration options but {} has no config class".format(cls.__name__)
Expand All @@ -193,19 +210,10 @@ def validate_config(cls, data: Mapping[str, JsonValue] | None = None) -> object
return None

@abstractmethod
def __call__(self, **kwargs: Any) -> Any: # pragma: nocover
def __call__(self, *args: CArgs.args, **kwargs: CArgs.kwargs) -> COut: # pragma: nocover
"""
Run the pipeline's operation and produce a result. This is the key
method for components to implement.
.. note::
Due to limitations of Python's type model, derived classes will have
a type error overriding this method when using strict type checking,
because it is very cumbersome (if not impossible) to propagate
parameter names and types through to a base class. The solution is
to either use basic type checking for implementations, or to disable
the typechecker on the ``__call__`` signature definition.
"""
...

Expand Down
1 change: 1 addition & 0 deletions lenskit/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ readme = "README.md"
license = { file = "LICENSE.md" }
dynamic = ["version"]
dependencies = [
"typing-extensions ~=4.12",
"pandas ~=2.0",
"pyarrow >=15",
"numpy >=1.25",
Expand Down
39 changes: 7 additions & 32 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels = ["conda-forge", "pytorch", "nodefaults"]
platforms = ["linux-64", "win-64", "osx-arm64"]

[dependencies]
typing-extensions = "~=4.12"
pandas = "~=2.0"
pyarrow = ">=15"
numpy = ">=1.25"
Expand Down

0 comments on commit 07f70c3

Please sign in to comment.