From 30eecd68599140d063c3edaa6148391c53ca4d92 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=C3=A9rome=20Eertmans?=
Date: Thu, 16 Jan 2025 15:36:14 +0100
Subject: [PATCH] feat(lib): add method to easily mask duplicate path candidates (#203)

* feat(lib): add method to easily mask duplicate path candidates

* chore(docs): improve
---
 differt/src/differt/geometry/_paths.py       | 65 ++++++++++++++++
 differt/src/differt/scene/_triangle_scene.py |  8 ++
 differt/tests/geometry/test_paths.py         | 79 +++++++++++++++++++-
 3 files changed, 151 insertions(+), 1 deletion(-)

diff --git a/differt/src/differt/geometry/_paths.py b/differt/src/differt/geometry/_paths.py
index bcfdc085..03e96b65 100644
--- a/differt/src/differt/geometry/_paths.py
+++ b/differt/src/differt/geometry/_paths.py
@@ -186,6 +186,71 @@ def squeeze(self, axis: int | Sequence[int] | None = None) -> Self:
             is_leaf=lambda x: x is None,
         )
 
+    @eqx.filter_jit
+    @jaxtyped(
+        typechecker=None
+    )  # typing.Self is (currently) not compatible with jaxtyping and beartype
+    def mask_duplicate_objects(self, axis: int = -1) -> Self:
+        """
+        Return a copy where duplicate objects along a given axis are masked out.
+
+        E.g., when generating path candidates from a generative Machine Learning model,
+        see :ref:`sampling-paths`, it is possible that the model generates the same
+        path candidate multiple times. This method allows masking these duplicates,
+        while maintaining the same batch dimensions and compatibility with :func:`jax.jit`.
+
+        Args:
+            axis: The batch axis along which the unique values are computed.
+
+                It defaults to the last axis, which is the axis where
+                different path candidates are stored when generating
+                paths with
+                :meth:`TriangleScene.compute_paths`.
+
+        Returns:
+            New paths with duplicates masked out, combined with any existing mask.
+
+        Raises:
+            ValueError: If the provided axis is out-of-bounds.
+        """
+        ndim = self.objects.ndim - 1  # batch axes only; the last axis holds the path
+        batch = self.objects.shape[:-1]
+        if not -ndim <= axis < ndim:
+            msg = f"The provided axis {axis} is out-of-bounds for batch of dimensions {ndim}!"
+            raise ValueError(msg)
+
+        size = batch[axis]  # number of entries along the deduplicated axis
+        # A negative ``axis`` counts batch axes only, so shift it past the path axis.
+        objects = jnp.moveaxis(self.objects, axis if axis >= 0 else axis - 1, -2)
+        indices = jnp.arange(size, dtype=objects.dtype)
+        # ``f`` marks, within one 2D slice, the rows kept as unique representatives.
+        @jaxtyped(typechecker=typechecker)
+        def f(
+            objects: Int[Array, "axis_length path_length"],
+        ) -> Bool[Array, " axis_length"]:
+            _, index = jnp.unique(
+                objects,
+                axis=0,
+                size=size,  # static output length, so the call can be traced
+                return_index=True,  # indices of the rows retained as unique
+            )
+            # NOTE(review): assumes padded ``index`` entries point at kept rows - confirm.
+            return jnp.isin(indices, index)
+        # One ``vmap`` per extra batch axis, so ``f`` applies to every leading axis.
+        for _ in range(max(ndim - 1, 0)):
+            f = jax.vmap(f)
+
+        mask = f(objects)
+        mask = jnp.moveaxis(mask, -1, axis)  # put the axis back where it was
+        mask = mask & self.mask if self.mask is not None else mask
+
+        return eqx.tree_at(
+            lambda p: p.mask,
+            self,
+            mask,
+            is_leaf=lambda x: x is None,
+        )
+
     @property
     @jaxtyped(typechecker=typechecker)
     def path_length(self) -> int:
