Skip to content

Commit

Permalink
Fix dependency overrides not working in old versions (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmievsa authored Jun 14, 2024
1 parent 9d564ed commit 18f2606
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 42 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ Please follow [the Keep a Changelog standard](https://keepachangelog.com/en/1.0.

## [Unreleased]

## [3.15.7]

### Fixed

* Wrong globals being used for wrapped endpoints in older versions, sometimes causing FastAPI to fail to resolve forward references on endpoint generation (see #192 for more details)
* dependency_overrides not working for old versions of the endpoints because they were wrapped and wraps did not have the same `__hash__` and `__eq__` as the original dependencies

## [3.15.6]

### Added
Expand Down
121 changes: 83 additions & 38 deletions cadwyn/route_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@
import fastapi.routing
import fastapi.security.base
import fastapi.utils
from fastapi import (
APIRouter,
Request, # noqa: F401 # We import Request for libraries like svcs that expect it to be in globals
Response, # noqa: F401 # We import Request for libraries like svcs that expect it to be in globals
)
from fastapi import APIRouter
from fastapi._compat import ModelField as FastAPIModelField
from fastapi._compat import create_body_model
from fastapi.params import Depends
Expand Down Expand Up @@ -67,7 +63,7 @@
if TYPE_CHECKING:
from fastapi.dependencies.models import Dependant

_T = TypeVar("_T", bound=Callable[..., Any])
_Call = TypeVar("_Call", bound=Callable[..., Any])
_R = TypeVar("_R", bound=fastapi.routing.APIRouter)
# This is a hack we do because we can't guarantee how the user will use the router.
_DELETED_ROUTE_TAG = "_CADWYN_DELETED_ROUTE"
Expand Down Expand Up @@ -121,7 +117,7 @@ def generate_versioned_routers(


class VersionedAPIRouter(fastapi.routing.APIRouter):
def only_exists_in_older_versions(self, endpoint: _T) -> _T:
def only_exists_in_older_versions(self, endpoint: _Call) -> _Call:
route = _get_route_from_func(self.routes, endpoint)
if route is None:
raise LookupError(
Expand Down Expand Up @@ -466,9 +462,10 @@ def _extract_internal_request_schemas_from_annotations(annotations: dict[str, An

for route in router.routes:
if isinstance(route, APIRoute): # pragma: no branch
route.endpoint = _modify_callable(
route.endpoint = _modify_callable_annotations(
route.endpoint,
modify_annotations=_extract_internal_request_schemas_from_annotations,
annotation_modifying_wrapper_factory=_copy_endpoint,
)
_remake_endpoint_dependencies(route)
return schema_to_internal_request_body_representation
Expand Down Expand Up @@ -575,7 +572,12 @@ def _change_versions_of_a_non_container_annotation(self, annotation: Any, versio
def modifier(annotation: Any):
return self._change_version_of_annotations(annotation, version_dir)

return _modify_callable(annotation, modifier, modifier)
return _modify_callable_annotations(
annotation,
modifier,
modifier,
annotation_modifying_wrapper_factory=_copy_function_through_class_based_wrapper,
)
else:
return annotation

Expand Down Expand Up @@ -636,12 +638,14 @@ def _validate_source_file_is_located_in_template_dir(self, annotation: type, sou
)


def _modify_callable(
call: Callable,
def _modify_callable_annotations(
call: _Call,
modify_annotations: Callable[[dict[str, Any]], dict[str, Any]] = lambda a: a,
modify_defaults: Callable[[tuple[Any, ...]], tuple[Any, ...]] = lambda a: a,
):
annotation_modifying_wrapper = _copy_function(call)
*,
annotation_modifying_wrapper_factory: Callable[[_Call], _Call],
) -> _Call:
annotation_modifying_wrapper = annotation_modifying_wrapper_factory(call)
old_params = inspect.signature(call).parameters
callable_annotations = annotation_modifying_wrapper.__annotations__
annotation_modifying_wrapper.__annotations__ = modify_annotations(callable_annotations)
Expand Down Expand Up @@ -798,37 +802,78 @@ def _get_route_from_func(
return None


def _copy_function(function: _T) -> _T:
while hasattr(function, "__alt_wrapped__"):
function = function.__alt_wrapped__
if not isinstance(function, types.FunctionType | types.MethodType):
# This means that the callable is actually an instance of a regular class
function = function.__call__
if inspect.iscoroutinefunction(function):
def _copy_endpoint(function: Any) -> Any:
function = _unwrap_callable(function)
function_copy: Any = types.FunctionType(
function.__code__,
function.__globals__,
name=function.__name__,
argdefs=function.__defaults__,
closure=function.__closure__,
)
function_copy = functools.update_wrapper(function_copy, function)
# Otherwise it will have the same signature as __wrapped__ due to how inspect module works
del function_copy.__wrapped__

function_copy._original_callable = function
function.__kwdefaults__ = function.__kwdefaults__.copy() if function.__kwdefaults__ is not None else {}

return function_copy

@functools.wraps(function)
async def annotation_modifying_wrapper( # pyright: ignore[reportRedeclaration]
*args: Any,
**kwargs: Any,
) -> Any:
return await function(*args, **kwargs)

class _CallableWrapper:
"""__eq__ and __hash__ are needed to make sure that dependency overrides work correctly.
They are based on putting dependencies (functions) as keys for the dictionary so if we want to be able to
override the wrapper, we need to make sure that it is equivalent to the original in __hash__ and __eq__
"""

def __init__(self, original_callable: Callable) -> None:
super().__init__()
self._original_callable = original_callable
functools.update_wrapper(self, original_callable)

@property
def __globals__(self):
"""FastAPI uses __globals__ to resolve forward references in type hints
It's supposed to be an attribute on the function but we use it as property to prevent python
from trying to pickle globals when we deepcopy this wrapper
"""
#
return self._original_callable.__globals__

def __call__(self, *args: Any, **kwargs: Any):
return self._original_callable(*args, **kwargs)

def __hash__(self):
return hash(self._original_callable)

def __eq__(self, value: object) -> bool:
return self._original_callable == value # pyright: ignore[reportUnnecessaryComparison]


class _AsyncCallableWrapper(_CallableWrapper):
async def __call__(self, *args: Any, **kwargs: Any):
return await self._original_callable(*args, **kwargs)


def _copy_function_through_class_based_wrapper(call: Any):
"""Separate from copy_endpoint because endpoints MUST be functions in FastAPI, they cannot be cls instances"""
call = _unwrap_callable(call)

if inspect.iscoroutinefunction(call):
return _AsyncCallableWrapper(call)
else:
return _CallableWrapper(call)

@functools.wraps(function)
def annotation_modifying_wrapper(
*args: Any,
**kwargs: Any,
) -> Any:
return function(*args, **kwargs)

# Otherwise it will have the same signature as __wrapped__ due to how inspect module works
annotation_modifying_wrapper.__alt_wrapped__ = ( # pyright: ignore[reportAttributeAccessIssue]
annotation_modifying_wrapper.__wrapped__
)
del annotation_modifying_wrapper.__wrapped__
def _unwrap_callable(call: Any) -> Any:
while hasattr(call, "_original_callable"):
call = call._original_callable
if not isinstance(call, types.FunctionType | types.MethodType):
# This means that the callable is actually an instance of a regular class
call = call.__call__

return cast(_T, annotation_modifying_wrapper)
return call


def _route_has_a_simple_body_schema(route: APIRoute) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cadwyn"
version = "3.15.6"
version = "3.15.7"
description = "Production-ready community-driven modern Stripe-like API versioning in FastAPI"
authors = ["Stanislav Zmiev <zmievsa@gmail.com>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test__cadwyn__with_dependency_overrides__overrides_should_be_applied(
run_schema_codegen(app.versions)

async def old_dependency():
return "old"
raise NotImplementedError

async def new_dependency():
return "new"
Expand Down
2 changes: 0 additions & 2 deletions tests/test_router_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@
def get_wrapped_endpoint(endpoint: Endpoint) -> Endpoint:
while hasattr(endpoint, "__wrapped__"):
endpoint = endpoint.__wrapped__
while hasattr(endpoint, "__alt_wrapped__"):
endpoint = endpoint.__alt_wrapped__
return endpoint


Expand Down

0 comments on commit 18f2606

Please sign in to comment.