diff --git a/CHANGELOG.md b/CHANGELOG.md index d5b1fd79..5a87ca0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ Please follow [the Keep a Changelog standard](https://keepachangelog.com/en/1.0. ## [Unreleased] +## [3.4.5] + +### Fixed + +* Previously, Cadwyn was unable to handle HTTP status errors in response converters +* Previously, Cadwyn did not set the default status code for ResponseInfo + ## [3.4.4] ### Fixed diff --git a/cadwyn/structure/versions.py b/cadwyn/structure/versions.py index f68ddbfc..6686b946 100644 --- a/cadwyn/structure/versions.py +++ b/cadwyn/structure/versions.py @@ -423,11 +423,12 @@ async def decorator(*args: Any, **kwargs: Any) -> _R: if response_param_name == _CADWYN_RESPONSE_PARAM_NAME: _add_keyword_only_parameter(decorator, _CADWYN_RESPONSE_PARAM_NAME, FastapiResponse) - return decorator # pyright: ignore[reportGeneralTypeIssues] + return decorator # pyright: ignore[reportReturnType] return wrapper - async def _convert_endpoint_response_to_version( + # TODO: Simplify it + async def _convert_endpoint_response_to_version( # noqa: C901 self, func_to_get_response_from: Endpoint, latest_route: APIRoute, @@ -437,14 +438,23 @@ async def _convert_endpoint_response_to_version( kwargs: dict[str, Any], fastapi_response_dependency: FastapiResponse, ) -> Any: + raised_exception = None if response_param_name == _CADWYN_RESPONSE_PARAM_NAME: kwargs.pop(response_param_name) - if is_async_callable(func_to_get_response_from): - response_or_response_body: FastapiResponse | object = await func_to_get_response_from(**kwargs) - else: - response_or_response_body: FastapiResponse | object = await run_in_threadpool( - func_to_get_response_from, - **kwargs, + try: + if is_async_callable(func_to_get_response_from): + response_or_response_body: FastapiResponse | object = await func_to_get_response_from(**kwargs) + else: + response_or_response_body: FastapiResponse | object = await run_in_threadpool( + func_to_get_response_from, + **kwargs, + ) + except HTTPException as exc: + raised_exception = exc + response_or_response_body = FastapiResponse( + content=json.dumps({"detail": raised_exception.detail}), + status_code=raised_exception.status_code, + headers=raised_exception.headers, ) api_version = self.api_version_var.get() if api_version is None: @@ -462,8 +472,18 @@ async def _convert_endpoint_response_to_version( else: body = None # TODO (https://github.com/zmievsa/cadwyn/issues/51): Only do this if there are migrations + response_info = ResponseInfo(response_or_response_body, body) else: + if fastapi_response_dependency.status_code is not None: + status_code = fastapi_response_dependency.status_code + elif route.status_code is not None: + status_code = route.status_code + elif raised_exception is not None: + raise NotImplementedError + else: + status_code = 200 + fastapi_response_dependency.status_code = status_code response_info = ResponseInfo( fastapi_response_dependency, _prepare_response_content( @@ -495,6 +515,17 @@ async def _convert_endpoint_response_to_version( if response_info.body is not None and hasattr(response_info._response, "body"): # TODO (https://github.com/zmievsa/cadwyn/issues/51): Only do this if there are migrations response_info._response.body = json.dumps(response_info.body).encode() + + if raised_exception is not None and response_info.status_code >= 400: + if isinstance(response_info.body, dict) and "detail" in response_info.body: + detail = response_info.body["detail"] + else: + detail = response_info.body + raise HTTPException( + status_code=response_info.status_code, + detail=detail, + headers=dict(response_info.headers), + ) return response_info._response return response_info.body diff --git a/pyproject.toml b/pyproject.toml index ec2396e7..3140845f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cadwyn" -version = "3.4.4" +version = "3.4.5" description = "Production-ready community-driven modern Stripe-like API versioning in FastAPI" authors = ["Stanislav Zmiev "] license = "MIT" diff --git a/tests/test_data_migrations.py b/tests/test_data_migrations.py index c53ba90b..f7825d7e 100644 --- a/tests/test_data_migrations.py +++ b/tests/test_data_migrations.py @@ -10,7 +10,7 @@ import fastapi import pytest from dirty_equals import IsPartialDict, IsStr -from fastapi import APIRouter, Body, Cookie, File, Header, Query, Request, Response, UploadFile +from fastapi import APIRouter, Body, Cookie, File, Header, HTTPException, Query, Request, Response, UploadFile from fastapi.responses import JSONResponse from fastapi.routing import APIRoute from starlette.responses import StreamingResponse @@ -490,7 +490,7 @@ def test__all_response_components_migration__post_endpoint__migration_filled_res @convert_response_to_previous_version_for(latest_module.AnyResponseSchema) def migrator(response: ResponseInfo): response.body["body_key"] = "body_val" - assert response.status_code is None + assert response.status_code == 200 response.status_code = 300 response.headers["header"] = "header_val" response.set_cookie("cookie_key", "cookie_val", max_age=83) @@ -1057,3 +1057,127 @@ def response_converter(response: ResponseInfo): clients = create_versioned_clients(version_change(req=request_converter, resp=response_converter)) assert clients[date(2000, 1, 1)].post("/test/83").json() == [83, "Hewwo", "World"] assert clients[date(2001, 1, 1)].post("/test/83").json() == [83, "wow"] + + +def test__request_and_response_migrations__for_endpoint_with_http_exception__can_migrate_to_200( + create_versioned_clients: CreateVersionedClients, + latest_module, + router: VersionedAPIRouter, +): + @router.post("/test") + async def endpoint(): + raise HTTPException(status_code=404) + + @convert_response_to_previous_version_for("/test", ["POST"]) + def response_converter(response: ResponseInfo): + response.status_code = 200 + response.body = {"hello": "darkness"} + response.headers["hewwo"] = "dawkness" + + clients = create_versioned_clients(version_change(resp=response_converter)) + resp_2000 = clients[date(2000, 1, 1)].post("/test") + assert resp_2000.status_code == 200 + assert resp_2000.json() == {"hello": "darkness"} + assert resp_2000.headers["hewwo"] == "dawkness" + + resp_2001 = clients[date(2001, 1, 1)].post("/test") + assert resp_2001.status_code == 404 + assert resp_2001.json() == {"detail": "Not Found"} + assert "hewwo" not in resp_2001.headers + + +def test__request_and_response_migrations__for_endpoint_with_http_exception__can_migrate_to_another_error( + create_versioned_clients: CreateVersionedClients, + latest_module, + router: VersionedAPIRouter, +): + @router.post("/test") + async def endpoint(): + raise HTTPException(status_code=404) + + @convert_response_to_previous_version_for("/test", ["POST"]) + def response_converter(response: ResponseInfo): + response.status_code = 401 + response.body = None + + clients = create_versioned_clients(version_change(resp=response_converter)) + resp_2000 = clients[date(2000, 1, 1)].post("/test") + assert resp_2000.status_code == 401 + assert resp_2000.json() == {"detail": "Unauthorized"} + + resp_2001 = clients[date(2001, 1, 1)].post("/test") + assert resp_2001.status_code == 404 + assert resp_2001.json() == {"detail": "Not Found"} + + +def test__request_and_response_migrations__for_endpoint_with_no_default_status_code__response_should_contain_default( + create_versioned_clients: CreateVersionedClients, + latest_module, + router: VersionedAPIRouter, +): + @router.post("/test") + async def endpoint(): + return 83 + + @convert_response_to_previous_version_for("/test", ["POST"]) + def response_converter(response: ResponseInfo): + assert response.status_code == 200 + + clients = create_versioned_clients(version_change(resp=response_converter)) + + resp_2000 = clients[date(2000, 1, 1)].post("/test") + assert resp_2000.status_code == 200 + assert resp_2000.json() == 83 + + resp_2001 = clients[date(2001, 1, 1)].post("/test") + assert resp_2001.status_code == 200 + assert resp_2001.json() == 83 + + +def test__request_and_response_migrations__for_endpoint_with_custom_status_code__response_should_contain_default( + create_versioned_clients: CreateVersionedClients, + latest_module, + router: VersionedAPIRouter, +): + @router.post("/test", status_code=201) + async def endpoint(): + return 83 + + @convert_response_to_previous_version_for("/test", ["POST"]) + def response_converter(response: ResponseInfo): + assert response.status_code == 201 + + clients = create_versioned_clients(version_change(resp=response_converter)) + + resp_2000 = clients[date(2000, 1, 1)].post("/test") + assert resp_2000.status_code == 201 + assert resp_2000.json() == 83 + + resp_2001 = clients[date(2001, 1, 1)].post("/test") + assert resp_2001.status_code == 201 + assert resp_2001.json() == 83 + + +def test__request_and_response_migrations__for_endpoint_with_modified_status_code__response_should_not_change( + create_versioned_clients: CreateVersionedClients, + latest_module, + router: VersionedAPIRouter, +): + @router.post("/test") + async def endpoint(response: Response): + response.status_code = 201 + return 83 + + @convert_response_to_previous_version_for("/test", ["POST"]) + def response_converter(response: ResponseInfo): + assert response.status_code == 201 + + clients = create_versioned_clients(version_change(resp=response_converter)) + + resp_2000 = clients[date(2000, 1, 1)].post("/test") + assert resp_2000.status_code == 201 + assert resp_2000.json() == 83 + + resp_2001 = clients[date(2001, 1, 1)].post("/test") + assert resp_2001.status_code == 201 + assert resp_2001.json() == 83