Skip to content

Commit

Permalink
Fix status code handling in migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
zmievsa committed Feb 16, 2024
1 parent 8203578 commit 03bb738
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 11 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ Please follow [the Keep a Changelog standard](https://keepachangelog.com/en/1.0.

## [Unreleased]

## [3.4.5]

### Fixed

* Previously, Cadwyn was unable to handle HTTP status errors in response converters
* Previously, Cadwyn did not set the default status code for ResponseInfo

## [3.4.4]

### Fixed
Expand Down
47 changes: 39 additions & 8 deletions cadwyn/structure/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,12 @@ async def decorator(*args: Any, **kwargs: Any) -> _R:
if response_param_name == _CADWYN_RESPONSE_PARAM_NAME:
_add_keyword_only_parameter(decorator, _CADWYN_RESPONSE_PARAM_NAME, FastapiResponse)

return decorator # pyright: ignore[reportGeneralTypeIssues]
return decorator # pyright: ignore[reportReturnType]

return wrapper

async def _convert_endpoint_response_to_version(
# TODO: Simplify it
async def _convert_endpoint_response_to_version( # noqa: C901
self,
func_to_get_response_from: Endpoint,
latest_route: APIRoute,
Expand All @@ -437,14 +438,23 @@ async def _convert_endpoint_response_to_version(
kwargs: dict[str, Any],
fastapi_response_dependency: FastapiResponse,
) -> Any:
raised_exception = None
if response_param_name == _CADWYN_RESPONSE_PARAM_NAME:
kwargs.pop(response_param_name)
if is_async_callable(func_to_get_response_from):
response_or_response_body: FastapiResponse | object = await func_to_get_response_from(**kwargs)
else:
response_or_response_body: FastapiResponse | object = await run_in_threadpool(
func_to_get_response_from,
**kwargs,
try:
if is_async_callable(func_to_get_response_from):
response_or_response_body: FastapiResponse | object = await func_to_get_response_from(**kwargs)
else:
response_or_response_body: FastapiResponse | object = await run_in_threadpool(
func_to_get_response_from,
**kwargs,
)
except HTTPException as exc:
raised_exception = exc
response_or_response_body = FastapiResponse(
content=json.dumps({"detail": raised_exception.detail}),
status_code=raised_exception.status_code,
headers=raised_exception.headers,
)
api_version = self.api_version_var.get()
if api_version is None:
Expand All @@ -462,8 +472,18 @@ async def _convert_endpoint_response_to_version(
else:
body = None
# TODO (https://github.com/zmievsa/cadwyn/issues/51): Only do this if there are migrations

response_info = ResponseInfo(response_or_response_body, body)
else:
if fastapi_response_dependency.status_code is not None:
status_code = fastapi_response_dependency.status_code
elif route.status_code is not None:
status_code = route.status_code
elif raised_exception is not None:
raise NotImplementedError
else:
status_code = 200
fastapi_response_dependency.status_code = status_code
response_info = ResponseInfo(
fastapi_response_dependency,
_prepare_response_content(
Expand Down Expand Up @@ -495,6 +515,17 @@ async def _convert_endpoint_response_to_version(
if response_info.body is not None and hasattr(response_info._response, "body"):
# TODO (https://github.com/zmievsa/cadwyn/issues/51): Only do this if there are migrations
response_info._response.body = json.dumps(response_info.body).encode()

if raised_exception is not None and response_info.status_code >= 400:
if isinstance(response_info.body, dict) and "detail" in response_info.body:
detail = response_info.body["detail"]
else:
detail = response_info.body
raise HTTPException(
status_code=response_info.status_code,
detail=detail,
headers=dict(response_info.headers),
)
return response_info._response
return response_info.body

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cadwyn"
version = "3.4.4"
version = "3.4.5"
description = "Production-ready community-driven modern Stripe-like API versioning in FastAPI"
authors = ["Stanislav Zmiev <zmievsa@gmail.com>"]
license = "MIT"
Expand Down
128 changes: 126 additions & 2 deletions tests/test_data_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import fastapi
import pytest
from dirty_equals import IsPartialDict, IsStr
from fastapi import APIRouter, Body, Cookie, File, Header, Query, Request, Response, UploadFile
from fastapi import APIRouter, Body, Cookie, File, Header, HTTPException, Query, Request, Response, UploadFile
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from starlette.responses import StreamingResponse
Expand Down Expand Up @@ -490,7 +490,7 @@ def test__all_response_components_migration__post_endpoint__migration_filled_res
@convert_response_to_previous_version_for(latest_module.AnyResponseSchema)
def migrator(response: ResponseInfo):
response.body["body_key"] = "body_val"
assert response.status_code is None
assert response.status_code == 200
response.status_code = 300
response.headers["header"] = "header_val"
response.set_cookie("cookie_key", "cookie_val", max_age=83)
Expand Down Expand Up @@ -1057,3 +1057,127 @@ def response_converter(response: ResponseInfo):
clients = create_versioned_clients(version_change(req=request_converter, resp=response_converter))
assert clients[date(2000, 1, 1)].post("/test/83").json() == [83, "Hewwo", "World"]
assert clients[date(2001, 1, 1)].post("/test/83").json() == [83, "wow"]


def test__request_and_response_migrations__for_endpoint_with_http_exception__can_migrate_to_200(
create_versioned_clients: CreateVersionedClients,
latest_module,
router: VersionedAPIRouter,
):
@router.post("/test")
async def endpoint():
raise HTTPException(status_code=404)

@convert_response_to_previous_version_for("/test", ["POST"])
def response_converter(response: ResponseInfo):
response.status_code = 200
response.body = {"hello": "darkness"}
response.headers["hewwo"] = "dawkness"

clients = create_versioned_clients(version_change(resp=response_converter))
resp_2000 = clients[date(2000, 1, 1)].post("/test")
assert resp_2000.status_code == 200
assert resp_2000.json() == {"hello": "darkness"}
assert resp_2000.headers["hewwo"] == "dawkness"

resp_2001 = clients[date(2001, 1, 1)].post("/test")
assert resp_2001.status_code == 404
assert resp_2001.json() == {"detail": "Not Found"}
assert "hewwo" not in resp_2001.headers


def test__request_and_response_migrations__for_endpoint_with_http_exception__can_migrate_to_another_error(
create_versioned_clients: CreateVersionedClients,
latest_module,
router: VersionedAPIRouter,
):
@router.post("/test")
async def endpoint():
raise HTTPException(status_code=404)

@convert_response_to_previous_version_for("/test", ["POST"])
def response_converter(response: ResponseInfo):
response.status_code = 401
response.body = None

clients = create_versioned_clients(version_change(resp=response_converter))
resp_2000 = clients[date(2000, 1, 1)].post("/test")
assert resp_2000.status_code == 401
assert resp_2000.json() == {"detail": "Unauthorized"}

resp_2001 = clients[date(2001, 1, 1)].post("/test")
assert resp_2001.status_code == 404
assert resp_2001.json() == {"detail": "Not Found"}


def test__request_and_response_migrations__for_endpoint_with_no_default_status_code__response_should_contain_default(
create_versioned_clients: CreateVersionedClients,
latest_module,
router: VersionedAPIRouter,
):
@router.post("/test")
async def endpoint():
return 83

@convert_response_to_previous_version_for("/test", ["POST"])
def response_converter(response: ResponseInfo):
assert response.status_code == 200

clients = create_versioned_clients(version_change(resp=response_converter))

resp_2000 = clients[date(2000, 1, 1)].post("/test")
assert resp_2000.status_code == 200
assert resp_2000.json() == 83

resp_2001 = clients[date(2001, 1, 1)].post("/test")
assert resp_2001.status_code == 200
assert resp_2001.json() == 83


def test__request_and_response_migrations__for_endpoint_with_custom_status_code__response_should_contain_default(
create_versioned_clients: CreateVersionedClients,
latest_module,
router: VersionedAPIRouter,
):
@router.post("/test", status_code=201)
async def endpoint():
return 83

@convert_response_to_previous_version_for("/test", ["POST"])
def response_converter(response: ResponseInfo):
assert response.status_code == 201

clients = create_versioned_clients(version_change(resp=response_converter))

resp_2000 = clients[date(2000, 1, 1)].post("/test")
assert resp_2000.status_code == 201
assert resp_2000.json() == 83

resp_2001 = clients[date(2001, 1, 1)].post("/test")
assert resp_2001.status_code == 201
assert resp_2001.json() == 83


def test__request_and_response_migrations__for_endpoint_with_modified_status_code__response_should_not_change(
create_versioned_clients: CreateVersionedClients,
latest_module,
router: VersionedAPIRouter,
):
@router.post("/test")
async def endpoint(response: Response):
response.status_code = 201
return 83

@convert_response_to_previous_version_for("/test", ["POST"])
def response_converter(response: ResponseInfo):
assert response.status_code == 201

clients = create_versioned_clients(version_change(resp=response_converter))

resp_2000 = clients[date(2000, 1, 1)].post("/test")
assert resp_2000.status_code == 201
assert resp_2000.json() == 83

resp_2001 = clients[date(2001, 1, 1)].post("/test")
assert resp_2001.status_code == 201
assert resp_2001.json() == 83

0 comments on commit 03bb738

Please sign in to comment.