From 97c8144dba642331309cc912b94dbca5d3828678 Mon Sep 17 00:00:00 2001 From: Sebastian Walter Date: Sun, 28 Jan 2024 20:49:00 +0100 Subject: [PATCH] add prefix subcrate --- text-utils-prefix/Cargo.toml | 23 + text-utils-prefix/benches/benchmark.rs | 116 +++++ text-utils-prefix/src/adaptive_radix_trie.rs | 188 +++++++ text-utils-prefix/src/lib.rs | 62 +++ text-utils-prefix/src/patricia_trie.rs | 495 +++++++++++++++++++ text-utils-prefix/src/trie.rs | 151 ++++++ 6 files changed, 1035 insertions(+) create mode 100644 text-utils-prefix/Cargo.toml create mode 100644 text-utils-prefix/benches/benchmark.rs create mode 100644 text-utils-prefix/src/adaptive_radix_trie.rs create mode 100644 text-utils-prefix/src/lib.rs create mode 100644 text-utils-prefix/src/patricia_trie.rs create mode 100644 text-utils-prefix/src/trie.rs diff --git a/text-utils-prefix/Cargo.toml b/text-utils-prefix/Cargo.toml new file mode 100644 index 0000000..ad88f29 --- /dev/null +++ b/text-utils-prefix/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "text-utils-prefix" +version = "0.1.0" +edition = "2021" + +[dependencies] + +[dev-dependencies] +criterion = "0.5" +art-tree = "0.2.0" +patricia_tree = "0.8.0" +rand = "0.8" +rand_distr = "0.4" +rand_chacha = "0.3" + +[profile.release] +lto = true +codegen-units = 1 +strip = true + +[[bench]] +name = "benchmark" +harness = false diff --git a/text-utils-prefix/benches/benchmark.rs b/text-utils-prefix/benches/benchmark.rs new file mode 100644 index 0000000..f04ea4b --- /dev/null +++ b/text-utils-prefix/benches/benchmark.rs @@ -0,0 +1,116 @@ +use std::fs; +use std::path::PathBuf; + +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::seq::SliceRandom; +use rand::SeedableRng; +use rand_chacha::ChaCha8Rng; +use text_utils_prefix::{patricia_trie::PatriciaTrie, trie::Trie}; +use text_utils_prefix::{ContinuationSearch, PrefixSearch}; + +use art_tree::{Art, ByteString}; +use patricia_tree::PatriciaMap; + +const ASCII_LETTERS: &[u8; 52] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + +fn bench_prefix(c: &mut Criterion) { + let dir = env!("CARGO_MANIFEST_DIR"); + let index = fs::read_to_string(PathBuf::from(dir).join("resources/test/index.txt")) + .expect("failed to read file"); + let words: Vec<_> = index.lines().map(|s| s.as_bytes()).take(100_000).collect(); + let mut rng = ChaCha8Rng::seed_from_u64(22); + // sample random word from all words + let word = *words.choose(&mut rng).unwrap(); + println!("choose word {}", String::from_utf8_lossy(word)); + let mut group = c.benchmark_group("prefix_search"); + let continuations: Vec<_> = ASCII_LETTERS.iter().map(|&c| [c]).collect(); + + // benchmark art-tree + let mut trie: Art<_, _> = Art::new(); + for (i, word) in words.iter().enumerate() { + trie.insert(ByteString::new(word), i); + } + group.bench_with_input("art_tree_insert", word, |b, input| { + b.iter(|| trie.insert(ByteString::new(input), 1)); + }); + group.bench_with_input("art_tree_get", word, |b, input| { + b.iter(|| trie.get(&ByteString::new(input))); + }); + + // benchmark patricia_tree + let mut trie: PatriciaMap<_> = PatriciaMap::new(); + for (i, word) in words.iter().enumerate() { + trie.insert(word, i); + } + group.bench_with_input("patricia_tree_insert", word, |b, input| { + b.iter(|| trie.insert(input, 1)); + }); + group.bench_with_input("patricia_tree_get", word, |b, input| { + b.iter(|| trie.get(input)); + }); + + // benchmark prefix tries + let mut trie: Trie<_> = words.iter().zip(0..words.len()).collect(); + group.bench_with_input("trie_insert", word, |b, input| { + b.iter(|| trie.insert(input, 1)); + }); + group.bench_with_input("trie_get", word, |b, input| { + b.iter(|| trie.get(input)); + }); + group.bench_with_input("trie_contains", word, |b, input| { + b.iter(|| trie.contains_prefix(&input[..input.len().saturating_sub(3)])); + }); + + // benchmark patricia trie + let mut trie: PatriciaTrie<_> = words.iter().zip(0..words.len()).collect(); + group.bench_with_input("patricia_trie_insert", word, |b, input| { + b.iter(|| trie.insert(input, 1)); + }); + group.bench_with_input("patricia_trie_get", word, |b, input| { + b.iter(|| trie.get(input)); + }); + group.bench_with_input("patricia_trie_contains", word, |b, input| { + b.iter(|| trie.contains_prefix(&input[..input.len().saturating_sub(3)])); + }); + let conts = trie.contains_continuations("Albert", &continuations); + assert_eq!( + conts.iter().map(|&b| if b { 1 } else { 0 }).sum::(), + 4 + ); + group.bench_with_input( + "patricia_trie_continuations", + &("Albert", &continuations), + |b, input| { + let (word, continuations) = input; + b.iter(|| trie.contains_continuations(&word, &continuations)); + }, + ); + group.bench_with_input( + "patricia_trie_batch_continuations", + &(["Albert"; 64], &continuations), + |b, input| { + let (words, continuations) = input; + b.iter(|| trie.batch_contains_continuations(words, &continuations)); + }, + ); + + // benchmark build, load, and save + drop(group); + let mut group = c.benchmark_group("prefix_io"); + let n = 10_000; + + // let trie: RadixTrie<_> = words.iter().zip(0..words.len()).take(n).collect(); + // let path = PathBuf::from(dir).join("resources/test/byte_trie.bin"); + // group.bench_with_input("byte_trie_build", &words, |b, input| { + // b.iter(|| { + // input + // .iter() + // .zip(0..input.len()) + // .take(n) + // .collect::>() + // }); + // }); +} + +criterion_group!(benches, bench_prefix); +criterion_main!(benches); diff --git a/text-utils-prefix/src/adaptive_radix_trie.rs b/text-utils-prefix/src/adaptive_radix_trie.rs new file mode 100644 index 0000000..60b2c2b --- /dev/null +++ b/text-utils-prefix/src/adaptive_radix_trie.rs @@ -0,0 +1,188 @@ +use std::iter::once; + +use crate::PrefixSearch; + +type Index = [u8; N]; +type Children = [Option>>; N]; + +enum NodeType { + Leaf(V), + N4(Index<4>, Children, usize), + N16(Index<16>, Children, usize), + N48(Box>, Children, usize), + N256(Children, usize), +} + +struct Node { + prefix: Box<[u8]>, + inner: NodeType, +} + +pub struct AdaptiveRadixTrie { + root: Option>, +} + +impl Default for AdaptiveRadixTrie { + fn default() -> Self { + Self { root: None } + } +} + +impl FromIterator<(K, V)> for AdaptiveRadixTrie +where + K: AsRef<[u8]>, +{ + fn from_iter>(iter: T) -> Self { + let mut trie = Self::default(); + for (k, v) in iter { + trie.insert(k, v); + } + trie + } +} + +enum Matching { + FullKey(usize), + FullNode, + Partial(usize, u8), +} + +impl Node { + #[inline] + fn is_leaf(&self) -> bool { + matches!(self.inner, NodeType::Leaf(_)) + } + + #[inline] + fn advance_key<'a>(&self, key: &mut impl Iterator) -> Matching { + let mut i = 0; + while i < self.prefix.len() { + let Some(k) = key.next() else { + return Matching::FullKey(i); + }; + if k != &self.prefix[i] { + return Matching::Partial(i, *k); + } + i += 1; + } + Matching::FullNode + } + + #[inline] + fn exact_match<'a>(&self, key: &mut impl Iterator) -> bool { + let mut i = 0; + while i < self.prefix.len() { + let Some(k) = key.next() else { + return false; + }; + if k != &self.prefix[i] { + return false; + } + i += 1; + } + // we have to be at the end of the key for an exact match + key.next().is_none() + } + + #[inline] + fn find(&self, key: &[u8]) -> Option<&Self> { + let mut node = self; + // extend given key with null byte + // because its needed for the correctness of the algorithm + // when it comes to key lookup + let mut key = key.iter().chain(once(&0)); + loop { + if node.is_leaf() { + if self.exact_match(&mut key) { + return Some(node); + } + break; + } + + let Matching::FullNode = self.advance_key(&mut key) else { + // if we have not a full node match, + // we can return early + return None; + }; + + let k = key.next()?; + let Some(child) = node.find_child(k) else { + break; + }; + node = child; + } + None + } + + #[inline] + fn find_child(&self, key: &u8) -> Option<&Self> { + match &self.inner { + NodeType::Leaf(_) => None, + NodeType::N4(keys, children, num_children) => { + for i in 0..*num_children { + if &keys[i] == key { + return children[i].as_deref(); + } + } + None + } + NodeType::N16(keys, children, num_children) => { + let idx = keys[..*num_children].binary_search(key).ok()?; + children[idx].as_deref() + } + NodeType::N48(keys, children, _) => { + children.get(keys[*key as usize] as usize)?.as_deref() + } + NodeType::N256(children, _) => children[*key as usize].as_deref(), + } + } + + fn upgrade(self) -> Result { + todo!() + } + + fn downgrade(self) -> Result { + todo!() + } +} + +impl PrefixSearch for AdaptiveRadixTrie { + fn insert(&mut self, key: K, value: V) + where + K: AsRef<[u8]>, + { + todo!() + } + + fn delete(&mut self, key: K) -> Option + where + K: AsRef<[u8]>, + { + todo!() + } + + fn get(&self, key: K) -> Option<&V> + where + K: AsRef<[u8]>, + { + let Some(root) = &self.root else { + return None; + }; + + root.find(key.as_ref()).and_then(|node| match &node.inner { + NodeType::Leaf(v) => Some(v), + _ => None, + }) + } + + fn contains_prefix

(&self, prefix: P) -> bool + where + P: AsRef<[u8]>, + { + let Some(root) = &self.root else { + return false; + }; + + todo!(); + } +} diff --git a/text-utils-prefix/src/lib.rs b/text-utils-prefix/src/lib.rs new file mode 100644 index 0000000..57a73cd --- /dev/null +++ b/text-utils-prefix/src/lib.rs @@ -0,0 +1,62 @@ +pub mod adaptive_radix_trie; +pub mod patricia_trie; +pub mod trie; + +pub trait PrefixSearch { + fn insert(&mut self, key: K, value: V) + where + K: AsRef<[u8]>; + + fn delete(&mut self, key: K) -> Option + where + K: AsRef<[u8]>; + + fn get(&self, key: K) -> Option<&V> + where + K: AsRef<[u8]>; + + fn contains_prefix

(&self, prefix: P) -> bool + where + P: AsRef<[u8]>; +} + +pub trait ContinuationSearch: PrefixSearch { + fn continuations<'a, P>(&'a self, prefix: P) -> impl Iterator, &'a V)> + where + P: AsRef<[u8]>, + V: 'a; + + fn contains_continuation(&self, prefix: P, continuation: C) -> bool + where + P: AsRef<[u8]>, + C: AsRef<[u8]>; + + fn contains_continuations(&self, prefix: P, continuations: &[C]) -> Vec + where + P: AsRef<[u8]>, + C: AsRef<[u8]>, + { + // default naive implementation, should be overridden if there is a more efficient way + continuations + .iter() + .map(|c| self.contains_continuation(prefix.as_ref(), c.as_ref())) + .collect() + } + + fn batch_contains_continuations( + &self, + prefixes: &[P], + continuations: &[C], + ) -> Vec> + where + P: AsRef<[u8]>, + C: AsRef<[u8]>, + Self: Sync, + { + // default naive implementation, should be overridden if there is a more efficient way + prefixes + .iter() + .map(|p| self.contains_continuations(p, continuations)) + .collect() + } +} diff --git a/text-utils-prefix/src/patricia_trie.rs b/text-utils-prefix/src/patricia_trie.rs new file mode 100644 index 0000000..ffa2cef --- /dev/null +++ b/text-utils-prefix/src/patricia_trie.rs @@ -0,0 +1,495 @@ +use std::{ + fmt::Debug, + iter::{empty, once}, +}; + +use crate::{ContinuationSearch, PrefixSearch}; + +#[derive(Default)] +enum NodeType { + #[default] + Empty, + Leaf(V), + Inner([Option>>; 256]), +} + +impl Debug for NodeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Empty => f.debug_tuple("Empty").finish(), + Self::Leaf(value) => f.debug_tuple("Leaf").field(value).finish(), + Self::Inner(children) => f + .debug_tuple("Inner") + .field( + &children + .iter() + .enumerate() + .filter_map(|(i, c)| if c.is_some() { Some((i, c)) } else { None }) + .collect::>(), + ) + .finish(), + } + } +} + +#[derive(Debug)] +struct Node { + prefix: Box<[u8]>, + inner: NodeType, +} + +#[derive(Debug)] +pub struct PatriciaTrie { + root: Option>, + num_keys: usize, +} + +impl PatriciaTrie { + pub fn size(&self) -> usize { + self.num_keys + } +} + +impl Default for PatriciaTrie { + fn default() -> Self { + Self { + root: None, + num_keys: 0, + } + } +} + +impl FromIterator<(K, V)> for PatriciaTrie +where + K: AsRef<[u8]>, +{ + fn from_iter>(iter: T) -> Self { + let mut trie = Self::default(); + for (k, v) in iter { + trie.insert(k, v); + } + trie + } +} + +enum Matching { + FullKey(usize), + FullPrefix(u8), + Exact, + Partial(usize, u8), +} + +impl Node { + fn new_leaf(prefix: Vec, value: V) -> Self { + Self { + prefix: prefix.into_boxed_slice(), + inner: NodeType::Leaf(value), + } + } + + fn new_inner(prefix: Vec) -> Self { + Self { + prefix: prefix.into_boxed_slice(), + inner: NodeType::Inner(std::array::from_fn(|_| None)), + } + } + + #[inline] + fn is_leaf(&self) -> bool { + matches!(self.inner, NodeType::Leaf(_)) + } + + #[inline] + fn is_inner(&self) -> bool { + !self.is_leaf() + } + + #[inline] + fn matching(&self, key: &mut impl Iterator, offset: usize) -> Matching { + let mut i = offset; + while i < self.prefix.len() { + let Some(k) = key.next() else { + return Matching::FullKey(i); + }; + if k != self.prefix[i] { + return Matching::Partial(i, k); + } + i += 1; + } + if let Some(k) = key.next() { + Matching::FullPrefix(k) + } else { + Matching::Exact + } + } + + #[inline] + fn find_iter(&self, mut key: impl Iterator) -> Option<&Self> { + let mut node = self; + loop { + if node.is_leaf() { + if let Matching::Exact = node.matching(&mut key, 0) { + return Some(node); + } + break; + } + + let Matching::FullPrefix(k) = node.matching(&mut key, 0) else { + // if we dont match the full node prefix, + // we can return early + return None; + }; + + let Some(child) = node.find_child(k) else { + break; + }; + node = child; + } + None + } + + #[inline] + fn has_child(&self, key: u8) -> bool { + match &self.inner { + NodeType::Empty | NodeType::Leaf(_) => false, + NodeType::Inner(children) => children[key as usize].is_some(), + } + } + + #[inline] + fn find_child(&self, key: u8) -> Option<&Self> { + match &self.inner { + NodeType::Empty | NodeType::Leaf(_) => None, + NodeType::Inner(children) => children[key as usize].as_deref(), + } + } + + #[inline] + fn find_child_mut(&mut self, key: u8) -> Option<&mut Self> { + match &mut self.inner { + NodeType::Empty | NodeType::Leaf(_) => None, + NodeType::Inner(children) => children[key as usize].as_deref_mut(), + } + } + + #[inline] + fn set_child(&mut self, key: u8, child: Self) -> Result<(), Self> { + let NodeType::Inner(children) = &mut self.inner else { + return Err(child); + }; + let pos = &mut children[key as usize]; + if pos.is_some() { + return Err(child); + } + *pos = Some(Box::new(child)); + Ok(()) + } + + #[inline] + fn contains_prefix_iter( + &self, + mut key: impl Iterator, + offset: usize, + ) -> Option<(&Self, usize)> { + let mut node = self; + // extend given key with null byte + // because its needed for the correctness of the algorithm + // when it comes to key lookup + loop { + let k = match node.matching(&mut key, offset) { + Matching::FullKey(n) => return Some((node, n)), + Matching::Exact => return Some((node, node.prefix.len())), + Matching::FullPrefix(k) => k, + Matching::Partial(_, _) => break, + }; + + let Some(child) = node.find_child(k) else { + break; + }; + node = child; + } + None + } +} + +impl PrefixSearch for PatriciaTrie { + fn insert(&mut self, key: K, value: V) + where + K: AsRef<[u8]>, + { + let mut key = key.as_ref().iter().copied().chain(once(0)); + // empty tree + let Some(root) = &mut self.root else { + // insert leaf at root + self.root = Some(Node::new_leaf(key.collect(), value)); + self.num_keys += 1; + return; + }; + let mut node = root; + loop { + let matching = node.matching(&mut key, 0); + if node.is_leaf() { + let (inner_prefix, new_prefix, n, k) = match matching { + Matching::FullKey(_) => unreachable!("should not happen"), + Matching::FullPrefix(_) => unreachable!("should not happen"), + Matching::Partial(n, k) => ( + node.prefix[..n].to_vec(), + node.prefix[n + 1..].to_vec(), + n, + k, + ), + Matching::Exact => { + // exact match, only replace leaf value + node.inner = NodeType::Leaf(value); + return; + } + }; + let mut inner = Node::new_inner(inner_prefix); + let NodeType::Leaf(node_value) = std::mem::take(&mut node.inner) else { + unreachable!("should not happen"); + }; + inner + .set_child(node.prefix[n], Node::new_leaf(new_prefix, node_value)) + .expect("should not happen"); + inner + .set_child(k, Node::new_leaf(key.collect(), value)) + .expect("should not happen"); + *node = inner; + break; + } else if let Matching::FullPrefix(k) = matching { + // full prefix match, either go to next child + // or append leaf with rest of key + if node.has_child(k) { + node = node.find_child_mut(k).expect("should not happen"); + continue; + } + node.set_child(k, Node::new_leaf(key.collect(), value)) + .expect("should not happen"); + } else if let Matching::Partial(n, k) = matching { + // partial prefix match, introduce new inner node + let mut inner = Node::new_inner(node.prefix[..n].to_vec()); + let mut new_node = Node::new_inner(node.prefix[n + 1..].to_vec()); + new_node.inner = std::mem::take(&mut node.inner); + inner + .set_child(node.prefix[n], new_node) + .expect("should not happen"); + inner + .set_child(k, Node::new_leaf(key.collect(), value)) + .expect("should not happen"); + *node = inner; + } + break; + } + self.num_keys += 1; + } + + fn delete(&mut self, key: K) -> Option + where + K: AsRef<[u8]>, + { + let Some(root) = &mut self.root else { + return None; + }; + + // handle special case where root is leaf + if root.is_leaf() { + let NodeType::Leaf(value) = std::mem::take(&mut root.inner) else { + unreachable!("should not happen"); + }; + self.root = None; + self.num_keys -= 1; + return Some(value); + } + + let mut node = root; + let mut key = key.as_ref().iter().copied().chain(once(0)); + loop { + let matching = node.matching(&mut key, 0); + + let Matching::FullPrefix(k) = matching else { + // on inner nodes we always need full prefix matching + return None; + }; + + // return if we dont find a child + let child = node.find_child(k)?; + + // traverse down if child is inner + if child.is_inner() { + node = node.find_child_mut(k)?; + continue; + } + + // handle case if child is leaf + let Matching::Exact = child.matching(&mut key, 0) else { + break; + }; + // key is an exact match for a leaf + let NodeType::Inner(children) = &mut node.inner else { + unreachable!("should not happen"); + }; + let child = std::mem::take(&mut children[k as usize])?; + let NodeType::Leaf(value) = child.inner else { + unreachable!("should not happen"); + }; + let child_indices: Vec<_> = children + .iter() + .enumerate() + .filter_map(|(i, child)| child.as_ref().map(|_| i)) + .collect(); + assert!(!child_indices.is_empty()); + if child_indices.len() == 1 { + // if we only have one child left, we can merge + // the child into the current node + let single_child_k = child_indices.into_iter().next().unwrap(); + let single_child = std::mem::take(&mut children[single_child_k])?; + let new_prefix: Vec<_> = node + .prefix + .iter() + .copied() + .chain(once(single_child_k as u8)) + .chain(single_child.prefix.iter().copied()) + .collect(); + node.prefix = new_prefix.into_boxed_slice(); + node.inner = single_child.inner; + } + self.num_keys -= 1; + return Some(value); + } + None + } + + fn get(&self, key: K) -> Option<&V> + where + K: AsRef<[u8]>, + { + let Some(root) = &self.root else { + return None; + }; + + let key = key.as_ref().iter().copied().chain(once(0)); + root.find_iter(key).and_then(|node| match &node.inner { + NodeType::Leaf(v) => Some(v), + _ => None, + }) + } + + fn contains_prefix

(&self, prefix: P) -> bool + where + P: AsRef<[u8]>, + { + let Some(root) = &self.root else { + return false; + }; + + let key = prefix.as_ref().iter().copied(); + root.contains_prefix_iter(key, 0).is_some() + } +} + +impl ContinuationSearch for PatriciaTrie { + fn continuations<'a, P>(&'a self, prefix: P) -> impl Iterator, &'a V)> + where + P: AsRef<[u8]>, + V: 'a, + { + empty() + } + + fn contains_continuation(&self, prefix: P, continuation: C) -> bool + where + P: AsRef<[u8]>, + C: AsRef<[u8]>, + { + let Some(root) = &self.root else { + return false; + }; + + let key = prefix + .as_ref() + .iter() + .chain(continuation.as_ref().iter()) + .copied(); + root.contains_prefix_iter(key, 0).is_some() + } + + fn contains_continuations(&self, prefix: P, continuations: &[C]) -> Vec + where + P: AsRef<[u8]>, + C: AsRef<[u8]>, + { + let Some(root) = &self.root else { + return vec![false; continuations.len()]; + }; + + let key = prefix.as_ref().iter().copied(); + let Some((node, n)) = root.contains_prefix_iter(key, 0) else { + return vec![false; continuations.len()]; + }; + + continuations + .iter() + .map(|c| { + let key = c.as_ref().iter().copied(); + node.contains_prefix_iter(key, n).is_some() + }) + .collect() + } +} + +#[cfg(test)] +mod test { + use crate::{patricia_trie::PatriciaTrie, PrefixSearch}; + use std::fs; + use std::path::PathBuf; + + #[test] + fn test_trie() { + let mut trie = PatriciaTrie::default(); + assert_eq!(trie.get(b"hello"), None); + assert_eq!(trie.get(b""), None); + assert!(!trie.contains_prefix(b"")); + trie.insert(b"hello", 1); + assert_eq!(trie.delete(b"hello"), Some(1)); + assert_eq!(trie.delete(b"hello "), None); + trie.insert(b"hello", 1); + trie.insert(b"hell", 2); + trie.insert(b"hello world", 3); + assert_eq!(trie.get(b"hello"), Some(&1)); + assert_eq!(trie.get(b"hell"), Some(&2)); + assert_eq!(trie.get(b"hello world"), Some(&3)); + assert_eq!(trie.contains_prefix(b"hell"), true); + assert_eq!(trie.contains_prefix(b"hello"), true); + assert_eq!(trie.contains_prefix(b""), true); + assert_eq!(trie.contains_prefix(b"hello world!"), false); + assert_eq!(trie.contains_prefix(b"test"), false); + assert_eq!(trie.delete(b"hello"), Some(1)); + assert_eq!(trie.get(b"hello"), None); + assert_eq!(trie.size(), 2); + + let dir = env!("CARGO_MANIFEST_DIR"); + let index = fs::read_to_string(PathBuf::from(dir).join("resources/test/index.txt")) + .expect("failed to read file"); + let N = 100_000; + let words: Vec<_> = index.lines().map(|s| s.as_bytes()).take(N).collect(); + + let mut trie: PatriciaTrie<_> = words.iter().enumerate().map(|(i, w)| (w, i)).collect(); + assert_eq!(trie.size(), N); + for (i, word) in words.iter().enumerate() { + assert_eq!(trie.get(word), Some(&i)); + for j in 0..word.len() { + assert!(trie.contains_prefix(&word[..=j])); + } + } + for (i, word) in words.iter().enumerate() { + let even = i % 2 == 0; + if even { + assert_eq!(trie.delete(word), Some(i)); + assert_eq!(trie.get(word), None); + } else { + assert_eq!(trie.get(word), Some(&i)); + } + } + assert_eq!(trie.size(), N / 2); + } +} diff --git a/text-utils-prefix/src/trie.rs b/text-utils-prefix/src/trie.rs new file mode 100644 index 0000000..bcb9b88 --- /dev/null +++ b/text-utils-prefix/src/trie.rs @@ -0,0 +1,151 @@ +use crate::PrefixSearch; + +struct Node { + value: Option, + children: [Option>>; 256], +} + +impl Default for Node { + fn default() -> Self { + Self { + value: None, + children: std::array::from_fn(|_| None), + } + } +} + +pub struct Trie { + root: Option>, +} + +impl Default for Trie { + fn default() -> Self { + Self { root: None } + } +} + +impl FromIterator<(K, V)> for Trie +where + K: AsRef<[u8]>, +{ + fn from_iter>(iter: T) -> Self { + let mut trie = Self::default(); + for (k, v) in iter { + trie.insert(k, v); + } + trie + } +} + +impl Node { + #[inline] + fn is_leaf(&self) -> bool { + self.children.is_empty() + } + + #[inline] + fn find(&self, key: &[u8]) -> Option<&Self> { + let mut node = self; + for k in key { + let Some(child) = &node.children[*k as usize] else { + return None; + }; + node = child; + } + Some(node) + } +} + +impl PrefixSearch for Trie { + fn insert(&mut self, key: K, value: V) + where + K: AsRef<[u8]>, + { + let mut node = if let Some(node) = &mut self.root { + node + } else { + self.root = Some(Node::default()); + self.root.as_mut().unwrap() + }; + for k in key.as_ref() { + let child = &mut node.children[*k as usize]; + if child.is_none() { + *child = Some(Default::default()); + } + node = unsafe { child.as_mut().unwrap_unchecked() }; + } + node.value = Some(value); + } + + fn delete(&mut self, key: K) -> Option + where + K: AsRef<[u8]>, + { + let Some(root) = &mut self.root else { + return None; + }; + + let key = key.as_ref(); + if key.is_empty() { + let value = root.value.take(); + self.root = None; + return value; + } + let mut node = root; + for k in key.iter().take(key.len() - 1) { + let Some(child) = &mut node.children[*k as usize] else { + return None; + }; + node = child; + } + let last = *key.last()? as usize; + let Some(child) = &mut node.children[last] else { + return None; + }; + if child.is_leaf() { + node.children[last].take().and_then(|node| node.value) + } else { + child.value.take() + } + } + + fn get(&self, key: K) -> Option<&V> + where + K: AsRef<[u8]>, + { + let Some(root) = &self.root else { + return None; + }; + root.find(key.as_ref()).and_then(|node| node.value.as_ref()) + } + + fn contains_prefix

(&self, prefix: P) -> bool + where + P: AsRef<[u8]>, + { + let Some(root) = &self.root else { + return false; + }; + root.find(prefix.as_ref()).is_some() + } +} + +#[cfg(test)] +mod test { + use crate::{trie::Trie, PrefixSearch}; + + #[test] + fn test_trie() { + let mut trie = Trie::default(); + assert_eq!(trie.get(b"hello"), None); + trie.insert(b"hello", 1); + assert_eq!(trie.get(b"hello"), Some(&1)); + assert_eq!(trie.contains_prefix(b"hell"), true); + assert_eq!(trie.contains_prefix(b"hello"), true); + assert_eq!(trie.contains_prefix(b""), true); + assert_eq!(trie.contains_prefix(b"hello world"), false); + assert_eq!(trie.contains_prefix(b"test"), false); + assert_eq!(trie.delete(b"hello"), Some(1)); + assert_eq!(trie.get(b"hello"), None); + } +}