From 97ef995b0854457273b52126c50045e665851876 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Mon, 12 Feb 2024 23:10:46 +0800
Subject: [PATCH] refactor(tokenizer): Tokenizer no longer adds bos/eos
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 tokenizer/Cargo.toml       |  2 +-
 tokenizer/src/bpe.rs       | 73 ++++++++++++--------------------------
 tokenizer/src/lib.rs       |  4 +--
 tokenizer/src/vocab_txt.rs | 21 +----------
 4 files changed, 25 insertions(+), 75 deletions(-)

diff --git a/tokenizer/Cargo.toml b/tokenizer/Cargo.toml
index 6cfd0632..48da11cd 100644
--- a/tokenizer/Cargo.toml
+++ b/tokenizer/Cargo.toml
@@ -7,5 +7,5 @@ authors = ["YdrMaster "]
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-patricia_tree = "0.8"
 memmap2 = "0.9"
+patricia_tree = "0.8"
diff --git a/tokenizer/src/bpe.rs b/tokenizer/src/bpe.rs
index cc56fd2e..6f807a1f 100644
--- a/tokenizer/src/bpe.rs
+++ b/tokenizer/src/bpe.rs
@@ -80,16 +80,6 @@ impl BPE {
 }
 
 impl Tokenizer for BPE {
-    #[inline]
-    fn bos(&self) -> utok {
-        1
-    }
-
-    #[inline]
-    fn eos(&self) -> utok {
-        2
-    }
-
     fn vocab_size(&self) -> usize {
         self.offsets.len()
     }
@@ -99,13 +89,12 @@ impl Tokenizer for BPE {
         self.max_piece_len
     }
 
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok> {
+    fn encode(&self, text: &str) -> Vec<utok> {
         let mut tokens = Vec::new();
-        if bos {
-            tokens.push(self.bos());
-        }
-        if !text.is_empty() {
-            tokens.push(self.find_piece("▁").unwrap())
+        if let Some(c) = text.chars().next() {
+            if c.is_alphabetic() {
+                tokens.push(self.find_piece("▁").unwrap())
+            }
         }
 
         text.chars().map(|c| c.to_string()).for_each(|c| {
@@ -116,51 +105,33 @@ impl Tokenizer for BPE {
             }
         });
 
-        fn map_pair(bpe: &BPE, tokens: &[utok], i: usize) -> Option<utok> {
+        fn map_pair(bpe: &BPE, tokens: &[utok], i: usize) -> Option<(utok, f32)> {
             bpe.find_piece(&format!(
                 "{}{}",
                 bpe.get_piece(tokens[i]),
                 bpe.get_piece(tokens[i + 1])
            ))
+            .map(|tok| (tok, bpe.get_score(tok)))
         }
 
         let mut merges = (0..tokens.len() - 1)
             .map(|i| map_pair(self, &tokens, i))
             .collect::<Vec<_>>();
-        loop {
-            let mut best_score = std::f32::NEG_INFINITY;
-            let mut replacement = None;
-            for (i, index) in merges.iter().enumerate() {
-                if let Some(index) = index {
-                    let score = self.get_score(*index);
-                    if score > best_score {
-                        best_score = score;
-                        replacement = Some((i, *index));
-                    }
-                }
-            }
-            match replacement {
-                Some((i, j)) => {
-                    tokens[i] = j;
-                    tokens.remove(i + 1);
-                    merges.remove(i);
-                    if let Some(index) = merges.get_mut(i - 1) {
-                        *index = map_pair(self, &tokens, i - 1);
-                    }
-                    if let Some(index) = merges.get_mut(i) {
-                        *index = map_pair(self, &tokens, i);
-                    }
-                }
-                None => break,
+        while let Some((i, (tok, _))) = merges
+            .iter()
+            .enumerate()
+            .filter_map(|(i, tok)| tok.map(|tok| (i, tok)))
+            .max_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
+        {
+            tokens[i] = tok;
+            tokens.remove(i + 1);
+            merges.remove(i);
+            if let Some(i) = i.checked_sub(1) {
+                merges[i] = map_pair(self, &tokens, i);
             }
+            merges[i] = map_pair(self, &tokens, i);
         }
 
-        if bos {
-            assert_eq!(tokens[0], self.bos());
-        }
-        if eos {
-            tokens.push(self.eos());
-        }
         tokens
     }
 
@@ -184,13 +155,13 @@ fn once_upon_a_time() {
     use std::time::Instant;
     if let Ok(bpe) = BPE::from_model_file("tokenizer.model") {
         const PROMPT: &str = "Once▁upon▁a▁time,";
-        let tokens = bpe.encode(PROMPT, true, false);
+        let tokens = bpe.encode(PROMPT);
         let t0 = Instant::now();
         for _ in 0..10000 {
-            let _tokens = bpe.encode(PROMPT, true, false);
+            let _tokens = bpe.encode(PROMPT);
         }
         let t1 = Instant::now();
         println!("{:?}", t1 - t0);
-        assert_eq!(tokens, &[1, 9038, 2501, 263, 931, 29892]);
+        assert_eq!(tokens, &[9038, 2501, 263, 931, 29892]);
     }
 }
diff --git a/tokenizer/src/lib.rs b/tokenizer/src/lib.rs
index 0949374e..89debba2 100644
--- a/tokenizer/src/lib.rs
+++ b/tokenizer/src/lib.rs
@@ -6,11 +6,9 @@ mod vocab_txt;
 pub type utok = u32;
 
 pub trait Tokenizer {
-    fn bos(&self) -> utok;
-    fn eos(&self) -> utok;
     fn vocab_size(&self) -> usize;
     fn max_piece_len(&self) -> usize;
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok>;
+    fn encode(&self, text: &str) -> Vec<utok>;
     fn decode(&self, token: utok) -> &str;
 }
 
diff --git a/tokenizer/src/vocab_txt.rs b/tokenizer/src/vocab_txt.rs
index 7af492ef..ae2280f9 100644
--- a/tokenizer/src/vocab_txt.rs
+++ b/tokenizer/src/vocab_txt.rs
@@ -40,16 +40,6 @@ impl VocabTxt {
 }
 
 impl Tokenizer for VocabTxt {
-    #[inline]
-    fn bos(&self) -> utok {
-        1
-    }
-
-    #[inline]
-    fn eos(&self) -> utok {
-        2
-    }
-
     fn vocab_size(&self) -> usize {
         self.words.len()
     }
@@ -59,11 +49,8 @@ impl Tokenizer for VocabTxt {
         self.max_piece_len
     }
 
-    fn encode(&self, mut text: &str, bos: bool, eos: bool) -> Vec<utok> {
+    fn encode(&self, mut text: &str) -> Vec<utok> {
         let mut tokens = Vec::<utok>::new();
-        if bos {
-            tokens.push(self.bos());
-        }
 
         while !text.is_empty() {
             let piece = if text.len() > self.max_piece_len {
@@ -82,12 +69,6 @@ impl Tokenizer for VocabTxt {
             }
         }
 
-        if bos {
-            assert_eq!(tokens[0], self.bos());
-        }
-        if eos {
-            tokens.push(self.eos());
-        }
         tokens
     }
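
Note: with this change the caller owns bos/eos handling. A minimal caller-side sketch (not part of the patch), assuming the crate is importable as `tokenizer`; the helper name `encode_with_specials` is hypothetical, and the ids 1 and 2 are simply the values the removed bos()/eos() implementations used to return.

    use tokenizer::{utok, Tokenizer};

    // Hypothetical wrapper: prepend/append the special tokens around encode(),
    // since the Tokenizer trait no longer does this itself.
    fn encode_with_specials(tk: &dyn Tokenizer, text: &str, bos: bool, eos: bool) -> Vec<utok> {
        const BOS: utok = 1; // value the removed bos() returned
        const EOS: utok = 2; // value the removed eos() returned
        let mut tokens = Vec::new();
        if bos {
            tokens.push(BOS);
        }
        tokens.extend(tk.encode(text));
        if eos {
            tokens.push(EOS);
        }
        tokens
    }

Call sites that previously passed bos/eos flags to `encode` (for example the removed `bpe.encode(PROMPT, true, false)`) would go through such a wrapper instead.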