Skip to content

Commit

Permalink
Use ContextVar to remove request kwarg requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
Zaczero committed Feb 10, 2024
1 parent ab8bbb3 commit 19dc8e9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
4 changes: 1 addition & 3 deletions api/v1/countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import APIRouter, Path, Request, Response
from fastapi import APIRouter, Path, Response
from shapely.geometry import mapping

from middlewares.cache_middleware import configure_cache
Expand All @@ -22,7 +22,6 @@ async def _count_aed_in_country(country: Country, aed_state: AEDState, send_stre
@router.get('/names')
@configure_cache(timedelta(hours=1), stale=timedelta(days=7))
async def get_names(
request: Request,
country_state: CountryStateDep,
aed_state: AEDStateDep,
language: str | None = None,
Expand Down Expand Up @@ -66,7 +65,6 @@ def limit_country_names(names: dict[str, str]):
@router.get('/{country_code}.geojson')
@configure_cache(timedelta(hours=1), stale=timedelta(seconds=0))
async def get_geojson(
request: Request,
response: Response,
country_code: Annotated[str, Path(min_length=2, max_length=5)],
country_state: CountryStateDep,
Expand Down
6 changes: 3 additions & 3 deletions api/v1/photos.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def _fetch_image(url: str) -> tuple[bytes, str]:

@router.get('/view/{id}.webp')
@configure_cache(timedelta(days=365), stale=timedelta(days=365))
async def view(request: Request, id: str, photo_state: PhotoStateDep) -> FileResponse:
async def view(id: str, photo_state: PhotoStateDep) -> FileResponse:
info = await photo_state.get_photo_by_id(id)

if info is None:
Expand All @@ -62,14 +62,14 @@ async def view(request: Request, id: str, photo_state: PhotoStateDep) -> FileRes

@router.get('/proxy/direct/{url_encoded:path}')
@configure_cache(timedelta(days=7), stale=timedelta(days=7))
async def proxy_direct(request: Request, url_encoded: str) -> FileResponse:
async def proxy_direct(url_encoded: str) -> FileResponse:
file, content_type = await _fetch_image(unquote_plus(url_encoded))
return Response(file, media_type=content_type)


@router.get('/proxy/wikimedia-commons/{path_encoded:path}')
@configure_cache(timedelta(days=7), stale=timedelta(days=7))
async def proxy_wikimedia_commons(request: Request, path_encoded: str) -> FileResponse:
async def proxy_wikimedia_commons(path_encoded: str) -> FileResponse:
async with get_http_client() as http:
url = f'https://commons.wikimedia.org/wiki/{unquote_plus(path_encoded)}'
r = await http.get(url)
Expand Down
11 changes: 9 additions & 2 deletions middlewares/cache_middleware.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import functools
from contextvars import ContextVar
from datetime import timedelta

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

_request_context = ContextVar('Request_context')


def make_cache_control(max_age: timedelta, stale: timedelta):
return f'public, max-age={int(max_age.total_seconds())}, stale-while-revalidate={int(stale.total_seconds())}'
Expand All @@ -17,7 +20,11 @@ def __init__(self, app: ASGIApp, max_age: timedelta, stale: timedelta):
self.stale = stale

async def dispatch(self, request: Request, call_next):
response = await call_next(request)
token = _request_context.set(request)
try:
response = await call_next(request)
finally:
_request_context.reset(token)

if request.method in ('GET', 'HEAD') and 200 <= response.status_code < 300:
try:
Expand All @@ -40,7 +47,7 @@ def configure_cache(max_age: timedelta, stale: timedelta):
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
request: Request = kwargs['request']
request: Request = _request_context.get()
request.state.max_age = max_age
request.state.stale = stale
return await func(*args, **kwargs)
Expand Down

0 comments on commit 19dc8e9

Please sign in to comment.