Skip to content

Commit

Permalink
FIX: Simplify decorators rasters.any_raster_to_xr_ds and `rasters_r…
Browse files Browse the repository at this point in the history
…io.any_raster_to_rio_ds` to better check the input types and to have clearer exceptions
  • Loading branch information
remi-braun committed Jan 6, 2025
1 parent 16047aa commit 40c0875
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- FIX: Manage case where we have a `pd.Dataframe` instead of a `gpd.GeoDataFrame` in `vectors.read` (reading a `.dbf` file for instance)
- FIX: Simplify decorator function of `rasters.read`, to better check the input types and to have a clearer function name and clearer exceptions
- FIX: Simplify decorators `rasters.any_raster_to_xr_ds` and `rasters_rio.any_raster_to_rio_ds` to better check the input types and to have clearer exceptions

## 1.44.2 (2024-12-23)

Expand Down
28 changes: 13 additions & 15 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,18 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any:

default_chunks = "auto" if dask.get_client() is not None else None
masked = kwargs.get("masked", True)
# By default, try with the input fct

# By default, try with the read fct: this fct returns the xr data structure as is and manages other input types such as tuple, rasterio datasets, paths...
try:
out = function(any_raster_type, *args, **kwargs)
except Exception as ex:
if isinstance(any_raster_type, xr.DataArray):
# Should work with a DataArray
raise ex
elif isinstance(any_raster_type, xr.Dataset):
# Try on every DataArray of the Dataset
out = function(
read(any_raster_type, chunks=default_chunks, masked=masked),
*args,
**kwargs,
)
except Exception:
# Try on every DataArray of the Dataset
# TODO: handle DataTrees?
if isinstance(any_raster_type, xr.Dataset):
try:
xds_dict = {}
convert_to_xdataset = False
Expand All @@ -215,13 +218,6 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any:
return xds
except Exception as ex:
raise TypeError("Function not available for xarray.Dataset") from ex

else:
out = function(
read(any_raster_type, chunks=default_chunks, masked=masked),
*args,
**kwargs,
)
return out

return wrapper
Expand Down Expand Up @@ -829,6 +825,8 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any:
"""
# Input is a path: open it with rasterio
if path.is_path(any_raster_type):
# GOTCHA: rasterio and cloudpathlib are not really compatible, so passing a CloudPath directly to rasterio (without turning it into a string) with cache the file!
# This is really not ideal, so use the string conversion instead
with rasterio.open(str(any_raster_type)) as ds:
out = function(ds, *args, **kwargs)
# Input is a tuple: we consider it's composed of an output of rasterio.read function, a numpy array and a metadata dict
Expand Down
101 changes: 54 additions & 47 deletions sertit/rasters_rio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import geopandas as gpd
import numpy as np
import xarray as xr
from shapely.geometry import Polygon

try:
Expand Down Expand Up @@ -196,55 +197,61 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any:
Returns:
Any: regular output
"""
try:
out = function(any_raster_type, *args, **kwargs)
except Exception as ex:
if path.is_path(any_raster_type):
with rasterio.open(str(any_raster_type)) as ds:
out = function(ds, *args, **kwargs)
elif isinstance(any_raster_type, tuple):
# Input is a path: open it with rasterio
if path.is_path(any_raster_type):
# GOTCHA: rasterio and cloudpathlib are not really compatible, so passing a CloudPath directly to rasterio (without turning it into a string) with cache the file!
# This is really not ideal, so use the string conversion instead
with rasterio.open(str(any_raster_type)) as ds:
out = function(ds, *args, **kwargs)
# Input is a tuple: we consider it's composed of an output of rasterio.read function, a numpy array and a metadata dict
elif isinstance(any_raster_type, tuple):
try:
arr, meta = any_raster_type
with (
MemoryFile() as memfile,
memfile.open(**meta, BIGTIFF=bigtiff_value(arr)) as ds,
):
ds.write(arr)
out = function(ds, *args, **kwargs)
else:
try:
import xarray as xr
except ModuleNotFoundError:
raise ex # noqa: B904
assert isinstance(arr, np.ndarray)
assert isinstance(meta, dict)
except (ValueError, AssertionError) as exc:
raise TypeError(
"Input tuple should be composed of a numpy array of your data and the corresponding metadata dictionary, is rasterio's sense."
) from exc

with (
MemoryFile() as memfile,
memfile.open(**meta, BIGTIFF=bigtiff_value(arr)) as ds,
):
ds.write(arr)
out = function(ds, *args, **kwargs)

# Return given xarray object as is
elif isinstance(any_raster_type, (xr.DataArray, xr.Dataset)):
from sertit.rasters import get_nodata_value_from_xr

nodata = get_nodata_value_from_xr(any_raster_type)

meta = {
"driver": "GTiff",
"dtype": any_raster_type.dtype,
"nodata": nodata,
"width": any_raster_type.rio.width,
"height": any_raster_type.rio.height,
"count": any_raster_type.rio.count,
"crs": any_raster_type.rio.crs,
"transform": any_raster_type.rio.transform(),
}
with (
MemoryFile() as memfile,
memfile.open(**meta, BIGTIFF=bigtiff_value(any_raster_type)) as ds,
):
if nodata is not None:
arr = any_raster_type.fillna(nodata)
else:
if isinstance(any_raster_type, (xr.DataArray, xr.Dataset)):
from sertit.rasters import get_nodata_value_from_xr

nodata = get_nodata_value_from_xr(any_raster_type)

meta = {
"driver": "GTiff",
"dtype": any_raster_type.dtype,
"nodata": nodata,
"width": any_raster_type.rio.width,
"height": any_raster_type.rio.height,
"count": any_raster_type.rio.count,
"crs": any_raster_type.rio.crs,
"transform": any_raster_type.rio.transform(),
}
with (
MemoryFile() as memfile,
memfile.open(
**meta, BIGTIFF=bigtiff_value(any_raster_type)
) as ds,
):
if nodata is not None:
arr = any_raster_type.fillna(nodata)
else:
arr = any_raster_type
ds.write(arr.data)
out = function(ds, *args, **kwargs)
else:
raise ex
arr = any_raster_type
ds.write(arr.data)
out = function(ds, *args, **kwargs)

# Run the fct directly on the input (which should be a rasterio Dataset). If not, this will fail and it's expected.
else:
out = function(any_raster_type, *args, **kwargs)

return out

return wrapper
Expand Down

0 comments on commit 40c0875

Please sign in to comment.