Skip to content

Commit

Permalink
improve init performance of pattern match vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
maxbachmann committed Nov 25, 2023
1 parent 527d017 commit e1b2fb2
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 60 deletions.
31 changes: 28 additions & 3 deletions benches/benches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn benchmark(c: &mut Criterion) {

let cached = distance::jaro_winkler::BatchComparator::new(s1.bytes(), None);
group.bench_with_input(
BenchmarkId::new("cached_rapidfuzz", i),
BenchmarkId::new("rapidfuzz (BatchComparator)", i),
&(&cached, &s2),
|b, val| {
b.iter(|| {
Expand Down Expand Up @@ -87,7 +87,7 @@ fn benchmark(c: &mut Criterion) {

let cached = distance::osa::BatchComparator::new(s1.chars());
group.bench_with_input(
BenchmarkId::new("cached_rapidfuzz", i),
BenchmarkId::new("rapidfuzz (BatchComparator)", i),
&(&cached, &s2),
|b, val| {
b.iter(|| {
Expand Down Expand Up @@ -133,7 +133,7 @@ fn benchmark(c: &mut Criterion) {

let cached = distance::levenshtein::BatchComparator::new(s1.bytes(), None);
group.bench_with_input(
BenchmarkId::new("cached_rapidfuzz", i),
BenchmarkId::new("rapidfuzz (BatchComparator)", i),
&(&cached, &s2),
|b, val| {
b.iter(|| {
Expand All @@ -145,6 +145,31 @@ fn benchmark(c: &mut Criterion) {

group.finish();

group = c.benchmark_group("Generic Levenshtein");

for i in lens.clone() {
let s1 = generate(i);
let s2 = generate(i);

group.bench_with_input(BenchmarkId::new("rapidfuzz", i), &(&s1, &s2), |b, val| {
b.iter(|| {
black_box(distance::levenshtein::distance(
val.0.bytes(),
val.1.bytes(),
Some(distance::levenshtein::WeightTable {
insertion_cost: 1,
deletion_cost: 2,
substitution_cost: 3,
}),
None,
None,
));
})
});
}

group.finish();

group = c.benchmark_group("DamerauLevenshtein");

for i in lens.clone() {
Expand Down
56 changes: 35 additions & 21 deletions src/details/pattern_match_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ impl BitvectorHashmap {
}

pub struct PatternMatchVector {
pub map_unsigned: BitvectorHashmap,
pub map_signed: BitvectorHashmap,
pub extended_ascii: [u64; 256],
pub map_unsigned: Option<BitvectorHashmap>,
pub map_signed: Option<BitvectorHashmap>,
}

pub trait BitVectorInterface {
Expand All @@ -78,23 +78,17 @@ pub trait BitVectorInterface {
fn size(&self) -> usize;
}

impl PatternMatchVector {
// right now this can't be used since rust fails to elide the memcpy
// on return
/*pub fn new<Iter1, CharT>(s1: Iter1) -> Self
where
Iter1: Iterator<Item = CharT>,
CharT: HashableChar,
{
let mut vec = Self {
map_unsigned: BitvectorHashmap::default(),
map_signed: BitvectorHashmap::default(),
impl Default for PatternMatchVector {
fn default() -> Self {
Self {
map_unsigned: None,
map_signed: None,
extended_ascii: [0; 256],
};
vec.insert(s1);
vec
}*/
}
}
}

impl PatternMatchVector {
pub fn insert<Iter1, CharT>(&mut self, s1: Iter1)
where
Iter1: Iterator<Item = CharT>,
Expand All @@ -114,17 +108,27 @@ impl PatternMatchVector {
match key.hash_char() {
Hash::SIGNED(value) => {
if value < 0 {
if self.map_signed.is_none() {
self.map_signed = Some(BitvectorHashmap::default());
}
let item = self
.map_signed
.as_mut()
.expect("map should have been created above")
.get_mut(u64::from_ne_bytes(value.to_ne_bytes()));
*item |= mask;
} else if value <= 255 {
let val_u8 = u8::try_from(value).expect("we check the bounds above");
let item = &mut self.extended_ascii[usize::from(val_u8)];
*item |= mask;
} else {
if self.map_unsigned.is_none() {
self.map_unsigned = Some(BitvectorHashmap::default());
}
let item = self
.map_unsigned
.as_mut()
.expect("map should have been created above")
.get_mut(u64::from_ne_bytes(value.to_ne_bytes()));
*item |= mask;
}
Expand All @@ -135,7 +139,14 @@ impl PatternMatchVector {
let item = &mut self.extended_ascii[usize::from(val_u8)];
*item |= mask;
} else {
let item = self.map_unsigned.get_mut(value);
if self.map_unsigned.is_none() {
self.map_unsigned = Some(BitvectorHashmap::default());
}
let item = self
.map_unsigned
.as_mut()
.expect("map should have been created above")
.get_mut(value);
*item |= mask;
}
}
Expand All @@ -152,21 +163,24 @@ impl BitVectorInterface for PatternMatchVector {
match key.hash_char() {
Hash::SIGNED(value) => {
if value < 0 {
self.map_signed.get(u64::from_ne_bytes(value.to_ne_bytes()))
self.map_signed
.as_ref()
.map_or(0, |map| map.get(u64::from_ne_bytes(value.to_ne_bytes())))
} else if value <= 255 {
let val_u8 = u8::try_from(value).expect("we check the bounds above");
self.extended_ascii[usize::from(val_u8)]
} else {
self.map_unsigned
.get(u64::from_ne_bytes(value.to_ne_bytes()))
.as_ref()
.map_or(0, |map| map.get(u64::from_ne_bytes(value.to_ne_bytes())))
}
}
Hash::UNSIGNED(value) => {
if value <= 255 {
let val_u8 = u8::try_from(value).expect("we check the bounds above");
self.extended_ascii[usize::from(val_u8)]
} else {
self.map_unsigned.get(value)
self.map_unsigned.as_ref().map_or(0, |map| map.get(value))
}
}
}
Expand Down
12 changes: 2 additions & 10 deletions src/distance/jaro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::details::common::{find_common_prefix, HashableChar};
use crate::details::distance::Metricf64;
use crate::details::intrinsics::{bit_mask_lsb_u64, blsi_u64, blsr_u64, ceil_div_usize};
use crate::details::pattern_match_vector::{
BitVectorInterface, BitvectorHashmap, BlockPatternMatchVector, PatternMatchVector,
BitVectorInterface, BlockPatternMatchVector, PatternMatchVector,
};
use crate::Hash;
use std::cmp::min;
Expand Down Expand Up @@ -412,16 +412,8 @@ where
if len1 == 0 || len2 == 0 {
// already has correct number of common chars and transpositions
} else if len1 <= 64 && len2 <= 64 {
// rust fails to elide the copy when returning the array
// from PatternMatchVector::new so manually inline it
//let block = PatternMatchVector::new(s2_iter.clone());
let mut pm = PatternMatchVector {
map_unsigned: BitvectorHashmap::default(),
map_signed: BitvectorHashmap::default(),
extended_ascii: [0; 256],
};
let mut pm = PatternMatchVector::default();
pm.insert(s1_iter);

let flagged = flag_similar_characters_word(&pm, len1, s2_iter.clone(), len2, bound);

common_chars += flagged.count_common_chars();
Expand Down
11 changes: 2 additions & 9 deletions src/distance/lcs_seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::details::distance::MetricUsize;
use crate::details::intrinsics::{carrying_add, ceil_div_usize};
use crate::details::matrix::ShiftedBitMatrix;
use crate::details::pattern_match_vector::{
BitVectorInterface, BitvectorHashmap, BlockPatternMatchVector, PatternMatchVector,
BitVectorInterface, BlockPatternMatchVector, PatternMatchVector,
};
use std::cmp::{max, min};

Expand Down Expand Up @@ -374,14 +374,7 @@ where
if len1 == 0 {
Some(0)
} else if len1 <= 64 {
// rust fails to elide the copy when returning the array
// from PatternMatchVector::new so manually inline it
//let block = PatternMatchVector::new(s2_iter.clone());
let mut pm = PatternMatchVector {
map_unsigned: BitvectorHashmap::default(),
map_signed: BitvectorHashmap::default(),
extended_ascii: [0; 256],
};
let mut pm = PatternMatchVector::default();
pm.insert(s1.clone());
longest_common_subsequence_with_pm(&pm, s1, len1, s2, len2, score_cutoff)
} else {
Expand Down
68 changes: 59 additions & 9 deletions src/distance/levenshtein.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::details::growing_hashmap::{GrowingHashmap, HybridGrowingHashmap};
use crate::details::intrinsics::{ceil_div_usize, shr64};
use crate::details::matrix::ShiftedBitMatrix;
use crate::details::pattern_match_vector::{
BitVectorInterface, BitvectorHashmap, BlockPatternMatchVector, PatternMatchVector,
BitVectorInterface, BlockPatternMatchVector, PatternMatchVector,
};
use crate::distance::indel;
use std::cmp::{max, min};
Expand Down Expand Up @@ -975,6 +975,7 @@ where

// do this first, since we can not remove any affix in encoded form
// todo actually we could at least remove the common prefix and just shift the band
// todo for short strings this is likely a performance regression
if score_cutoff >= 4 {
// todo could save up to 25% even without score_cutoff when ignoring irrelevant paths
// in the upper and lower corner
Expand Down Expand Up @@ -1070,14 +1071,7 @@ where

/* when the shorter string has fewer than 65 elements, Hyyrö's algorithm can be used */
if affix.len2 <= 64 {
// rust fails to elide the copy when returning the array
// from PatternMatchVector::new so manually inline it
//let block = PatternMatchVector::new(s2_iter.clone());
let mut pm = PatternMatchVector {
map_unsigned: BitvectorHashmap::default(),
map_signed: BitvectorHashmap::default(),
extended_ascii: [0; 256],
};
let mut pm = PatternMatchVector::default();
pm.insert(affix.s2.clone());

let res: DistanceResult<0, 0> = hyrroe2003(
Expand Down Expand Up @@ -1316,6 +1310,17 @@ impl MetricUsize for IndividualComparator {
}
}

/// Levenshtein distance
///
/// Calculates the Levenshtein distance.
///
/// # Examples
///
/// ```
/// use rapidfuzz::distance::levenshtein;
///
/// assert_eq!(Some(3), levenshtein::distance("CA".chars(), "ABC".chars(), None, None, None));
/// ```
pub fn distance<Iter1, Iter2, Elem1, Elem2, ScoreCutoff, ScoreHint>(
s1: Iter1,
s2: Iter2,
Expand Down Expand Up @@ -1345,6 +1350,22 @@ where
)
}

/// Levenshtein similarity in the range [0, max]
///
/// This is calculated as `maximum - `[`distance`], where maximum is defined as
/// ```notrust
/// if len1 >= len2:
/// maximum = min(
/// len1 * deletion_cost + len2 * insertion_cost,
/// len2 * substitution_cost + (len1 - len2) * deletion_cost
/// )
/// else:
/// maximum = min(
/// len1 * deletion_cost + len2 * insertion_cost,
/// len1 * substitution_cost + (len2 - len1) * insertion_cost,
/// )
/// ```
///
pub fn similarity<Iter1, Iter2, Elem1, Elem2, ScoreCutoff, ScoreHint>(
s1: Iter1,
s2: Iter2,
Expand Down Expand Up @@ -1374,6 +1395,21 @@ where
)
}

/// Normalized Levenshtein distance in the range [0.0, 1.0]
///
/// This is calculated as [`distance`]` / maximum`, where maximum is defined as
/// ```notrust
/// if len1 >= len2:
/// maximum = min(
/// len1 * deletion_cost + len2 * insertion_cost,
/// len2 * substitution_cost + (len1 - len2) * deletion_cost
/// )
/// else:
/// maximum = min(
/// len1 * deletion_cost + len2 * insertion_cost,
/// len1 * substitution_cost + (len2 - len1) * insertion_cost,
/// )
/// ```
pub fn normalized_distance<Iter1, Iter2, Elem1, Elem2, ScoreCutoff, ScoreHint>(
s1: Iter1,
s2: Iter2,
Expand Down Expand Up @@ -1403,6 +1439,10 @@ where
)
}

/// Normalized Levenshtein similarity in the range [0.0, 1.0]
///
/// This is calculated as `1.0 - `[`normalized_distance`].
///
pub fn normalized_similarity<Iter1, Iter2, Elem1, Elem2, ScoreCutoff, ScoreHint>(
s1: Iter1,
s2: Iter2,
Expand Down Expand Up @@ -1432,6 +1472,16 @@ where
)
}

/// `One x Many` comparisons using the Levenshtein distance
///
/// # Examples
///
/// ```
/// use rapidfuzz::distance::levenshtein;
///
/// let scorer = levenshtein::BatchComparator::new("CA".chars(), None);
/// assert_eq!(Some(3), scorer.distance("ABC".chars(), None, None));
/// ```
pub struct BatchComparator<Elem1> {
s1: Vec<Elem1>,
pm: BlockPatternMatchVector,
Expand Down
9 changes: 1 addition & 8 deletions src/distance/osa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,7 @@ impl MetricUsize for IndividualComparator {
None
}
} else if affix.len1 <= 64 {
// rust fails to elide the copy when returning the array
// from PatternMatchVector::new so manually inline it
//let block = PatternMatchVector::new(s2_iter.clone());
let mut pm = PatternMatchVector {
map_unsigned: BitvectorHashmap::default(),
map_signed: BitvectorHashmap::default(),
extended_ascii: [0; 256],
};
let mut pm = PatternMatchVector::default();
pm.insert(affix.s1.clone());
hyrroe2003(
&pm,
Expand Down

0 comments on commit e1b2fb2

Please sign in to comment.