feat(tokenizer): complete the basic logic of the bpe tokenizer
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Feb 12, 2024
1 parent d611d06 commit 6e08644
Showing 3 changed files with 124 additions and 38 deletions.
112 changes: 94 additions & 18 deletions tokenizer/src/bpe.rs
@@ -1,27 +1,54 @@
use crate::{utok, Tokenizer};
use std::path::Path;
use crate::{utok, ByteDecoder, Tokenizer};
use std::{io::Result, path::Path};

pub struct BPE {
mmap: memmap2::Mmap,
/// Stores, for each index, the offset of its entry in the file; used to look up a token string from an index.
offsets: Vec<usize>,
/// Stores indices sorted by the lexicographic order of their token strings; used to look up an index from a token string.
sorted_indices: Vec<utok>,
max_piece_len: usize,
byte_pieces: ByteDecoder,
}

impl BPE {
pub fn from_model(model_file: impl AsRef<Path>) -> Self {
let file = std::fs::File::open(model_file).unwrap();
let mmap = unsafe { memmap2::Mmap::map(&file) }.unwrap();
pub fn from_model(model_file: impl AsRef<Path>) -> Result<Self> {
let file = std::fs::File::open(model_file)?;
let mmap = unsafe { memmap2::Mmap::map(&file) }?;
// format: 10 <total_len> 10 <str_len> <str;str_len> 21 <score;4> []
let mut offsets = Vec::new();
let mut offset = 0usize;
let mut max_piece_len = 0usize;
loop {
let slice = &mmap[offset..];
if slice.is_empty() || slice[0] != 10 {
break;
}
max_piece_len = max_piece_len.max(slice[3] as usize);
offsets.push(offset + 3);
offset += 2 + slice[1] as usize;
}
Self { mmap, offsets }
let mut sorted_indices = (0..offsets.len() as utok).collect::<Vec<_>>();
sorted_indices.sort_by_key(|&i| {
let slice = &mmap[offsets[i as usize]..];
let len = slice[0] as usize;
std::str::from_utf8(&slice[1..][..len]).unwrap()
});
Ok(Self {
mmap,
offsets,
sorted_indices,
max_piece_len,
byte_pieces: ByteDecoder::new(),
})
}

#[inline]
fn find_piece(&self, piece: &str) -> Option<utok> {
self.sorted_indices
.binary_search_by_key(&piece, |&i| self.get_piece(i))
.ok()
.map(|i| self.sorted_indices[i])
}

#[inline]
@@ -43,31 +70,80 @@ impl BPE
}

impl Tokenizer for BPE {
fn bos(&self) -> crate::utok {
todo!()
#[inline]
fn bos(&self) -> utok {
1
}

fn eos(&self) -> crate::utok {
todo!()
#[inline]
fn eos(&self) -> utok {
2
}

#[inline]
fn max_piece_len(&self) -> usize {
todo!()
self.max_piece_len
}

fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<crate::utok> {
todo!()
fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok> {
let mut tokens = Vec::new();
if bos {
tokens.push(self.bos());
}
if !text.is_empty() {
tokens.push(self.find_piece(" ").unwrap())
}

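// Map each character to its token if the vocabulary has it; otherwise fall back to its raw bytes, each offset by 3.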
text.chars().map(|c| c.to_string()).for_each(|c| {
if let Some(index) = self.find_piece(&c) {
tokens.extend([index]);
} else {
tokens.extend(c.bytes().map(|c| c as utok + 3));
}
});

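// Greedily merge: repeatedly replace the adjacent pair whose concatenation is in the vocabulary with the best score, until no pair merges.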
loop {
let mut best_score = f32::NEG_INFINITY;
let mut replacement = None;
for (i, pair) in tokens.windows(2).enumerate() {
let pair = format!("{}{}", self.get_piece(pair[0]), self.get_piece(pair[1]));
if let Some(index) = self.find_piece(&pair) {
let score = self.get_score(index);
if score > best_score {
best_score = score;
replacement = Some((i, index));
}
}
}
match replacement {
Some((i, j)) => {
tokens[i] = j;
tokens.remove(i + 1);
}
None => break,
}
}

if bos {
assert_eq!(tokens[0], self.bos());
}
if eos {
tokens.push(self.eos());
}
tokens
}

fn decode(&self, token: crate::utok, next: crate::utok) -> &str {
todo!()
#[inline]
fn decode(&self, token: utok) -> &str {
self.byte_pieces.decode(self.get_piece(token))
}
}

#[test]
fn read_tokenizer() {
let bpe = BPE::from_model("tokenizer.model");
for i in 0..bpe.offsets.len() {
println!("{}: {}", bpe.get_piece(i as utok), bpe.get_score(i as utok));
if let Ok(bpe) = BPE::from_model("tokenizer.model") {
for i in 0..bpe.offsets.len() {
println!("{}: {}", bpe.get_piece(i as utok), bpe.get_score(i as utok));
}
}
}
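
The `// format:` comment in `from_model` describes the model file as a sequence of records: tag byte 10, a record length, then a nested tag 10 with the piece length and the piece bytes, then tag 21 followed by a 4-byte score. Below is a small standalone sketch that builds two hypothetical records and walks them with the same offset arithmetic as `from_model`; it assumes the score is a little-endian f32 (protobuf fixed32), which this diff does not show.

// Standalone sketch of the record layout noted in `from_model`
// (`10 <total_len> 10 <str_len> <str;str_len> 21 <score;4>`).
// The records below are hypothetical; the little-endian f32 score is an assumption.
fn main() {
    let mut blob: Vec<u8> = Vec::new();
    for (piece, score) in [("hi", 0.5f32), ("lo", -1.0f32)] {
        let total_len = 2 + piece.len() + 5; // 10 <str_len> <str> 21 <score;4>
        blob.push(10);
        blob.push(total_len as u8);
        blob.push(10);
        blob.push(piece.len() as u8);
        blob.extend_from_slice(piece.as_bytes());
        blob.push(21);
        blob.extend_from_slice(&score.to_le_bytes());
    }

    // Walk the records the same way `from_model` does: each offset points at <str_len>.
    let mut offsets = Vec::new();
    let mut offset = 0usize;
    while offset < blob.len() && blob[offset] == 10 {
        offsets.push(offset + 3);
        offset += 2 + blob[offset + 1] as usize;
    }
    for &o in &offsets {
        let len = blob[o] as usize;
        let piece = std::str::from_utf8(&blob[o + 1..][..len]).unwrap();
        let score = f32::from_le_bytes(blob[o + 1 + len + 1..][..4].try_into().unwrap());
        println!("{piece}: {score}");
    }
}
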
24 changes: 23 additions & 1 deletion tokenizer/src/lib.rs
@@ -10,8 +10,30 @@ pub trait Tokenizer {
fn eos(&self) -> utok;
fn max_piece_len(&self) -> usize;
fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok>;
fn decode(&self, token: utok, next: utok) -> &str;
fn decode(&self, token: utok) -> &str;
}

pub use bpe::BPE;
pub use vocab_txt::VocabTxt;

struct ByteDecoder([u8; 256]);

impl ByteDecoder {
fn new() -> Self {
let mut ans = Self([0; 256]);
for (i, b) in ans.0.iter_mut().enumerate() {
*b = i as _;
}
ans
}

fn decode<'a>(&'a self, piece: &'a str) -> &'a str {
if let Some(byte) = piece.strip_prefix("<0x").and_then(|s| s.strip_suffix('>')) {
let byte = u8::from_str_radix(byte, 16).unwrap();
let byte = std::slice::from_ref(&self.0[byte as usize]);
unsafe { std::str::from_utf8_unchecked(byte) }
} else {
piece
}
}
}
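
For illustration, the byte fallback that `ByteDecoder` handles maps a piece spelled `<0xXX>` back to its single raw byte and leaves every other piece untouched. A minimal sketch of the same rule as a free function, returning an owned `String` (lossy for non-UTF-8 bytes) instead of borrowing from a byte table:

// Sketch of the `<0xXX>` byte-fallback rule; unlike the table-based
// ByteDecoder above, this returns an owned String and replaces
// non-UTF-8 bytes instead of handing them back raw.
fn decode_byte_piece(piece: &str) -> String {
    if let Some(hex) = piece.strip_prefix("<0x").and_then(|s| s.strip_suffix('>')) {
        let byte = u8::from_str_radix(hex, 16).unwrap();
        String::from_utf8_lossy(&[byte]).into_owned()
    } else {
        piece.to_string()
    }
}

fn main() {
    assert_eq!(decode_byte_piece("<0x41>"), "A");
    assert_eq!(decode_byte_piece("hello"), "hello");
}
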
26 changes: 7 additions & 19 deletions tokenizer/src/vocab_txt.rs
@@ -1,4 +1,4 @@
use super::{utok, Tokenizer};
use crate::{utok, ByteDecoder, Tokenizer};
use memmap2::Mmap;
use patricia_tree::PatriciaMap;
use std::{fs::File, path::Path};
@@ -12,7 +12,7 @@ pub struct VocabTxt {
/// Maximum length of a piece in the vocabulary.
max_piece_len: usize,
/// Escape decoder for single-byte pieces.
byte_pieces: [u8; 256],
byte_pieces: ByteDecoder,
}

impl VocabTxt {
@@ -29,16 +29,12 @@ impl VocabTxt {
words.push(piece.to_string());
trie.insert(piece, i as _);
}
let mut ans = Self {
Self {
words,
trie,
max_piece_len,
byte_pieces: [0; 256],
};
for i in 0..=255u8 {
ans.byte_pieces[i as usize] = i;
byte_pieces: ByteDecoder::new(),
}
ans
}
}

@@ -90,16 +86,8 @@ impl Tokenizer for VocabTxt {
tokens
}

fn decode(&self, token: utok, next: utok) -> &str {
let piece = self.words[next as usize].as_str();
if let Some(byte) = piece.strip_prefix("<0x").and_then(|s| s.strip_suffix('>')) {
let byte = u8::from_str_radix(byte, 16).unwrap();
let byte = &self.byte_pieces[byte as usize..][..1];
unsafe { std::str::from_utf8_unchecked(byte) }
} else if token == self.bos() && piece.starts_with(' ') {
&piece[1..]
} else {
piece
}
#[inline]
fn decode(&self, token: utok) -> &str {
self.byte_pieces.decode(self.words[token as usize].as_str())
}
}
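
For reference, the heart of `BPE::encode` above is a greedy merge loop: on each pass it finds the adjacent token pair whose concatenation exists in the vocabulary with the highest score and fuses it, stopping when no pair can be merged. A simplified sketch over plain strings with a hypothetical score table:

use std::collections::HashMap;

// Simplified sketch of the greedy pair-merge loop from `BPE::encode`,
// over plain strings with a hypothetical vocabulary of scores.
fn merge(mut pieces: Vec<String>, vocab: &HashMap<String, f32>) -> Vec<String> {
    loop {
        let mut best: Option<(usize, f32)> = None;
        for (i, pair) in pieces.windows(2).enumerate() {
            let merged = format!("{}{}", pair[0], pair[1]);
            if let Some(&score) = vocab.get(&merged) {
                if best.map_or(true, |(_, s)| score > s) {
                    best = Some((i, score));
                }
            }
        }
        match best {
            Some((i, _)) => {
                // Replace pieces[i] and pieces[i + 1] by their concatenation.
                let merged = format!("{}{}", pieces[i], pieces[i + 1]);
                pieces[i] = merged;
                pieces.remove(i + 1);
            }
            None => break,
        }
    }
    pieces
}

fn main() {
    let vocab = HashMap::from([("he".to_string(), 1.0), ("hel".to_string(), 2.0)]);
    let pieces = vec!["h".into(), "e".into(), "l".into()];
    assert_eq!(merge(pieces, &vocab), vec!["hel".to_string()]);
}
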
