From 57afdca1adb41556efdd40aeef8738a78b37f904 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 16 Jan 2024 20:35:13 +0100 Subject: [PATCH] feat(lib): implementing path candidates iterator (#28) * feat(lib): implementing path candidates iterator Cleaning must still be done * chore(lib): fixes and tests --- .pre-commit-config.yaml | 2 +- .rustfmt.toml | 13 ++++ python/differt/_core/rt/utils.pyi | 5 ++ python/differt/rt/utils.py | 22 +++++- src/geometry/triangle_mesh.rs | 14 ++-- src/rt/utils.rs | 107 ++++++++++++++---------------- tests/rt/test_utils.py | 26 +++++++- 7 files changed, 124 insertions(+), 65 deletions(-) create mode 100644 .rustfmt.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8d39abab..b99eceda 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: args: [--fix] - id: ruff-format - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.346 + rev: v1.1.347 hooks: - id: pyright - repo: https://github.com/doublify/pre-commit-rust diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 00000000..414f4784 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,13 @@ +condense_wildcard_suffixes = true +error_on_line_overflow = true +error_on_unformatted = true +force_multiline_blocks = true +format_code_in_doc_comments = true +format_macro_matchers = true +format_strings = true +imports_granularity = "Crate" +match_block_trailing_comma = true +normalize_doc_attributes = true +unstable_features = true +version = "Two" +wrap_comments = true diff --git a/python/differt/_core/rt/utils.pyi b/python/differt/_core/rt/utils.pyi index ef12994c..62b5d955 100644 --- a/python/differt/_core/rt/utils.pyi +++ b/python/differt/_core/rt/utils.pyi @@ -1,6 +1,11 @@ +from collections.abc import Iterator + import numpy as np from jaxtyping import UInt def generate_all_path_candidates( num_primitives: int, order: int ) -> UInt[np.ndarray, "num_path_candidates order"]: ... +def generate_all_path_candidates_iter( + num_primitives: int, order: int +) -> Iterator[UInt[np.ndarray, " order"]]: ... diff --git a/python/differt/rt/utils.py b/python/differt/rt/utils.py index 360a057a..a99880ab 100644 --- a/python/differt/rt/utils.py +++ b/python/differt/rt/utils.py @@ -1,4 +1,5 @@ """Ray Tracing utilies.""" +from collections.abc import Iterator import jax.numpy as jnp from jaxtyping import Array, Bool, Float, UInt, jaxtyped @@ -34,7 +35,26 @@ def generate_all_path_candidates( """ return jnp.asarray( _core.rt.utils.generate_all_path_candidates(num_primitives, order), - dtype=jnp.uint32, + ) + + +@jaxtyped(typechecker=typechecker) +def generate_all_path_candidates_iter( + num_primitives: int, order: int +) -> Iterator[UInt[Array, " order"]]: + """ + Iterator variant of :func:`generate_all_path_candidates`. + + Args: + num_primitives: The (positive) number of primitives. + order: The path order. + + Returns: + An iterator of unsigned arrays with primitive indices. + """ + return map( + jnp.asarray, + _core.rt.utils.generate_all_path_candidates_iter(num_primitives, order), ) diff --git a/src/geometry/triangle_mesh.rs b/src/geometry/triangle_mesh.rs index 1ac0724b..860abf29 100644 --- a/src/geometry/triangle_mesh.rs +++ b/src/geometry/triangle_mesh.rs @@ -1,11 +1,8 @@ -use std::fs::File; -use std::io::BufReader; +use std::{fs::File, io::BufReader}; use numpy::{Element, PyArray2}; use obj::raw::object::{parse_obj, RawObj}; -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; -use pyo3::types::PyType; +use pyo3::{exceptions::PyValueError, prelude::*, types::PyType}; #[pyclass] struct TriangleMesh { @@ -84,7 +81,12 @@ impl TryFrom for TriangleMesh { PTN(v) if v.len() == 3 => { triangles.push((v[0].0, v[1].0, v[2].0)); }, - _ => return Err(PyValueError::new_err("Cannot create TriangleMesh from an object that contains something else than triangles")), + _ => { + return Err(PyValueError::new_err( + "Cannot create TriangleMesh from an object that contains something else \ + than triangles", + )); + }, } } diff --git a/src/rt/utils.rs b/src/rt/utils.rs index 61c1a798..2d7dcb5e 100644 --- a/src/rt/utils.rs +++ b/src/rt/utils.rs @@ -1,9 +1,11 @@ -use numpy::ndarray::parallel::prelude::*; -use numpy::ndarray::{s, Array2, ArrayView2, Axis}; -use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2}; +use numpy::{ + ndarray::{parallel::prelude::*, s, Array2, ArrayView2, Axis}, + IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2, +}; use pyo3::prelude::*; -/// Generate an array of all path candidates (assuming fully connected primitives). +/// Generate an array of all path candidates (assuming fully connected +/// primitives). #[pyfunction] pub fn generate_all_path_candidates( py: Python<'_>, @@ -49,10 +51,11 @@ pub fn generate_all_path_candidates( path_candidates.into_pyarray(py) } +/// Iterator variant of [`generate_all_path_candidates`]. #[pyclass] pub struct AllPathCandidates { /// Number of primitives. - num_primitives: u32, + num_primitives: usize, /// Path order. order: u32, /// Exact number of path candidates that will be generated. @@ -60,60 +63,72 @@ pub struct AllPathCandidates { /// The index of the current path candidate. index: usize, /// Last path candidate. - path_candidate: Vec, - counter: Vec, + path_candidate: Vec, + /// Count how many times a given index has been changed. + counter: Vec, + done: bool, } impl AllPathCandidates { #[inline] - fn new(num_primitives: u32, order: u32) -> Self { - let num_choices = num_primitives.saturating_sub(1) as usize; + fn new(num_primitives: usize, order: u32) -> Self { + let num_choices = num_primitives.saturating_sub(1); let num_candidates_per_batch = num_choices.pow(order.saturating_sub(1)); - let num_candidates = (num_primitives as usize) * num_candidates_per_batch; + let num_candidates = num_primitives * num_candidates_per_batch; + let index = 0; + let path_candidate = (0..order as usize).collect(); // [0, 1, 2, ..., order - 1] + let mut counter = vec![1; order as usize]; + counter[0] = 0; Self { num_primitives, order, num_candidates, - index: 0, - path_candidate: (0..order).collect(), - counter: vec![2; order as usize], + index, + path_candidate, + counter, + done: num_primitives == 0, } } } impl Iterator for AllPathCandidates { - type Item = Vec; + type Item = Vec; #[inline] fn next(&mut self) -> Option { - self.index += 1; - + if self.done { + return None; + } + // 1. Output is generated as a copy of the current path_candidate let path_candidate = self.path_candidate.clone(); - let start = self + // 2. Generate the next path candidate + + // Identify which 'index' should be increased by 1, + // from right to left. + if let Some(start) = self .counter .iter() - .rposition(|&count| count < self.num_primitives)?; - - println!("Actual counter: {:?}", self.counter); - println!("start index:{:?}", start); - - self.counter[start] += 1; - self.path_candidate[start] = (self.path_candidate[start] + 1) % self.num_primitives; - - for i in (start + 1)..(self.order as usize) { - self.path_candidate[i] = (self.path_candidate[i - 1] + 1) % self.num_primitives; - self.counter[i] = 2; + .rposition(|&count| count < self.num_primitives - 1) + { + self.counter[start] += 1; + self.path_candidate[start] = (self.path_candidate[start] + 1) % self.num_primitives; + + for i in (start + 1)..(self.order as usize) { + self.path_candidate[i] = (self.path_candidate[i - 1] + 1) % self.num_primitives; + self.counter[i] = 1; + } + } else { + self.done = true; } - println!("{:?}", self.counter); Some(path_candidate) } #[inline] fn size_hint(&self) -> (usize, Option) { - let rem = self.num_candidates - self.index; + let rem = self.num_candidates.saturating_sub(self.index); (rem, Some(rem)) } @@ -130,7 +145,10 @@ impl AllPathCandidates { slf } - fn __next__<'py>(mut slf: PyRefMut<'py, Self>, py: Python<'py>) -> Option<&'py PyArray1> { + fn __next__<'py>( + mut slf: PyRefMut<'py, Self>, + py: Python<'py>, + ) -> Option<&'py PyArray1> { slf.next().map(|v| PyArray1::from_vec(py, v)) } } @@ -139,7 +157,7 @@ impl AllPathCandidates { #[pyfunction] pub fn generate_all_path_candidates_iter( _py: Python<'_>, - num_primitives: u32, + num_primitives: usize, order: u32, ) -> AllPathCandidates { AllPathCandidates::new(num_primitives, order) @@ -194,6 +212,7 @@ impl PathCandidates { pub(crate) fn create_module(py: Python<'_>) -> PyResult<&PyModule> { let m = pyo3::prelude::PyModule::new(py, "utils")?; m.add_function(wrap_pyfunction!(generate_all_path_candidates, m)?)?; + m.add_function(wrap_pyfunction!(generate_all_path_candidates_iter, m)?)?; m.add_function(wrap_pyfunction!( generate_path_candidates_from_visibility_matrix, m @@ -250,30 +269,6 @@ mod tests { }); } - /* - #[rstest] - #[should_panic] // Because we do not handle this edge case (empty iterator) - #[case(0, 0)] - #[should_panic] // Because we do not handle this edge case (empty iterator) - #[case(3, 0)] - #[should_panic] // Because we do not handle this edge case (empty iterator) - #[case(0, 3)] - #[case(9, 1)] - #[case(3, 1)] - #[case(3, 2)] - #[case(3, 3)] - fn test_generate_all_path_candidates_iter(#[case] num_primitives: u32, #[case] order: u32) { - Python::with_gil(|py| { - let got: Vec> = - generate_all_path_candidates_iter(py, num_primitives, order).collect(); - let expected = generate_all_path_candidates(py, num_primitives, order); - - let got = PyArray2::from_vec2(py, &got).unwrap(); - - assert_eq!(got.to_owned_array().t(), expected.to_owned_array()); - }); - }*/ - #[rstest] #[case( array![ diff --git a/tests/rt/test_utils.py b/tests/rt/test_utils.py index 9e262861..66e501dc 100644 --- a/tests/rt/test_utils.py +++ b/tests/rt/test_utils.py @@ -5,7 +5,11 @@ import pytest from jaxtyping import Array -from differt.rt.utils import generate_all_path_candidates, rays_intersect_triangles +from differt.rt.utils import ( + generate_all_path_candidates, + generate_all_path_candidates_iter, + rays_intersect_triangles, +) from differt.utils import sorted_array2 @@ -52,6 +56,26 @@ def test_generate_all_path_candidates( chex.assert_trees_all_equal(got, expected) +@pytest.mark.parametrize( + "num_primitives,order", + [ + (3, 1), + (3, 2), + (3, 3), + (5, 4), + ], +) +def test_generate_all_path_candidates_iter(num_primitives: int, order: int) -> None: + expected = generate_all_path_candidates(num_primitives, order) + expected = sorted_array2(expected.T).T + got = list(generate_all_path_candidates_iter(num_primitives, order)) + got = jnp.asarray(got).T + got = sorted_array2(got.T).T + + chex.assert_trees_all_equal_shapes_and_dtypes(got, expected) + chex.assert_trees_all_equal(got, expected) + + @pytest.mark.parametrize( "ray_orig,ray_dest,expected", [