Skip to content

Commit

Permalink
perf(tokenizer): 减少 bpe 合并中的重复计算以优化性能
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Feb 12, 2024
1 parent 6e08644 commit 132d172
Showing 1 changed file with 38 additions and 6 deletions.
44 changes: 38 additions & 6 deletions tokenizer/src/bpe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl Tokenizer for BPE {
tokens.push(self.bos());
}
if !text.is_empty() {
tokens.push(self.find_piece(" ").unwrap())
tokens.push(self.find_piece("").unwrap())
}

text.chars().map(|c| c.to_string()).for_each(|c| {
Expand All @@ -102,23 +102,40 @@ impl Tokenizer for BPE {
}
});

fn map_pair(bpe: &BPE, tokens: &[utok], i: usize) -> Option<utok> {
bpe.find_piece(&format!(
"{}{}",
bpe.get_piece(tokens[i]),
bpe.get_piece(tokens[i + 1])
))
}

let mut merges = (0..tokens.len() - 1)
.map(|i| map_pair(self, &tokens, i))
.collect::<Vec<_>>();
loop {
let mut best_score = std::f32::NEG_INFINITY;
let mut replacement = None;
for (i, pair) in tokens.windows(2).enumerate() {
let pair = format!("{}{}", self.get_piece(pair[0]), self.get_piece(pair[1]));
if let Some(index) = self.find_piece(&pair) {
let score = self.get_score(index);
for (i, index) in merges.iter().enumerate() {
if let Some(index) = index {
let score = self.get_score(*index);
if score > best_score {
best_score = score;
replacement = Some((i, index));
replacement = Some((i, *index));
}
}
}
match replacement {
Some((i, j)) => {
tokens[i] = j;
tokens.remove(i + 1);
merges.remove(i);
if let Some(index) = merges.get_mut(i - 1) {
*index = map_pair(self, &tokens, i - 1);
}
if let Some(index) = merges.get_mut(i) {
*index = map_pair(self, &tokens, i);
}
}
None => break,
}
Expand Down Expand Up @@ -147,3 +164,18 @@ fn read_tokenizer() {
}
}
}

#[test]
fn once_upon_a_time() {
use std::time::Instant;
if let Ok(bpe) = BPE::from_model("tokenizer.model") {
let tokens = bpe.encode("Once▁upon▁a▁time,", true, false);
let t0 = Instant::now();
for _ in 0..10000 {
let _tokens = bpe.encode("Once▁upon▁a▁time,", true, false);
}
let t1 = Instant::now();
println!("{:?}", t1 - t0);
assert_eq!(tokens, &[1, 9038, 2501, 263, 931, 29892]);
}
}

0 comments on commit 132d172

Please sign in to comment.