Skip to content

Commit

Permalink
Merge branch 'main' of github.com:goat-community/goat-core
Browse files Browse the repository at this point in the history
  • Loading branch information
EPajares committed Sep 4, 2024
2 parents 638f22f + fb4d276 commit 559065b
Show file tree
Hide file tree
Showing 12 changed files with 284 additions and 68 deletions.
23 changes: 16 additions & 7 deletions src/core/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Dict, Optional
from uuid import UUID

import boto3
from pydantic import BaseSettings, HttpUrl, PostgresDsn, validator
from uuid import UUID


class AsyncPostgresDsn(PostgresDsn):
allowed_schemes = {"postgres+asyncpg", "postgresql+asyncpg"}
Expand All @@ -24,15 +25,17 @@ class Settings(BaseSettings):
CUSTOMER_SCHEMA: Optional[str] = "customer"
REGION_MAPPING_PT_TABLE: Optional[str] = "basic.region_mapping_pt"
BASE_STREET_NETWORK: Optional[UUID] = "903ecdca-b717-48db-bbce-0219e41439cf"
STREET_NETWORK_EDGE_DEFAULT_LAYER_PROJECT_ID = 36126
STREET_NETWORK_NODE_DEFAULT_LAYER_PROJECT_ID = 37319

ASYNC_CLIENT_DEFAULT_TIMEOUT: Optional[float] = (
10.0 # Default timeout for async http client
)
ASYNC_CLIENT_READ_TIMEOUT: Optional[float] = (
30.0 # Read timeout for async http client
)
CRUD_NUM_RETRIES: Optional[int] = 20 # Number of times to retry calling an endpoint
CRUD_RETRY_INTERVAL: Optional[int] = 2 # Number of seconds to wait between retries
CRUD_NUM_RETRIES: Optional[int] = 20 # Number of times to retry calling an endpoint
CRUD_RETRY_INTERVAL: Optional[int] = 2 # Number of seconds to wait between retries

HEATMAP_GRAVITY_MAX_SENSITIVITY: int = 1000000

Expand Down Expand Up @@ -138,9 +141,15 @@ def assemble_s3_client(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
region_name=values.get("AWS_REGION"),
)

DEFAULT_PROJECT_THUMBNAIL: Optional[str] = "https://assets.plan4better.de/img/goat_new_project_artwork.png"
DEFAULT_LAYER_THUMBNAIL: Optional[str] = "https://assets.plan4better.de/img/goat_new_dataset_thumbnail.png"
DEFAULT_REPORT_THUMBNAIL: Optional[str] = "https://goat-app-assets.s3.eu-central-1.amazonaws.com/logos/goat_green.png"
DEFAULT_PROJECT_THUMBNAIL: Optional[str] = (
"https://assets.plan4better.de/img/goat_new_project_artwork.png"
)
DEFAULT_LAYER_THUMBNAIL: Optional[str] = (
"https://assets.plan4better.de/img/goat_new_dataset_thumbnail.png"
)
DEFAULT_REPORT_THUMBNAIL: Optional[str] = (
"https://goat-app-assets.s3.eu-central-1.amazonaws.com/logos/goat_green.png"
)
ASSETS_URL: Optional[str] = None
THUMBNAIL_DIR_LAYER: Optional[str] = None

Expand All @@ -162,7 +171,7 @@ def set_thumbnail_dir_project(cls, v: Optional[str], values: Dict[str, Any]) ->

MARKER_DIR: Optional[str] = "icons/maki"
MARKER_PREFIX: Optional[str] = "goat-marker-"

class Config:
case_sensitive = True

Expand Down
21 changes: 10 additions & 11 deletions src/core/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
self.file_path = os.path.join(self.folder_path, "file." + self.file_ending)
else:
self.file_path = os.path.join(
self.folder_path, "file." + FileUploadType.gpkg.value
self.folder_path, "file." + FileUploadType.geojson.value
)

async def _fetch_and_write(self):
Expand Down Expand Up @@ -247,10 +247,11 @@ class OGRExternalServiceFetching:
def __init__(self, url: HttpUrl, output_file: str):
self.url = url

# Initialize output GeoPackage data source
output_driver = ogr.GetDriverByName(OgrDriverType.gpkg.value)
# Initialize output GeoJSON data source
driver_type = OgrDriverType.geojson
output_driver = ogr.GetDriverByName(driver_type)
if output_driver is None:
raise Exception(f"{OgrDriverType.gpkg.value} driver is not available.")
raise Exception(f"{driver_type} driver is not available.")

