Skip to content

Commit

Permalink
final fix
Browse files Browse the repository at this point in the history
ThomasFaria committed Oct 10, 2024
1 parent 8b90d02 commit a44cb43
Showing 2 changed files with 13 additions and 13 deletions.
22 changes: 11 additions & 11 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@
"""

import gc
import json
import os
from contextlib import asynccontextmanager
from typing import Dict
@@ -12,7 +11,8 @@
import mlflow
import numpy as np
import pyarrow.parquet as pq
from fastapi import FastAPI, Query, Response
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
from osgeo import gdal
from shapely.geometry import box

@@ -126,7 +126,7 @@ async def predict_image(image: str, polygons: bool = False) -> Dict:
lsi = load_from_cache(image, n_bands)

if polygons:
return Response(content=create_geojson_from_mask(lsi).to_json(), media_type="text/plain")
return JSONResponse(content=create_geojson_from_mask(lsi).to_json())
else:
return {"mask": lsi.label.tolist()}

@@ -204,19 +204,19 @@ def predict_cluster(
f"""Loading predictions from cache for images: {", ".join(images_from_cache)}"""
)
# Load from cache
predictions += [load_from_cache(im, n_bands) for im in images_from_cache]
predictions += [load_from_cache(im, n_bands, fs) for im in images_from_cache]

# Restrict predictions to the selected cluster
preds_cluster = subset_predictions(predictions, selected_cluster)

stats_cluster = compute_roi_statistics(predictions, selected_cluster)

response_data = {
"predictions": preds_cluster.loc[:, "geometry"].to_json(),
"statistics": stats_cluster,
"predictions": preds_cluster.to_json(),
"statistics": stats_cluster.to_json(),
}

Response(content=json.dumps(response_data), media_type="text/plain")
return JSONResponse(content=response_data)


@app.get("/predict_bbox", tags=["Predict Bounding Box"])
@@ -286,15 +286,15 @@ def predict_bbox(
if images_from_cache:
logger.info(f"Loading predictions from cache for images: {", ".join(images_from_cache)}")
# Load from cache
predictions += [load_from_cache(im, n_bands) for im in images_from_cache]
predictions += [load_from_cache(im, n_bands, fs) for im in images_from_cache]

preds_bbox = subset_predictions(predictions, bbox_geo)

stats_bbox = compute_roi_statistics(predictions, bbox_geo)

response_data = {
"predictions": preds_bbox.loc[:, "geometry"].to_json(),
"statistics": stats_bbox,
"predictions": preds_bbox.to_json(),
"statistics": stats_bbox.to_json(),
}

Response(content=json.dumps(response_data), media_type="text/plain")
return JSONResponse(content=response_data)
4 changes: 2 additions & 2 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -429,7 +429,7 @@ def subset_predictions(
geometry=[unary_union(roi.geometry).intersection(unary_union(preds.geometry))],
crs=roi.crs,
)
return preds_roi
return preds_roi.reset_index(drop=True)


def get_filename_to_polygons(dep: str, year: int, fs: S3FileSystem) -> gpd.GeoDataFrame:
@@ -494,7 +494,7 @@ def compute_roi_statistics(predictions: list, roi: gpd.GeoDataFrame) -> Dict[str
area_cluster=area_cluster, area_building=area_building, pct_building=pct_building
)

return roi.to_json()
return roi.reset_index(drop=True)


def get_cache_path(image: str) -> str:

0 comments on commit a44cb43

Please sign in to comment.