Skip to content

Commit

Permalink
Add basic request modification (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmievsa authored Sep 15, 2023
1 parent d6a5bc4 commit 9f62195
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 40 deletions.
31 changes: 28 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,32 @@ class MyChange(VersionChange):

```

#### Changing endpoint logic (Experimental)

Oftentimes you change some of the logic of your endpoint in a way that is incompatible with or not yet supported by Universi's migrations. In order to combat this, we have come up with an ugly hack that allows you to change any detail about your endpoint's arguments or logic:

```python
from fastapi.params import Param
from fastapi import Header


class MyVersionChange(VersionChange):
description = "..."
instructions_to_migrate_to_previous_version = ()

@endpoint("/users", ["GET"]).was
def get_old_endpoint():
from some_business_logic import SomeController


async def get_users(some_old_parameter: str = Param(), some_new_required_header: str = Header()):
return SomeController(some_old_parameter, some_new_required_header)

return get_users
```

As you see, it's hacky in more ways than one. Any imports to your business logic must happen within the function to prevent circular dependencies and you have to have a function within a function as a result. It is therefore not advised to use this functionality unlesss absolutely required. I recommend to instead add an issue on our github. However, if Universi definitely cannot solve your problem -- this should be your "get out of jail free" card.

#### Dealing with endpoint duplicates

Sometimes, when you're doing some advanced changes in between versions, you will need to rewrite your endpoint function entirely. So essentially you'd have the following structure:
Expand Down Expand Up @@ -497,7 +523,7 @@ class MyChange(VersionChange):

```

#### Rename a schema
#### Rename a schema (Experimental)

If you wish to rename your schema to make sure that its name is different in openapi.json:

Expand All @@ -512,8 +538,7 @@ class MyChange(VersionChange):

```

which will replace all references to this schema with the new name. Note that this functionality is still experimental
so minor issues can be expected. If you find any -- feel free to report it in issues.
which will replace all references to this schema with the new name.

Note also that renaming a schema should not technically be a breaking change.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "universi"
version = "1.9.1"
version = "1.10.0.rc0"
description = "Modern Stripe-like API versioning in FastAPI"
authors = ["Stanislav Zmiev <zmievsa@gmail.com>"]
license = "MIT"
Expand Down
2 changes: 2 additions & 0 deletions tests/_data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def some_fastapi_dependency(hi: str = "hellow"):
return hi
67 changes: 52 additions & 15 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Annotated, Any, NewType, TypeAlias, cast, get_args

import pytest
from fastapi import APIRouter, Body, Depends, FastAPI
from fastapi import APIRouter, Body, Depends, FastAPI, Header
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from pydantic import BaseModel
Expand Down Expand Up @@ -38,14 +38,14 @@ def router() -> VersionedAPIRouter:

@pytest.fixture()
def test_path() -> str:
return "/test/{hewoo}"
return "/test/{hewwo}"


@pytest.fixture()
def test_endpoint(router: VersionedAPIRouter, test_path: str) -> Endpoint:
@router.get(test_path)
async def test(hewwo: int):
raise NotImplementedError
return hewwo

return test

Expand Down Expand Up @@ -94,9 +94,6 @@ def __call__(
latest_schemas_module: Any = Default,
) -> tuple[list[APIRoute], list[APIRoute]]:
routers = self.create_versioned_copies(*instructions, latest_schemas_module=latest_schemas_module)
for router in routers.values():
for route in router.routes:
assert isinstance(route, APIRoute)
return cast(
tuple[list[APIRoute], list[APIRoute]],
(routers[date(2000, 1, 1)].routes, routers[date(2001, 1, 1)].routes),
Expand Down Expand Up @@ -310,7 +307,7 @@ def test__router_generation__re_creating_an_existing_endpoint__error(
with pytest.raises(
RouterGenerationError,
match=re.escape(
"Endpoint \"['GET'] /test/{hewoo}\" you tried to restore in "
"Endpoint \"['GET'] /test/{hewwo}\" you tried to restore in "
'"MyVersionChange" already existed in a newer version',
),
):
Expand All @@ -324,7 +321,7 @@ def test__router_generation__editing_an_endpoint_with_wrong_method__should_raise
):
with pytest.raises(
RouterGenerationError,
match=re.escape('Endpoint "[\'POST\'] /test/{hewoo}" you tried to change in "MyVersionChange" doesn\'t exist'),
match=re.escape('Endpoint "[\'POST\'] /test/{hewwo}" you tried to change in "MyVersionChange" doesn\'t exist'),
):
create_versioned_copies(endpoint(test_path, ["POST"]).had(description="Hewwo"))

Expand All @@ -333,27 +330,27 @@ def test__router_generation__editing_an_endpoint_with_a_less_general_method__sho
router: VersionedAPIRouter,
create_versioned_copies: CreateVersionedCopies,
):
@router.route("/test/{hewoo}", methods=["GET", "POST"])
@router.route("/test/{hewwo}", methods=["GET", "POST"])
async def test(hewwo: int):
raise NotImplementedError

with pytest.raises(
RouterGenerationError,
match=re.escape('Endpoint "[\'GET\'] /test/{hewoo}" you tried to change in "MyVersionChange" doesn\'t exist'),
match=re.escape('Endpoint "[\'GET\'] /test/{hewwo}" you tried to change in "MyVersionChange" doesn\'t exist'),
):
create_versioned_copies(endpoint("/test/{hewoo}", ["GET"]).had(description="Hewwo"))
create_versioned_copies(endpoint("/test/{hewwo}", ["GET"]).had(description="Hewwo"))


def test__router_generation__editing_multiple_endpoints_with_same_route(
router: VersionedAPIRouter,
create_versioned_api_routes: CreateVersionedAPIRoutes,
):
@router.api_route("/test/{hewoo}", methods=["GET", "POST"])
@router.api_route("/test/{hewwo}", methods=["GET", "POST"])
async def test(hewwo: int):
raise NotImplementedError

routes_2000, routes_2001 = create_versioned_api_routes(
endpoint("/test/{hewoo}", ["GET", "POST"]).had(description="Meaw"),
endpoint("/test/{hewwo}", ["GET", "POST"]).had(description="Meaw"),
)
assert len(routes_2000) == len(routes_2001) == 1
assert routes_2000[0].description == "Meaw"
Expand All @@ -367,7 +364,7 @@ def test__router_generation__editing_an_endpoint_with_a_more_general_method__sho
):
with pytest.raises(
RouterGenerationError,
match=re.escape('Endpoint "[\'POST\'] /test/{hewoo}" you tried to change in "MyVersionChange" doesn\'t exist'),
match=re.escape('Endpoint "[\'POST\'] /test/{hewwo}" you tried to change in "MyVersionChange" doesn\'t exist'),
):
create_versioned_copies(endpoint(test_path, ["GET", "POST"]).had(description="Hewwo"))

Expand Down Expand Up @@ -582,7 +579,7 @@ def test__router_generation__changing_attribute_to_the_same_value__error(
with pytest.raises(
RouterGenerationError,
match=re.escape(
'Expected attribute "path" of endpoint "[\'GET\'] /test/{hewoo}" to be different in "MyVersionChange", but '
'Expected attribute "path" of endpoint "[\'GET\'] /test/{hewwo}" to be different in "MyVersionChange", but '
"it was the same. It means that your version change has no effect on the attribute and can be removed.",
),
):
Expand Down Expand Up @@ -993,3 +990,43 @@ class V2001(VersionChange):
}
assert routers[date(2000, 1, 1)].routes[0].endpoint.func == test_endpoint2
assert routers[date(2000, 1, 1)].routes[0].endpoint.func == test_endpoint2


def test__router_generation__adding_request_header_using_hacks(
router: VersionedAPIRouter,
test_endpoint: Endpoint,
test_path: str,
api_version_var: ContextVar[date | None],
):
class MyVersionChange(VersionChange):
description = "..."
instructions_to_migrate_to_previous_version = ()

@endpoint(test_path, ["GET"]).was
def get_old_endpoint():
from tests._data.utils import some_fastapi_dependency

async def test_endpoint(hewwo: int, header: str = Header(), dep: str = Depends(some_fastapi_dependency)):
return {"hewwo": hewwo, "header": header, "dep": dep}

return test_endpoint

versions = generate_all_router_versions(
router,
versions=VersionBundle(
Version(date(2001, 1, 1), MyVersionChange),
Version(date(2000, 1, 1)),
api_version_var=api_version_var,
),
latest_schemas_module=None,
)

client_2000 = client(versions[date(2000, 1, 1)])
client_2001 = client(versions[date(2001, 1, 1)])

assert client_2000.get("/test/83", headers={"header": "11"}, params={"hi": "Mark"}).json() == {
"hewwo": 83,
"header": "11",
"dep": "Mark",
}
assert client_2001.get("/test/83", params={"header": 11, "hi": "Mark"}).json() == 83
4 changes: 1 addition & 3 deletions tests/test_tutorial/test_users_example001/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from universi.routing import generate_all_router_versions


if __name__ == "__main__":
from universi.routing import generate_all_router_versions
from datetime import date
from pathlib import Path

Expand Down
4 changes: 1 addition & 3 deletions tests/test_tutorial/test_users_example002/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from universi.routing import generate_all_router_versions


if __name__ == "__main__":
from universi.routing import generate_all_router_versions
from datetime import date
from pathlib import Path

Expand Down
Empty file removed tests/tools.py
Empty file.
4 changes: 2 additions & 2 deletions universi/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _get_union_of_versioned_name(
if isinstance(node, ast.ClassDef):
# We add [schemas_per_version[0]] because imported_modules include "latest" and schemas_per_version do not
union = " | ".join(
f"{module.alias}.{_get_mod_name(node, module, schemas)}"
f"{module.alias}.{_get_modified_name_of_ast_node(node, module, schemas)}"
for module, schemas in zip(imported_modules, [schemas_per_version[0], *schemas_per_version])
)
return ast.Name(
Expand All @@ -196,7 +196,7 @@ def _get_union_of_versioned_name(
return node


def _get_mod_name(node: ast.ClassDef, module: ImportedModule, schemas: dict[str, ModelInfo]):
def _get_modified_name_of_ast_node(node: ast.ClassDef, module: ImportedModule, schemas: dict[str, ModelInfo]):
node_python_path = f"{module.absolute_python_path_to_origin}.{node.name}"
if node_python_path in schemas:
return schemas[node_python_path].name
Expand Down
31 changes: 22 additions & 9 deletions universi/routing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
import functools
import inspect
import itertools
import typing
import warnings
from collections.abc import Callable, Sequence
Expand Down Expand Up @@ -59,6 +58,7 @@
EndpointDidntExistInstruction,
EndpointExistedInstruction,
EndpointHadInstruction,
EndpointWasInstruction,
)
from universi.structure.versions import VersionChange

Expand Down Expand Up @@ -243,6 +243,15 @@ def _apply_endpoint_changes_to_router(self, router: fastapi.routing.APIRouter, v
'Endpoint "{endpoint_methods} {endpoint_path}" you tried to change in'
' "{version_change_name}" doesn\'t exist'
)
elif isinstance(instruction, EndpointWasInstruction):
for original_route in original_routes:
methods_to_which_we_applied_changes |= original_route.methods
original_route.endpoint = instruction.get_old_endpoint()
_remake_endpoint_dependencies(original_route)
err = (
'Endpoint "{endpoint_methods} {endpoint_path}" whose handler you tried to change in'
' "{version_change_name}" doesn\'t exist'
)
else:
assert_never(instruction)
method_diff = methods_we_should_have_applied_changes_to - methods_to_which_we_applied_changes
Expand Down Expand Up @@ -317,14 +326,7 @@ def migrate_router_to_version(self, router: fastapi.routing.APIRouter, version:
route.response_model = self._change_version_of_annotations(route.response_model, version_dir)
route.dependencies = self._change_version_of_annotations(route.dependencies, version_dir)
route.endpoint = self._change_version_of_annotations(route.endpoint, version_dir)
route.dependant = get_dependant(path=route.path_format, call=route.endpoint)
route.body_field = get_body_field(dependant=route.dependant, name=route.unique_id)
for depends in route.dependencies[::-1]:
route.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=route.path_format),
)
route.app = request_response(route.get_route_handler())
_remake_endpoint_dependencies(route)

def _change_versions_of_a_non_container_annotation(self, annotation: Any, version_dir: Path) -> Any:
if isinstance(annotation, _BaseGenericAlias | GenericAlias):
Expand Down Expand Up @@ -438,6 +440,17 @@ def _change_version_of_type(self, annotation: type, version_dir: Path):
return annotation


def _remake_endpoint_dependencies(route: fastapi.routing.APIRoute):
route.dependant = get_dependant(path=route.path_format, call=route.endpoint)
route.body_field = get_body_field(dependant=route.dependant, name=route.unique_id)
for depends in route.dependencies[::-1]:
route.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=route.path_format),
)
route.app = request_response(route.get_route_handler())


def _add_data_migrations_to_all_routes(router: fastapi.routing.APIRouter, versions: VersionBundle):
for route in router.routes:
if isinstance(route, fastapi.routing.APIRoute):
Expand Down
21 changes: 20 additions & 1 deletion universi/structure/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import Any, cast

from fastapi import Response
from fastapi.params import Depends
Expand Down Expand Up @@ -71,6 +71,14 @@ class EndpointDidntExistInstruction:
endpoint_func_name: str | None


@dataclass(slots=True)
class EndpointWasInstruction:
endpoint_path: str
endpoint_methods: Sequence[str]
endpoint_func_name: str | None
get_old_endpoint: Callable[..., Any]


@dataclass(slots=True)
class EndpointInstructionFactory:
endpoint_path: str
Expand Down Expand Up @@ -138,6 +146,17 @@ def had(
),
)

def was(self, get_old_endpoint: Callable[[], Any]) -> type[staticmethod]:
return cast(
type[staticmethod],
EndpointWasInstruction(
endpoint_path=self.endpoint_path,
endpoint_methods=self.endpoint_methods,
endpoint_func_name=self.endpoint_func_name,
get_old_endpoint=get_old_endpoint,
),
)


def endpoint(path: str, methods: list[str], /, *, func_name: str | None = None) -> EndpointInstructionFactory:
return EndpointInstructionFactory(path, methods, func_name)
Expand Down
8 changes: 5 additions & 3 deletions universi/structure/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import assert_never

from universi.exceptions import UniversiError, UniversiStructureError
from universi.structure.endpoints import AlterEndpointSubInstruction
from universi.structure.endpoints import AlterEndpointSubInstruction, EndpointWasInstruction
from universi.structure.enums import AlterEnumSubInstruction

from .._utils import Sentinel
Expand All @@ -30,7 +30,7 @@ class VersionChange:
instructions_to_migrate_to_previous_version: ClassVar[Sequence[PossibleInstructions]] = Sentinel
alter_schema_instructions: ClassVar[Sequence[AlterSchemaSubInstruction | AlterSchemaInstruction]] = Sentinel
alter_enum_instructions: ClassVar[Sequence[AlterEnumSubInstruction]] = Sentinel
alter_endpoint_instructions: ClassVar[Sequence[AlterEndpointSubInstruction]] = Sentinel
alter_endpoint_instructions: ClassVar[Sequence[AlterEndpointSubInstruction | EndpointWasInstruction]] = Sentinel
alter_response_instructions: ClassVar[dict[Any, AlterResponseInstruction]] = Sentinel
alter_request_instructions: ClassVar[dict[Any, list[AlterRequestInstruction]]] = Sentinel
_bound_versions: "VersionBundle | None"
Expand Down Expand Up @@ -59,6 +59,8 @@ def __init_subclass__(cls, _abstract: bool = False) -> None:
cls.alter_schema_instructions.append(instruction)
elif isinstance(instruction, AlterResponseInstruction):
cls.alter_response_instructions[instruction.schema] = instruction
elif isinstance(instruction, EndpointWasInstruction):
cls.alter_endpoint_instructions.append(instruction)

cls._check_no_subclassing()
cls._bound_versions = None
Expand Down Expand Up @@ -86,7 +88,7 @@ def _validate_subclass(cls):
for attr_name, attr_value in cls.__dict__.items():
if not isinstance(
attr_value,
AlterResponseInstruction | SchemaPropertyDefinitionInstruction,
AlterResponseInstruction | SchemaPropertyDefinitionInstruction | EndpointWasInstruction,
) and attr_name not in {
"description",
"side_effects",
Expand Down

0 comments on commit 9f62195

Please sign in to comment.