From 094c1e4a6a10401e2a1e9b91a1497713843348f6 Mon Sep 17 00:00:00 2001 From: Buck McCready Date: Tue, 18 May 2021 17:37:47 -0700 Subject: [PATCH] Changes to use ControlFlow trait for visitors - Changed query functions to use visitor and ControlFlow trait rather than using a FnMut trait returning a bool - Traits are implemented for FnMut and () to allow for simple use with closures - Idea taken from petgraph --- README.md | 6 +- benches/bench_static_aabb2d_index.rs | 2 - examples/build_and_query.rs | 4 +- src/core.rs | 119 +++++++++++++++++++++++++++ src/lib.rs | 6 +- src/static_aabb2d_index.rs | 68 +++++++-------- tests/test.rs | 30 ++++--- 7 files changed, 176 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index b066217..25afbc1 100644 --- a/README.md +++ b/README.md @@ -31,10 +31,10 @@ Fast static spatial index data structure for 2D axis aligned bounding boxes util assert_eq!(query_results, vec![1]); // the query may also be done with a visiting function that can stop the query early let mut visited_results: Vec = Vec::new(); - let mut visitor = |box_added_pos: usize| -> bool { + let mut visitor = |box_added_pos: usize| -> Control<()> { visited_results.push(box_added_pos); - // return true to continue visiting results, false to stop early - true + // return continue to continue visiting results, break to stop early + Control::Continue }; index.visit_query(-1.0, -1.0, -0.5, -0.5, &mut visitor); diff --git a/benches/bench_static_aabb2d_index.rs b/benches/bench_static_aabb2d_index.rs index 0142145..c82bf8d 100644 --- a/benches/bench_static_aabb2d_index.rs +++ b/benches/bench_static_aabb2d_index.rs @@ -91,7 +91,6 @@ fn bench_visit_query(b: &mut Bencher, index: &StaticAABB2DIndex) { b.max_y + delta, &mut |index: usize| { query_results.push(index); - true }, ); } @@ -150,7 +149,6 @@ fn bench_visit_query_reuse_stack(b: &mut Bencher, index: &StaticAABB2DIndex b.max_y + delta, &mut |index: usize| { query_results.push(index); - true }, &mut stack, ); diff --git a/examples/build_and_query.rs b/examples/build_and_query.rs index f8caef7..7b1cb5a 100644 --- a/examples/build_and_query.rs +++ b/examples/build_and_query.rs @@ -20,10 +20,8 @@ fn main() { assert_eq!(query_results, vec![1]); // the query may also be done with a visiting function that can stop the query early let mut visited_results: Vec = Vec::new(); - let mut visitor = |box_added_pos: usize| -> bool { + let mut visitor = |box_added_pos: usize| { visited_results.push(box_added_pos); - // return true to continue visiting results, false to stop early - true }; index.visit_query(-1.0, -1.0, -0.5, -0.5, &mut visitor); diff --git a/src/core.rs b/src/core.rs index 5a6366c..429595b 100644 --- a/src/core.rs +++ b/src/core.rs @@ -38,6 +38,7 @@ impl IndexableNum for f32 {} impl IndexableNum for f64 {} /// Simple 2D axis aligned bounding box which holds the extents of a 2D box. +#[allow(clippy::upper_case_acronyms)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub struct AABB { /// Min x extent of the axis aligned bounding box. @@ -132,3 +133,121 @@ where self.min_x <= min_x && self.min_y <= min_y && self.max_x >= max_x && self.max_y >= max_y } } + +/// Basic control flow enum that can be used when visiting query results. +#[derive(Debug)] +pub enum Control { + /// Indicates to the query function to continue visiting results. + Continue, + /// Indicates to the query function to stop visiting results and return a value. + Break(B), +} + +impl Default for Control { + fn default() -> Self { + Control::Continue + } +} + +/// Trait for control flow inside query functions. +pub trait ControlFlow { + /// Constructs state indicating to continue. + fn continuing() -> Self; + /// Should return true if control flow should break. + fn should_break(&self) -> bool; +} + +impl ControlFlow for Control { + #[inline] + fn continuing() -> Self { + Control::Continue + } + + #[inline] + fn should_break(&self) -> bool { + matches!(*self, Control::Break(_)) + } +} + +impl ControlFlow for () { + #[inline] + fn continuing() -> Self {} + + #[inline] + fn should_break(&self) -> bool { + false + } +} + +impl ControlFlow for Result +where + C: ControlFlow, +{ + fn continuing() -> Self { + Ok(C::continuing()) + } + + fn should_break(&self) -> bool { + matches!(self, Err(_)) + } +} + +/// Visitor trait used to visit the results of a StaticAABB2DIndex query. +/// +/// This trait is blanket implemented for FnMut(usize) -> impl ControlFlow. +pub trait QueryVisitor +where + T: IndexableNum, + C: ControlFlow, +{ + /// Visit the index position of AABB returned by query. + fn visit(&mut self, index_pos: usize) -> C; +} + +impl QueryVisitor for F +where + T: IndexableNum, + C: ControlFlow, + F: FnMut(usize) -> C, +{ + #[inline] + fn visit(&mut self, index_pos: usize) -> C { + self(index_pos) + } +} + +/// Visitor trait used to visit the results of a StaticAABB2DIndex nearest neighbors query. +pub trait NeighborVisitor +where + T: IndexableNum, + C: ControlFlow, +{ + /// Visits the result containing the index position of the AABB neighbor and its euclidean + /// distance squared to the nearest neighbor input. + fn visit(&mut self, index_pos: usize, dist_squared: T) -> C; +} + +impl NeighborVisitor for F +where + T: IndexableNum, + C: ControlFlow, + F: FnMut(usize, T) -> C, +{ + #[inline] + fn visit(&mut self, index_pos: usize, dist_squared: T) -> C { + self(index_pos, dist_squared) + } +} + +#[macro_export] +macro_rules! try_control { + ($e:expr) => { + match $e { + x => { + if x.should_break() { + return x; + } + } + } + }; +} diff --git a/src/lib.rs b/src/lib.rs index a67ad66..a79c9e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,10 +29,10 @@ //! assert_eq!(query_results, vec![1]); //! // the query may also be done with a visiting function that can stop the query early //! let mut visited_results: Vec = Vec::new(); -//! let mut visitor = |box_added_pos: usize| -> bool { +//! let mut visitor = |box_added_pos: usize| -> Control<()> { //! visited_results.push(box_added_pos); -//! // return true to continue visiting results, false to stop early -//! true +//! // return continue to continue visiting results, break to stop early +//! Control::Continue //! }; //! //! index.visit_query(-1.0, -1.0, -0.5, -0.5, &mut visitor); diff --git a/src/static_aabb2d_index.rs b/src/static_aabb2d_index.rs index 5ce2807..41ac51f 100644 --- a/src/static_aabb2d_index.rs +++ b/src/static_aabb2d_index.rs @@ -5,7 +5,7 @@ use std::{ collections::BinaryHeap, }; -use crate::{IndexableNum, AABB}; +use crate::{try_control, ControlFlow, IndexableNum, NeighborVisitor, QueryVisitor, AABB}; /// Error type for errors that may be returned in attempting to build the index. #[derive(Debug, PartialEq)] @@ -98,10 +98,10 @@ where /// assert_eq!(query_results, vec![1]); /// // the query may also be done with a visiting function that can stop the query early /// let mut visited_results: Vec = Vec::new(); -/// let mut visitor = |box_added_pos: usize| -> bool { +/// let mut visitor = |box_added_pos: usize| -> Control<()> { /// visited_results.push(box_added_pos); -/// // return true to continue visiting results, false to stop early -/// true +/// // return continue to continue visiting results, break to stop early +/// Control::Continue /// }; /// /// index.visit_query(-1.0, -1.0, -0.5, -0.5, &mut visitor); @@ -180,9 +180,7 @@ where let mut n = num_items; let mut num_nodes = num_items; - let mut level_bounds: Vec = Vec::new(); - - level_bounds.push(n); + let mut level_bounds: Vec = vec![n]; // calculate the total number of nodes in the R-tree to allocate space for // and the index of each tree level (level_bounds, used in search later) @@ -851,7 +849,6 @@ where let mut results = Vec::new(); let mut visitor = |i| { results.push(i); - true }; self.visit_query(min_x, min_y, max_x, max_y, &mut visitor); results @@ -900,11 +897,12 @@ where /// Same as [StaticAABB2DIndex::query] but instead of returning a collection of indexes a /// `visitor` function is called for each index that would be returned. The `visitor` returns a - /// bool indicating whether to continue visiting (true) or not (false). + /// control flow indicating whether to continue visiting or break. #[inline] - pub fn visit_query(&self, min_x: T, min_y: T, max_x: T, max_y: T, visitor: &mut F) + pub fn visit_query(&self, min_x: T, min_y: T, max_x: T, max_y: T, visitor: &mut V) where - F: FnMut(usize) -> bool, + C: ControlFlow, + V: QueryVisitor, { let mut stack: Vec = Vec::with_capacity(16); self.visit_query_with_stack(min_x, min_y, max_x, max_y, visitor, &mut stack); @@ -974,7 +972,6 @@ where let mut results = Vec::new(); let mut visitor = |i| { results.push(i); - true }; self.visit_query_with_stack(min_x, min_y, max_x, max_y, &mut visitor, stack); results @@ -983,23 +980,25 @@ where /// Same as [StaticAABB2DIndex::visit_query] but accepts an existing [Vec] to be used as a stack /// buffer when performing the query to avoid the need for allocation (this is for performance /// benefit only). - pub fn visit_query_with_stack( + pub fn visit_query_with_stack( &self, min_x: T, min_y: T, max_x: T, max_y: T, - visitor: &mut F, + visitor: &mut V, stack: &mut Vec, - ) where - F: FnMut(usize) -> bool, + ) -> C + where + C: ControlFlow, + V: QueryVisitor, { let mut node_index = self.boxes.len() - 1; let mut level = self.level_bounds.len() - 1; // ensure the stack is empty for use stack.clear(); - 'search_loop: loop { + loop { let end = min( node_index + self.node_size, *get_at_index!(self.level_bounds, level), @@ -1014,9 +1013,7 @@ where let index = *get_at_index!(self.indices, pos); if node_index < self.num_items { - if !visitor(index) { - break 'search_loop; - } + try_control!(visitor.visit(index)) } else { stack.push(index); stack.push(level - 1); @@ -1027,17 +1024,16 @@ where level = stack.pop().unwrap(); node_index = stack.pop().unwrap(); } else { - break 'search_loop; + return C::continuing(); } } } /// Visit all neighboring items in order of minimum euclidean distance to the point defined by - /// `x` and `y` until `visitor` returns false. + /// `x` and `y` until `visitor` breaks or all items have been visited. /// /// ## Notes - /// * The visitor function must return false to stop visiting items or all items will be - /// visited. + /// * The visitor function must break to stop visiting items or all items will be visited. /// * The visitor function receives the index of the item being visited and the squared /// euclidean distance to that item from the point given. /// * Because distances are squared (`dx * dx + dy * dy`) be cautious of smaller numeric types @@ -1046,9 +1042,10 @@ where /// * If repeatedly calling this method then [StaticAABB2DIndex::visit_neighbors_with_queue] can /// be used to avoid repeated allocations for the priority queue used internally. #[inline] - pub fn visit_neighbors(&self, x: T, y: T, visitor: &mut F) + pub fn visit_neighbors(&self, x: T, y: T, visitor: &mut V) where - F: FnMut(usize, T) -> bool, + C: ControlFlow, + V: NeighborVisitor, { let mut queue = NeighborPriorityQueue::new(); self.visit_neighbors_with_queue(x, y, visitor, &mut queue); @@ -1056,14 +1053,16 @@ where /// Works the same as [StaticAABB2DIndex::visit_neighbors] but accepts an existing binary heap /// to be used as a priority queue to avoid allocations. - pub fn visit_neighbors_with_queue( + pub fn visit_neighbors_with_queue( &self, x: T, y: T, - visitor: &mut F, + visitor: &mut V, queue: &mut NeighborPriorityQueue, - ) where - F: FnMut(usize, T) -> bool, + ) -> C + where + C: ControlFlow, + V: NeighborVisitor, { // small helper function to compute axis distance between point and bounding box axis fn axis_dist(k: U, min: U, max: U) -> U @@ -1082,7 +1081,7 @@ where let mut node_index = self.boxes.len() - 1; queue.clear(); - 'search_loop: loop { + loop { let upper_bound_level_index = match self.level_bounds.binary_search(&node_index) { // level bound found, add one to get upper bound Ok(i) => i + 1, @@ -1113,10 +1112,7 @@ where while let Some(state) = queue.pop() { if state.is_leaf_node { // visit leaf node - if !visitor(state.index, state.dist) { - // stop visiting if visitor returns false - break 'search_loop; - } + try_control!(visitor.visit(state.index, state.dist)) } else { // update node index for next iteration node_index = state.index; @@ -1127,7 +1123,7 @@ where } if !continue_search { - break 'search_loop; + return C::continuing(); } } } diff --git a/tests/test.rs b/tests/test.rs index 0450917..b0e90b9 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -284,7 +284,6 @@ fn visit_query() { let mut results = Vec::new(); let mut visitor = |i| { results.push(i); - true }; index.visit_query(40, 40, 60, 60, &mut visitor); @@ -300,7 +299,6 @@ fn visit_query_with_many_levels() { let mut results = Vec::new(); let mut visitor = |i| { results.push(i); - true }; index.visit_query(40, 40, 60, 60, &mut visitor); @@ -318,7 +316,6 @@ fn visit_query_with_stack() { let mut results = Vec::new(); let mut visitor = |i| { results.push(i); - true }; index.visit_query_with_stack(40, 40, 60, 60, &mut visitor, &mut stack); @@ -336,7 +333,6 @@ fn visit_query_with_stack_with_many_levels() { let mut results = Vec::new(); let mut visitor = |i| { results.push(i); - true }; index.visit_query_with_stack(40, 40, 60, 60, &mut visitor, &mut stack); @@ -352,7 +348,11 @@ fn visit_query_stops_early() { let mut results = HashSet::new(); let mut visitor = |i| { results.insert(i); - results.len() != 2 + if results.len() != 2 { + Control::Continue + } else { + Control::Break(()) + } }; index.visit_query(40, 40, 60, 60, &mut visitor); @@ -368,7 +368,11 @@ fn visit_neighbors_max_results() { let max_results = 3; let mut visitor = |i, _| { results.push(i); - results.len() < max_results + if results.len() < max_results { + Control::Continue + } else { + Control::Break(()) + } }; index.visit_neighbors(50, 50, &mut visitor); @@ -386,9 +390,9 @@ fn visit_neighbors_max_distance() { let mut visitor = |i, d| { if (d as f64) < max_distance_squared { results.push(i); - return true; + return Control::Continue; } - false + Control::Break(()) }; index.visit_neighbors(50, 50, &mut visitor); @@ -406,9 +410,13 @@ fn visit_neighbors_max_results_filtered() { // filtering by only collecting indexes which are even if i % 2 == 0 { results.push(i); - return results.len() < max_results; + if results.len() < max_results { + return Control::Continue; + } + + return Control::Break(()); } - true + Control::Continue }; index.visit_neighbors(50, 50, &mut visitor); @@ -423,8 +431,6 @@ fn visit_neighbors_all_items() { let mut results = Vec::new(); let mut visitor = |i, _| { results.push(i); - // visit all items by always returning true - true }; index.visit_neighbors(50, 50, &mut visitor);