Skip to content

Commit

Permalink
feat(lib): add method to easily mask duplicate path candidates (#203)
Browse files Browse the repository at this point in the history
* feat(lib): add method to easily mask duplicate path candidates

* chore(docs): improve
  • Loading branch information
jeertmans authored Jan 16, 2025
1 parent 79c0fe5 commit 30eecd6
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 1 deletion.
65 changes: 65 additions & 0 deletions differt/src/differt/geometry/_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,71 @@ def squeeze(self, axis: int | Sequence[int] | None = None) -> Self:
is_leaf=lambda x: x is None,
)

@eqx.filter_jit
@jaxtyped(
typechecker=None
) # typing.Self is (currently) not compatible with jaxtyping and beartype
def mask_duplicate_objects(self, axis: int = -1) -> Self:
"""
Return a copy by masking duplicate objects along a given axis.
E.g., when generating path candidates from a generative Machine Learning model,
see :ref:`sampling-paths`, it is possible that the model generates the same
path candidate multiple times. This method allows to mask these duplicates,
while maintaining the same batch dimensions and compatibility with :func:`jax.jit`.
Args:
axis: The batch axis along which the unique values are computed.
It defaults to the last axis, which is the axis where
different path candidates are stored when generating
paths with
:meth:`TriangleScene.compute_paths<differt.scene.TriangleScene.compute_paths>`.
Returns:
A new paths instance with masked duplicate objects.
Raises:
ValueError: If the provided axis is out-of-bounds.
"""
ndim = self.objects.ndim - 1
batch = self.objects.shape[:-1]
if not -ndim <= axis < ndim:
msg = f"The provided axis {axis} is out-of-bounds for batch of dimensions {ndim}!"
raise ValueError(msg)

size = batch[axis]

objects = jnp.moveaxis(self.objects, axis if axis >= 0 else axis - 1, -2)
indices = jnp.arange(size, dtype=objects.dtype)

@jaxtyped(typechecker=typechecker)
def f(
objects: Int[Array, "axis_length path_length"],
) -> Bool[Array, " axis_length"]:
_, index = jnp.unique(
objects,
axis=0,
size=size,
return_index=True,
)

return jnp.isin(indices, index)

for _ in range(max(ndim - 1, 0)):
f = jax.vmap(f)

mask = f(objects)
mask = jnp.moveaxis(mask, -1, axis)
mask = mask & self.mask if self.mask is not None else mask

return eqx.tree_at(
lambda p: p.mask,
self,
mask,
is_leaf=lambda x: x is None,
)

@property
@jaxtyped(typechecker=typechecker)
def path_length(self) -> int:
Expand Down
8 changes: 8 additions & 0 deletions differt/src/differt/scene/_triangle_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,14 @@ def compute_paths( # noqa: C901
The paths, as class wrapping path vertices, object indices, and a masked
identify valid paths.
The returned paths have the following batch dimensions:
* ``[*transmitters_batch *receivers_batch num_path_candidates]``,
* ``[*transmitters_batch *receivers_batch chunk_size]``,
* or ``[*transmitters_batch *receivers_batch num_rays]``,
depending on the method used.
Raises:
ValueError: If neither ``order`` nor ``path_candidates`` has been provided,
or if both have been provided simultaneously.
Expand Down
79 changes: 78 additions & 1 deletion differt/tests/geometry/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import pytest
from jaxtyping import PRNGKeyArray

from differt.geometry import path_lengths
from differt.geometry import TriangleMesh, path_lengths
from differt.geometry._paths import Paths, SBRPaths, merge_cell_ids
from differt.scene import TriangleScene


def test_merge_cell_ids() -> None:
Expand Down Expand Up @@ -90,6 +91,82 @@ def test_squeeze(
key=key,
).squeeze(axis=axis)

def test_mask_duplicate_objects(self, key: PRNGKeyArray) -> None:
mesh = TriangleMesh.box() # 6 objects
# 6 path candidates, only 3 are unique
path_candidates = jnp.array([
[0, 1, 2],
[1, 0, 2],
[0, 1, 2],
[0, 1, 2],
[2, 3, 4],
[1, 0, 2],
])

# 1 - One TX, one RX, no batch dimension

key_rx, key_tx = jax.random.split(key, 2)

scene = TriangleScene(
transmitters=jax.random.normal(key_tx, (3,)),
receivers=jax.random.normal(key_rx, (3,)),
mesh=mesh,
)

paths = scene.compute_paths(path_candidates=path_candidates)

assert paths.mask is not None
paths = eqx.tree_at(lambda p: p.mask, paths, jnp.ones_like(paths.mask))

got = paths.mask_duplicate_objects()

chex.assert_trees_all_equal(got.mask.sum(axis=-1), 3)

paths = eqx.tree_at(lambda p: p.mask, paths, None)

got = paths.mask_duplicate_objects()

chex.assert_trees_all_equal(got.mask.sum(axis=-1), 3)

with pytest.raises(
ValueError,
match="The provided axis -2 is out-of-bounds for batch of dimensions 1!",
):
_ = paths.mask_duplicate_objects(axis=-2)

# 2 - Many TXs, many RXs, multiple batch dimensions

scene = scene.with_transmitters_grid(2, 1).with_receivers_grid(4, 3)

paths = scene.compute_paths(path_candidates=path_candidates)

assert paths.mask is not None
chex.assert_shape(paths.mask, (1, 2, 3, 4, path_candidates.shape[0]))
paths = eqx.tree_at(lambda p: p.mask, paths, jnp.ones_like(paths.mask))

got = paths.mask_duplicate_objects()

chex.assert_shape(got.mask, (1, 2, 3, 4, path_candidates.shape[0]))
chex.assert_trees_all_equal(got.mask.sum(axis=-1), 3)

paths = eqx.tree_at(lambda p: p.mask, paths, None)

got = paths.mask_duplicate_objects()

chex.assert_shape(got.mask, (1, 2, 3, 4, path_candidates.shape[0]))
chex.assert_trees_all_equal(got.mask.sum(axis=-1), 3)

paths = eqx.tree_at(
lambda p: (p.vertices, p.objects),
paths,
(jnp.swapaxes(paths.vertices, 0, -3), jnp.swapaxes(paths.objects, 0, -2)),
)

got = paths.mask_duplicate_objects(axis=0)

chex.assert_shape(got.mask, (path_candidates.shape[0], 2, 3, 4, 1))
chex.assert_trees_all_equal(got.mask.sum(axis=0), 3)

@pytest.mark.parametrize("path_length", [3, 5])
@pytest.mark.parametrize("batch", [(), (1,), (1, 2, 3, 4)])
@pytest.mark.parametrize("num_objects", [1, 10])
Expand Down

0 comments on commit 30eecd6

Please sign in to comment.