
Give error when initializing tokenizer with too high stride #1306

Merged · 5 commits · Jul 28, 2023
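After this change, configuring a stride at or above the effective max length fails when truncation is set up, rather than later at encode time. A minimal sketch of the new behavior against the Rust API (the tokenizer.json path is a hypothetical placeholder; values are illustrative):

    use tokenizers::{Tokenizer, TruncationParams};

    fn main() -> tokenizers::Result<()> {
        let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

        // with_truncation now returns a Result: a stride >= the effective
        // max length (max_length minus added special tokens) is rejected
        // up front instead of blowing up later during encode.
        let result = tokenizer.with_truncation(Some(TruncationParams {
            max_length: 10,
            stride: 10, // too high: every window would be pure overlap
            ..Default::default()
        }));
        assert!(result.is_err());
        Ok(())
    }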
7 changes: 4 additions & 3 deletions bindings/python/src/tokenizer.rs
@@ -712,15 +712,16 @@ impl PyTokenizer {
         }
     }

-        self.tokenizer.with_truncation(Some(params));
-
+        if let Err(error_message) = self.tokenizer.with_truncation(Some(params)) {
+            return Err(PyError(error_message.to_string()).into_pyerr::<exceptions::PyValueError>());
+        }
         Ok(())
     }

     /// Disable truncation
     #[pyo3(text_signature = "(self)")]
     fn no_truncation(&mut self) {
-        self.tokenizer.with_truncation(None);
+        let _ = self.tokenizer.with_truncation(None);
Collaborator: Can you expect here? no_truncation should never fail, but it's better to actually fail than be in the dark.

}

/// Get the currently set truncation parameters
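A sketch of the reviewer's suggestion (not part of this diff): surface the error with expect instead of discarding it with let _:

    /// Disable truncation
    #[pyo3(text_signature = "(self)")]
    fn no_truncation(&mut self) {
        // with_truncation(None) should never fail; if it somehow does,
        // panic loudly rather than silently swallowing the error.
        self.tokenizer
            .with_truncation(None)
            .expect("no_truncation should never fail");
    }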
14 changes: 7 additions & 7 deletions tokenizers/src/models/mod.rs
@@ -37,19 +37,19 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
         let mut holes = vec![];
         let result = if let Some(max) = self.vocab_r.iter().map(|(key, _)| key).max() {
             let iter = (0..*max + 1).filter_map(|i| {
-                if let Some(token) = self.vocab_r.get(&i){
-                    Some((token, i))
-                }else{
-                    holes.push(i);
-                    None
-                }
+                if let Some(token) = self.vocab_r.get(&i) {
+                    Some((token, i))
+                } else {
+                    holes.push(i);
+                    None
+                }
             });
             serializer.collect_map(iter)
         } else {
             serializer.collect_map(std::iter::empty::<(&str, u32)>())
         };

-        if !holes.is_empty(){
+        if !holes.is_empty() {
             warn!("The OrderedVocab you are attempting to save contains holes for indices {:?}, your vocabulary could be corrupted !", holes);
             println!("The OrderedVocab you are attempting to save contains holes for indices {:?}, your vocabulary could be corrupted !", holes);
         }
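For context, a "hole" here is an index in 0..=max with no token in the reversed vocabulary. A self-contained sketch of the detection logic (simplified from the serializer above; names are illustrative):

    use std::collections::HashMap;

    // Collect indices with no assigned token, mirroring the holes
    // bookkeeping done inside the filter_map above.
    fn find_holes(vocab_r: &HashMap<u32, String>) -> Vec<u32> {
        match vocab_r.keys().max() {
            Some(&max) => (0..=max).filter(|i| !vocab_r.contains_key(i)).collect(),
            None => vec![],
        }
    }

    fn main() {
        let vocab_r: HashMap<u32, String> =
            [(0, "hello".to_string()), (2, "world".to_string())].into();
        assert_eq!(find_holes(&vocab_r), vec![1]); // index 1 is a hole
    }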
34 changes: 27 additions & 7 deletions tokenizers/src/tokenizer/mod.rs
@@ -497,6 +497,10 @@ impl DerefMut for Tokenizer {
     }
 }

+#[derive(thiserror::Error, Debug)]
+#[error("{0}")]
+pub struct TruncationParamError(String);

/// A `Tokenizer` is capable of encoding/decoding any text.
#[derive(Clone, Debug)]
pub struct TokenizerImpl<M, N, PT, PP, D> {
@@ -595,9 +599,21 @@
}

     /// Set the truncation parameters
-    pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &mut Self {
+    ///
+    /// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
+    pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
+        if let Some(trunc_params) = &trunc {
+            let n_added_tokens = self.get_n_added_tokens(false);
Collaborator: Why false? It depends.

Contributor Author: I assumed that fewer tokens are added when is_pair=false (I'm not actually sure this is a good assumption). So, when with_truncation() is called, we check whether there is any possible case in which this combination of max_length and stride is valid for this tokenizer, and fail if there isn't.

Another reason I kept the original assert!() in the inference part of the code is that the additional special token(s) added when is_pair=true could push it over the edge and make the stride too large for the max_length.

Contributor Author: @Narsil, can you confirm whether my assumption in the comment above is a good one (i.e. is added_tokens(false) always less than added_tokens(true))?

+            let effective_max_length = trunc_params.max_length - n_added_tokens;
+            if effective_max_length <= trunc_params.stride {
+                return Err(Box::new(TruncationParamError(format!(
+                    "tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ",
+                    trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens
+                ))));
+            }
+        }
         self.truncation = trunc;
-        self
+        Ok(self)
     }

/// Get the currently set truncation parameters
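A worked example of the new check (numbers are illustrative, assuming a BERT-style post-processor that adds 2 special tokens for a single sequence and 3 for a pair, which is why validating with is_pair = false is the most permissive case):

    // max_length = 10, n_added_tokens = 2 (single sequence)
    // effective_max_length = 10 - 2 = 8
    //   stride = 8 -> 8 <= 8, rejected: every window would be all overlap
    //   stride = 7 -> 7 <  8, accepted
    fn validate(max_length: usize, stride: usize, n_added_tokens: usize) -> Result<(), String> {
        let effective_max_length = max_length - n_added_tokens;
        if effective_max_length <= stride {
            return Err(format!(
                "stride {} >= effective max length {}",
                stride, effective_max_length
            ));
        }
        Ok(())
    }

    fn main() {
        assert!(validate(10, 8, 2).is_err());
        assert!(validate(10, 7, 2).is_ok());
    }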
@@ -902,11 +918,7 @@
         // 1. First we truncate if needed
         let (encoding, pair_encoding) = {
             if let Some(trunc) = &self.truncation {
-                let n_added_tokens = if let Some(processor) = &self.post_processor {
-                    processor.added_tokens(pair_encoding.is_some())
-                } else {
-                    0
-                };
+                let n_added_tokens = self.get_n_added_tokens(pair_encoding.is_some());

                 if add_special_tokens && n_added_tokens > 0 {
                     let params = TruncationParams {
@@ -950,6 +962,14 @@

         Ok(final_encoding)
     }
+
+    fn get_n_added_tokens(&self, is_pair: bool) -> usize {
+        if let Some(processor) = &self.post_processor {
+            processor.added_tokens(is_pair)
+        } else {
+            0
+        }
+    }
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>