Skip to content

Commit

Permalink
Merge pull request #157 from fmi-faim/fs_stitching
Browse files Browse the repository at this point in the history
Update stitching_utils
  • Loading branch information
imagejan authored Jul 5, 2024
2 parents f08b4a9 + d3341d8 commit 6232e5f
Show file tree
Hide file tree
Showing 2 changed files with 279 additions and 73 deletions.
184 changes: 140 additions & 44 deletions src/faim_ipa/stitching/stitching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import numpy as np
from numpy._typing import NDArray
from scipy.ndimage import distance_transform_edt
from scipy.ndimage import distance_transform_cdt

from faim_ipa.stitching.Tile import Tile, TilePosition


def fuse_linear(warped_tiles: NDArray, warped_masks: NDArray) -> NDArray:
def fuse_linear(warped_tiles: NDArray, warped_distance_masks: NDArray) -> NDArray:
"""
Fuse transformed tiles using a linear gradient to compute the weighted
average where tiles are overlapping.
Expand All @@ -16,76 +16,103 @@ def fuse_linear(warped_tiles: NDArray, warped_masks: NDArray) -> NDArray:
----------
warped_tiles :
Tile images transformed to the final image space.
warped_masks :
Masks indicating foreground pixels for the transformed tiles.
warped_distance_masks :
Distance masks for the transformed tiles. They are non-zero for
foreground pixels, and the value is the distance to the closest edge of
the tile.
Returns
-------
Fused image.
"""
dtype = warped_tiles.dtype
if warped_tiles.shape[0] > 1:
warped_masks = warped_masks[:, 0]
weights = np.zeros_like(warped_masks, dtype=np.float32)
for i, mask in enumerate(warped_masks):
weights[i] = distance_transform_edt(
warped_masks[i].astype(np.float32),
)

