Skip to content

Commit

Permalink
Add converter validation (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmievsa authored Mar 24, 2024
1 parent 630e274 commit 7ee3871
Show file tree
Hide file tree
Showing 28 changed files with 308 additions and 317 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
- id: python-check-blanket-noqa

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
rev: v0.3.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ Please follow [the Keep a Changelog standard](https://keepachangelog.com/en/1.0.

## [Unreleased]

## [3.13.0]

### Added

* Validation for path converters to make sure that impossible HTTP methods cannot be used
* Validation for both path and schema converters to make sure that they are used at some point. Otherwise, router generation will raise an error

## [3.12.1]

### Fixed
Expand Down
5 changes: 3 additions & 2 deletions cadwyn/_asts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TYPE_CHECKING,
Any,
List,
cast,
get_args,
get_origin,
)
Expand All @@ -32,7 +33,7 @@


# A parent type of typing._GenericAlias
_BaseGenericAlias = type(List[int]).mro()[1] # noqa: UP006
_BaseGenericAlias = cast(type, type(List[int])).mro()[1] # noqa: UP006

# type(list[int]) and type(List[int]) are different which is why we have to do this.
# Please note that this problem is much wider than just lists which is why we use typing._BaseGenericAlias
Expand Down Expand Up @@ -134,7 +135,7 @@ def transform_auto(_: auto) -> Any:
return PlainRepr("auto()")


def transform_union(value: UnionType) -> Any: # pyright: ignore[reportInvalidTypeForm]
def transform_union(value: UnionType) -> Any:
return "typing.Union[" + (", ".join(get_fancy_repr(a) for a in get_args(value))) + "]"


Expand Down
11 changes: 9 additions & 2 deletions cadwyn/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ class PydanticFieldWrapper:

annotation: Any

init_model_field: dataclasses.InitVar[ModelField] # pyright: ignore[reportInvalidTypeForm]
init_model_field: dataclasses.InitVar[ModelField]
field_info: FieldInfo = dataclasses.field(init=False)

annotation_ast: ast.expr | None = None
# In the expressions "foo: str | None = None" and "foo: str | None = Field(default=None)"
# the value_ast is "None" and "Field(default=None)" respectively
value_ast: ast.expr | None = None

def __post_init__(self, init_model_field: ModelField): # pyright: ignore[reportInvalidTypeForm]
def __post_init__(self, init_model_field: ModelField):
if isinstance(init_model_field, FieldInfo):
self.field_info = init_model_field
else:
Expand Down Expand Up @@ -111,6 +111,13 @@ def passed_field_attributes(self):
return attributes | extras


def get_annotation_from_model_field(model: ModelField) -> Any:
if PYDANTIC_V2:
return model.field_info.annotation
else:
return model.annotation


def model_fields(model: type[BaseModel]) -> dict[str, FieldInfo]:
if PYDANTIC_V2:
return model.model_fields
Expand Down
12 changes: 12 additions & 0 deletions cadwyn/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ class RouterPathParamsModifiedError(RouterGenerationError):
pass


class RouteResponseBySchemaConverterDoesNotApplyToAnythingError(RouterGenerationError):
pass


class RouteRequestBySchemaConverterDoesNotApplyToAnythingError(RouterGenerationError):
pass


class RouteByPathConverterDoesNotApplyToAnythingError(RouterGenerationError):
pass


class RouteAlreadyExistsError(RouterGenerationError):
def __init__(self, *routes: APIRoute):
self.routes = routes
Expand Down
163 changes: 115 additions & 48 deletions cadwyn/route_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
Annotated,
Any,
Generic,
TypeAlias,
TypeVar,
_BaseGenericAlias, # pyright: ignore[reportAttributeAccessIssue]
cast,
Expand All @@ -29,6 +28,7 @@
import fastapi.routing
import fastapi.security.base
import fastapi.utils
from fastapi import APIRouter
from fastapi._compat import ModelField as FastAPIModelField
from fastapi._compat import create_body_model
from fastapi.dependencies.models import Dependant
Expand All @@ -39,19 +39,23 @@
)
from fastapi.params import Depends
from fastapi.routing import APIRoute
from issubclass import issubclass as lenient_issubclass
from pydantic import BaseModel
from starlette.routing import (
BaseRoute,
request_response,
)
from typing_extensions import Self, assert_never, deprecated

from cadwyn._compat import model_fields, rebuild_fastapi_body_param
from cadwyn._compat import get_annotation_from_model_field, model_fields, rebuild_fastapi_body_param
from cadwyn._package_utils import get_version_dir_path
from cadwyn._utils import Sentinel, UnionType, get_another_version_of_cls
from cadwyn.exceptions import (
CadwynError,
RouteAlreadyExistsError,
RouteByPathConverterDoesNotApplyToAnythingError,
RouteRequestBySchemaConverterDoesNotApplyToAnythingError,
RouteResponseBySchemaConverterDoesNotApplyToAnythingError,
RouterGenerationError,
RouterPathParamsModifiedError,
)
Expand All @@ -68,8 +72,6 @@
_R = TypeVar("_R", bound=fastapi.routing.APIRouter)
# This is a hack we do because we can't guarantee how the user will use the router.
_DELETED_ROUTE_TAG = "_CADWYN_DELETED_ROUTE"
_EndpointPath: TypeAlias = str
_EndpointMethod: TypeAlias = str