self.output_data_source = output_driver.CreateDataSource(output_file)
if self.output_data_source is None:
Expand All @@ -262,7 +263,7 @@ def fetch_wfs(self, layer_name: str):
ogr.UseExceptions()

# Initialize WFS data source
wfs_data_source = ogr.Open(str(self.url))
wfs_data_source = ogr.Open(f"WFS:{str(self.url)}")
if wfs_data_source is None:
raise Exception(f"Could not open WFS service at {self.url}")

Expand All @@ -271,11 +272,6 @@ def fetch_wfs(self, layer_name: str):
if input_layer is None:
raise Exception(f"Could not find layer {layer_name} in WFS service.")

# Get geometry column name and type
geom_column = input_layer.GetGeometryColumn()
if not geom_column:
raise Exception("Could not determine geometry column for WFS layer.")

geom_type = input_layer.GetGeomType()
if geom_type == ogr.wkbUnknown:
first_feature = input_layer.GetNextFeature()
Expand All @@ -291,7 +287,6 @@ def fetch_wfs(self, layer_name: str):
layer_name,
srs=input_layer.GetSpatialRef(),
geom_type=geom_type,
options=[f"GEOMETRY_NAME={geom_column}"],
)
if output_layer is None:
raise Exception(
Expand All @@ -308,6 +303,10 @@ def fetch_wfs(self, layer_name: str):
for feature in input_layer:
output_layer.CreateFeature(feature)

# Cleanup
self.output_data_source = None
wfs_data_source = None

ogr.DontUseExceptions()

async def fetch_mvt(self):
Expand Down
4 changes: 4 additions & 0 deletions src/crud/crud_catchment_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ async def catchment_area(
),
"layer_id": str(layer_id),
"scenario_id": str(params.scenario_id) if params.scenario_id else None,
"street_network_edge_layer_project_id": params.layer_project_id_street_network_edge,
"street_network_node_layer_project_id": params.layer_project_id_street_network_node,
}

await call_routing_endpoint(
Expand Down Expand Up @@ -658,6 +660,8 @@ async def catchment_area(
),
"layer_id": str(layer_id),
"scenario_id": str(params.scenario_id) if params.scenario_id else None,
"street_network_edge_layer_project_id": params.layer_project_id_street_network_edge,
"street_network_node_layer_project_id": params.layer_project_id_street_network_node,
}

await call_routing_endpoint(
Expand Down
81 changes: 70 additions & 11 deletions src/crud/crud_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ def __init__(self, job_id, background_tasks, async_session, user_id):
f'{settings.USER_DATA_SCHEMA}."{str(self.job_id).replace("-", "")}"'
)

