Skip to content

Commit

Permalink
fix: flask & quart resp validation (#345)
Browse files Browse the repository at this point in the history
* fix: flask response parser

Signed-off-by: Keming <kemingy94@gmail.com>

* apply to quart

Signed-off-by: Keming <kemingy94@gmail.com>

* format

Signed-off-by: Keming <kemingy94@gmail.com>

* fix lint

Signed-off-by: Keming <kemingy94@gmail.com>

---------

Signed-off-by: Keming <kemingy94@gmail.com>
  • Loading branch information
kemingy authored Sep 25, 2023
1 parent 9ea8007 commit eb6a554
Show file tree
Hide file tree
Showing 17 changed files with 203 additions and 511 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ test: import_test
pip install --force-reinstall 'pydantic[email]<2'
pytest tests -vv -rs

update_snapshot:
@pytest --snapshot-update

doc:
cd docs && make html

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "spectree"
version = "1.2.3"
version = "1.2.4"
dynamic = []
description = "generate OpenAPI document and validate request&response with Python annotations."
readme = "README.md"
Expand Down
32 changes: 11 additions & 21 deletions spectree/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Union,
)

from .._pydantic import ValidationError, is_root_model, serialize_model_instance
from .._pydantic import serialize_model_instance
from .._types import JsonType, ModelType, OptionalModelType
from ..config import Configuration
from ..response import Response
Expand Down Expand Up @@ -138,50 +138,40 @@ class ResponseValidationResult:


def validate_response(
skip_validation: bool,
validation_model: OptionalModelType,
response_payload: Any,
) -> ResponseValidationResult:
"""Validate a given ``response_payload`` against a ``validation_model``.
This does nothing if ``validation_model is None``.
:param skip_validation: When set to true, validation is not carried out
and the input ``response_payload`` is returned as-is. This is equivalent
to not providing a ``validation_model``.
:param validation_model: Pydantic model used to validate the provided
``response_payload``.
:param response_payload: Validated response payload. A :class:`RawResponsePayload`
should be provided when the plugin view function returned an already
JSON-serialized response payload.
"""
if not validation_model:
return ResponseValidationResult(payload=response_payload)

final_response_payload = None
skip_validation = False
if isinstance(response_payload, RawResponsePayload):
final_response_payload = response_payload.payload
elif skip_validation or validation_model is None:
final_response_payload = response_payload

if not skip_validation and validation_model and not final_response_payload:
else:
if isinstance(response_payload, validation_model):
skip_validation = True
final_response_payload = serialize_model_instance(response_payload)
elif is_root_model(validation_model):
# Make it possible to return an instance of the model __root__ type
# (i.e. not the root model itself).
try:
response_payload = validation_model(__root__=response_payload)
except ValidationError:
raise
else:
skip_validation = True
final_response_payload = serialize_model_instance(response_payload)
else:
final_response_payload = response_payload

if validation_model and not skip_validation:
if not skip_validation:
validator = (
validation_model.parse_raw
if isinstance(final_response_payload, bytes)
else validation_model.parse_obj
)
validator(final_response_payload)
final_response_payload = serialize_model_instance(
validator(final_response_payload)
)

