From 16047aa90c4f4d22a61df2aa3624a302c639473d Mon Sep 17 00:00:00 2001 From: BRAUN REMI Date: Mon, 6 Jan 2025 13:58:25 +0100 Subject: [PATCH] FIX: Simplify decorator function of `rasters.read`, to have clearer name and clearer exceptions --- CHANGES.md | 1 + ci/test_rasters.py | 2 +- sertit/rasters.py | 64 +++++++++++++++++++++++----------------------- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index e06d41c..05448f7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,7 @@ ## 1.44.3 (2025-mm-dd) - FIX: Manage case where we have a `pd.Dataframe` instead of a `gpd.GeoDataFrame` in `vectors.read` (reading a `.dbf` file for instance) +- FIX: Simplify decorator function of `rasters.read`, to better check the input types and to have a clearer function name and clearer exceptions ## 1.44.2 (2024-12-23) diff --git a/ci/test_rasters.py b/ci/test_rasters.py index d09addd..08337d0 100644 --- a/ci/test_rasters.py +++ b/ci/test_rasters.py @@ -612,7 +612,7 @@ def test_merge_different_crs_gtiff(tmp_path): def _test_raster_after_write(test_path, dtype, nodata_val): - with rasterio.open(test_path) as ds: + with rasterio.open(str(test_path)) as ds: assert ds.meta["dtype"] == dtype or ds.meta["dtype"] == dtype.__name__ assert ds.meta["nodata"] == nodata_val assert ds.read()[:, 0, 0] == nodata_val # Check value diff --git a/sertit/rasters.py b/sertit/rasters.py index 09482e3..3401dde 100644 --- a/sertit/rasters.py +++ b/sertit/rasters.py @@ -781,7 +781,7 @@ def crop( ) -def _any_raster_to_rio_ds(function: Callable) -> Callable: +def __read__any_raster_to_rio_ds(function: Callable) -> Callable: """ Specific declination of rasters_rio.any_raster_to_rio_ds for this specific case, handling the xarray object differently. @@ -827,37 +827,37 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any: Returns: Any: regular output """ - try: - out = function(any_raster_type, *args, **kwargs) - except Exception as ex: - if path.is_path(any_raster_type): - with rasterio.open(str(any_raster_type)) as ds: - out = function(ds, *args, **kwargs) - elif isinstance(any_raster_type, tuple): + # Input is a path: open it with rasterio + if path.is_path(any_raster_type): + with rasterio.open(str(any_raster_type)) as ds: + out = function(ds, *args, **kwargs) + # Input is a tuple: we consider it's composed of an output of rasterio.read function, a numpy array and a metadata dict + elif isinstance(any_raster_type, tuple): + try: arr, meta = any_raster_type - from rasterio import MemoryFile - - with ( - MemoryFile() as memfile, - memfile.open(**meta, BIGTIFF=rasters_rio.bigtiff_value(arr)) as ds, - ): - ds.write(arr) - out = function(ds, *args, **kwargs) - else: - # Try if xarray is importable - try: - if isinstance(any_raster_type, (xr.DataArray, xr.Dataset)): - try: - file_path = any_raster_type.encoding["source"] - with rasterio.open(file_path) as ds: - out = function(ds, *args, **kwargs) - except Exception: - # Return given xarray object - return any_raster_type - else: - raise ex - except Exception as exc: - raise ex from exc + assert isinstance(arr, np.ndarray) + assert isinstance(meta, dict) + except (ValueError, AssertionError) as exc: + raise TypeError( + "Input tuple should be composed of a numpy array of your data and the corresponding metadata dictionary, is rasterio's sense." + ) from exc + + from rasterio import MemoryFile + + with ( + MemoryFile() as memfile, + memfile.open(**meta, BIGTIFF=rasters_rio.bigtiff_value(arr)) as ds, + ): + ds.write(arr) + out = function(ds, *args, **kwargs) + + # Return given xarray object as is + elif isinstance(any_raster_type, (xr.DataArray, xr.Dataset)): + out = any_raster_type + + # Run the fct directly on the input (which should be a rasterio Dataset). If not, this will fail and it's expected. + else: + out = function(any_raster_type, *args, **kwargs) return out return wrapper @@ -883,7 +883,7 @@ def wrapper(xda: xr.DataArray, *_args, **_kwargs): return wrapper -@_any_raster_to_rio_ds +@__read__any_raster_to_rio_ds def read( ds: AnyRasterType, resolution: Union[tuple, list, float] = None,