denominator = weights.sum(axis=0)
weights = np.true_divide(weights, denominator, where=denominator > 0)
weights = np.nan_to_num(weights, nan=0, posinf=1, neginf=0)
weights = np.clip(
weights,
0,
1,
denominator = warped_distance_masks.sum(axis=0)
weights = np.true_divide(
warped_distance_masks, denominator, where=denominator > 0
)
weights = weights[:, np.newaxis, ...]
weights = np.clip(np.nan_to_num(weights, nan=0, posinf=1, neginf=0), 0, 1)
else:
weights = warped_masks
weights = warped_distance_masks.astype(bool)

return np.sum(warped_tiles * weights, axis=0).astype(dtype)


def fuse_mean(warped_tiles: NDArray, warped_masks: NDArray) -> NDArray:
def fuse_linear_random(
warped_tiles: NDArray, warped_distance_masks: NDArray
) -> NDArray:
"""
Fuse transformed tiles by sampling random pixels where tiles are
overlapping, using a linear gradient to compute the random weights.
Parameters
----------
warped_tiles :
Tile images transformed to the final image space.
warped_distance_masks :
Distance masks for the transformed tiles. They are non-zero for
foreground pixels, and the value is the distance to the closest edge of
the tile.
Returns
-------
Fused image.
"""
np.random.seed(0)
dtype = warped_tiles.dtype
if warped_tiles.shape[0] > 1:
denominator = warped_distance_masks.sum(axis=0)
weights = np.true_divide(
warped_distance_masks, denominator, where=denominator > 0
)
weights = np.clip(np.nan_to_num(weights, nan=0, posinf=1, neginf=0), 0, 1)
weights = np.cumsum(weights, axis=0)
weights = np.insert(weights, 0, np.zeros_like(weights[0]), axis=0)
rand_tile = np.random.rand(*warped_tiles.shape[1:])
for i in range(len(warped_tiles)):
warped_tiles[i, (rand_tile < weights[i]) | (weights[i + 1] < rand_tile)] = 0

return np.sum(warped_tiles, axis=0).astype(dtype)


def fuse_mean(warped_tiles: NDArray, warped_distance_masks: NDArray) -> NDArray:
"""
Fuse transformed tiles and compute the mean of the overlapping pixels.
Parameters
----------
warped_tiles :
Tile images transformed to the final image space.
warped_masks :
Masks indicating foreground pixels for the transformed tiles.
warped_distance_masks :
Distance masks for the transformed tiles. They are non-zero for
foreground pixels, and the value is the distance to the closest edge of
the tile.
Returns
-------
Fused image.
"""
warped_masks = warped_masks[:, 0]
warped_masks = warped_distance_masks.astype(bool)
denominator = warped_masks.sum(axis=0)
weights = np.true_divide(warped_masks, denominator, where=denominator > 0)
weights = np.clip(
np.nan_to_num(weights, nan=0, posinf=1, neginf=0),
0,
1,
)
weights = weights[:, np.newaxis, ...]
weights = np.clip(np.nan_to_num(weights, nan=0, posinf=1, neginf=0), 0, 1)

fused_image = np.sum(warped_tiles * weights, axis=0)
return fused_image.astype(warped_tiles.dtype)


def fuse_sum(warped_tiles: NDArray, warped_masks: NDArray) -> NDArray:
def fuse_sum(warped_tiles: NDArray, warped_distance_masks: NDArray) -> NDArray:
"""
Fuse transformed tiles and compute the sum of the overlapping pixels.
Parameters
----------
warped_tiles :
Tile images transformed to the final image space.
warped_masks :
Masks indicating foreground pixels for the transformed tiles.
warped_distance_masks :
Distance masks for the transformed tiles. They are non-zero for
foreground pixels, and the value is the distance to the closest edge of
the tile. (Not used in this function)
Returns
-------
Expand All @@ -95,6 +122,62 @@ def fuse_sum(warped_tiles: NDArray, warped_masks: NDArray) -> NDArray:
return fused_image.astype(warped_tiles.dtype)


def fuse_overlay_fwd(warped_tiles: NDArray, warped_distance_masks: NDArray) -> NDArray:
"""
Fuse transformed tiles. Where tiles overlap, the tile later in the sequence
overwrites the earlier one.
Parameters
----------
warped_tiles :
Tile images transformed to the final image space.
warped_distance_masks :
Distance masks for the transformed tiles. They are non-zero for
foreground pixels, and the value is the distance to the closest edge of
the tile.
Returns
-------
Fused image.
"""

warped_masks = warped_distance_masks.astype(bool)

fused_image = np.zeros_like(warped_tiles[0])
for tile, mask in zip(warped_tiles, warped_masks):
fused_image[mask] = tile[mask]

return fused_image


def fuse_overlay_bwd(warped_tiles: NDArray, warped_distance_masks: NDArray) -> NDArray:
"""
Fuse transformed tiles. Where tiles overlap, the tile earlier in the
sequence overwrites the later one.
Parameters
----------
warped_tiles :
Tile images transformed to the final image space.
warped_distance_masks :
Distance masks for the transformed tiles. They are non-zero for
foreground pixels, and the value is the distance to the closest edge of
the tile.
Returns
-------
Fused image.
"""

warped_masks = warped_distance_masks.astype(bool)

fused_image = np.zeros_like(warped_tiles[0])
for tile, mask in zip(reversed(warped_tiles), reversed(warped_masks)):
fused_image[mask] = tile[mask]

return fused_image


def translate_tiles_2d(
block_info, chunk_shape, tiles, build_acquisition_mask: bool = False
):
Expand All @@ -112,14 +195,21 @@ def translate_tiles_2d(
Returns
-------
translated tiles, translated masks
translated tiles, translated distance masks
"""
array_location = block_info[None]["array-location"]
chunk_zyx_origin = np.array(
[array_location[2][0], array_location[3][0], array_location[4][0]]
)

if not all(tile.shape == tiles[0].shape for tile in tiles):
raise ValueError("All tiles must have the same shape.")
distance_mask = get_distance_mask(tiles[0].shape)
if distance_mask.ndim == 2:
distance_mask = distance_mask[np.newaxis, ...]

warped_tiles = []
warped_masks = []
warped_distance_masks = []
for tile in tiles:
tile_origin = np.array(tile.get_zyx_position())
if build_acquisition_mask:
Expand All @@ -128,19 +218,26 @@ def translate_tiles_2d(
tile_data = tile.load_data()
if tile_data.ndim == 2:
tile_data = tile_data[np.newaxis, ...]
warped_mask, warped_tile = shift_yx(
chunk_zyx_origin, tile_data, tile_origin, chunk_shape
warped_tile = shift_yx(chunk_zyx_origin, tile_data, tile_origin, chunk_shape)
warped_distance_mask = shift_yx(
chunk_zyx_origin, distance_mask, tile_origin, chunk_shape
)

warped_tiles.append(warped_tile)
warped_masks.append(warped_mask)
warped_distance_masks.append(warped_distance_mask)

return np.array(warped_tiles), np.array(warped_distance_masks)


return np.array(warped_tiles), np.array(warped_masks)
def get_distance_mask(tile_shape):
mask = np.zeros(tile_shape, dtype=bool)
mask[..., 1:-1, 1:-1] = True
distance_mask = distance_transform_cdt(mask, metric="taxicab") + 1
return distance_mask.astype(np.uint16)


def shift_yx(chunk_zyx_origin, tile_data, tile_origin, chunk_shape):
warped_tile = np.zeros(chunk_shape, dtype=tile_data.dtype)
warped_mask = np.zeros(chunk_shape, dtype=bool)
yx_shift = (tile_origin - chunk_zyx_origin)[1:]
if yx_shift[0] < 0:
tile_start_y = abs(yx_shift[0])
Expand All @@ -165,8 +262,7 @@ def shift_yx(chunk_zyx_origin, tile_data, tile_origin, chunk_shape):
start_x = max(0, yx_shift[1])
end_x = start_x + tile_data.shape[2]
warped_tile[: tile_data.shape[0], start_y:end_y, start_x:end_x] = tile_data
warped_mask[: tile_data.shape[0], start_y:end_y, start_x:end_x] = True
return warped_mask, warped_tile
return warped_tile


def assemble_chunk(
Expand Down Expand Up @@ -202,14 +298,14 @@ def assemble_chunk(
tiles = tile_map[chunk_location]

if len(tiles) > 0:
warped_tiles, warped_masks = warp_func(
warped_tiles, warped_distance_masks = warp_func(
block_info, chunk_shape[-3:], tiles, build_acquisition_mask
)

if len(tiles) > 1:
stitched_img = fuse_func(
warped_tiles,
warped_masks,
warped_distance_masks,
)
stitched_img = stitched_img[np.newaxis, np.newaxis, ...]
else:
Expand Down
Loading

0 comments on commit 6232e5f

Please sign in to comment.