return ResponseValidationResult(payload=final_response_payload)
8 changes: 3 additions & 5 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,11 @@ def validate(

func(*args, **kwargs)

if not self._data_set_manually(_resp):
if not self._data_set_manually(_resp) and not skip_validation and resp:
try:
status = int(_resp.status[:3])
response_validation_result = validate_response(
skip_validation=skip_validation,
validation_model=resp.find_model(status) if resp else None,
validation_model=resp.find_model(status),
response_payload=_resp.media,
)
except ValidationError as err:
Expand Down Expand Up @@ -327,11 +326,10 @@ async def validate(

await func(*args, **kwargs)

if not self._data_set_manually(_resp):
if not self._data_set_manually(_resp) and not skip_validation and resp:
try:
status = int(_resp.status[:3])
response_validation_result = validate_response(
skip_validation=skip_validation,
validation_model=resp.find_model(status) if resp else None,
response_payload=_resp.media,
)
Expand Down
67 changes: 30 additions & 37 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from .._pydantic import ValidationError
from .._types import ModelType
from ..response import Response
from ..utils import get_multidict_items, werkzeug_parse_rule
from .base import BasePlugin, Context, RawResponsePayload, validate_response
from ..utils import flask_response_unpack, get_multidict_items, werkzeug_parse_rule
from .base import BasePlugin, Context, validate_response


class FlaskPlugin(BasePlugin):
Expand Down Expand Up @@ -199,44 +199,37 @@ def validate(

result = func(*args, **kwargs)

status = 200
rest = []
if resp and isinstance(result, tuple):
if len(result) > 1:
response_payload, status, *rest = result
else:
response_payload = result[0]
elif isinstance(result, flask.Response):
response_payload, status = result, result.status_code
rest.append(result.headers)
else:
response_payload = result
payload, status, additional_headers = flask_response_unpack(result)
if isinstance(payload, flask.Response):
payload, resp_status, resp_headers = (
payload.get_json(),
payload.status_code,
payload.headers,
)
# the inner flask.Response.status_code only takes effect when there is
# no other status code
if status == 200:
status = resp_status
additional_headers.update(resp_headers)

try:
response_validation_result = validate_response(
skip_validation=skip_validation,
validation_model=resp.find_model(status) if resp else None,
response_payload=(
RawResponsePayload(payload=response_payload.get_json())
if (
isinstance(response_payload, flask.Response)
and not skip_validation
if not skip_validation and resp:
try:
response_validation_result = validate_response(
validation_model=resp.find_model(status),
response_payload=payload,
)
except ValidationError as err:
response = make_response(err.errors(), 500)
else:
response = make_response(
(
response_validation_result.payload,
status,
additional_headers,
)
else response_payload
),
)
except ValidationError:
response = make_response(
jsonify({"message": "response validation error"}), 500
)
else:
response = make_response(
(
response_validation_result.payload,
status,
*rest,
)
)
else:
response = make_response(result)

after(request, response, resp_validation_error, None)

Expand Down
66 changes: 30 additions & 36 deletions spectree/plugins/quart_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from .._pydantic import ValidationError
from .._types import ModelType
from ..response import Response
from ..utils import get_multidict_items, werkzeug_parse_rule
from .base import BasePlugin, Context, RawResponsePayload, validate_response
from ..utils import flask_response_unpack, get_multidict_items, werkzeug_parse_rule
from .base import BasePlugin, Context, validate_response


class QuartPlugin(BasePlugin):
Expand Down Expand Up @@ -211,43 +211,37 @@ async def validate(
else func(*args, **kwargs)
)

status = 200
rest = []
if resp and isinstance(result, tuple):
if len(result) > 1:
response_payload, status, *rest = result
else:
response_payload = result[0]
else:
response_payload = result
payload, status, additional_headers = flask_response_unpack(result)
if isinstance(payload, quart.Response):
payload, resp_status, resp_headers = (
await payload.get_json(),
payload.status_code,
payload.headers,
)
# the inner quart.Response.status_code only takes effect when there is
# no other status code
if status == 200:
status = resp_status
additional_headers.append(resp_headers)

try:
response_validation_result = validate_response(
skip_validation=skip_validation,
validation_model=resp.find_model(status) if resp else None,
response_payload=(
RawResponsePayload(
payload=(await response.get_json()) if response else None
)
if (
isinstance(response_payload, quart.Response)
and not skip_validation
if not skip_validation and resp:
try:
response_validation_result = validate_response(
validation_model=resp.find_model(status),
response_payload=payload,
)
except ValidationError as err:
response = await make_response(err.errors(), 500)
else:
response = await make_response(
(
response_validation_result.payload,
status,
additional_headers,
)
else response_payload
),
)
except ValidationError:
response = await make_response(
jsonify({"message": "response validation error"}), 500
)
else:
response = await make_response(
(
response_validation_result.payload,
status,
*rest,
)
)
else:
response = await make_response(result)

after(request, response, resp_validation_error, None)

Expand Down
19 changes: 8 additions & 11 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,18 @@ async def validate(
response = func(*args, **kwargs)

if not skip_validation and resp and response:
if (
if not (
isinstance(response, JSONResponse)
and hasattr(response, "_model_class")
and response._model_class == resp.find_model(response.status_code)
):
skip_validation = True

try:
validate_response(
skip_validation=skip_validation,
validation_model=resp.find_model(response.status_code),
response_payload=RawResponsePayload(payload=response.body),
)
except ValidationError as err:
response = JSONResponse(err.errors(), 500)
try:
validate_response(
validation_model=resp.find_model(response.status_code),
response_payload=RawResponsePayload(payload=response.body),
)
except ValidationError as err:
response = JSONResponse(err.errors(), 500)

after(request, response, resp_validation_error, instance)

Expand Down
25 changes: 25 additions & 0 deletions spectree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,31 @@ def werkzeug_parse_rule(
yield None, None, remaining


def flask_response_unpack(resp: Any) -> Tuple[Any, int, Dict[str, Any]]:
"""Parse Flask response object into a tuple of (payload, status_code, headers)."""
status = 200
headers: Dict[str, str] = {}
payload = None
if not isinstance(resp, tuple):
return resp, status, headers
if len(resp) == 1:
payload = resp[0]
elif len(resp) == 2:
payload = resp[0]
if isinstance(resp[1], int):
status = resp[1]
else:
headers = resp[1]
elif len(resp) == 3:
payload, status, headers = resp
else:
raise ValueError(
f"Invalid return tuple: {resp}, expect (body,), (body, status), "
"(body, headers), or (body, status, headers)."
)
return payload, status, headers


def parse_resp(func: Any, naming_strategy: NamingStrategy = get_model_key):
"""
get the response spec
Expand Down
Loading

0 comments on commit eb6a554

Please sign in to comment.