diff --git a/CHANGES.md b/CHANGES.md index 05448f7..66910d6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,6 +4,7 @@ - FIX: Manage case where we have a `pd.Dataframe` instead of a `gpd.GeoDataFrame` in `vectors.read` (reading a `.dbf` file for instance) - FIX: Simplify decorator function of `rasters.read`, to better check the input types and to have a clearer function name and clearer exceptions +- FIX: Simplify decorators `rasters.any_raster_to_xr_ds` and `rasters_rio.any_raster_to_rio_ds` to better check the input types and to have clearer exceptions ## 1.44.2 (2024-12-23) diff --git a/sertit/rasters.py b/sertit/rasters.py index 3401dde..516b566 100644 --- a/sertit/rasters.py +++ b/sertit/rasters.py @@ -193,15 +193,18 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any: default_chunks = "auto" if dask.get_client() is not None else None masked = kwargs.get("masked", True) - # By default, try with the input fct + + # By default, try with the read fct: this fct returns the xr data structure as is and manages other input types such as tuple, rasterio datasets, paths... try: - out = function(any_raster_type, *args, **kwargs) - except Exception as ex: - if isinstance(any_raster_type, xr.DataArray): - # Should work with a DataArray - raise ex - elif isinstance(any_raster_type, xr.Dataset): - # Try on every DataArray of the Dataset + out = function( + read(any_raster_type, chunks=default_chunks, masked=masked), + *args, + **kwargs, + ) + except Exception: + # Try on every DataArray of the Dataset + # TODO: handle DataTrees? + if isinstance(any_raster_type, xr.Dataset): try: xds_dict = {} convert_to_xdataset = False @@ -215,13 +218,6 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any: return xds except Exception as ex: raise TypeError("Function not available for xarray.Dataset") from ex - - else: - out = function( - read(any_raster_type, chunks=default_chunks, masked=masked), - *args, - **kwargs, - ) return out return wrapper @@ -829,6 +825,8 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any: """ # Input is a path: open it with rasterio if path.is_path(any_raster_type): + # GOTCHA: rasterio and cloudpathlib are not really compatible, so passing a CloudPath directly to rasterio (without turning it into a string) with cache the file! + # This is really not ideal, so use the string conversion instead with rasterio.open(str(any_raster_type)) as ds: out = function(ds, *args, **kwargs) # Input is a tuple: we consider it's composed of an output of rasterio.read function, a numpy array and a metadata dict diff --git a/sertit/rasters_rio.py b/sertit/rasters_rio.py index bc0fcdf..4d1c0f7 100644 --- a/sertit/rasters_rio.py +++ b/sertit/rasters_rio.py @@ -28,6 +28,7 @@ import geopandas as gpd import numpy as np +import xarray as xr from shapely.geometry import Polygon try: @@ -196,55 +197,61 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any: Returns: Any: regular output """ - try: - out = function(any_raster_type, *args, **kwargs) - except Exception as ex: - if path.is_path(any_raster_type): - with rasterio.open(str(any_raster_type)) as ds: - out = function(ds, *args, **kwargs) - elif isinstance(any_raster_type, tuple): + # Input is a path: open it with rasterio + if path.is_path(any_raster_type): + # GOTCHA: rasterio and cloudpathlib are not really compatible, so passing a CloudPath directly to rasterio (without turning it into a string) with cache the file! + # This is really not ideal, so use the string conversion instead + with rasterio.open(str(any_raster_type)) as ds: + out = function(ds, *args, **kwargs) + # Input is a tuple: we consider it's composed of an output of rasterio.read function, a numpy array and a metadata dict + elif isinstance(any_raster_type, tuple): + try: arr, meta = any_raster_type - with ( - MemoryFile() as memfile, - memfile.open(**meta, BIGTIFF=bigtiff_value(arr)) as ds, - ): - ds.write(arr) - out = function(ds, *args, **kwargs) - else: - try: - import xarray as xr - except ModuleNotFoundError: - raise ex # noqa: B904 + assert isinstance(arr, np.ndarray) + assert isinstance(meta, dict) + except (ValueError, AssertionError) as exc: + raise TypeError( + "Input tuple should be composed of a numpy array of your data and the corresponding metadata dictionary, is rasterio's sense." + ) from exc + + with ( + MemoryFile() as memfile, + memfile.open(**meta, BIGTIFF=bigtiff_value(arr)) as ds, + ): + ds.write(arr) + out = function(ds, *args, **kwargs) + + # Return given xarray object as is + elif isinstance(any_raster_type, (xr.DataArray, xr.Dataset)): + from sertit.rasters import get_nodata_value_from_xr + + nodata = get_nodata_value_from_xr(any_raster_type) + + meta = { + "driver": "GTiff", + "dtype": any_raster_type.dtype, + "nodata": nodata, + "width": any_raster_type.rio.width, + "height": any_raster_type.rio.height, + "count": any_raster_type.rio.count, + "crs": any_raster_type.rio.crs, + "transform": any_raster_type.rio.transform(), + } + with ( + MemoryFile() as memfile, + memfile.open(**meta, BIGTIFF=bigtiff_value(any_raster_type)) as ds, + ): + if nodata is not None: + arr = any_raster_type.fillna(nodata) else: - if isinstance(any_raster_type, (xr.DataArray, xr.Dataset)): - from sertit.rasters import get_nodata_value_from_xr - - nodata = get_nodata_value_from_xr(any_raster_type) - - meta = { - "driver": "GTiff", - "dtype": any_raster_type.dtype, - "nodata": nodata, - "width": any_raster_type.rio.width, - "height": any_raster_type.rio.height, - "count": any_raster_type.rio.count, - "crs": any_raster_type.rio.crs, - "transform": any_raster_type.rio.transform(), - } - with ( - MemoryFile() as memfile, - memfile.open( - **meta, BIGTIFF=bigtiff_value(any_raster_type) - ) as ds, - ): - if nodata is not None: - arr = any_raster_type.fillna(nodata) - else: - arr = any_raster_type - ds.write(arr.data) - out = function(ds, *args, **kwargs) - else: - raise ex + arr = any_raster_type + ds.write(arr.data) + out = function(ds, *args, **kwargs) + + # Run the fct directly on the input (which should be a rasterio Dataset). If not, this will fail and it's expected. + else: + out = function(any_raster_type, *args, **kwargs) + return out return wrapper