Skip to content

Commit

Permalink
FIX: Don't take nodata value into account in `ci.assert_raster_almost…
Browse files Browse the repository at this point in the history
…_equal_magnitude`
  • Loading branch information
remi-braun committed Feb 19, 2025
1 parent d9563a3 commit d26fbbf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Release History

## 1.45.3 (2025-mm-dd)

- FIX: Don't take nodata value into account in `ci.assert_raster_almost_equal_magnitude`

## 1.45.2 (2025-02-17)

- FIX: Fix regression in `geometry.simplify_footprint`
Expand Down
24 changes: 18 additions & 6 deletions sertit/ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,12 @@ def assert_raster_almost_equal(
)
LOGGER.info(f"Checking Band {i + 1}{desc}")
try:
marr_1 = ds_1.read(i + 1)
marr_2 = ds_2.read(i + 1)
marr_1 = ds_1.read(i + 1, masked=True)
marr_2 = ds_2.read(i + 1, masked=True)

# Set nodata
marr_1 = np.where(marr_1.mask, np.nan, marr_1)
marr_2 = np.where(marr_2.mask, np.nan, marr_2)
np.testing.assert_array_almost_equal(marr_1, marr_2, decimal=decimal)
except AssertionError:
text = f"Band {i + 1}{desc} failed"
Expand Down Expand Up @@ -293,8 +297,12 @@ def assert_raster_almost_equal_magnitude(
)
LOGGER.info(f"Checking Band {i + 1}{desc}")
try:
marr_1 = ds_1.read(i + 1)
marr_2 = ds_2.read(i + 1)
marr_1 = ds_1.read(i + 1, masked=True)
marr_2 = ds_2.read(i + 1, masked=True)

# Set nodata
marr_1 = np.where(marr_1.mask, np.nan, marr_1)
marr_2 = np.where(marr_2.mask, np.nan, marr_2)

# Manage better the number of (decimals are for a magnitude of 0)
magnitude = np.floor(np.log10(abs(np.nanmedian(marr_1))))
Expand Down Expand Up @@ -352,8 +360,12 @@ def assert_raster_max_mismatch(
assert_meta(ds_1.meta, ds_2.meta)

# Compute the number of mismatch
arr_1 = ds_1.read()
arr_2 = ds_2.read()
arr_1 = ds_1.read(masked=True)
arr_2 = ds_2.read(masked=True)

# Set nodata (to 0 as we will make a diff)
arr_1 = np.where(arr_1.mask, 0, arr_1)
arr_2 = np.where(arr_2.mask, 0, arr_2)

if decimal >= 0:
arr_1 = np.round(arr_1, decimal)
Expand Down

0 comments on commit d26fbbf

Please sign in to comment.