Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(test): typecheck dataclasses #20

Merged
merged 4 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions python/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Mesh geometry made of triangles and utilities."""
from __future__ import annotations

from functools import cached_property
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -106,19 +104,19 @@ def paths_intersect_triangles(
return jnp.any(intersect, axis=(0, 2))


@jaxtyped(typechecker=typechecker)
class TriangleMesh(eqx.Module):
"""
A simple geometry made of triangles.

Args:
vertices: The array of triangle vertices.
triangles: The array of triangle indices.

"""

vertices: Float[Array, "num_vertices 3"]
vertices: Float[Array, "num_vertices 3"] = eqx.field(converter=jnp.asarray)
"""The array of triangle vertices."""
triangles: UInt[Array, "num_triangles 3"]
triangles: UInt[Array, "num_triangles 3"] = eqx.field(converter=jnp.asarray)
"""The array of triangle indices."""

@cached_property
Expand All @@ -136,7 +134,7 @@ def diffraction_edges(self) -> UInt[Array, "num_edges 3"]:
raise NotImplementedError

@classmethod
def load_obj(cls, file: Path) -> TriangleMesh:
def load_obj(cls, file: Path) -> "TriangleMesh":
"""
Load a triangle mesh from a Wavefront .obj file.

Expand All @@ -153,7 +151,8 @@ def load_obj(cls, file: Path) -> TriangleMesh:
"""
mesh = _core.geometry.triangle_mesh.TriangleMesh.load_obj(str(file))
return cls(
vertices=jnp.asarray(mesh.vertices), triangles=jnp.asarray(mesh.triangles)
vertices=mesh.vertices,
triangles=mesh.triangles,
)

def plot(self, **kwargs: Any) -> Any:
Expand Down
31 changes: 30 additions & 1 deletion src/geometry/triangle_mesh.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,51 @@
use std::fs::File;
use std::io::BufReader;

use numpy::{Element, PyArray2};
use obj::raw::object::{parse_obj, RawObj};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyType;

#[pyclass]
#[pyo3(get_all)]
struct TriangleMesh {
/// Array of size [num_vertices 3].
vertices: Vec<(f32, f32, f32)>,
/// Array of size [num_triangles 3].
triangles: Vec<(usize, usize, usize)>,
}

#[inline]
fn pyarray2_from_vec_tuple<'py, T: Copy + Element>(
py: Python<'py>,
v: &[(T, T, T)],
) -> &'py PyArray2<T> {
let n = v.len();
unsafe {
let arr = PyArray2::<T>::new(py, [n, 3], false);

for i in 0..n {
let tup = v.get_unchecked(i);
arr.uget_raw([i, 0]).write(tup.0);
arr.uget_raw([i, 1]).write(tup.1);
arr.uget_raw([i, 2]).write(tup.2);
}
arr
}
}

#[pymethods]
impl TriangleMesh {
#[getter]
fn vertices<'py>(&self, py: Python<'py>) -> &'py PyArray2<f32> {
pyarray2_from_vec_tuple(py, &self.vertices)
}

#[getter]
fn triangles<'py>(&self, py: Python<'py>) -> &'py PyArray2<usize> {
pyarray2_from_vec_tuple(py, &self.triangles)
}

#[classmethod]
fn load_obj(_: &PyType, filename: &str) -> PyResult<Self> {
let input = BufReader::new(File::open(filename)?);
Expand Down
36 changes: 35 additions & 1 deletion tests/geometry/test_triangle_mesh.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from collections.abc import Iterator
from contextlib import AbstractContextManager
from contextlib import nullcontext as does_not_raise
from pathlib import Path

import chex
import jax.numpy as jnp
import jaxtyping
import pytest
from chex import Array

from differt.geometry.triangle_mesh import (
TriangleMesh,
triangles_contain_vertices_assuming_inside_same_plane,
)

from ..utils import random_inputs


@pytest.fixture(scope="module")
def two_buildings_obj_file() -> Iterator[Path]:
Expand All @@ -33,6 +38,36 @@ def sphere() -> Iterator[TriangleMesh]:
yield TriangleMesh(vertices=vertices, triangles=triangles)


@pytest.mark.parametrize(
("triangle_vertices,vertices,expectation"),
[
((20, 10, 3, 3), (20, 10, 3), does_not_raise()),
((10, 3, 3), (10, 3), does_not_raise()),
((3, 3), (3,), does_not_raise()),
(
(3, 3),
(4,),
pytest.raises(TypeError),
),
(
(10, 3, 3),
(12, 3),
pytest.raises(TypeError),
),
],
)
@random_inputs("triangle_vertices", "vertices")
def test_triangles_contain_vertices_assuming_inside_same_planes_random_inputs(
triangle_vertices: Array,
vertices: Array,
expectation: AbstractContextManager[Exception],
) -> None:
with expectation:
_ = triangles_contain_vertices_assuming_inside_same_plane(
triangle_vertices, vertices
)


def test_triangles_contain_vertices_assuming_inside_same_planes() -> None:
triangle_vertices = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
vertices = jnp.array(
Expand Down Expand Up @@ -63,7 +98,6 @@ def test_triangles_contain_vertices_assuming_inside_same_planes() -> None:


class TestTriangleMesh:
@pytest.mark.xfail(reason="Unknown, to be investigated...")
def test_invalid_args(self) -> None:
vertices = jnp.ones((10, 2))
triangles = jnp.ones((20, 3))
Expand Down
3 changes: 2 additions & 1 deletion tests/geometry/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from chex import Array

from differt.geometry.utils import pairwise_cross
from tests.utils import random_inputs

from ..utils import random_inputs


def test_pairwise_cross() -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/rt/test_image_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
image_of_vertices_with_respect_to_mirrors,
intersection_of_line_segments_with_planes,
)
from tests.utils import random_inputs

from ..utils import random_inputs


def test_image_of_vertices_with_respect_to_mirrors() -> None:
Expand Down
Loading