Skip to content

Commit

Permalink
update adaptive prefix trie
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Jan 30, 2024
1 parent f046099 commit dcbf45a
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 47 deletions.
71 changes: 70 additions & 1 deletion text-utils-prefix/benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use criterion::{criterion_group, criterion_main, Criterion};
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use text_utils_prefix::adaptive_radix_trie::AdaptiveRadixTrie;
use text_utils_prefix::{optimized_continuations, ContinuationSearch, PrefixSearch};
use text_utils_prefix::{patricia_trie::PatriciaTrie, trie::Trie};

Expand Down Expand Up @@ -123,12 +124,80 @@ fn bench_prefix(c: &mut Criterion) {
.collect();
group.bench_with_input(
"patricia_trie_continuations_batch_optimized_parallel",
&(inputs, continuations),
&(&inputs, &continuations),
|b, input| {
let (words, continuations) = input;
b.iter(|| {
trie.batch_contains_continuations_optimized_parallel(
*words,
continuations,
&permutation,
&skips,
)
});
},
);

// benchmark adaptive radix tree
let mut trie: AdaptiveRadixTrie<_> = words.iter().zip(0..words.len()).collect();
group.bench_with_input("adaptive_radix_trie_insert", word, |b, input| {
b.iter(|| trie.insert(input, 1));
});
group.bench_with_input("adaptive_radix_trie_get", word, |b, input| {
b.iter(|| trie.get(input));
});
group.bench_with_input("adaptive_radix_trie_contains", word, |b, input| {
b.iter(|| trie.contains_prefix(&input[..input.len().saturating_sub(3)]));
});
group.bench_with_input(
"adaptive_radix_trie_continuations",
&("Albert", &continuations),
|b, input| {
let (word, continuations) = input;
b.iter(|| trie.contains_continuations(&word, &continuations));
},
);
group.bench_with_input(
"adaptive_radix_trie_continuations_optimized",
&("Albert", &continuations),
|b, input| {
let (word, continuations) = input;
b.iter(|| {
trie.contains_continuations_optimized(&word, &continuations, &permutation, &skips)
});
},
);
group.bench_with_input(
"adaptive_radix_trie_continuations_batch",
&(["Albert"; 64], &continuations),
|b, input| {
let (words, continuations) = input;
b.iter(|| trie.batch_contains_continuations(words, &continuations));
},
);
group.bench_with_input(
"adaptive_radix_trie_continuations_batch_optimized",
&(["Albert"; 64], &continuations),
|b, input| {
let (words, continuations) = input;
b.iter(|| {
trie.batch_contains_continuations_optimized(
words,
&continuations,
&permutation,
&skips,
)
});
},
);
group.bench_with_input(
"adaptive_radix_trie_continuations_batch_optimized_parallel",
&(&inputs, &continuations),
|b, input| {
let (words, continuations) = input;
b.iter(|| {
trie.batch_contains_continuations_optimized_parallel(
*words,
continuations,
&permutation,
&skips,
Expand Down
100 changes: 84 additions & 16 deletions text-utils-prefix/src/adaptive_radix_trie.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::iter::{empty, once};
use std::{
collections::HashMap,
iter::{empty, once},
};

use crate::{ContinuationSearch, PrefixSearch};

Expand All @@ -25,21 +28,61 @@ struct Node<V> {
#[derive(Debug)]
pub struct AdaptiveRadixTrie<V> {
root: Option<Node<V>>,
num_keys: usize,
}

/// Structural summary of an `AdaptiveRadixTrie`, as returned by its `stats` method.
#[derive(Debug)]
pub struct AdaptiveRadixTrieStats {
/// Largest depth of any node, with the root counted as depth 0
/// (also 0 for an empty trie).
pub depth: usize,
/// Total number of nodes, summed over all node kinds.
pub num_nodes: usize,
/// Number of leaf nodes in the trie.
pub num_keys: usize,
/// Per node kind ("leaf", "n4", "n16", "n48", "n256"): the number of nodes
/// of that kind and the mean length of their prefixes.
pub node_info: HashMap<String, (usize, f32)>,
}

impl<V> AdaptiveRadixTrie<V> {
pub fn size(&self) -> usize {
self.num_keys
pub fn stats(&self) -> AdaptiveRadixTrieStats {
let mut dist = HashMap::from_iter(
["leaf", "n4", "n16", "n48", "n256"]
.iter()
.map(|&s| (s.to_string(), (0, 0.0))),
);
let Some(root) = &self.root else {
return AdaptiveRadixTrieStats {
depth: 0,
num_nodes: 0,
num_keys: 0,
node_info: dist,
};
};
let mut stack = vec![(root, 0)];
let mut max_depth = 0;
while let Some((node, depth)) = stack.pop() {
max_depth = max_depth.max(depth);
let name = match &node.inner {
NodeType::Empty => unreachable!("should not happen"),
NodeType::Leaf(_) => "leaf",
NodeType::N4(..) => "n4",
NodeType::N16(..) => "n16",
NodeType::N48(..) => "n48",
NodeType::N256(..) => "n256",
};
let val = dist.get_mut(name).unwrap();
val.0 += 1;
let n = val.0 as f32;
val.1 = (val.1 * (n - 1.0) + node.prefix.len() as f32) / n;
stack.extend(node.children().map(|child| (child, depth + 1)));
}
AdaptiveRadixTrieStats {
depth: max_depth,
num_nodes: dist.iter().map(|(_, (n, _))| n).sum(),
num_keys: dist["leaf"].0,
node_info: dist,
}
}
}

impl<V> Default for AdaptiveRadixTrie<V> {
fn default() -> Self {
Self {
root: None,
num_keys: 0,
}
Self { root: None }
}
}

Expand Down Expand Up @@ -137,10 +180,35 @@ impl<V> Node<V> {
self.find_child(key).is_some()
}

/// Returns an iterator over the child nodes currently stored in this node.
///
/// Empty and leaf nodes yield nothing. For N4/N16/N48 nodes only the first
/// `num_children` slots are inspected; for N256 nodes every slot is. Slots
/// holding `None` are skipped in all cases.
fn children(&self) -> Box<dyn Iterator<Item = &Self> + '_> {
    // pick the range of slots that may hold children for this node kind
    let slots = match &self.inner {
        NodeType::Empty | NodeType::Leaf(_) => return Box::new(empty()),
        NodeType::N4(_, children, num_children) => &children[..*num_children],
        NodeType::N16(_, children, num_children) => &children[..*num_children],
        NodeType::N48(_, children, num_children) => &children[..*num_children],
        NodeType::N256(children, _) => &children[..],
    };
    Box::new(slots.iter().filter_map(|slot| slot.as_deref()))
}

#[inline]
fn set_child(&mut self, key: u8, child: Self) {
// potentially upgrade the current node before insertion, will change
// nothing if the node does not need to be upgraded
assert!(self.find_child(key).is_none());
self.upgrade();
match &mut self.inner {
NodeType::Empty | NodeType::Leaf(_) => unreachable!("should not happen"),
Expand Down Expand Up @@ -206,7 +274,7 @@ impl<V> Node<V> {
Matching::FullKey(n) => return Some((node, n)),
Matching::Exact => return Some((node, node.prefix.len())),
Matching::FullPrefix(k) => k,
Matching::Partial(_, _) => break,
Matching::Partial(..) => break,
};

let Some(child) = node.find_child(k) else {
Expand Down Expand Up @@ -326,6 +394,7 @@ impl<V> Node<V> {
std::array::from_fn(|i| {
let idx = index[i];
if idx < 48 {
assert!(children[idx as usize].is_some());
std::mem::take(&mut children[idx as usize])
} else {
None
Expand Down Expand Up @@ -355,7 +424,6 @@ impl<V> PrefixSearch for AdaptiveRadixTrie<V> {
let Some(root) = &mut self.root else {
// insert leaf at root
self.root = Some(Node::new_leaf(key.collect(), value));
self.num_keys += 1;
return;
};
let mut node = root;
Expand Down Expand Up @@ -404,7 +472,6 @@ impl<V> PrefixSearch for AdaptiveRadixTrie<V> {
}
break;
}
self.num_keys += 1;
}

fn delete<K>(&mut self, key: K) -> Option<V>
Expand All @@ -421,7 +488,6 @@ impl<V> PrefixSearch for AdaptiveRadixTrie<V> {
unreachable!("should not happen");
};
self.root = None;
self.num_keys -= 1;
return Some(value);
}

Expand Down Expand Up @@ -479,7 +545,6 @@ impl<V> PrefixSearch for AdaptiveRadixTrie<V> {
// node.prefix = new_prefix.into_boxed_slice();
// node.inner = single_child.inner;
// }
// self.num_keys -= 1;
// return Some(value);
}
None
Expand Down Expand Up @@ -624,7 +689,7 @@ mod test {
// trie.insert(b"hello", 1);
trie.insert(b"hell", 2);
trie.insert(b"hello world", 3);
println!("{:#?}", trie);
// println!("{:#?}", trie);
assert_eq!(trie.get(b"hello"), Some(&1));
assert_eq!(trie.get(b"hell"), Some(&2));
assert_eq!(trie.get(b"hello world"), Some(&3));
Expand All @@ -645,13 +710,15 @@ mod test {

let mut trie: AdaptiveRadixTrie<_> =
words.iter().enumerate().map(|(i, w)| (w, i)).collect();
assert_eq!(trie.size(), n);
let stats = trie.stats();
assert_eq!(stats.num_keys, n);
for (i, word) in words.iter().enumerate() {
assert_eq!(trie.get(word), Some(&i));
for j in 0..word.len() {
assert!(trie.contains_prefix(&word[..=j]));
}
}
println!("{:#?}", trie.stats());
// for (i, word) in words.iter().enumerate() {
// let even = i % 2 == 0;
// if even {
Expand All @@ -661,6 +728,7 @@ mod test {
// assert_eq!(trie.get(word), Some(&i));
// }
// }
// assert_eq!(trie.size(), n / 2);
// let stats = trie.stats();
// assert_eq!(stats.num_keys, n / 2);
}
}
Loading

0 comments on commit dcbf45a

Please sign in to comment.