@dataclass(slots=True, frozen=True, eq=True)
Expand All @@ -78,13 +80,6 @@ class _EndpointInfo:
endpoint_methods: frozenset[str]


@dataclass(slots=True)
class _RouterInfo(Generic[_R]):
router: _R
routes_with_migrated_requests: dict[_EndpointPath, set[_EndpointMethod]]
route_bodies_with_migrated_requests: set[type[BaseModel]]


@deprecated("It will soon be deleted. Use HeadVersion version changes instead.")
class InternalRepresentationOf:
def __class_getitem__(cls, original_schema: type, /) -> type[Self]:
Expand Down Expand Up @@ -161,22 +156,15 @@ def transform(self) -> dict[VersionDate, _R]:
self.parent_router
)
router = deepcopy(self.parent_router)
router_infos: dict[VersionDate, _RouterInfo] = {}
routes_with_migrated_requests = {}
route_bodies_with_migrated_requests: set[type[BaseModel]] = set()
routers: dict[VersionDate, _R] = {}

for version in self.versions:
self.annotation_transformer.migrate_router_to_version(router, version)

router_infos[version.value] = _RouterInfo(
router,
routes_with_migrated_requests,
route_bodies_with_migrated_requests,
)
self._validate_all_data_converters_are_applied(router, version)

routers[version.value] = router
# Applying changes for the next version
routes_with_migrated_requests = _get_migrated_routes_by_path(version)
route_bodies_with_migrated_requests = {
schema for change in version.version_changes for schema in change.alter_request_by_schema_instructions
}
router = deepcopy(router)
self._apply_endpoint_changes_to_router(router, version)

Expand All @@ -194,21 +182,21 @@ def transform(self) -> dict[VersionDate, _R]:
continue
_add_request_and_response_params(head_route)
copy_of_dependant = deepcopy(head_route.dependant)
# Remember this: if len(body_params) == 1, then route.body_schema == route.dependant.body_params[0]
if len(copy_of_dependant.body_params) == 1:

if _route_has_a_simple_body_schema(head_route):
self._replace_internal_representation_with_the_versioned_schema(
copy_of_dependant,
schema_to_internal_request_body_representation,
)

for older_router_info in list(router_infos.values()):
older_route = older_router_info.router.routes[route_index]
for older_router in list(routers.values()):
older_route = older_router.routes[route_index]

