Skip to content

Commit

Permalink
Cleaner solution with a skip_serialization decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
Zaczero committed Feb 21, 2024
1 parent 08be876 commit 5767c7e
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 63 deletions.
82 changes: 40 additions & 42 deletions api/v1/countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from anyio import create_task_group
from fastapi import APIRouter, Path
from fastapi.responses import ORJSONResponse
from sentry_sdk import start_span
from shapely.geometry import mapping

from middlewares.cache_middleware import configure_cache
from middlewares.skip_serialization import skip_serialization
from models.country import Country
from states.aed_state import AEDState
from states.country_state import CountryState
Expand All @@ -17,6 +17,7 @@

@router.get('/names')
@configure_cache(timedelta(hours=1), stale=timedelta(days=7))
@skip_serialization()
async def get_names(language: str | None = None):
countries = await CountryState.get_all_countries()
country_count_map: dict[str, int] = {}
Expand All @@ -36,54 +37,51 @@ def limit_country_names(names: dict[str, str]):
return {language: name}
return names

return ORJSONResponse(
[
{
'country_code': country.code,
'country_names': limit_country_names(country.names),
'feature_count': country_count_map[country.name],
'data_path': f'/api/v1/countries/{country.code}.geojson',
}
for country in countries
]
+ [
{
'country_code': 'WORLD',
'country_names': {'default': 'World'},
'feature_count': sum(country_count_map.values()),
'data_path': '/api/v1/countries/WORLD.geojson',
}
]
)
return [
{
'country_code': country.code,
'country_names': limit_country_names(country.names),
'feature_count': country_count_map[country.name],
'data_path': f'/api/v1/countries/{country.code}.geojson',
}
for country in countries
] + [
{
'country_code': 'WORLD',
'country_names': {'default': 'World'},
'feature_count': sum(country_count_map.values()),
'data_path': '/api/v1/countries/WORLD.geojson',
}
]


@router.get('/{country_code}.geojson')
@configure_cache(timedelta(hours=1), stale=timedelta(seconds=0))
@skip_serialization(
{
'Content-Disposition': 'attachment',
'Content-Type': 'application/geo+json; charset=utf-8',
}
)
async def get_geojson(country_code: Annotated[str, Path(min_length=2, max_length=5)]):
if country_code == 'WORLD':
aeds = await AEDState.get_all_aeds()
else:
aeds = await AEDState.get_aeds_by_country_code(country_code)

return ORJSONResponse(
{
'type': 'FeatureCollection',
'features': [
{
'type': 'Feature',
'geometry': mapping(aed.position),
'properties': {
'@osm_type': 'node',
'@osm_id': aed.id,
'@osm_version': aed.version,
**aed.tags,
},
}
for aed in aeds
],
},
headers={
'Content-Disposition': 'attachment',
'Content-Type': 'application/geo+json; charset=utf-8',
},
)
return {
'type': 'FeatureCollection',
'features': [
{
'type': 'Feature',
'geometry': mapping(aed.position),
'properties': {
'@osm_type': 'node',
'@osm_id': aed.id,
'@osm_version': aed.version,
**aed.tags,
},
}
for aed in aeds
],
}
41 changes: 20 additions & 21 deletions api/v1/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from urllib.parse import quote_plus

from fastapi import APIRouter, HTTPException
from fastapi.responses import ORJSONResponse
from pytz import timezone
from shapely import get_coordinates
from tzfpy import get_tz

from middlewares.cache_middleware import configure_cache
from middlewares.skip_serialization import skip_serialization
from states.aed_state import AEDState
from states.photo_state import PhotoState
from utils import get_wikimedia_commons_url
Expand Down Expand Up @@ -73,6 +73,7 @@ async def _get_image_data(tags: dict[str, str]) -> dict:

@router.get('/node/{node_id}')
@configure_cache(timedelta(minutes=1), stale=timedelta(minutes=5))
@skip_serialization()
async def get_node(node_id: int):
aed = await AEDState.get_aed_by_id(node_id)

Expand All @@ -89,23 +90,21 @@ async def get_node(node_id: int):
'@timezone_offset': timezone_offset,
}

return ORJSONResponse(
{
'version': 0.6,
'copyright': 'OpenStreetMap and contributors',
'attribution': 'https://www.openstreetmap.org/copyright',
'license': 'https://opendatacommons.org/licenses/odbl/1-0/',
'elements': [
{
**photo_dict,
**timezone_dict,
'type': 'node',
'id': aed.id,
'lat': y,
'lon': x,
'tags': aed.tags,
'version': aed.version,
}
],
}
)
return {
'version': 0.6,
'copyright': 'OpenStreetMap and contributors',
'attribution': 'https://www.openstreetmap.org/copyright',
'license': 'https://opendatacommons.org/licenses/odbl/1-0/',
'elements': [
{
**photo_dict,
**timezone_dict,
'type': 'node',
'id': aed.id,
'lat': y,
'lon': x,
'tags': aed.tags,
'version': aed.version,
}
],
}
20 changes: 20 additions & 0 deletions middlewares/skip_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import functools
from collections.abc import Mapping

from fastapi import Response

from orjson_response import CustomORJSONResponse


def skip_serialization(headers: Mapping[str, str] | None = None):
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
raw_response = await func(*args, **kwargs)
if isinstance(raw_response, Response):
return raw_response
return CustomORJSONResponse(raw_response, headers=headers)

return wrapper

return decorator

0 comments on commit 5767c7e

Please sign in to comment.