Skip to content

Commit

Permalink
Split DEM tests into 4
Browse files Browse the repository at this point in the history
  • Loading branch information
remi-braun committed Dec 6, 2024
1 parent 70d2c76 commit 56aa0ef
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
57 changes: 44 additions & 13 deletions CI/SCRIPTS/test_rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def raster_path():
return rasters_path().joinpath("raster.tif")


@pytest.fixture
def dem_path():
return rasters_path().joinpath("dem.tif")


@pytest.fixture
def xda(raster_path):
return rasters.read(raster_path)
Expand Down Expand Up @@ -336,7 +341,7 @@ def test_sieve(tmp_path, xda, xds, xda_dask):

# With dask
sieve_xda_dask = rasters.sieve(xda_dask, sieve_thresh=20, connectivity=4)
# assert sieve_xda_dask.chunks is not None TODO
assert sieve_xda_dask.chunks is not None
np.testing.assert_array_equal(sieve_xda, sieve_xda_dask)
ci.assert_xr_encoding_attrs(xda_dask, sieve_xda_dask)

Expand Down Expand Up @@ -787,45 +792,71 @@ def test_where():

@s3_env
@dask_env
def test_dem_fct(tmp_path):
"""Test DEM fct, i.e. slope and hillshade"""
# Paths IN
dem_path = rasters_path().joinpath("dem.tif")
def test_aspect(tmp_path, dem_path):
"""Test aspect function"""
aspect_path = rasters_path().joinpath("aspect.tif")
hlsd_path = rasters_path().joinpath("hillshade.tif")
slope_path = rasters_path().joinpath("slope.tif")
slope_r_path = rasters_path().joinpath("slope_r.tif")
slope_p_path = rasters_path().joinpath("slope_p.tif")

# Path OUT
aspect_path_out = os.path.join(tmp_path, "aspect_out.tif")
hlsd_path_out = os.path.join(tmp_path, "hillshade_out.tif")
slope_path_out = os.path.join(tmp_path, "slope_out.tif")
slope_r_path_out = os.path.join(tmp_path, "slope_r_out.tif")
slope_p_path_out = os.path.join(tmp_path, "slope_p_out.tif")

# Aspect
aspect = rasters.aspect(dem_path)
# assert aspect.chunks is not None
rasters.write(aspect, aspect_path_out, dtype="float32")
ci.assert_raster_almost_equal(aspect_path, aspect_path_out, decimal=4)


@s3_env
@dask_env
def test_hillshade(tmp_path, dem_path):
"""Test hillshade function"""
hlsd_path = rasters_path().joinpath("hillshade.tif")
hlsd_path_out = os.path.join(tmp_path, "hillshade_out.tif")

# Hillshade
hlsd = rasters.hillshade(dem_path, 34.0, 45.2)
# assert hlsd.chunks is not None
rasters.write(hlsd, hlsd_path_out, dtype="float32")
ci.assert_raster_almost_equal(hlsd_path, hlsd_path_out, decimal=4)


@s3_env
@dask_env
def test_slope(tmp_path, dem_path):
"""Test slope function"""
slope_path = rasters_path().joinpath("slope.tif")
slope_path_out = os.path.join(tmp_path, "slope_out.tif")
# Slope
slp = rasters.slope(dem_path)
# assert slp.chunks is not None
rasters.write(slp, slope_path_out, dtype="float32")
ci.assert_raster_almost_equal(slope_path, slope_path_out, decimal=4)


@s3_env
@dask_env
def test_slope_rad(tmp_path, dem_path):
"""Test slope (radian) function"""
slope_r_path = rasters_path().joinpath("slope_r.tif")
slope_r_path_out = os.path.join(tmp_path, "slope_r_out.tif")

# Slope rad
slp_r = rasters.slope(dem_path, in_pct=False, in_rad=True)
# assert slp_r.chunks is not None
rasters.write(slp_r, slope_r_path_out, dtype="float32")
ci.assert_raster_almost_equal(slope_r_path, slope_r_path_out, decimal=4)


@s3_env
@dask_env
def test_slope_pct(tmp_path, dem_path):
"""Test slope (pct) function"""
slope_p_path = rasters_path().joinpath("slope_p.tif")
slope_p_path_out = os.path.join(tmp_path, "slope_p_out.tif")

# Slope pct
slp_p = rasters.slope(dem_path, in_pct=True)
# assert slp_p.chunks is not None
rasters.write(slp_p, slope_p_path_out, dtype="float32")
ci.assert_raster_almost_equal(slope_p_path, slope_p_path_out, decimal=4)

Expand Down
14 changes: 11 additions & 3 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1926,7 +1926,7 @@ def hillshade(
res=np.abs(xds.rio.resolution()),
)

xds = xds.copy(data=out).rename(kwargs.get("name", "hillshade"))
xds = xds.copy(data=out)

except ImportError:
LOGGER.debug(
Expand All @@ -1937,6 +1937,9 @@ def hillshade(

xds = xds.copy(data=arr)

xds = xds.rename(kwargs.get("name", "hillshade"))
xds.attrs["long_name"] = "hillshade"

return xds


Expand All @@ -1962,7 +1965,7 @@ def slope(
from xarray.ufuncs import tan
from xrspatial import slope

xds = slope(xds, name=kwargs.get("name", "slope"))
xds = slope(xds)

if in_pct:
xds = 100 * tan(xds * DEG_2_RAD)
Expand All @@ -1978,6 +1981,9 @@ def slope(

xds = xds.copy(data=arr)

xds = xds.rename(kwargs.get("name", "slope"))
xds.attrs["long_name"] = "slope"

return xds


Expand All @@ -1996,7 +2002,9 @@ def aspect(xds: AnyRasterType, **kwargs) -> AnyXrDataStructure:
try:
from xrspatial import aspect

return aspect(xds, name=kwargs.get("name", "aspect"))
xds = aspect(xds, name=kwargs.get("name", "aspect"))
xds.attrs["long_name"] = "aspect"
return xds
except ImportError:
raise NotImplementedError(
"'Aspect' cannot be computed when 'xarray-spatial' is not installed."
Expand Down

0 comments on commit 56aa0ef

Please sign in to comment.