# We know they are APIRoutes because of the check at the very beginning of the top loop.
# I.e. Because head_route is an APIRoute, both routes are APIRoutes too
older_route = cast(APIRoute, older_route)
# Wait.. Why do we need this code again?
if older_route.body_field is not None and len(older_route.dependant.body_params) == 1:
if older_route.body_field is not None and _route_has_a_simple_body_schema(older_route):
template_older_body_model = self.annotation_transformer._change_version_of_annotations(
older_route.body_field.type_,
self.annotation_transformer.head_version_dir,
Expand All @@ -224,13 +212,99 @@ def transform(self) -> dict[VersionDate, _R]:
copy_of_dependant,
self.versions,
)
for _, router_info in router_infos.items():
router_info.router.routes = [
for _, router in routers.items():
router.routes = [
route
for route in router_info.router.routes
for route in router.routes
if not (isinstance(route, fastapi.routing.APIRoute) and _DELETED_ROUTE_TAG in route.tags)
]
return {version: router_info.router for version, router_info in router_infos.items()}
return routers

def _validate_all_data_converters_are_applied(self, router: APIRouter, version: Version):
path_to_route_methods_mapping, head_response_models, head_request_bodies = self._extract_all_routes_identifiers(
router
)

for version_change in version.version_changes:
for by_path_converters in [
*version_change.alter_response_by_path_instructions.values(),
*version_change.alter_request_by_path_instructions.values(),
]:
for by_path_converter in by_path_converters:
missing_methods = by_path_converter.methods.difference(
path_to_route_methods_mapping[by_path_converter.path]
)

if missing_methods:
raise RouteByPathConverterDoesNotApplyToAnythingError(
f"{by_path_converter.repr_name} "
f'"{version_change.__name__}.{by_path_converter.transformer.__name__}" '
f"failed to find routes with the following methods: {list(missing_methods)}. "
f"This means that you are trying to apply this converter to non-existing endpoint(s). "
"Please, check whether the path and methods are correct. (hint: path must include "
"all path variables and have a name that was used in the version that this "
"VersionChange resides in)"
)

for by_schema_converters in version_change.alter_request_by_schema_instructions.values():
for by_schema_converter in by_schema_converters:
missing_models = set(by_schema_converter.schemas) - head_request_bodies
if missing_models:
raise RouteRequestBySchemaConverterDoesNotApplyToAnythingError(
f"Request by body schema converter "
f'"{version_change.__name__}.{by_schema_converter.transformer.__name__}" '
f"failed to find routes with the following body schemas: "
f"{[m.__name__ for m in missing_models]}. "
f"This means that you are trying to apply this converter to non-existing endpoint(s). "
)
for by_schema_converters in version_change.alter_response_by_schema_instructions.values():
for by_schema_converter in by_schema_converters:
missing_models = set(by_schema_converter.schemas) - head_response_models
if missing_models:
raise RouteResponseBySchemaConverterDoesNotApplyToAnythingError(
f"Response by response model converter "
f'"{version_change.__name__}.{by_schema_converter.transformer.__name__}" '
f"failed to find routes with the following response models: "
f"{[m.__name__ for m in missing_models]}. "
f"This means that you are trying to apply this converter to non-existing endpoint(s). "
)

def _extract_all_routes_identifiers(
self, router: APIRouter
) -> tuple[defaultdict[str, set[str]], set[Any], set[Any]]:
response_models: set[Any] = set()
request_bodies: set[Any] = set()
path_to_route_methods_mapping: dict[str, set[str]] = defaultdict(set)

for route in router.routes:
if isinstance(route, APIRoute):
if route.response_model is not None and lenient_issubclass(route.response_model, BaseModel):
# FIXME: This is going to fail on Pydantic 1
response_models.add(route.response_model)
# Not sure if it can ever be None when it's a simple schema. Eh, I would rather be safe than sorry
if _route_has_a_simple_body_schema(route) and route.body_field is not None:
annotation = get_annotation_from_model_field(route.body_field)
if lenient_issubclass(annotation, BaseModel):
# FIXME: This is going to fail on Pydantic 1
request_bodies.add(annotation)
path_to_route_methods_mapping[route.path] |= route.methods

head_response_models = {
self.annotation_transformer._change_version_of_annotations(
model,
self.versions.versioned_directories_with_head[0],
)
for model in response_models
}
head_request_bodies = {
self.annotation_transformer._change_version_of_annotations(
body,
self.versions.versioned_directories_with_head[0],
)
for body in request_bodies
}

return path_to_route_methods_mapping, head_response_models, head_request_bodies

def _replace_internal_representation_with_the_versioned_schema(
self,
Expand Down Expand Up @@ -422,8 +496,8 @@ def __init__(self, head_schemas_package: ModuleType, versions: VersionBundle) ->
self.versions = versions
self.versions.head_schemas_package = head_schemas_package
self.head_schemas_package = head_schemas_package
self.head_version_dir = min(versions.versioned_directories) # "head" < "v0000_00_00"
self.latest_version_dir = max(versions.versioned_directories) # "v2005_11_11" > "v2000_11_11"
self.head_version_dir = min(versions.versioned_directories_with_head) # "head" < "v0000_00_00"
self.latest_version_dir = max(versions.versioned_directories_with_head) # "v2005_11_11" > "v2000_11_11"

# This cache is not here for speeding things up. It's for preventing the creation of copies of the same object
# because such copies could produce weird behaviors at runtime, especially if you/fastapi do any comparisons.
Expand Down Expand Up @@ -537,7 +611,7 @@ def _change_version_of_type(self, annotation: type, version_dir: Path):
)
else:
self._validate_source_file_is_located_in_template_dir(annotation, source_file)
return get_another_version_of_cls(annotation, version_dir, self.versions.versioned_directories)
return get_another_version_of_cls(annotation, version_dir, self.versions.versioned_directories_with_head)
else:
return annotation

Expand All @@ -550,7 +624,7 @@ def _validate_source_file_is_located_in_template_dir(self, annotation: type, sou
if (
source_file.startswith(dir_with_versions)
and not source_file.startswith(template_dir)
and any(source_file.startswith(str(d)) for d in self.versions.versioned_directories)
and any(source_file.startswith(str(d)) for d in self.versions.versioned_directories_with_head)
):
raise RouterGenerationError(
f'"{annotation}" is not defined in "{self.head_version_dir}" even though it must be. '
Expand Down Expand Up @@ -725,18 +799,6 @@ def _get_route_from_func(
return None


def _get_migrated_routes_by_path(version: Version) -> dict[_EndpointPath, set[_EndpointMethod]]:
request_by_path_migration_instructions = [
version_change.alter_request_by_path_instructions for version_change in version.version_changes
]
migrated_routes = defaultdict(set)
for instruction_dict in request_by_path_migration_instructions:
for path, instruction_list in instruction_dict.items():
for instruction in instruction_list:
migrated_routes[path] |= instruction.methods
return migrated_routes


def _copy_function(function: _T) -> _T:
while hasattr(function, "__alt_wrapped__"):
function = function.__alt_wrapped__
Expand Down Expand Up @@ -768,3 +830,8 @@ def annotation_modifying_wrapper(
del annotation_modifying_wrapper.__wrapped__

return cast(_T, annotation_modifying_wrapper)


def _route_has_a_simple_body_schema(route: APIRoute) -> bool:
# Remember this: if len(body_params) == 1, then route.body_schema == route.dependant.body_params[0]
return len(route.dependant.body_params) == 1
Loading

0 comments on commit 7ee3871

Please sign in to comment.