refactor(tokenizer): the Tokenizer is no longer responsible for adding bos/eos

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Feb 12, 2024
1 parent 2ec4aaa commit 97ef995
Showing 4 changed files with 25 additions and 75 deletions.
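
With this commit, encode only maps text to token ids; call sites decide whether to add the special tokens themselves. Below is a minimal sketch of such a call site, assuming the ids 1 (bos) and 2 (eos) that the removed default implementations returned; encode_with_specials and the tokenizer crate path are illustrative, not part of this commit.

// Hypothetical helper: the caller, not the tokenizer, adds bos/eos now.
// The crate name `tokenizer` is an assumption based on the directory layout.
use tokenizer::{utok, Tokenizer};

fn encode_with_specials(t: &dyn Tokenizer, text: &str, bos: bool, eos: bool) -> Vec<utok> {
    const BOS: utok = 1; // id the removed bos() used to return
    const EOS: utok = 2; // id the removed eos() used to return
    let mut tokens = Vec::new();
    if bos {
        tokens.push(BOS);
    }
    tokens.extend(t.encode(text)); // encode no longer takes bos/eos flags
    if eos {
        tokens.push(EOS);
    }
    tokens
}
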
2 changes: 1 addition & 1 deletion tokenizer/Cargo.toml
@@ -7,5 +7,5 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-patricia_tree = "0.8"
 memmap2 = "0.9"
+patricia_tree = "0.8"
73 changes: 22 additions & 51 deletions tokenizer/src/bpe.rs
@@ -80,16 +80,6 @@ impl BPE {
 }
 
 impl Tokenizer for BPE {
-    #[inline]
-    fn bos(&self) -> utok {
-        1
-    }
-
-    #[inline]
-    fn eos(&self) -> utok {
-        2
-    }
-
     fn vocab_size(&self) -> usize {
         self.offsets.len()
     }
@@ -99,13 +89,12 @@ impl Tokenizer for BPE {
         self.max_piece_len
     }
 
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok> {
+    fn encode(&self, text: &str) -> Vec<utok> {
         let mut tokens = Vec::new();
-        if bos {
-            tokens.push(self.bos());
-        }
-        if !text.is_empty() {
-            tokens.push(self.find_piece("▁").unwrap())
+        if let Some(c) = text.chars().next() {
+            if c.is_alphabetic() {
+                tokens.push(self.find_piece("▁").unwrap())
+            }
         }
 
         text.chars().map(|c| c.to_string()).for_each(|c| {
@@ -116,51 +105,33 @@
             }
         });
 
-        fn map_pair(bpe: &BPE, tokens: &[utok], i: usize) -> Option<utok> {
+        fn map_pair(bpe: &BPE, tokens: &[utok], i: usize) -> Option<(utok, f32)> {
             bpe.find_piece(&format!(
                 "{}{}",
                 bpe.get_piece(tokens[i]),
                 bpe.get_piece(tokens[i + 1])
             ))
+            .map(|tok| (tok, bpe.get_score(tok)))
         }
 
         let mut merges = (0..tokens.len() - 1)
             .map(|i| map_pair(self, &tokens, i))
             .collect::<Vec<_>>();
-        loop {
-            let mut best_score = std::f32::NEG_INFINITY;
-            let mut replacement = None;
-            for (i, index) in merges.iter().enumerate() {
-                if let Some(index) = index {
-                    let score = self.get_score(*index);
-                    if score > best_score {
-                        best_score = score;
-                        replacement = Some((i, *index));
-                    }
-                }
-            }
-            match replacement {
-                Some((i, j)) => {
-                    tokens[i] = j;
-                    tokens.remove(i + 1);
-                    merges.remove(i);
-                    if let Some(index) = merges.get_mut(i - 1) {
-                        *index = map_pair(self, &tokens, i - 1);
-                    }
-                    if let Some(index) = merges.get_mut(i) {
-                        *index = map_pair(self, &tokens, i);
-                    }
-                }
-                None => break,
+        while let Some((i, (tok, _))) = merges
+            .iter()
+            .enumerate()
+            .filter_map(|(i, tok)| tok.map(|tok| (i, tok)))
+            .max_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
+        {
+            tokens[i] = tok;
+            tokens.remove(i + 1);
+            merges.remove(i);
+            if let Some(i) = i.checked_sub(1) {
+                merges[i] = map_pair(self, &tokens, i);
             }
+            merges[i] = map_pair(self, &tokens, i);
         }
 
-        if bos {
-            assert_eq!(tokens[0], self.bos());
-        }
-        if eos {
-            tokens.push(self.eos());
-        }
         tokens
     }

@@ -184,13 +155,13 @@ fn once_upon_a_time() {
     use std::time::Instant;
     if let Ok(bpe) = BPE::from_model_file("tokenizer.model") {
         const PROMPT: &str = "Once▁upon▁a▁time,";
-        let tokens = bpe.encode(PROMPT, true, false);
+        let tokens = bpe.encode(PROMPT);
         let t0 = Instant::now();
         for _ in 0..10000 {
-            let _tokens = bpe.encode(PROMPT, true, false);
+            let _tokens = bpe.encode(PROMPT);
         }
         let t1 = Instant::now();
         println!("{:?}", t1 - t0);
-        assert_eq!(tokens, &[1, 9038, 2501, 263, 931, 29892]);
+        assert_eq!(tokens, &[9038, 2501, 263, 931, 29892]);
     }
 }
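
The new merge loop replaces the hand-written best-score scan with a single max_by over a cache of candidate merges. Below is a standalone sketch of that strategy under illustrative names; score stands in for the vocabulary lookup that map_pair performs, and nothing here is the repository's API.

// Greedy BPE merging: repeatedly apply the highest-scoring adjacent-pair
// merge until no pair of neighbors forms a known piece.
fn merge_greedily(tokens: &mut Vec<u32>, score: impl Fn(u32, u32) -> Option<(u32, f32)>) {
    // merges[i] caches the merged token and score for (tokens[i], tokens[i + 1]).
    let mut merges: Vec<Option<(u32, f32)>> = (0..tokens.len().saturating_sub(1))
        .map(|i| score(tokens[i], tokens[i + 1]))
        .collect();
    while let Some((i, (tok, _))) = merges
        .iter()
        .enumerate()
        .filter_map(|(i, m)| m.map(|m| (i, m)))
        .max_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
    {
        tokens[i] = tok;      // replace the left half with the merged token
        tokens.remove(i + 1); // drop the right half
        merges.remove(i);
        // Only the pairs touching position i changed, so refresh just those.
        if let Some(prev) = i.checked_sub(1) {
            merges[prev] = score(tokens[prev], tokens[prev + 1]);
        }
        if i < merges.len() {
            merges[i] = score(tokens[i], tokens[i + 1]);
        }
    }
}

The sketch bounds-checks the second refresh because i equals the shrunken merges.len() when the merged pair was the rightmost one; the committed code indexes merges[i] directly at that step.
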
4 changes: 1 addition & 3 deletions tokenizer/src/lib.rs
@@ -6,11 +6,9 @@ mod vocab_txt;
 pub type utok = u32;
 
 pub trait Tokenizer {
-    fn bos(&self) -> utok;
-    fn eos(&self) -> utok;
     fn vocab_size(&self) -> usize;
     fn max_piece_len(&self) -> usize;
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<utok>;
+    fn encode(&self, text: &str) -> Vec<utok>;
     fn decode(&self, token: utok) -> &str;
 }

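For reference, an implementer of the slimmed-down trait now supplies just these four methods. A toy example for illustration only; ByteTokenizer is made up and not code from this repository.

// Toy tokenizer: one token per byte, no special tokens anywhere.
struct ByteTokenizer;

impl Tokenizer for ByteTokenizer {
    fn vocab_size(&self) -> usize {
        256
    }
    fn max_piece_len(&self) -> usize {
        1
    }
    fn encode(&self, text: &str) -> Vec<utok> {
        // No bos/eos handling here; callers add special tokens themselves.
        text.bytes().map(|b| b as utok).collect()
    }
    fn decode(&self, _token: utok) -> &str {
        unimplemented!("a real implementation returns the piece for this id")
    }
}
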
21 changes: 1 addition & 20 deletions tokenizer/src/vocab_txt.rs
@@ -40,16 +40,6 @@ impl VocabTxt {
 }
 
 impl Tokenizer for VocabTxt {
-    #[inline]
-    fn bos(&self) -> utok {
-        1
-    }
-
-    #[inline]
-    fn eos(&self) -> utok {
-        2
-    }
-
     fn vocab_size(&self) -> usize {
         self.words.len()
     }
@@ -59,11 +49,8 @@ impl Tokenizer for VocabTxt {
         self.max_piece_len
     }
 
-    fn encode(&self, mut text: &str, bos: bool, eos: bool) -> Vec<utok> {
+    fn encode(&self, mut text: &str) -> Vec<utok> {
         let mut tokens = Vec::<utok>::new();
-        if bos {
-            tokens.push(self.bos());
-        }
 
         while !text.is_empty() {
             let piece = if text.len() > self.max_piece_len {
@@ -82,12 +69,6 @@
             }
         }
 
-        if bos {
-            assert_eq!(tokens[0], self.bos());
-        }
-        if eos {
-            tokens.push(self.eos());
-        }
         tokens
     }

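Unchanged here is the overall strategy of VocabTxt::encode: greedily take the longest vocabulary piece that prefixes the remaining text, emit its id, and advance. Below is a self-contained sketch of that loop, with a HashMap standing in for the word table; all names are illustrative, not the repository's.

use std::collections::HashMap;

// Greedy longest-prefix tokenization: candidates start at max_piece_len
// bytes and shrink until a piece matches.
fn greedy_encode(vocab: &HashMap<&str, u32>, max_piece_len: usize, mut text: &str) -> Vec<u32> {
    let mut tokens = Vec::new();
    while !text.is_empty() {
        // Clip the candidate length to a char boundary so slicing cannot panic.
        let mut len = text.len().min(max_piece_len);
        while !text.is_char_boundary(len) {
            len -= 1;
        }
        let hit = loop {
            if let Some(&tok) = vocab.get(&text[..len]) {
                break Some((tok, len));
            }
            match (1..len).rev().find(|&l| text.is_char_boundary(l)) {
                Some(shorter) => len = shorter,
                None => break None,
            }
        };
        match hit {
            Some((tok, len)) => {
                tokens.push(tok);
                text = &text[len..];
            }
            None => {
                // No piece covers the leading char: skip it. A real
                // implementation might emit an <unk> id instead.
                let mut rest = text.chars();
                rest.next();
                text = rest.as_str();
            }
        }
    }
    tokens
}

With pieces {"a": 0, "ab": 1, "abc": 2} and max_piece_len = 3, encoding "abcab" yields [2, 1]: "abc" first, then "ab".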