Skip to content

Commit

Permalink
Add support for newer fastapi (#224)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmievsa authored Oct 17, 2024
1 parent 836fc87 commit 12a3877
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 524 deletions.
4 changes: 2 additions & 2 deletions cadwyn/_asts.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def transform_generic_alias(value: GenericAliasUnion) -> Any:
return f"{get_fancy_repr(get_origin(value))}[{', '.join(get_fancy_repr(a) for a in get_args(value))}]"


def transform_none(_: NoneType) -> Any:
def transform_none(_: Any) -> Any:
return "None"


Expand Down Expand Up @@ -231,7 +231,7 @@ def delete_keyword_from_call(attr_name: str, call: ast.Call):
def get_ast_keyword_from_argument_name_and_value(name: str, value: Any):
if not isinstance(value, ast.AST):
value = ast.parse(get_fancy_repr(value), mode="eval").body
return ast.keyword(arg=name, value=value)
return ast.keyword(arg=name, value=value) # pyright: ignore[reportArgumentType, reportCallIssue]


def pop_docstring_from_cls_body(cls_body: list[ast.stmt]) -> list[ast.stmt]:
Expand Down
4 changes: 2 additions & 2 deletions cadwyn/codegen/_plugins/class_rebuilding.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def _modify_schema_cls(
cls_node.name = model_info.name

field_definitions = [
ast.AnnAssign(
ast.AnnAssign( # pyright: ignore[reportCallIssue]
target=ast.Name(name, ctx=ast.Store()),
annotation=copy.deepcopy(field.annotation_ast),
annotation=copy.deepcopy(field.annotation_ast), # pyright: ignore[reportArgumentType]
# We do this because next plugins **might** use a transformer which will edit the ast within the field
# and break rendering
value=copy.deepcopy(field.value_ast),
Expand Down
8 changes: 4 additions & 4 deletions cadwyn/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ async def dispatch(
request=request,
dependant=self.version_header_validation_dependant,
async_exit_stack=async_exit_stack,
embed_body_fields=False,
)
values, errors, *_ = solved_result
if errors:
return self.default_response_class(status_code=422, content=_normalize_errors(errors))
api_version = cast(date, values[self.api_version_header_name.replace("-", "_")])
if solved_result.errors:
return self.default_response_class(status_code=422, content=_normalize_errors(solved_result.errors))
api_version = cast(date, solved_result.values[self.api_version_header_name.replace("-", "_")])
self.api_version_var.set(api_version)

response = await call_next(request)
Expand Down
4 changes: 2 additions & 2 deletions cadwyn/route_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def _extract_internal_request_schemas_from_router(

def _extract_internal_request_schemas_from_annotations(annotations: dict[str, Any]):
for key, annotation in annotations.items():
if isinstance(annotation, type(Annotated[int, int])):
if isinstance(annotation, type(Annotated[int, int])): # pyright: ignore[reportArgumentType]
args = get_args(annotation)
if isinstance(args[1], type) and issubclass( # pragma: no branch
args[1],
Expand Down Expand Up @@ -525,7 +525,7 @@ def migrate_route_to_version(
):
if route.response_model is not None and not ignore_response_model:
route.response_model = self._change_version_of_annotations(route.response_model, version_dir)
route.response_field = fastapi.utils.create_response_field(
route.response_field = fastapi.utils.create_model_field(
name="Response_" + route.unique_id,
type_=route.response_model,
mode="serialization",
Expand Down
12 changes: 8 additions & 4 deletions cadwyn/structure/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ async def _migrate_request(
current_version: VersionDate,
head_route: APIRoute,
exit_stack: AsyncExitStack,
*,
embed_body_fields: bool,
) -> dict[str, Any]:
method = request.method
for v in reversed(self.versions):
Expand All @@ -465,19 +467,20 @@ async def _migrate_request(
request.scope["headers"] = tuple((key.encode(), value.encode()) for key, value in request_info.headers.items())
del request._headers
# Remember this: if len(body_params) == 1, then route.body_schema == route.dependant.body_params[0]
dependencies, errors, _, _, _ = await solve_dependencies(
result = await solve_dependencies(
request=request,
response=response,
dependant=head_dependant,
body=request_info.body,
dependency_overrides_provider=head_route.dependency_overrides_provider,
async_exit_stack=exit_stack,
embed_body_fields=embed_body_fields,
)
if errors:
if result.errors:
raise CadwynHeadRequestValidationError(
_normalize_errors(errors), body=request_info.body, version=current_version
_normalize_errors(result.errors), body=request_info.body, version=current_version
)
return dependencies
return result.values

def _migrate_response(
self,
Expand Down Expand Up @@ -755,6 +758,7 @@ async def _convert_endpoint_kwargs_to_version(
api_version,
head_route,
exit_stack=exit_stack,
embed_body_fields=route._embed_body_fields,
)
# Because we re-added it into our kwargs when we did solve_dependencies
if _CADWYN_REQUEST_PARAM_NAME in new_kwargs:
Expand Down
Loading

0 comments on commit 12a3877

Please sign in to comment.