diff --git a/src/rt/utils.rs b/src/rt/utils.rs index 2d7dcb5e..7c10c1ec 100644 --- a/src/rt/utils.rs +++ b/src/rt/utils.rs @@ -54,37 +54,32 @@ pub fn generate_all_path_candidates( /// Iterator variant of [`generate_all_path_candidates`]. #[pyclass] pub struct AllPathCandidates { - /// Number of primitives. + /// Number of primitives to choose from. num_primitives: usize, /// Path order. order: u32, - /// Exact number of path candidates that will be generated. - num_candidates: usize, - /// The index of the current path candidate. - index: usize, /// Last path candidate. path_candidate: Vec, /// Count how many times a given index has been changed. counter: Vec, + /// Whether iterator is consumed. done: bool, } impl AllPathCandidates { #[inline] fn new(num_primitives: usize, order: u32) -> Self { - let num_choices = num_primitives.saturating_sub(1); - let num_candidates_per_batch = num_choices.pow(order.saturating_sub(1)); - let num_candidates = num_primitives * num_candidates_per_batch; - let index = 0; let path_candidate = (0..order as usize).collect(); // [0, 1, 2, ..., order - 1] let mut counter = vec![1; order as usize]; - counter[0] = 0; + + // Must check in case oder is zero. + if let Some(count) = counter.get_mut(0) { + *count = 0; + } Self { num_primitives, order, - num_candidates, - index, path_candidate, counter, done: num_primitives == 0, @@ -100,13 +95,8 @@ impl Iterator for AllPathCandidates { if self.done { return None; } - // 1. Output is generated as a copy of the current path_candidate let path_candidate = self.path_candidate.clone(); - // 2. Generate the next path candidate - - // Identify which 'index' should be increased by 1, - // from right to left. if let Some(start) = self .counter .iter() @@ -125,18 +115,6 @@ impl Iterator for AllPathCandidates { Some(path_candidate) } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - let rem = self.num_candidates.saturating_sub(self.index); - - (rem, Some(rem)) - } - - #[inline] - fn count(self) -> usize { - self.num_candidates - } } #[pymethods] @@ -268,6 +246,26 @@ mod tests { assert_eq!(got.to_owned_array(), expected); }); } + + #[rstest] + #[case(0, 0, 0)] + #[case(3, 0, 1)] + #[case(0, 3, 0)] + #[case(9, 1, 9)] + #[case(3, 1, 3)] + #[case(3, 2, 6)] + #[case(3, 3, 12)] + fn test_generate_all_path_candidates_iter_count( + #[case] num_primitives: usize, + #[case] order: u32, + #[case] expected: usize, + ) { + Python::with_gil(|py| { + let got = generate_all_path_candidates_iter(py, num_primitives, order); + + assert_eq!(got.count(), expected); + }); + } #[rstest] #[case(