From 6e086442adcadafc7cbd458916a48fe09a68a479 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Mon, 12 Feb 2024 09:49:42 +0800
Subject: [PATCH] feat(tokenizer): complete the basic bpe tokenizer logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 tokenizer/src/bpe.rs       | 112 +++++++++++++++++++++++++++++++------
 tokenizer/src/lib.rs       |  24 +++++++-
 tokenizer/src/vocab_txt.rs |  26 +++------
 3 files changed, 124 insertions(+), 38 deletions(-)

diff --git a/tokenizer/src/bpe.rs b/tokenizer/src/bpe.rs
index fad244e4..a0a61e02 100644
--- a/tokenizer/src/bpe.rs
+++ b/tokenizer/src/bpe.rs
@@ -1,27 +1,54 @@
-use crate::{utok, Tokenizer};
-use std::path::Path;
+use crate::{utok, ByteDecoder, Tokenizer};
+use std::{io::Result, path::Path};
 
 pub struct BPE {
     mmap: memmap2::Mmap,
+    /// Byte offset of each token's record in the file, for looking up a token string by index.
     offsets: Vec<usize>,
+    /// Token indices sorted by their strings' lexicographic order, for looking up an index by token string.
+    sorted_indices: Vec<utok>,
+    max_piece_len: usize,
+    byte_pieces: ByteDecoder,
 }
 
 impl BPE {
-    pub fn from_model(model_file: impl AsRef<Path>) -> Self {
-        let file = std::fs::File::open(model_file).unwrap();
-        let mmap = unsafe { memmap2::Mmap::map(&file) }.unwrap();
+    pub fn from_model(model_file: impl AsRef<Path>) -> Result<Self> {
+        let file = std::fs::File::open(model_file)?;
+        let mmap = unsafe { memmap2::Mmap::map(&file) }?;
         // format: 10 <total_len> 10 <str_len> <str;str_len> 21 <score;4>
         let mut offsets = Vec::new();
         let mut offset = 0usize;
+        let mut max_piece_len = 0usize;
         loop {
             let slice = &mmap[offset..];
             if slice.is_empty() || slice[0] != 10 {
                 break;
             }
+            max_piece_len = max_piece_len.max(slice[3] as usize);
             offsets.push(offset + 3);
             offset += 2 + slice[1] as usize;
         }
-        Self { mmap, offsets }
+        let mut sorted_indices = (0..offsets.len() as utok).collect::<Vec<_>>();
+        sorted_indices.sort_by_key(|&i| {
+            let slice = &mmap[offsets[i as usize]..];
+            let len = slice[0] as usize;
+            std::str::from_utf8(&slice[1..][..len]).unwrap()
+        });
+        Ok(Self {
+            mmap,
+            offsets,
+            sorted_indices,
+            max_piece_len,
+            byte_pieces: ByteDecoder::new(),
+        })
+    }
+
+    #[inline]
+    fn find_piece(&self, piece: &str) -> Option<utok> {
+        self.sorted_indices
+            .binary_search_by_key(&piece, |&i| self.get_piece(i))
+            .ok()
+            .map(|i| self.sorted_indices[i])
     }
 
     #[inline]
@@ -43,31 +70,80 @@
 }
 
 impl Tokenizer for BPE {
-    fn bos(&self) -> crate::utok {
-        todo!()
+    #[inline]
+    fn bos(&self) -> utok {
+        1
     }
 
-    fn eos(&self) -> crate::utok {
-        todo!()
+    #[inline]
+    fn eos(&self) -> utok {
+        2
     }
 
+    #[inline]
     fn max_piece_len(&self) -> usize {
-        todo!()
+        self.max_piece_len
     }
 
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok> {
-        todo!()
+    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok> {
+        let mut tokens = Vec::new();
+        if bos {
+            tokens.push(self.bos());
+        }
+        if !text.is_empty() {
+            tokens.push(self.find_piece(" ").unwrap())
+        }
+
+        // Seed with one token per character; unknown characters fall back to
+        // raw bytes, offset by 3 to skip the <unk>/<s>/</s> control tokens.
+        text.chars().map(|c| c.to_string()).for_each(|c| {
+            if let Some(index) = self.find_piece(&c) {
+                tokens.extend([index]);
+            } else {
+                tokens.extend(c.bytes().map(|c| c as utok + 3));
+            }
+        });
+
+        // Repeatedly merge the adjacent pair whose merged piece has the
+        // highest score, until no adjacent pair forms a known piece.
+        loop {
+            let mut best_score = std::f32::NEG_INFINITY;
+            let mut replacement = None;
+            for (i, pair) in tokens.windows(2).enumerate() {
+                let pair = format!("{}{}", self.get_piece(pair[0]), self.get_piece(pair[1]));
+                if let Some(index) = self.find_piece(&pair) {
+                    let score = self.get_score(index);
+                    if score > best_score {
+                        best_score = score;
+                        replacement = Some((i, index));
+                    }
+                }
+            }
+            match replacement {
+                Some((i, j)) => {
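+                    // Apply the merge in place: the left slot takes the
+                    // merged token and the right neighbour is removed.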
+                    tokens[i] = j;
+                    tokens.remove(i + 1);
+                }
+                None => break,
+            }
+        }
+
+        if bos {
+            assert_eq!(tokens[0], self.bos());
+        }
+        if eos {
+            tokens.push(self.eos());
+        }
+        tokens
     }
 
-    fn decode(&self, token: crate::utok, next: crate::utok) -> &str {
-        todo!()
+    #[inline]
+    fn decode(&self, token: utok) -> &str {
+        self.byte_pieces.decode(self.get_piece(token))
     }
 }
 
 #[test]
 fn read_tokenizer() {
-    let bpe = BPE::from_model("tokenizer.model");
-    for i in 0..bpe.offsets.len() {
-        println!("{}: {}", bpe.get_piece(i as utok), bpe.get_score(i as utok));
+    if let Ok(bpe) = BPE::from_model("tokenizer.model") {
+        for i in 0..bpe.offsets.len() {
+            println!("{}: {}", bpe.get_piece(i as utok), bpe.get_score(i as utok));
+        }
     }
 }
diff --git a/tokenizer/src/lib.rs b/tokenizer/src/lib.rs
index 6f079ce8..be3a8175 100644
--- a/tokenizer/src/lib.rs
+++ b/tokenizer/src/lib.rs
@@ -10,8 +10,30 @@ pub trait Tokenizer {
     fn eos(&self) -> utok;
     fn max_piece_len(&self) -> usize;
     fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok>;
-    fn decode(&self, token: utok, next: utok) -> &str;
+    fn decode(&self, token: utok) -> &str;
 }
 
 pub use bpe::BPE;
 pub use vocab_txt::VocabTxt;
+
+struct ByteDecoder([u8; 256]);
+
+impl ByteDecoder {
+    fn new() -> Self {
+        let mut ans = Self([0; 256]);
+        for (i, b) in ans.0.iter_mut().enumerate() {
+            *b = i as _;
+        }
+        ans
+    }
+
+    fn decode<'a>(&'a self, piece: &'a str) -> &'a str {
+        // Pieces like "<0x0A>" are byte-fallback escapes; map them back to
+        // the raw byte. Everything else passes through unchanged.
+        if let Some(byte) = piece.strip_prefix("<0x").and_then(|s| s.strip_suffix('>')) {
+            let byte = u8::from_str_radix(byte, 16).unwrap();
+            let byte = std::slice::from_ref(&self.0[byte as usize]);
+            unsafe { std::str::from_utf8_unchecked(byte) }
+        } else {
+            piece
+        }
+    }
+}
diff --git a/tokenizer/src/vocab_txt.rs b/tokenizer/src/vocab_txt.rs
index 229e4aeb..d8802a58 100644
--- a/tokenizer/src/vocab_txt.rs
+++ b/tokenizer/src/vocab_txt.rs
@@ -1,4 +1,4 @@
-use super::{utok, Tokenizer};
+use crate::{utok, ByteDecoder, Tokenizer};
 use memmap2::Mmap;
 use patricia_tree::PatriciaMap;
 use std::{fs::File, path::Path};
@@ -12,7 +12,7 @@ pub struct VocabTxt {
     /// Maximum length of any piece in the vocabulary.
     max_piece_len: usize,
     /// Escape decoding for single-byte pieces.
-    byte_pieces: [u8; 256],
+    byte_pieces: ByteDecoder,
 }
 
 impl VocabTxt {
@@ -29,16 +29,12 @@
             words.push(piece.to_string());
             trie.insert(piece, i as _);
         }
-        let mut ans = Self {
+        Self {
             words,
             trie,
             max_piece_len,
-            byte_pieces: [0; 256],
-        };
-        for i in 0..=255u8 {
-            ans.byte_pieces[i as usize] = i;
+            byte_pieces: ByteDecoder::new(),
         }
-        ans
     }
 }
@@ -90,16 +86,8 @@ impl Tokenizer for VocabTxt {
         tokens
     }
 
-    fn decode(&self, token: utok, next: utok) -> &str {
-        let piece = self.words[next as usize].as_str();
-        if let Some(byte) = piece.strip_prefix("<0x").and_then(|s| s.strip_suffix('>')) {
-            let byte = u8::from_str_radix(byte, 16).unwrap();
-            let byte = &self.byte_pieces[byte as usize..][..1];
-            unsafe { std::str::from_utf8_unchecked(byte) }
-        } else if token == self.bos() && piece.starts_with(' ') {
-            &piece[1..]
-        } else {
-            piece
-        }
+    #[inline]
+    fn decode(&self, token: utok) -> &str {
+        self.byte_pieces.decode(self.words[token as usize].as_str())
     }
 }
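
Usage note (not part of the patch): a minimal sketch of driving the new API end to end, assuming the crate is named `tokenizer`, `utok` is an integer alias such as `u32`, and a sentencepiece `tokenizer.model` sits in the working directory; the prompt string is made up for illustration.

    use tokenizer::{Tokenizer, BPE};

    fn main() -> std::io::Result<()> {
        // from_model now returns io::Result instead of panicking on a bad path.
        let bpe = BPE::from_model("tokenizer.model")?;
        // Encode with BOS but without EOS, as a typical prompt prefix would be.
        let tokens = bpe.encode("Once upon a time", true, false);
        // decode() now takes a single token; callers concatenate the pieces.
        let text: String = tokens.iter().map(|&t| bpe.decode(t)).collect();
        println!("{tokens:?} -> {text}");
        Ok(())
    }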
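
A possible unit test for the shared ByteDecoder, sketched here rather than taken from the patch; since the type is private it would have to live inside the crate, e.g. in lib.rs.

    #[test]
    fn byte_decoder_unescapes_byte_pieces() {
        let d = ByteDecoder::new();
        // "<0x0A>" is sentencepiece's byte-fallback escape for a line feed.
        assert_eq!(d.decode("<0x0A>"), "\n");
        // Ordinary pieces pass through unchanged.
        assert_eq!(d.decode("hello"), "hello");
    }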