Skip to content

Commit

Permalink
feat(lib): implementing path candidates iterator (#28)
Browse files Browse the repository at this point in the history
* feat(lib): implementing path candidates iterator

Cleaning must still be done

* chore(lib): fixes and tests
  • Loading branch information
jeertmans authored Jan 16, 2024
1 parent ffce9cf commit 57afdca
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
args: [--fix]
- id: ruff-format
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.346
rev: v1.1.347
hooks:
- id: pyright
- repo: https://github.com/doublify/pre-commit-rust
Expand Down
13 changes: 13 additions & 0 deletions .rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
condense_wildcard_suffixes = true
error_on_line_overflow = true
error_on_unformatted = true
force_multiline_blocks = true
format_code_in_doc_comments = true
format_macro_matchers = true
format_strings = true
imports_granularity = "Crate"
match_block_trailing_comma = true
normalize_doc_attributes = true
unstable_features = true
version = "Two"
wrap_comments = true
5 changes: 5 additions & 0 deletions python/differt/_core/rt/utils.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from collections.abc import Iterator

import numpy as np
from jaxtyping import UInt

def generate_all_path_candidates(
num_primitives: int, order: int
) -> UInt[np.ndarray, "num_path_candidates order"]: ...
def generate_all_path_candidates_iter(
num_primitives: int, order: int
) -> Iterator[UInt[np.ndarray, " order"]]: ...
22 changes: 21 additions & 1 deletion python/differt/rt/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Ray Tracing utilies."""
from collections.abc import Iterator

import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, UInt, jaxtyped
Expand Down Expand Up @@ -34,7 +35,26 @@ def generate_all_path_candidates(
"""
return jnp.asarray(
_core.rt.utils.generate_all_path_candidates(num_primitives, order),
dtype=jnp.uint32,
)


@jaxtyped(typechecker=typechecker)
def generate_all_path_candidates_iter(
num_primitives: int, order: int
) -> Iterator[UInt[Array, " order"]]:
"""
Iterator variant of :func:`generate_all_path_candidates`.
Args:
num_primitives: The (positive) number of primitives.
order: The path order.
Returns:
An iterator of unsigned arrays with primitive indices.
"""
return map(
jnp.asarray,
_core.rt.utils.generate_all_path_candidates_iter(num_primitives, order),
)


Expand Down
14 changes: 8 additions & 6 deletions src/geometry/triangle_mesh.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use std::fs::File;
use std::io::BufReader;
use std::{fs::File, io::BufReader};

use numpy::{Element, PyArray2};
use obj::raw::object::{parse_obj, RawObj};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyType;
use pyo3::{exceptions::PyValueError, prelude::*, types::PyType};

#[pyclass]
struct TriangleMesh {
Expand Down Expand Up @@ -84,7 +81,12 @@ impl TryFrom<RawObj> for TriangleMesh {
PTN(v) if v.len() == 3 => {
triangles.push((v[0].0, v[1].0, v[2].0));
},
_ => return Err(PyValueError::new_err("Cannot create TriangleMesh from an object that contains something else than triangles")),
_ => {
return Err(PyValueError::new_err(
"Cannot create TriangleMesh from an object that contains something else \
than triangles",
));
},
}
}

Expand Down
107 changes: 51 additions & 56 deletions src/rt/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use numpy::ndarray::parallel::prelude::*;
use numpy::ndarray::{s, Array2, ArrayView2, Axis};
use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2};
use numpy::{
ndarray::{parallel::prelude::*, s, Array2, ArrayView2, Axis},
IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2,
};
use pyo3::prelude::*;