@job_log(job_step_name="internal_layer_create")
async def create_internal(
self,
layer_in: ILayerFromDatasetCreate,
Expand All @@ -695,9 +694,8 @@ async def create_internal(
geom_type = SupportedOgrGeomType[
file_metadata["data_types"]["geometry"]["type"]
].value
additional_attributes["properties"] = get_base_style(
feature_geometry_type=geom_type
)
if not layer_in.properties:
layer_in.properties = get_base_style(feature_geometry_type=geom_type)
additional_attributes["type"] = LayerType.feature
additional_attributes["feature_layer_type"] = FeatureType.standard
additional_attributes["feature_layer_geometry_type"] = geom_type
Expand Down Expand Up @@ -752,13 +750,8 @@ async def create_internal(
project_id=project_id,
)

return {
"msg": "Layer successfully created.",
"status": JobStatusType.finished.value,
}
return layer.id

@run_background_or_immediately(settings)
@job_init()
async def import_file(
self,
file_metadata: dict,
Expand Down Expand Up @@ -798,13 +791,30 @@ async def import_file(
job_id=self.job_id,
)
# Create layer metadata and thumbnail
result = await self.create_internal(
layer_id = await self.create_internal(
layer_in=layer_in,
file_metadata=file_metadata,
attribute_mapping=attribute_mapping,
project_id=project_id,
)

return result, layer_id

@run_background_or_immediately(settings)
@job_init()
async def import_file_job(
self,
file_metadata: dict,
layer_in: ILayerFromDatasetCreate,
project_id: UUID = None,
):
"""Create a layer from a dataset file."""

result, _ = await self.import_file(
file_metadata=file_metadata,
layer_in=layer_in,
project_id=project_id,
)
return result


Expand Down Expand Up @@ -980,3 +990,52 @@ async def delete_multi_run(
layers: list[Layer],
):
return await self.delete_multi(async_session=async_session, layers=layers)


class CRUDLayerDatasetUpdate(CRUDFailedJob):
"""CRUD class for updating the dataset of an existing layer and updating all layer project references."""

def __init__(self, job_id, background_tasks, async_session, user_id):
super().__init__(job_id, background_tasks, async_session, user_id)

@run_background_or_immediately(settings)
@job_init()
async def update(
self,
existing_layer_id: UUID,
file_metadata: dict,
layer_in: ILayerFromDatasetCreate,
):
"""Update layer dataset."""

original_name = layer_in.name

# Create a new layer with the updated dataset while transferring existing layer properties
result, layer_id = await CRUDLayerImport(
background_tasks=self.background_tasks,
async_session=self.async_session,
user_id=self.user_id,
job_id=self.job_id,
).import_file(
file_metadata=file_metadata,
layer_in=layer_in,
)

# Update all layer project references with the new layer id
await crud_layer_project.update_layer_id(
async_session=self.async_session,
layer_id=existing_layer_id,
new_layer_id=layer_id,
)

# Delete the old layer
await layer.delete(async_session=self.async_session, id=existing_layer_id)

# Rename the new layer
await layer.update(
async_session=self.async_session,
id=layer_id,
layer_in={"name": original_name},
)

return result
46 changes: 35 additions & 11 deletions src/crud/crud_layer_project.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
# Standard library imports
import re
from typing import List, Union
from uuid import UUID

# Third party imports
from fastapi import HTTPException, status
from pydantic import ValidationError, parse_obj_as, BaseModel
from sqlalchemy import select, text
from pydantic import BaseModel, ValidationError, parse_obj_as
from sqlalchemy import select, text, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from typing import Union, List

# Local application imports
from .base import CRUDBase
from src.schemas.error import UnsupportedLayerTypeError, LayerNotFoundError
from src.utils import build_where_clause
from src.core.layer import CRUDLayerBase
from src.db.models._link_model import LayerProjectLink
from src.db.models.layer import Layer
from src.db.models.project import Project
from src.schemas.layer import LayerType, FeatureGeometryType
from src.schemas.error import LayerNotFoundError, UnsupportedLayerTypeError
from src.schemas.layer import FeatureGeometryType, LayerType
from src.schemas.project import (
layer_type_mapping_read,
layer_type_mapping_update,
)
from src.core.layer import CRUDLayerBase
from src.utils import build_where_clause

# Local application imports
from .base import CRUDBase


class CRUDLayerProject(CRUDLayerBase):
async def layer_projects_to_schemas(
Expand Down Expand Up @@ -143,7 +144,10 @@ async def get_internal(
# Check if geometry type is correct
if layer_project.type == LayerType.feature.value:
if expected_geometry_types is not None:
if layer_project.feature_layer_geometry_type not in expected_geometry_types:
if (
layer_project.feature_layer_geometry_type
not in expected_geometry_types
):
raise UnsupportedLayerTypeError(
f"Layer {layer_project.name} is not a {[geom_type.value for geom_type in expected_geometry_types]} layer"
)
Expand Down Expand Up @@ -352,4 +356,24 @@ async def check_exceed_feature_cnt(
)
return feature_cnt

async def update_layer_id(
self,
async_session: AsyncSession,
layer_id: UUID,
new_layer_id: UUID,
):
"""Update layer id in layer project link."""

# Update all layers from project by id
query = (
update(LayerProjectLink)
.where(LayerProjectLink.layer_id == layer_id)
.values(layer_id=new_layer_id)
)

async with async_session.begin():
await async_session.execute(query)
await async_session.commit()


layer_project = CRUDLayerProject(LayerProjectLink)
2 changes: 2 additions & 0 deletions src/crud/crud_nearby_station_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ async def nearby_station_access(self, params: INearbyStationAccess):
catchment_area_type=CatchmentAreaTypeActiveMobility.polygon,
polygon_difference=True,
scenario_id=params.scenario_id,
layer_project_id_street_network_edge=params.layer_project_id_street_network_edge,
layer_project_id_street_network_node=params.layer_project_id_street_network_node,
)

# Compute catchment area
Expand Down
Loading

0 comments on commit 559065b

Please sign in to comment.