diff --git a/connexion/middleware/main.py b/connexion/middleware/main.py index 317da6f3e..0e9b94e2c 100644 --- a/connexion/middleware/main.py +++ b/connexion/middleware/main.py @@ -1,3 +1,4 @@ +import copy import dataclasses import enum import logging @@ -180,7 +181,9 @@ def __init__( self.app = app self.lifespan = lifespan self.middlewares = ( - middlewares if middlewares is not None else self.default_middlewares + middlewares + if middlewares is not None + else copy.copy(self.default_middlewares) ) self.middleware_stack: t.Optional[t.Iterable[ASGIApp]] = None self.apis: t.List[API] = [] @@ -223,11 +226,16 @@ def add_middleware( if isinstance(middleware, partial): middleware = middleware.func - if middleware == position: + if middleware == position.value: self.middlewares.insert( m, t.cast(ASGIApp, partial(middleware_class, **options)) ) break + else: + raise ValueError( + f"Could not insert middleware at position {position.name}. " + f"Please make sure you have a {position.value} in your stack." + ) def _build_middleware_stack(self) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]: """Apply all middlewares to the provided app. diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py index 0987ccf66..00445f280 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py @@ -1,8 +1,3 @@ -import json - -import flask - - def fix_data(data): return data.replace(b'\\"', b'"') diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 512c59c68..f71ed6844 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,8 +1,6 @@ -import sys -from unittest import mock - import pytest -from connexion.middleware import ConnexionMiddleware +from connexion.middleware import ConnexionMiddleware, MiddlewarePosition +from connexion.middleware.swagger_ui import SwaggerUIMiddleware from starlette.datastructures import MutableHeaders from conftest import build_app_from_fixture @@ -49,3 +47,37 @@ def test_routing_middleware(middleware_app): assert ( response.headers.get("operation_id") == "fakeapi.hello.post_greeting" ), response.status_code + + +def test_add_middleware(spec, app_class): + """Test adding middleware via the `add_middleware` method.""" + app = build_app_from_fixture("simple", app_class=app_class, spec_file=spec) + app.add_middleware(TestMiddleware) + + app_client = app.test_client() + response = app_client.post("/v1.0/greeting/robbe") + + assert ( + response.headers.get("operation_id") == "fakeapi.hello.post_greeting" + ), response.status_code + + +def test_position(spec, app_class): + """Test adding middleware via the `add_middleware` method.""" + middlewares = [ + middleware + for middleware in ConnexionMiddleware.default_middlewares + if middleware != SwaggerUIMiddleware + ] + app = build_app_from_fixture( + "simple", app_class=app_class, spec_file=spec, middlewares=middlewares + ) + + with pytest.raises(ValueError) as exc_info: + app.add_middleware(TestMiddleware, position=MiddlewarePosition.BEFORE_SWAGGER) + + assert ( + exc_info.value.args[0] + == f"Could not insert middleware at position BEFORE_SWAGGER. " + f"Please make sure you have a {SwaggerUIMiddleware} in your stack." + )