diff --git a/differt/src/differt/scene/_triangle_scene.py b/differt/src/differt/scene/_triangle_scene.py
index 74a16604..c1467b52 100644
--- a/differt/src/differt/scene/_triangle_scene.py
+++ b/differt/src/differt/scene/_triangle_scene.py
@@ -863,6 +863,14 @@ def compute_paths( # noqa: C901
             The paths, as class wrapping path vertices, object indices, and a masked
             identify valid paths.
 
+            The returned paths have the following batch dimensions:
+
+            * ``[*transmitters_batch *receivers_batch num_path_candidates]``,
+            * ``[*transmitters_batch *receivers_batch chunk_size]``,
+            * or ``[*transmitters_batch *receivers_batch num_rays]``,
+
+            depending on the method used.
+
         Raises:
             ValueError: If neither ``order`` nor ``path_candidates`` has been provided,
                 or if both have been provided simultaneously.
diff --git a/differt/tests/geometry/test_paths.py b/differt/tests/geometry/test_paths.py
index 2029b3e1..7195cf1e 100644
--- a/differt/tests/geometry/test_paths.py
+++ b/differt/tests/geometry/test_paths.py
@@ -9,8 +9,9 @@
 import pytest
 from jaxtyping import PRNGKeyArray
 
-from differt.geometry import path_lengths
+from differt.geometry import TriangleMesh, path_lengths
 from differt.geometry._paths import Paths, SBRPaths, merge_cell_ids
+from differt.scene import TriangleScene
 
 
 def test_merge_cell_ids() -> None:
@@ -90,6 +91,82 @@ def test_squeeze(
                 key=key,
             ).squeeze(axis=axis)
 
+    def test_mask_duplicate_objects(self, key: PRNGKeyArray) -> None:
+        mesh = TriangleMesh.box()  # 6 objects
+        # 6 path candidates, only 3 are unique: rows 0/2/3 and rows 1/5 are identical
+        path_candidates = jnp.array([
+            [0, 1, 2],
+            [1, 0, 2],
+            [0, 1, 2],
+            [0, 1, 2],
+            [2, 3, 4],
+            [1, 0, 2],
+        ])
+
+        # 1 - One TX, one RX: the path candidates axis is the only batch dimension
+
+        key_rx, key_tx = jax.random.split(key, 2)
+
+        scene = TriangleScene(
+            transmitters=jax.random.normal(key_tx, (3,)),
+            receivers=jax.random.normal(key_rx, (3,)),
+            mesh=mesh,
+        )
+
+        paths = scene.compute_paths(path_candidates=path_candidates)
+        # Force an all-True mask so that only duplicates can be masked afterwards.
+        assert paths.mask is not None
+        paths = eqx.tree_at(lambda p: p.mask, paths, jnp.ones_like(paths.mask))
+
+        got = paths.mask_duplicate_objects()
+        # Exactly one entry per unique candidate must remain unmasked.
+        chex.assert_trees_all_equal(got.mask.sum(axis=-1), 3)
+        # Same check without a pre-existing mask: the result must still carry one.
+        paths = eqx.tree_at(lambda p: p.mask, paths, None)
+
+        got = paths.mask_duplicate_objects()
+
+        chex.assert_trees_all_equal(got.mask.sum(axis=-1), 3)
+        # With a single batch dimension, axis -2 is out-of-bounds.
+        with pytest.raises(
+            ValueError,
+            match="The provided axis -2 is out-of-bounds for batch of dimensions 1!",
+        ):
+            _ = paths.mask_duplicate_objects(axis=-2)
+
+        # 2 - Many TXs, many RXs, multiple batch dimensions
+
+        scene = scene.with_transmitters_grid(2, 1).with_receivers_grid(4, 3)
+
+        paths = scene.compute_paths(path_candidates=path_candidates)
+        # Same scenario as above, now with leading TX and RX batch axes.
+        assert paths.mask is not None
+        chex.assert_shape(paths.mask, (1, 2, 3, 4, path_candidates.shape[0]))
+        paths = eqx.tree_at(lambda p: p.mask, paths, jnp.ones_like(paths.mask))
+
+        got = paths.mask_duplicate_objects()
+
+        chex.assert_shape(got.mask, (1, 2, 3, 4, path_candidates.shape[0]))
+        chex.assert_trees_all_equal(got.mask.sum(axis=-1), 3)
+        # And again without a pre-existing mask.
+        paths = eqx.tree_at(lambda p: p.mask, paths, None)
+
+        got = paths.mask_duplicate_objects()
+
+        chex.assert_shape(got.mask, (1, 2, 3, 4, path_candidates.shape[0]))
+        chex.assert_trees_all_equal(got.mask.sum(axis=-1), 3)
+        # Move the path candidates axis to the front to exercise a non-default axis.
+        paths = eqx.tree_at(
+            lambda p: (p.vertices, p.objects),
+            paths,
+            (jnp.swapaxes(paths.vertices, 0, -3), jnp.swapaxes(paths.objects, 0, -2)),
+        )
+
+        got = paths.mask_duplicate_objects(axis=0)
+
+        chex.assert_shape(got.mask, (path_candidates.shape[0], 2, 3, 4, 1))
+        chex.assert_trees_all_equal(got.mask.sum(axis=0), 3)
+
     @pytest.mark.parametrize("path_length", [3, 5])
     @pytest.mark.parametrize("batch", [(), (1,), (1, 2, 3, 4)])
     @pytest.mark.parametrize("num_objects", [1, 10])