Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(lib): add method to easily mask duplicate path candidates #203

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions differt/src/differt/geometry/_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,71 @@ def squeeze(self, axis: int | Sequence[int] | None = None) -> Self:
is_leaf=lambda x: x is None,
)

@eqx.filter_jit
# typing.Self is (currently) not compatible with jaxtyping and beartype
@jaxtyped(typechecker=None)
def mask_duplicate_objects(self, axis: int = -1) -> Self:
    """
    Return a copy where duplicate objects along a given axis are masked.

    Path candidates produced by, e.g., a generative Machine Learning model
    (see :ref:`sampling-paths`) may contain the same candidate several times.
    This method masks the redundant occurrences, while keeping the batch
    dimensions unchanged and remaining compatible with :func:`jax.jit`.

    If the instance already has a mask, the result is the logical AND of
    that mask and the duplicate mask.

    Args:
        axis: The batch axis along which the unique values are computed.

            It defaults to the last axis, which is the axis where
            different path candidates are stored when generating
            paths with
            :meth:`TriangleScene.compute_paths<differt.scene.TriangleScene.compute_paths>`.

    Returns:
        A new paths instance with masked duplicate objects.

    Raises:
        ValueError: If the provided axis is out-of-bounds.
    """
    # The last dimension of ``objects`` indexes the path, everything before is batch.
    ndim = self.objects.ndim - 1
    if axis < -ndim or axis >= ndim:
        msg = f"The provided axis {axis} is out-of-bounds for batch of dimensions {ndim}!"
        raise ValueError(msg)

    num_entries = self.objects.shape[:-1][axis]

    # ``axis`` refers to batch dimensions only: because ``objects`` has one
    # extra trailing dimension, a negative axis must be shifted by one.
    if axis >= 0:
        source = axis
    else:
        source = axis - 1

    candidates = jnp.moveaxis(self.objects, source, -2)
    positions = jnp.arange(num_entries, dtype=candidates.dtype)

    @jaxtyped(typechecker=typechecker)
    def keep_unique(
        rows: Int[Array, "axis_length path_length"],
    ) -> Bool[Array, " axis_length"]:
        # A static ``size`` keeps the output shape fixed, as required under jit.
        _, kept = jnp.unique(
            rows,
            axis=0,
            size=num_entries,
            return_index=True,
        )

        # True only at the positions that `jnp.unique` selected.
        return jnp.isin(positions, kept)

    # Vectorize over every remaining (leading) batch dimension.
    batched_keep_unique = keep_unique
    for _ in range(ndim - 1):
        batched_keep_unique = jax.vmap(batched_keep_unique)

    # Put the processed axis back where the caller expects it.
    unique_mask = jnp.moveaxis(batched_keep_unique(candidates), -1, axis)

    if self.mask is not None:
        unique_mask = unique_mask & self.mask

    return eqx.tree_at(
        lambda paths: paths.mask,
        self,
        unique_mask,
        is_leaf=lambda leaf: leaf is None,
    )

