Skip to content

Commit

Permalink
Don't eval types, implement classmethod transform
Browse files Browse the repository at this point in the history
  • Loading branch information
bswck committed Nov 25, 2023
1 parent 077b4d7 commit 8820a00
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 55 deletions.
76 changes: 26 additions & 50 deletions runtime_generics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,9 @@

from __future__ import annotations

import inspect
from itertools import chain
from typing import TYPE_CHECKING, Any, ForwardRef, Generic, Protocol, TypeVar, cast
from typing import _eval_type as _typing_eval_type # type: ignore[attr-defined]
from types import MethodType
from typing import TYPE_CHECKING, Any, ForwardRef, Generic, Protocol, TypeVar
from typing import _GenericAlias as _typing_GenericAlias # type: ignore[attr-defined]
from typing import get_args as _typing_get_args

Expand Down Expand Up @@ -119,6 +118,20 @@ def _try_forward_ref(obj: str) -> str | ForwardRef:
return obj


class _ClassMethodProxy:
def __init__(
self,
alias_proxy: _AliasProxy,
cls_method: classmethod[Any, Any, Any],
) -> None:
self.alias_proxy = alias_proxy
self.cls_method = cls_method
self.__func__ = self.cls_method.__func__

def __get__(self, instance: object, owner: type[Any] | None = None) -> MethodType:
return MethodType(self.cls_method.__func__, self.alias_proxy)


class _AliasProxy(
_typing_GenericAlias, # type: ignore[misc,call-arg]
_root=True,
Expand All @@ -134,6 +147,14 @@ def __init__(
for param in (params if isinstance(params, tuple) else (params,))
)
super().__init__(origin, patched_params, **kwds)
cls_dict = vars(origin)
for cls_method_name, cls_method in cls_dict.items():
if isinstance(cls_method, classmethod):
setattr(
origin,
cls_method_name,
_ClassMethodProxy(self, cls_method),
)

def __get_arguments__(
self,
Expand All @@ -145,9 +166,6 @@ def __get_arguments__(
_origin or self.__origin__,
"__get_arguments__",
)
if not isinstance(method, classmethod): # pragma: no cover
msg = f"Expected {method} to be a classmethod, got {type(method)}"
raise TypeError(msg)
return method.__func__(self, instance)

# @override?
Expand Down Expand Up @@ -238,15 +256,6 @@ def get_all_arguments(instance: object) -> tuple[Any, ...]:
get_all_args = get_all_arguments


class _SelectorProtocol(Protocol):
@classmethod
def __get_arguments__(
cls,
instance: GenericClass,
) -> tuple[Any, ...]: # pragma: no cover
...


class FunctionalSelectorMixin:
"""
Mixin for functional selectors.
Expand All @@ -263,7 +272,7 @@ class Select(Generic[Unpack[GenericArguments]]):
def __get_arguments__(
cls,
instance: GenericClass,
) -> Any:
) -> tuple[Any, ...]:
"""Return the selected type arguments."""
arguments = get_all_arguments(instance)
tvars = _typing_get_args(cls)
Expand Down Expand Up @@ -368,39 +377,12 @@ class _FunctionalIndex(
index: Any = _FunctionalIndex


def _eval_generic_type(
obj: object,
*,
local_ns: dict[str, object] | None = None,
global_ns: dict[str, object] | None = None,
stack_offset: int = 1,
) -> _SelectorProtocol:
if local_ns is None or global_ns is None:
frame = inspect.stack()[stack_offset].frame
global_ns = global_ns or frame.f_globals or {}
local_ns = local_ns or frame.f_locals or {}
return cast(
_SelectorProtocol,
_typing_eval_type(
_try_forward_ref(obj) if isinstance(obj, str) else obj,
globalns=global_ns,
localns=local_ns,
),
)


def get_arguments(
instance: object,
argument_type: type[
Select[Unpack[GenericArguments]] | Index[Unpack[GenericArguments]]
]
| str
| ForwardRef
| None = None,
*,
local_ns: dict[str, object] | None = None,
global_ns: dict[str, object] | None = None,
stack_offset: int = 1,
) -> tuple[Any, ...]:
"""
Get the single type argument of a runtime generic instance.
Expand Down Expand Up @@ -449,13 +431,7 @@ def get_arguments(
if argument_type is None:
return arguments

impl = _eval_generic_type(
argument_type,
global_ns=global_ns,
local_ns=local_ns,
stack_offset=stack_offset + 1,
)
return impl.__get_arguments__(instance)
return argument_type.__get_arguments__(instance)


get_args = get_arguments
Expand Down
14 changes: 9 additions & 5 deletions tests/test_runtime_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,13 @@ def test_get_arguments() -> None:
assert get_arguments(SingleArgGeneric[complex]()) == (complex,)
assert get_arguments(TwoArgGeneric[int, str]()) == (int, str)
assert get_arguments(TwoArgGeneric[int, bytearray](), Select[T2]) == (bytearray,) # type: ignore[valid-type]
assert get_arguments(TwoArgGeneric[float, int](), "Select[T]") == (float,)
assert get_arguments(TwoArgGeneric[str, float](), Select["T2"]) == (float,) # type: ignore[valid-type]
assert get_arguments(
VariadicGeneric[str, float, bytearray, bytes](), Index[1:3]
) == (
float,
bytearray,
bytes,
) # right inclusive
assert get_arguments(
VariadicGeneric[int, bytearray, bytes, complex](), "Select[T]"
) == (int,)


def test_dunder_args_two() -> None:
Expand Down Expand Up @@ -107,3 +102,12 @@ def test_index() -> None:
index["?"](TwoArgGeneric[int, str]())
with raises(TypeError):
index[None](TwoArgGeneric[int, str]())

def test_classmethod_transform() -> None:
@runtime_generic
class C(Generic[T]):
@classmethod
def foo(cls) -> type[C[T]]:
return cls

assert C[int].foo() == C[int]

0 comments on commit 8820a00

Please sign in to comment.