diff --git a/CHANGELOG.md b/CHANGELOG.md index ccd85a45..c0e75a32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ Please follow [the Keep a Changelog standard](https://keepachangelog.com/en/1.0. ## [Unreleased] +## [3.15.7] + +### Fixed + +* Wrong globals being used for wrapped endpoints in older versions, sometimes causing FastAPI to fail to resolve forward references on endpoint generation (see #192 for more details) +* dependency_overrides not working for old versions of the endpoints because they were wrapped and wraps did not have the same `__hash__` and `__eq__` as the original dependencies + ## [3.15.6] ### Added diff --git a/cadwyn/route_generation.py b/cadwyn/route_generation.py index 43225820..8a1ead0f 100644 --- a/cadwyn/route_generation.py +++ b/cadwyn/route_generation.py @@ -29,11 +29,7 @@ import fastapi.routing import fastapi.security.base import fastapi.utils -from fastapi import ( - APIRouter, - Request, # noqa: F401 # We import Request for libraries like svcs that expect it to be in globals - Response, # noqa: F401 # We import Request for libraries like svcs that expect it to be in globals -) +from fastapi import APIRouter from fastapi._compat import ModelField as FastAPIModelField from fastapi._compat import create_body_model from fastapi.params import Depends @@ -67,7 +63,7 @@ if TYPE_CHECKING: from fastapi.dependencies.models import Dependant -_T = TypeVar("_T", bound=Callable[..., Any]) +_Call = TypeVar("_Call", bound=Callable[..., Any]) _R = TypeVar("_R", bound=fastapi.routing.APIRouter) # This is a hack we do because we can't guarantee how the user will use the router. _DELETED_ROUTE_TAG = "_CADWYN_DELETED_ROUTE" @@ -121,7 +117,7 @@ def generate_versioned_routers( class VersionedAPIRouter(fastapi.routing.APIRouter): - def only_exists_in_older_versions(self, endpoint: _T) -> _T: + def only_exists_in_older_versions(self, endpoint: _Call) -> _Call: route = _get_route_from_func(self.routes, endpoint) if route is None: raise LookupError( @@ -466,9 +462,10 @@ def _extract_internal_request_schemas_from_annotations(annotations: dict[str, An for route in router.routes: if isinstance(route, APIRoute): # pragma: no branch - route.endpoint = _modify_callable( + route.endpoint = _modify_callable_annotations( route.endpoint, modify_annotations=_extract_internal_request_schemas_from_annotations, + annotation_modifying_wrapper_factory=_copy_endpoint, ) _remake_endpoint_dependencies(route) return schema_to_internal_request_body_representation @@ -575,7 +572,12 @@ def _change_versions_of_a_non_container_annotation(self, annotation: Any, versio def modifier(annotation: Any): return self._change_version_of_annotations(annotation, version_dir) - return _modify_callable(annotation, modifier, modifier) + return _modify_callable_annotations( + annotation, + modifier, + modifier, + annotation_modifying_wrapper_factory=_copy_function_through_class_based_wrapper, + ) else: return annotation @@ -636,12 +638,14 @@ def _validate_source_file_is_located_in_template_dir(self, annotation: type, sou ) -def _modify_callable( - call: Callable, +def _modify_callable_annotations( + call: _Call, modify_annotations: Callable[[dict[str, Any]], dict[str, Any]] = lambda a: a, modify_defaults: Callable[[tuple[Any, ...]], tuple[Any, ...]] = lambda a: a, -): - annotation_modifying_wrapper = _copy_function(call) + *, + annotation_modifying_wrapper_factory: Callable[[_Call], _Call], +) -> _Call: + annotation_modifying_wrapper = annotation_modifying_wrapper_factory(call) old_params = inspect.signature(call).parameters callable_annotations = annotation_modifying_wrapper.__annotations__ annotation_modifying_wrapper.__annotations__ = modify_annotations(callable_annotations) @@ -798,37 +802,78 @@ def _get_route_from_func( return None -def _copy_function(function: _T) -> _T: - while hasattr(function, "__alt_wrapped__"): - function = function.__alt_wrapped__ - if not isinstance(function, types.FunctionType | types.MethodType): - # This means that the callable is actually an instance of a regular class - function = function.__call__ - if inspect.iscoroutinefunction(function): +def _copy_endpoint(function: Any) -> Any: + function = _unwrap_callable(function) + function_copy: Any = types.FunctionType( + function.__code__, + function.__globals__, + name=function.__name__, + argdefs=function.__defaults__, + closure=function.__closure__, + ) + function_copy = functools.update_wrapper(function_copy, function) + # Otherwise it will have the same signature as __wrapped__ due to how inspect module works + del function_copy.__wrapped__ + + function_copy._original_callable = function + function.__kwdefaults__ = function.__kwdefaults__.copy() if function.__kwdefaults__ is not None else {} + + return function_copy - @functools.wraps(function) - async def annotation_modifying_wrapper( # pyright: ignore[reportRedeclaration] - *args: Any, - **kwargs: Any, - ) -> Any: - return await function(*args, **kwargs) +class _CallableWrapper: + """__eq__ and __hash__ are needed to make sure that dependency overrides work correctly. + They are based on putting dependencies (functions) as keys for the dictionary so if we want to be able to + override the wrapper, we need to make sure that it is equivalent to the original in __hash__ and __eq__ + """ + + def __init__(self, original_callable: Callable) -> None: + super().__init__() + self._original_callable = original_callable + functools.update_wrapper(self, original_callable) + + @property + def __globals__(self): + """FastAPI uses __globals__ to resolve forward references in type hints + It's supposed to be an attribute on the function but we use it as property to prevent python + from trying to pickle globals when we deepcopy this wrapper + """ + # + return self._original_callable.__globals__ + + def __call__(self, *args: Any, **kwargs: Any): + return self._original_callable(*args, **kwargs) + + def __hash__(self): + return hash(self._original_callable) + + def __eq__(self, value: object) -> bool: + return self._original_callable == value # pyright: ignore[reportUnnecessaryComparison] + + +class _AsyncCallableWrapper(_CallableWrapper): + async def __call__(self, *args: Any, **kwargs: Any): + return await self._original_callable(*args, **kwargs) + + +def _copy_function_through_class_based_wrapper(call: Any): + """Separate from copy_endpoint because endpoints MUST be functions in FastAPI, they cannot be cls instances""" + call = _unwrap_callable(call) + + if inspect.iscoroutinefunction(call): + return _AsyncCallableWrapper(call) else: + return _CallableWrapper(call) - @functools.wraps(function) - def annotation_modifying_wrapper( - *args: Any, - **kwargs: Any, - ) -> Any: - return function(*args, **kwargs) - # Otherwise it will have the same signature as __wrapped__ due to how inspect module works - annotation_modifying_wrapper.__alt_wrapped__ = ( # pyright: ignore[reportAttributeAccessIssue] - annotation_modifying_wrapper.__wrapped__ - ) - del annotation_modifying_wrapper.__wrapped__ +def _unwrap_callable(call: Any) -> Any: + while hasattr(call, "_original_callable"): + call = call._original_callable + if not isinstance(call, types.FunctionType | types.MethodType): + # This means that the callable is actually an instance of a regular class + call = call.__call__ - return cast(_T, annotation_modifying_wrapper) + return call def _route_has_a_simple_body_schema(route: APIRoute) -> bool: diff --git a/pyproject.toml b/pyproject.toml index 5f89967e..aa03c48a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cadwyn" -version = "3.15.6" +version = "3.15.7" description = "Production-ready community-driven modern Stripe-like API versioning in FastAPI" authors = ["Stanislav Zmiev "] license = "MIT" diff --git a/tests/test_applications.py b/tests/test_applications.py index 5a300157..dcf78f29 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -106,7 +106,7 @@ def test__cadwyn__with_dependency_overrides__overrides_should_be_applied( run_schema_codegen(app.versions) async def old_dependency(): - return "old" + raise NotImplementedError async def new_dependency(): return "new" diff --git a/tests/test_router_generation.py b/tests/test_router_generation.py index 8c4749d1..65151d81 100644 --- a/tests/test_router_generation.py +++ b/tests/test_router_generation.py @@ -51,8 +51,6 @@ def get_wrapped_endpoint(endpoint: Endpoint) -> Endpoint: while hasattr(endpoint, "__wrapped__"): endpoint = endpoint.__wrapped__ - while hasattr(endpoint, "__alt_wrapped__"): - endpoint = endpoint.__alt_wrapped__ return endpoint