@property
@jaxtyped(typechecker=typechecker)
def path_length(self) -> int:
Expand Down
8 changes: 8 additions & 0 deletions differt/src/differt/scene/_triangle_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,14 @@ def compute_paths( # noqa: C901
The paths, as a class wrapping path vertices, object indices, and a mask to
identify valid paths.

The returned paths have the following batch dimensions:

* ``[*transmitters_batch *receivers_batch num_path_candidates]``,
* ``[*transmitters_batch *receivers_batch chunk_size]``,
* or ``[*transmitters_batch *receivers_batch num_rays]``,

depending on the method used.

Raises:
ValueError: If neither ``order`` nor ``path_candidates`` has been provided,
or if both have been provided simultaneously.
Expand Down
79 changes: 78 additions & 1 deletion differt/tests/geometry/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import pytest
from jaxtyping import PRNGKeyArray

from differt.geometry import path_lengths
from differt.geometry import TriangleMesh, path_lengths
from differt.geometry._paths import Paths, SBRPaths, merge_cell_ids
from differt.scene import TriangleScene


def test_merge_cell_ids() -> None:
Expand Down Expand Up @@ -90,6 +91,82 @@ def test_squeeze(
key=key,
).squeeze(axis=axis)

def test_mask_duplicate_objects(self, key: PRNGKeyArray) -> None:
    """Check duplicate masking with and without batch dimensions."""
    box = TriangleMesh.box()  # 6 objects
    # 6 path candidates, only 3 are unique
    path_candidates = jnp.array([
        [0, 1, 2],
        [1, 0, 2],
        [0, 1, 2],
        [0, 1, 2],
        [2, 3, 4],
        [1, 0, 2],
    ])
    num_candidates = path_candidates.shape[0]
    num_unique = 3

    # 1 - One TX, one RX, no batch dimension

    key_rx, key_tx = jax.random.split(key, 2)

    scene = TriangleScene(
        transmitters=jax.random.normal(key_tx, (3,)),
        receivers=jax.random.normal(key_rx, (3,)),
        mesh=box,
    )

    paths = scene.compute_paths(path_candidates=path_candidates)

    assert paths.mask is not None
    # Start from an all-true mask so that only duplicates can be masked out
    all_valid = jnp.ones_like(paths.mask)
    paths = eqx.tree_at(lambda p: p.mask, paths, all_valid)

    deduplicated = paths.mask_duplicate_objects()

    chex.assert_trees_all_equal(deduplicated.mask.sum(axis=-1), num_unique)

    # Same check when the paths have no mask at all
    paths = eqx.tree_at(lambda p: p.mask, paths, None)

    deduplicated = paths.mask_duplicate_objects()

    chex.assert_trees_all_equal(deduplicated.mask.sum(axis=-1), num_unique)

    with pytest.raises(
        ValueError,
        match="The provided axis -2 is out-of-bounds for batch of dimensions 1!",
    ):
        _ = paths.mask_duplicate_objects(axis=-2)

    # 2 - Many TXs, many RXs, multiple batch dimensions

    scene = scene.with_transmitters_grid(2, 1).with_receivers_grid(4, 3)
    batched_shape = (1, 2, 3, 4, num_candidates)

    paths = scene.compute_paths(path_candidates=path_candidates)

    assert paths.mask is not None
    chex.assert_shape(paths.mask, batched_shape)
    all_valid = jnp.ones_like(paths.mask)
    paths = eqx.tree_at(lambda p: p.mask, paths, all_valid)

    deduplicated = paths.mask_duplicate_objects()

    chex.assert_shape(deduplicated.mask, batched_shape)
    chex.assert_trees_all_equal(deduplicated.mask.sum(axis=-1), num_unique)

    paths = eqx.tree_at(lambda p: p.mask, paths, None)

    deduplicated = paths.mask_duplicate_objects()

    chex.assert_shape(deduplicated.mask, batched_shape)
    chex.assert_trees_all_equal(deduplicated.mask.sum(axis=-1), num_unique)

    # Move the path candidates to the first batch axis and deduplicate along it
    swapped_vertices = jnp.swapaxes(paths.vertices, 0, -3)
    swapped_objects = jnp.swapaxes(paths.objects, 0, -2)
    paths = eqx.tree_at(
        lambda p: (p.vertices, p.objects),
        paths,
        (swapped_vertices, swapped_objects),
    )

    deduplicated = paths.mask_duplicate_objects(axis=0)

    chex.assert_shape(deduplicated.mask, (num_candidates, 2, 3, 4, 1))
    chex.assert_trees_all_equal(deduplicated.mask.sum(axis=0), num_unique)

@pytest.mark.parametrize("path_length", [3, 5])
@pytest.mark.parametrize("batch", [(), (1,), (1, 2, 3, 4)])
@pytest.mark.parametrize("num_objects", [1, 10])
Expand Down
Loading