Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "chore(lib): return array of usize instead (#22)" #27

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/rt/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@ use pyo3::prelude::*;
#[pyfunction]
pub fn generate_all_path_candidates(
py: Python<'_>,
num_primitives: usize,
num_primitives: u32,
order: u32,
) -> &PyArray2<usize> {
) -> &PyArray2<u32> {
if order == 0 {
// One path of size 0
return Array2::default((0, 1)).into_pyarray(py);
} else if num_primitives == 0 {
// Zero path of size order
return Array2::default((order as usize, 0)).into_pyarray(py);
} else if order == 1 {
let mut path_candidates = Array2::default((1, num_primitives));
let mut path_candidates = Array2::default((1, num_primitives as usize));

for j in 0..num_primitives {
path_candidates[(0, j)] = j;
path_candidates[(0, j as usize)] = j;
}
return path_candidates.into_pyarray(py);
}
let num_choices = num_primitives - 1;
let num_choices = (num_primitives - 1) as usize;
let num_candidates_per_batch = num_choices.pow(order - 1);
let num_candidates = num_primitives * num_candidates_per_batch;
let num_candidates = (num_primitives as usize) * num_candidates_per_batch;

let mut path_candidates = Array2::default((order as usize, num_candidates));
let mut batch_size = num_candidates_per_batch;
Expand Down Expand Up @@ -164,11 +164,11 @@ pub fn generate_path_candidates_from_visibility_matrix<'py>(
py: Python<'py>,
visibility_matrix: PyReadonlyArray2<'py, bool>,
order: u32,
) -> &'py PyArray2<usize> {
) -> &'py PyArray2<u32> {
let num_primitives = visibility_matrix.shape()[0];

if order <= 1 || num_primitives == 0 {
return generate_all_path_candidates(py, num_primitives, order);
return generate_all_path_candidates(py, num_primitives as u32, order);
}

let _indices = where_true(&visibility_matrix.as_array());
Expand Down Expand Up @@ -239,9 +239,9 @@ mod tests {
].t().to_owned()
)]
fn test_generate_all_path_candidates(
#[case] num_primitives: usize,
#[case] num_primitives: u32,
#[case] order: u32,
#[case] expected: Array2<usize>,
#[case] expected: Array2<u32>,
) {
Python::with_gil(|py| {
let got = generate_all_path_candidates(py, num_primitives, order);
Expand Down
Loading