/// Generate an array of all path candidates (assuming fully connected primitives).
/// Generate an array of all path candidates (assuming fully connected
/// primitives).
#[pyfunction]
pub fn generate_all_path_candidates(
py: Python<'_>,
Expand Down Expand Up @@ -49,71 +51,84 @@ pub fn generate_all_path_candidates(
path_candidates.into_pyarray(py)
}

/// Iterator variant of [`generate_all_path_candidates`].
#[pyclass]
pub struct AllPathCandidates {
/// Number of primitives.
num_primitives: u32,
num_primitives: usize,
/// Path order.
order: u32,
/// Exact number of path candidates that will be generated.
num_candidates: usize,
/// The index of the current path candidate.
index: usize,
/// Last path candidate.
path_candidate: Vec<u32>,
counter: Vec<u32>,
path_candidate: Vec<usize>,
/// Count how many times a given index has been changed.
counter: Vec<usize>,
done: bool,
}

impl AllPathCandidates {
#[inline]
fn new(num_primitives: u32, order: u32) -> Self {
let num_choices = num_primitives.saturating_sub(1) as usize;
fn new(num_primitives: usize, order: u32) -> Self {
let num_choices = num_primitives.saturating_sub(1);
let num_candidates_per_batch = num_choices.pow(order.saturating_sub(1));
let num_candidates = (num_primitives as usize) * num_candidates_per_batch;
let num_candidates = num_primitives * num_candidates_per_batch;
let index = 0;
let path_candidate = (0..order as usize).collect(); // [0, 1, 2, ..., order - 1]
let mut counter = vec![1; order as usize];
counter[0] = 0;

Self {
num_primitives,
order,
num_candidates,
index: 0,
path_candidate: (0..order).collect(),
counter: vec![2; order as usize],
index,
path_candidate,
counter,
done: num_primitives == 0,
}
}
}

impl Iterator for AllPathCandidates {
type Item = Vec<u32>;
type Item = Vec<usize>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.index += 1;

if self.done {
return None;
}
// 1. Output is generated as a copy of the current path_candidate
let path_candidate = self.path_candidate.clone();

let start = self
// 2. Generate the next path candidate

// Identify which 'index' should be increased by 1,
// from right to left.
if let Some(start) = self
.counter
.iter()
.rposition(|&count| count < self.num_primitives)?;

println!("Actual counter: {:?}", self.counter);
println!("start index:{:?}", start);

self.counter[start] += 1;
self.path_candidate[start] = (self.path_candidate[start] + 1) % self.num_primitives;

for i in (start + 1)..(self.order as usize) {
self.path_candidate[i] = (self.path_candidate[i - 1] + 1) % self.num_primitives;
self.counter[i] = 2;
.rposition(|&count| count < self.num_primitives - 1)
{
self.counter[start] += 1;
self.path_candidate[start] = (self.path_candidate[start] + 1) % self.num_primitives;

for i in (start + 1)..(self.order as usize) {
self.path_candidate[i] = (self.path_candidate[i - 1] + 1) % self.num_primitives;
self.counter[i] = 1;
}
} else {
self.done = true;
}
println!("{:?}", self.counter);

Some(path_candidate)
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let rem = self.num_candidates - self.index;
let rem = self.num_candidates.saturating_sub(self.index);

(rem, Some(rem))
}
Expand All @@ -130,7 +145,10 @@ impl AllPathCandidates {
slf
}

fn __next__<'py>(mut slf: PyRefMut<'py, Self>, py: Python<'py>) -> Option<&'py PyArray1<u32>> {
fn __next__<'py>(
mut slf: PyRefMut<'py, Self>,
py: Python<'py>,
) -> Option<&'py PyArray1<usize>> {
slf.next().map(|v| PyArray1::from_vec(py, v))
}
}
Expand All @@ -139,7 +157,7 @@ impl AllPathCandidates {
#[pyfunction]
pub fn generate_all_path_candidates_iter(
_py: Python<'_>,
num_primitives: u32,
num_primitives: usize,
order: u32,
) -> AllPathCandidates {
AllPathCandidates::new(num_primitives, order)
Expand Down Expand Up @@ -194,6 +212,7 @@ impl PathCandidates {
pub(crate) fn create_module(py: Python<'_>) -> PyResult<&PyModule> {
let m = pyo3::prelude::PyModule::new(py, "utils")?;
m.add_function(wrap_pyfunction!(generate_all_path_candidates, m)?)?;
m.add_function(wrap_pyfunction!(generate_all_path_candidates_iter, m)?)?;
m.add_function(wrap_pyfunction!(
generate_path_candidates_from_visibility_matrix,
m
Expand Down Expand Up @@ -250,30 +269,6 @@ mod tests {
});
}

/*
#[rstest]
#[should_panic] // Because we do not handle this edge case (empty iterator)
#[case(0, 0)]
#[should_panic] // Because we do not handle this edge case (empty iterator)
#[case(3, 0)]
#[should_panic] // Because we do not handle this edge case (empty iterator)
#[case(0, 3)]
#[case(9, 1)]
#[case(3, 1)]
#[case(3, 2)]
#[case(3, 3)]
fn test_generate_all_path_candidates_iter(#[case] num_primitives: u32, #[case] order: u32) {
Python::with_gil(|py| {
let got: Vec<Vec<_>> =
generate_all_path_candidates_iter(py, num_primitives, order).collect();
let expected = generate_all_path_candidates(py, num_primitives, order);
let got = PyArray2::from_vec2(py, &got).unwrap();
assert_eq!(got.to_owned_array().t(), expected.to_owned_array());
});
}*/

#[rstest]
#[case(
array![
Expand Down
26 changes: 25 additions & 1 deletion tests/rt/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import pytest
from jaxtyping import Array

from differt.rt.utils import generate_all_path_candidates, rays_intersect_triangles
from differt.rt.utils import (
generate_all_path_candidates,
generate_all_path_candidates_iter,
rays_intersect_triangles,
)
from differt.utils import sorted_array2


Expand Down Expand Up @@ -52,6 +56,26 @@ def test_generate_all_path_candidates(
chex.assert_trees_all_equal(got, expected)


@pytest.mark.parametrize(
"num_primitives,order",
[
(3, 1),
(3, 2),
(3, 3),
(5, 4),
],
)
def test_generate_all_path_candidates_iter(num_primitives: int, order: int) -> None:
expected = generate_all_path_candidates(num_primitives, order)
expected = sorted_array2(expected.T).T
got = list(generate_all_path_candidates_iter(num_primitives, order))
got = jnp.asarray(got).T
got = sorted_array2(got.T).T

chex.assert_trees_all_equal_shapes_and_dtypes(got, expected)
chex.assert_trees_all_equal(got, expected)


@pytest.mark.parametrize(
"ray_orig,ray_dest,expected",
[
Expand Down

0 comments on commit 57afdca

Please sign in to comment.