Skip to content

Commit

Permalink
Merge pull request #178 from fmi-faim/distance-map-2d
Browse files Browse the repository at this point in the history
Improve stitching performance
  • Loading branch information
imagejan authored Sep 9, 2024
2 parents 3a99a42 + 96dbdb5 commit 51b968a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 37 deletions.
68 changes: 42 additions & 26 deletions src/faim_ipa/stitching/stitching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,29 @@ def fuse_linear(warped_tiles: NDArray, warped_distance_masks: NDArray) -> NDArra
Fused image.
"""
dtype = warped_tiles.dtype
fused_image = np.zeros_like(warped_tiles[0], dtype=np.float32)

if warped_tiles.shape[0] > 1:
denominator = warped_distance_masks.sum(axis=0)
weights = np.true_divide(
warped_distance_masks, denominator, where=denominator > 0
)
weights = np.clip(np.nan_to_num(weights, nan=0, posinf=1, neginf=0), 0, 1)
denominator = np.sum(warped_distance_masks, axis=0)

for tile, mask in zip(warped_tiles, warped_distance_masks, strict=True):
weight = np.divide(mask, denominator, where=denominator > 0)
np.clip(weight, 0, 1, out=weight)
np.add(
fused_image,
tile.astype(np.float32) * weight,
out=fused_image,
where=weight > 0,
)
else:
weights = warped_distance_masks.astype(bool)
np.add(
fused_image,
warped_tiles[0],
out=fused_image,
where=warped_distance_masks[0] > 0,
)

return np.sum(warped_tiles * weights, axis=0).astype(dtype)
return np.clip(fused_image, 0, np.iinfo(dtype).max).astype(dtype)


def fuse_linear_random(
Expand Down Expand Up @@ -92,13 +105,21 @@ def fuse_mean(warped_tiles: NDArray, warped_distance_masks: NDArray) -> NDArray:
-------
Fused image.
"""
warped_masks = warped_distance_masks.astype(bool)
denominator = warped_masks.sum(axis=0)
weights = np.true_divide(warped_masks, denominator, where=denominator > 0)
weights = np.clip(np.nan_to_num(weights, nan=0, posinf=1, neginf=0), 0, 1)
denominator = np.sum(warped_distance_masks > 0, axis=0)
fused_image = np.zeros_like(warped_tiles[0], dtype=np.float32)

for tile, mask in zip(warped_tiles, warped_distance_masks, strict=True):
weight = np.divide(mask > 0, denominator, where=denominator > 0)
np.add(
fused_image,
tile.astype(np.float32) * weight,
out=fused_image,
where=weight > 0,
)

fused_image = np.sum(warped_tiles * weights, axis=0)
return fused_image.astype(warped_tiles.dtype)
return np.clip(fused_image, 0, np.iinfo(warped_tiles.dtype).max).astype(
warped_tiles.dtype
)


def fuse_sum(
Expand Down Expand Up @@ -232,39 +253,34 @@ def translate_tiles_2d(
def get_distance_mask(tile_shape):
mask = np.zeros(tile_shape[-2:], dtype=bool)
mask[1:-1, 1:-1] = True
distance_mask = distance_transform_cdt(mask, metric="taxicab").astype(np.uint16) + 1
if len(tile_shape) == 3:
distance_mask = np.repeat(distance_mask[np.newaxis, ...], tile_shape[0], axis=0)
else:
distance_mask = distance_mask[np.newaxis]
return distance_mask
return distance_transform_cdt(mask, metric="taxicab").astype(np.uint16) + 1


def shift_yx(chunk_zyx_origin, tile_data, tile_origin, chunk_shape):
warped_tile = np.zeros(chunk_shape, dtype=tile_data.dtype)
yx_shift = (tile_origin - chunk_zyx_origin)[1:]
if yx_shift[0] < 0:
tile_start_y = abs(yx_shift[0])
tile_end_y = min(tile_start_y + chunk_shape[1], tile_data.shape[1])
tile_end_y = min(tile_start_y + chunk_shape[1], tile_data.shape[-2])
else:
tile_start_y = 0
tile_end_y = max(
0, min(tile_start_y + chunk_shape[1] - yx_shift[0], tile_data.shape[1])
0, min(tile_start_y + chunk_shape[1] - yx_shift[0], tile_data.shape[-2])
)
if yx_shift[1] < 0:
tile_start_x = abs(yx_shift[1])
tile_end_x = min(tile_start_x + chunk_shape[2], tile_data.shape[2])
tile_end_x = min(tile_start_x + chunk_shape[2], tile_data.shape[-1])
else:
tile_start_x = 0
tile_end_x = min(
tile_start_x + chunk_shape[2] - yx_shift[1], tile_data.shape[2]
tile_start_x + chunk_shape[2] - yx_shift[1], tile_data.shape[-1]
)
tile_data = tile_data[:, tile_start_y:tile_end_y, tile_start_x:tile_end_x]
tile_data = tile_data[..., tile_start_y:tile_end_y, tile_start_x:tile_end_x]
if tile_data.size > 0:
start_y = max(0, yx_shift[0])
end_y = start_y + tile_data.shape[1]
end_y = start_y + tile_data.shape[-2]
start_x = max(0, yx_shift[1])
end_x = start_x + tile_data.shape[2]
end_x = start_x + tile_data.shape[-1]
warped_tile[: tile_data.shape[0], start_y:end_y, start_x:end_x] = tile_data
return warped_tile

Expand Down
18 changes: 7 additions & 11 deletions tests/stitching/test_stitching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,11 @@ def test_shift_to_origin():
def test_get_distance_mask():
expected_distance_mask_2d = np.array(
[
[
[1, 1, 1, 1, 1, 1],
[1, 2, 2, 2, 2, 1],
[1, 2, 3, 3, 2, 1],
[1, 2, 2, 2, 2, 1],
[1, 1, 1, 1, 1, 1],
]
[1, 1, 1, 1, 1, 1],
[1, 2, 2, 2, 2, 1],
[1, 2, 3, 3, 2, 1],
[1, 2, 2, 2, 2, 1],
[1, 1, 1, 1, 1, 1],
],
dtype=np.uint16,
)
Expand All @@ -268,13 +266,11 @@ def test_get_distance_mask():

result = get_distance_mask((1, 5, 6))
assert result.dtype == np.uint16
assert_array_equal(result, expected_distance_mask_2d.reshape((1, 5, 6)))
assert_array_equal(result, expected_distance_mask_2d)

result = get_distance_mask((4, 5, 6))
assert result.dtype == np.uint16
assert_array_equal(
result, np.concatenate([expected_distance_mask_2d for n in range(4)])
)
assert_array_equal(result, expected_distance_mask_2d)


def test_translate_3d_tiles_2d(tiles):
Expand Down

0 comments on commit 51b968a

Please sign in to comment.