Skip to content

Commit

Permalink
chore(lib): use Equinox's Module instead of Chex's dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Jan 5, 2024
1 parent 9b75479 commit 39da9a4
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 20 deletions.
47 changes: 32 additions & 15 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"chex>=0.1.84",
"jax>=0.4.20",
"jaxtyping>=0.2.24",
"numpy>=1.26.1",
"optax>=0.1.7",
"typing-extensions>=4.9.0;python_version < '3.10'",
"equinox>=0.11.2",
]
description = "Differentiable Ray Tracing Toolbox for Radio Propagation Simulations"
dynamic = ["license", "readme", "version"]
Expand Down Expand Up @@ -90,6 +90,7 @@ github-action = [
"jax[cpu]>=0.4.20",
]
test = [
"chex>=0.1.84",
"differt[all]",
"open3d-cpu>=0.17.0;python_version <= '3.10' and sys_platform == 'linux'",
"pytest>=7.4.3",
Expand Down
16 changes: 12 additions & 4 deletions python/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from pathlib import Path
from typing import Any

import equinox as eqx
import jax.numpy as jnp
import numpy as np
from chex import dataclass
from jaxtyping import Array, Bool, Float, Scalar, UInt, jaxtyped
from typeguard import typechecked as typechecker

Expand Down Expand Up @@ -106,8 +106,7 @@ def paths_intersect_triangles(
return jnp.any(intersect, axis=(0, 2))


@dataclass
class TriangleMesh:
class TriangleMesh(eqx.Module):
"""
A simple geometry made of triangles.
Expand Down Expand Up @@ -158,7 +157,16 @@ def load_obj(cls, file: Path) -> TriangleMesh:
)

def plot(self, **kwargs: Any) -> Any:
"""*TODO*."""
"""
Plot this mesh on a 3D scene.
Args:
kwargs: Keyword arguments passed to
:py:func:`draw_mesh<differt.plotting.draw_mesh>`.
Returns:
The resulting plot output.
"""
return draw_mesh(
vertices=np.asarray(self.vertices),
triangles=np.asarray(self.triangles),
Expand Down

0 comments on commit 39da9a4

Please sign in to comment.