From c09f5b930aa26f3433deb6302b7635dc7c06bf69 Mon Sep 17 00:00:00 2001 From: kitty <914384274@qq.com> Date: Mon, 22 Jul 2024 23:56:50 +0800 Subject: [PATCH] feat(RUST): String detection is performed using SIMD techniques (#1752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? Using SIMD technology to speed up string detection ## Related issues ## Does this PR introduce any user-facing change? - [x] Does this PR introduce any public API change? - [ ] Does this PR introduce any binary protocol compatibility change? ## Benchmark In fury/rust, run ```bash cargo bench ``` result: ```bash SIMD sse short time: [992.07 ps 993.91 ps 995.81 ps] change: [-1.1907% -0.7021% -0.2276%] (p = 0.00 < 0.05) Change within noise threshold. Found 5 outliers among 100 measurements (5.00%) 2 (2.00%) high mild 3 (3.00%) high severe SIMD sse long time: [368.88 ns 369.70 ns 370.59 ns] change: [-1.1407% -0.4415% +0.1537%] (p = 0.20 > 0.05) No change in performance detected. Found 11 outliers among 100 measurements (11.00%) 9 (9.00%) high mild 2 (2.00%) high severe SIMD avx short time: [4.4313 ns 4.4425 ns 4.4566 ns] change: [-0.4239% -0.0440% +0.3277%] (p = 0.82 > 0.05) No change in performance detected. Found 9 outliers among 100 measurements (9.00%) 5 (5.00%) high mild 4 (4.00%) high severe SIMD avx long time: [18.215 ns 18.277 ns 18.351 ns] change: [+0.6658% +1.0962% +1.6058%] (p = 0.00 < 0.05) Change within noise threshold. Found 9 outliers among 100 measurements (9.00%) 5 (5.00%) high mild 4 (4.00%) high severe Standard short time: [5.1115 ns 5.1670 ns 5.2491 ns] change: [+0.5348% +1.5193% +2.5623%] (p = 0.00 < 0.05) Change within noise threshold. Found 3 outliers among 100 measurements (3.00%) 2 (2.00%) high mild 1 (1.00%) high severe Standard long time: [3.6904 µs 3.7205 µs 3.7606 µs] change: [+1.5445% +2.5638% +3.9167%] (p = 0.00 < 0.05) Performance has regressed. Found 13 outliers among 100 measurements (13.00%) 7 (7.00%) high mild 6 (6.00%) high severe ``` --------- Co-authored-by: hezz --- rust/fury/Cargo.toml | 10 ++ rust/fury/benches/simd_bench.rs | 112 ++++++++++++++++++ rust/fury/src/meta/meta_string.rs | 4 +- rust/fury/src/meta/mod.rs | 1 + rust/fury/src/meta/string_util.rs | 189 ++++++++++++++++++++++++++++++ 5 files changed, 315 insertions(+), 1 deletion(-) create mode 100644 rust/fury/benches/simd_bench.rs create mode 100644 rust/fury/src/meta/string_util.rs diff --git a/rust/fury/Cargo.toml b/rust/fury/Cargo.toml index d39d537572..bba04afb73 100644 --- a/rust/fury/Cargo.toml +++ b/rust/fury/Cargo.toml @@ -30,3 +30,13 @@ lazy_static = { version = "1.4" } byteorder = { version = "1.4" } chrono = "0.4" thiserror = { default-features = false, version = "1.0" } + + +[[bench]] +name = "simd_bench" +harness = false + + +[dev-dependencies] +criterion = "0.5.1" +rand = "0.8.5" \ No newline at end of file diff --git a/rust/fury/benches/simd_bench.rs b/rust/fury/benches/simd_bench.rs new file mode 100644 index 0000000000..7649533fc3 --- /dev/null +++ b/rust/fury/benches/simd_bench.rs @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +#[cfg(target_feature = "avx2")] +use std::arch::x86_64::*; + +#[cfg(target_feature = "sse2")] +use std::arch::x86_64::*; + +#[cfg(target_feature = "avx2")] +pub(crate) const MIN_DIM_SIZE_AVX: usize = 32; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +pub(crate) const MIN_DIM_SIZE_SIMD: usize = 16; + +#[cfg(target_feature = "sse2")] +unsafe fn is_latin_sse(s: &str) -> bool { + let bytes = s.as_bytes(); + let len = s.len(); + let ascii_mask = _mm_set1_epi8(0x80u8 as i8); + let remaining = len % MIN_DIM_SIZE_SIMD; + let range_end = len - remaining; + for i in (0..range_end).step_by(MIN_DIM_SIZE_SIMD) { + let chunk = _mm_loadu_si128(bytes.as_ptr().add(i) as *const __m128i); + let masked = _mm_and_si128(chunk, ascii_mask); + let cmp = _mm_cmpeq_epi8(masked, _mm_setzero_si128()); + if _mm_movemask_epi8(cmp) != 0xFFFF { + return false; + } + } + for item in bytes.iter().take(range_end).skip(range_end) { + if !item.is_ascii() { + return false; + } + } + true +} + +#[cfg(target_feature = "avx2")] +unsafe fn is_latin_avx(s: &str) -> bool { + let bytes = s.as_bytes(); + let len = s.len(); + let ascii_mask = _mm256_set1_epi8(0x80u8 as i8); + let remaining = len % MIN_DIM_SIZE_AVX; + let range_end = len - remaining; + for i in (0..(len - remaining)).step_by(MIN_DIM_SIZE_AVX) { + let chunk = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i); + let masked = _mm256_and_si256(chunk, ascii_mask); + let cmp = _mm256_cmpeq_epi8(masked, _mm256_setzero_si256()); + if _mm256_movemask_epi8(cmp) != 0xFFFF { + return false; + } + } + for item in bytes.iter().take(range_end).skip(range_end) { + if !item.is_ascii() { + return false; + } + } + true +} + +fn is_latin_std(s: &str) -> bool { + s.bytes().all(|b| b.is_ascii()) +} + +fn criterion_benchmark(c: &mut Criterion) { + let test_str_short = "Hello, World!"; + let test_str_long = "Hello, World! ".repeat(1000); + + #[cfg(target_feature = "sse2")] + c.bench_function("SIMD sse short", |b| { + b.iter(|| unsafe { is_latin_sse(black_box(test_str_short)) }) + }); + #[cfg(target_feature = "sse2")] + c.bench_function("SIMD sse long", |b| { + b.iter(|| unsafe { is_latin_sse(black_box(&test_str_long)) }) + }); + #[cfg(target_feature = "avx2")] + c.bench_function("SIMD avx short", |b| { + b.iter(|| unsafe { is_latin_avx(black_box(test_str_short)) }) + }); + #[cfg(target_feature = "avx2")] + c.bench_function("SIMD avx long", |b| { + b.iter(|| unsafe { is_latin_avx(black_box(&test_str_long)) }) + }); + + c.bench_function("Standard short", |b| { + b.iter(|| is_latin_std(black_box(test_str_short))) + }); + + c.bench_function("Standard long", |b| { + b.iter(|| is_latin_std(black_box(&test_str_long))) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/rust/fury/src/meta/meta_string.rs b/rust/fury/src/meta/meta_string.rs index 3b46265e8d..627433c9fd 100644 --- a/rust/fury/src/meta/meta_string.rs +++ b/rust/fury/src/meta/meta_string.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use crate::meta::string_util; + #[derive(Debug, PartialEq)] pub enum Encoding { Utf8 = 0x00, @@ -102,7 +104,7 @@ impl MetaStringEncoder { } fn is_latin(&self, s: &str) -> bool { - s.bytes().all(|b| b.is_ascii()) + string_util::is_latin(s) } pub fn encode(&self, input: &str) -> Result { diff --git a/rust/fury/src/meta/mod.rs b/rust/fury/src/meta/mod.rs index 4e4d40b29a..02871e8257 100644 --- a/rust/fury/src/meta/mod.rs +++ b/rust/fury/src/meta/mod.rs @@ -16,4 +16,5 @@ // under the License. mod meta_string; +mod string_util; pub use meta_string::{Encoding, MetaStringDecoder, MetaStringEncoder}; diff --git a/rust/fury/src/meta/string_util.rs b/rust/fury/src/meta/string_util.rs new file mode 100644 index 0000000000..ea8659110b --- /dev/null +++ b/rust/fury/src/meta/string_util.rs @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(target_feature = "neon")] +use std::arch::aarch64::*; + +#[cfg(target_feature = "avx2")] +use std::arch::x86_64::*; + +#[cfg(target_feature = "sse2")] +use std::arch::x86_64::*; + +#[cfg(target_arch = "x86_64")] +pub(crate) const MIN_DIM_SIZE_AVX: usize = 32; + +#[cfg(any( + target_arch = "x86", + target_arch = "x86_64", + all(target_arch = "aarch64", target_feature = "neon") +))] +pub(crate) const MIN_DIM_SIZE_SIMD: usize = 16; + +#[cfg(target_arch = "x86_64")] +unsafe fn is_latin_avx(s: &str) -> bool { + let bytes = s.as_bytes(); + let len = bytes.len(); + let ascii_mask = _mm256_set1_epi8(0x80u8 as i8); + let remaining = len % MIN_DIM_SIZE_AVX; + let range_end = len - remaining; + + for i in (0..range_end).step_by(MIN_DIM_SIZE_AVX) { + let chunk = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i); + let masked = _mm256_and_si256(chunk, ascii_mask); + let cmp = _mm256_cmpeq_epi8(masked, _mm256_setzero_si256()); + if _mm256_movemask_epi8(cmp) != -1 { + return false; + } + } + for item in bytes.iter().take(len).skip(range_end) { + if !item.is_ascii() { + return false; + } + } + true +} + +#[cfg(target_feature = "sse2")] +unsafe fn is_latin_sse(s: &str) -> bool { + let bytes = s.as_bytes(); + let len = bytes.len(); + let ascii_mask = _mm_set1_epi8(0x80u8 as i8); + let remaining = len % MIN_DIM_SIZE_SIMD; + let range_end = len - remaining; + for i in (0..range_end).step_by(MIN_DIM_SIZE_SIMD) { + let chunk = _mm_loadu_si128(bytes.as_ptr().add(i) as *const __m128i); + let masked = _mm_and_si128(chunk, ascii_mask); + let cmp = _mm_cmpeq_epi8(masked, _mm_setzero_si128()); + if _mm_movemask_epi8(cmp) != 0xFFFF { + return false; + } + } + for item in bytes.iter().take(len).skip(range_end) { + if !item.is_ascii() { + return false; + } + } + true +} + +#[cfg(target_feature = "neon")] +unsafe fn is_latin_neon(s: &str) -> bool { + let bytes = s.as_bytes(); + let len = bytes.len(); + let ascii_mask = vdupq_n_u8(0x80); + let remaining = len % MIN_DIM_SIZE_SIMD; + let range_end = len - remaining; + for i in (0..range_end).step_by(MIN_DIM_SIZE_SIMD) { + let chunk = vld1q_u8(bytes.as_ptr().add(i)); + let masked = vandq_u8(chunk, ascii_mask); + let cmp = vceqq_u8(masked, vdupq_n_u8(0)); + if vminvq_u8(cmp) == 0 { + return false; + } + } + for item in bytes.iter().take(len).skip(range_end) { + if !item.is_ascii() { + return false; + } + } + true +} + +fn is_latin_standard(s: &str) -> bool { + s.bytes().all(|b| b.is_ascii()) +} + +pub(crate) fn is_latin(s: &str) -> bool { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && s.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { is_latin_avx(s) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { is_latin_sse(s) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { is_latin_neon(s) }; + } + } + is_latin_standard(s) +} + +#[cfg(test)] +mod tests { + // 导入外部模块中的内容 + use super::*; + use rand::Rng; + + fn generate_random_string(length: usize) -> String { + const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + let mut rng = rand::thread_rng(); + + let result: String = (0..length) + .map(|_| { + let idx = rng.gen_range(0..CHARSET.len()); + CHARSET[idx] as char + }) + .collect(); + + result + } + + #[test] + fn test_is_latin() { + let s = generate_random_string(1000); + let not_latin_str = generate_random_string(1000) + "abc\u{1234}"; + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") { + assert!(unsafe { is_latin_avx(&s) }); + assert!(!unsafe { is_latin_avx(¬_latin_str) }); + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD { + assert!(unsafe { is_latin_sse(&s) }); + assert!(!unsafe { is_latin_sse(¬_latin_str) }); + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD { + assert!(unsafe { is_latin_neon(&s) }); + assert!(!unsafe { is_latin_neon(¬_latin_str) }); + } + } + assert!(is_latin_standard(&s)); + assert!(!is_latin_standard(¬_latin_str)); + } +}