From d99440d27f47ada1fddd615be6c0f53e3af5218c Mon Sep 17 00:00:00 2001
From: Sebastian Walter
Date: Thu, 12 Dec 2024 16:50:03 +0100
Subject: [PATCH] some fixes + return window indices in batch.indices()

---
 src/data/mod.rs     |   9 ++-
 src/data/task.rs    | 182 ++++++++++++++++++--------------------
 src/tokenization.rs | 144 +----------------------------------
 src/unicode.rs      |  20 ++---
 src/utils.rs        |   6 +-
 src/windows.rs      |   6 +-
 6 files changed, 103 insertions(+), 264 deletions(-)

diff --git a/src/data/mod.rs b/src/data/mod.rs
index 0d159b2..c041b73 100644
--- a/src/data/mod.rs
+++ b/src/data/mod.rs
@@ -516,13 +516,18 @@ impl InferenceBatch {
         })
     }

-    fn indices(&self) -> anyhow::Result<Vec<usize>> {
+    fn indices(&self) -> anyhow::Result<Vec<(usize, usize)>> {
         self.batch
             .as_ref()
             .ok_or_else(|| {
                 anyhow!("can only get indices before getting items, because they are moved")
             })
-            .map(|batch| batch.iter().map(|item| item.item_idx).collect())
+            .map(|batch| {
+                batch
+                    .iter()
+                    .map(|item| (item.item_idx, item.window_idx))
+                    .collect()
+            })
     }

     fn items(&mut self) -> anyhow::Result> {
diff --git a/src/data/task.rs b/src/data/task.rs
index 01ccda2..d228362 100644
--- a/src/data/task.rs
+++ b/src/data/task.rs
@@ -1,11 +1,13 @@
+use anyhow::anyhow;
+use std::collections::HashMap;
+
+use itertools::Itertools;
 use pyo3::prelude::*;
 use pyo3::types::PyDict;

 use crate::{
-    tokenization::{
-        tokenizer, TokenizationConstraint, TokenizationConstraintConfig, TokenizerConfig,
-    },
-    utils::{py_invalid_type_error, py_required_key_error},
+    tokenization::{tokenizer, TokenizerConfig},
+    utils::{py_invalid_type_error, py_required_key_error, py_value_error},
     whitespace::operations,
 };

@@ -19,17 +21,10 @@ pub enum TrainTaskConfig {
     WhitespaceCorrection(bool, TokenizerConfig),
     // text generation aka language modeling
     Generation(bool, TokenizerConfig, bool, Option<String>),
-    // constrained text generation
-    ConstrainedGeneration(
-        bool,
-        TokenizerConfig,
-        bool,
-        TokenizationConstraintConfig,
-        Option<String>,
-        Option<String>,
-    ),
     // conditional generation aka text-to-text
     ConditionalGeneration(TokenizerConfig, bool, TokenizerConfig, bool),
+    // classification
+    Classification(TokenizerConfig, bool, Vec<String>),
 }

 impl<'a> FromPyObject<'a> for TrainTaskConfig {
@@ -79,46 +74,6 @@ impl<'a> FromPyObject<'a> for TrainTaskConfig {
                     separator,
                 )
             }
-            "constrained_generation" => {
-                let mask_input = d
-                    .get_item("mask_input")?
-                    .map(|item| item.extract())
-                    .transpose()?
-                    .unwrap_or_default();
-                let Some(tokenizer_config) = d.get_item("tokenizer")? else {
-                    return Err(py_required_key_error(
-                        "tokenizer",
-                        "constrained generation config",
-                    ));
-                };
-                let ignore_special_tokens = d
-                    .get_item("ignore_special_tokens")?
-                    .map(|item| item.extract())
-                    .transpose()?
-                    .unwrap_or_default();
-                let Some(constraint_config) = d.get_item("constraint")? else {
-                    return Err(py_required_key_error(
-                        "constraint",
-                        "constrained generation config",
-                    ));
-                };
-                let separator = d
-                    .get_item("separator")?
-                    .map(|item| item.extract())
-                    .transpose()?;
-                let suffix = d
-                    .get_item("suffix")?
-                    .map(|item| item.extract())
-                    .transpose()?;
-                TrainTaskConfig::ConstrainedGeneration(
-                    mask_input,
-                    tokenizer_config.extract()?,
-                    ignore_special_tokens,
-                    constraint_config.extract()?,
-                    separator,
-                    suffix,
-                )
-            }
             "conditional_generation" => {
                 let Some(input_tokenizer) = d.get_item("input_tokenizer")? else {
                     return Err(py_required_key_error(
@@ -149,6 +104,32 @@ impl<'a> FromPyObject<'a> for TrainTaskConfig {
                     target_ignore_special_tokens,
                 )
             }
+            "classification" => {
+                let Some(tokenizer_config) = d.get_item("tokenizer")? else {
+                    return Err(py_required_key_error("tokenizer", "classification config"));
+                };
+                let ignore_special_tokens = d
+                    .get_item("ignore_special_tokens")?
+                    .map(|item| item.extract())
+                    .transpose()?
+                    .unwrap_or_default();
+                let Some(classes) = d.get_item("classes")? else {
+                    return Err(py_required_key_error("classes", "classification config"));
+                };
+                let classes: Vec<_> = classes.extract()?;
+                if classes.len() < 2 {
+                    return Err(py_value_error(
+                        "classification requires at least two classes",
+                    ));
+                } else if classes.iter().unique().count() != classes.len() {
+                    return Err(py_value_error("classes must be unique"));
+                }
+                TrainTaskConfig::Classification(
+                    tokenizer_config.extract()?,
+                    ignore_special_tokens,
+                    classes,
+                )
+            }
             k => {
                 return Err(py_invalid_type_error(k, "task"));
             }
@@ -183,16 +164,11 @@ fn generation_input(
     mask_prefix: bool,
     tokenizer_cfg: TokenizerConfig,
     ignore_special_tokens: bool,
-    constraint: Option<TokenizationConstraintConfig>,
     separator: Option<String>,
     suffix: Option<String>,
 ) -> Box {
     let tokenizer = tokenizer(tokenizer_cfg)
         .expect("failed to create tokenizer for generation input function");
-    let constraint = constraint
-        .map(TokenizationConstraint::from_config)
-        .transpose()
-        .expect("failed to create tokenization constraint for generation input function");
     Box::new(move |item| {
         let mask_len = if mask_prefix {
             let pfx = format!("{}{}", item.input, separator.as_deref().unwrap_or_default());
@@ -204,38 +180,16 @@ fn generation_input(
         } else {
             0
         };
-        let mut token_ids = if let Some(constraint) = &constraint {
-            let mut token_ids = tokenizer
-                .tokenize(
-                    &format!("{}{}", item.input, separator.as_deref().unwrap_or_default()),
-                    ignore_special_tokens,
-                )?
-                .token_ids;
-            if let Err(e) =
-                tokenizer.tokenize_with_constraint(&item.target, ignore_special_tokens, constraint)
-            {
-                println!("Failed to tokenize with constraint: {e}");
-            };
-            let constrained_token_ids = tokenizer
-                .tokenize_with_constraint(&item.target, ignore_special_tokens, constraint)?
-                .token_ids;
-            token_ids.extend(constrained_token_ids);
-            if let Some(suffix) = suffix.as_ref() {
-                token_ids.extend(tokenizer.tokenize(suffix, ignore_special_tokens)?.token_ids);
-            }
-            token_ids
-        } else {
-            let joined = format!(
-                "{}{}{}{}",
-                item.input,
-                separator.as_deref().unwrap_or_default(),
-                item.target,
-                suffix.as_deref().unwrap_or_default()
-            );
-            tokenizer
-                .tokenize(&joined, ignore_special_tokens)?
-                .token_ids
-        };
+        let joined = format!(
+            "{}{}{}{}",
+            item.input,
+            separator.as_deref().unwrap_or_default(),
+            item.target,
+            suffix.as_deref().unwrap_or_default()
+        );
+        let mut token_ids = tokenizer
+            .tokenize(&joined, ignore_special_tokens)?
+            .token_ids;
         // for n tokens, 1..n-1 are input, 2..n are labels
         let labels = vec![-1; mask_len]
             .into_iter()
@@ -280,6 +234,35 @@ fn conditional_generation_input(
     })
 }

+fn classification_input(
+    tokenizer_cfg: TokenizerConfig,
+    ignore_special_tokens: bool,
+    classes: Vec<String>,
+) -> Box {
+    assert!(
+        classes.len() <= i32::MAX as usize,
+        "too many classes for classification task"
+    );
+    let tokenizer = tokenizer(tokenizer_cfg)
+        .expect("failed to create tokenizer for classification input function");
+    let class_to_index: HashMap<_, _> = classes
+        .into_iter()
+        .enumerate()
+        .map(|(i, c)| (c, i as i32))
+        .collect();
+    Box::new(move |item| {
+        Ok(TrainTaskInput::Classification {
+            token_ids: tokenizer
+                .tokenize(&item.input, ignore_special_tokens)?
+                .token_ids,
+            pad_token_id: tokenizer.pad_token_id(),
+            label: class_to_index.get(&item.target).copied().ok_or_else(|| {
+                anyhow!("class '{}' not found in classification task", item.target)
+            })?,
+        })
+    })
+}
+
 pub fn train_task(task: TrainTaskConfig) -> Box {
     match task {
         TrainTaskConfig::WhitespaceCorrection(use_graphemes, tokenizer) => {
@@ -294,25 +277,9 @@ pub fn train_task(task: TrainTaskConfig) -> Box {
             mask_input,
             tokenizer_cfg,
             ignore_special_tokens,
-            None,
             separator,
             None,
         ),
-        TrainTaskConfig::ConstrainedGeneration(
-            mask_input,
-            tokenizer_cfg,
-            ignore_special_tokens,
-            constraint,
-            separator,
-            suffix,
-        ) => generation_input(
-            mask_input,
-            tokenizer_cfg,
-            ignore_special_tokens,
-            Some(constraint),
-            separator,
-            suffix,
-        ),
         TrainTaskConfig::ConditionalGeneration(
             input_tokenizer,
             input_ignore_special_tokens,
@@ -324,5 +291,8 @@ pub fn train_task(task: TrainTaskConfig) -> Box {
             target_tokenizer,
             target_ignore_special_tokens,
         ),
+        TrainTaskConfig::Classification(tokenizer_config, ignore_special_tokens, classes) => {
+            classification_input(tokenizer_config, ignore_special_tokens, classes)
+        }
     }
 }
diff --git a/src/tokenization.rs b/src/tokenization.rs
index 10a2efc..fe9a8ac 100644
--- a/src/tokenization.rs
+++ b/src/tokenization.rs
@@ -136,85 +136,6 @@ impl<'a> FromPyObject<'a> for SpecialConfig {
     }
 }

-#[derive(Clone, Debug)]
-pub enum TokenizationConstraintConfig {
-    LR1Grammar {
-        lexer: String,
-        grammar: String,
-        skip_ignore_tokens: bool,
-    },
-}
-
-pub enum TokenizationConstraint {
-    // LR1Grammar {
-    //     parser: LR1GrammarParser,
-    //     skip_ignore_tokens: bool,
-    // },
-}
-
-impl TokenizationConstraint {
-    pub fn from_config(_config: TokenizationConstraintConfig) -> anyhow::Result<Self> {
-        unimplemented!()
-        // match config {
-        //     TokenizationConstraintConfig::LR1Grammar {
-        //         lexer,
-        //         grammar,
-        //         skip_ignore_tokens,
-        //     } => {
-        //         let parser = LR1GrammarParser::from_files(&grammar, &lexer).map_err(|e| {
-        //             anyhow!(
-        //                 "failed to create LR1 grammar parser from lexer {lexer} and grammar {grammar}: {e}"
-        //             )
-        //         })?;
-        //         Ok(Self::LR1Grammar {
-        //             parser,
-        //             skip_ignore_tokens,
-        //         })
-        //     }
-        // }
-    }
-}
-
-impl<'a> FromPyObject<'a> for TokenizationConstraintConfig {
-    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
-        let d: &Bound<'_, PyDict> = ob.downcast()?;
-        let Some(constraint_type) = d.get_item("type")? else {
-            return Err(py_required_key_error("type", "generation config"));
-        };
-        let constraint_type: String = constraint_type.extract()?;
-        let constraint = match constraint_type.as_str() {
-            "lr1_grammar" => {
-                let Some(lexer) = d.get_item("lexer")? else {
-                    return Err(py_required_key_error(
-                        "lexer",
-                        "tokenization constraint config",
-                    ));
-                };
-                let Some(grammar) = d.get_item("grammar")? else {
-                    return Err(py_required_key_error(
-                        "grammar",
-                        "tokenization constraint config",
-                    ));
-                };
-                let skip_ignore_tokens = d
-                    .get_item("skip_ignore_tokens")?
-                    .map(|v| v.extract())
-                    .transpose()?
-                    .unwrap_or(false);
-                TokenizationConstraintConfig::LR1Grammar {
-                    lexer: lexer.extract()?,
-                    grammar: grammar.extract()?,
-                    skip_ignore_tokens,
-                }
-            }
-            k => {
-                return Err(py_invalid_type_error(k, "tokenization constraint config"));
-            }
-        };
-        Ok(constraint)
-    }
-}
-
 /// This is a tokenizer config, containing configs for special tokens, language,
 /// and the actual tokenize config inside it.
 #[derive(Clone, Debug)]
@@ -717,54 +638,6 @@ pub trait Tokenize: BaseTokenize {

     fn tokenize(&self, s: &str, ignore_special_tokens: bool) -> anyhow::Result<Tokenization>;

-    fn tokenize_with_constraint(
-        &self,
-        _s: &str,
-        _ignore_special_tokens: bool,
-        _constraint: &TokenizationConstraint,
-    ) -> anyhow::Result<Tokenization> {
-        unimplemented!();
-        // match constraint {
-        //     TokenizationConstraint::LR1Grammar {
-        //         parser,
-        //         skip_ignore_tokens,
-        //     } => {
-        //         let lexemes = parser.lex(s).map_err(|e| {
-        //             anyhow!("tokenizing with grammar constraint failed with a lexer error: {e}")
-        //         })?;
-        //         let mut all_token_ids = vec![];
-        //         let num_lexemes = lexemes.len();
-        //         for (i, (lexeme, (start, len))) in lexemes.into_iter().enumerate() {
-        //             if *skip_ignore_tokens && lexeme.is_none() {
-        //                 continue;
-        //             }
-        //             let tokenization =
-        //                 self.tokenize(&s[start..start + len], ignore_special_tokens)?;
-        //             if !matches!(tokenization.info, TokenizationInfo::Empty) {
-        //                 return Err(anyhow!(
-        //                     "default implementation does not support tokenization info with grammar constraint"
-        //                 ));
-        //             }
-        //             let pfx = if i == 0 { 0 } else { self.num_prefix_tokens() };
-        //             let sfx = if i == num_lexemes - 1 {
-        //                 0
-        //             } else {
-        //                 self.num_suffix_tokens()
-        //             };
-        //             let num_tokens = tokenization.token_ids.len() - pfx - sfx;
-        //             all_token_ids.extend(
-        //                 tokenization
-        //                     .token_ids
-        //                     .into_iter()
-        //                     .skip(pfx)
-        //                     .take(num_tokens),
-        //             );
-        //         }
-        //         Ok(Tokenization::new(all_token_ids, TokenizationInfo::Empty))
-        //     }
-        // }
-    }
-
     fn de_tokenize(&self, token_ids: &[u32], ignore_special_tokens: bool) -> anyhow::Result<String>;
 }

@@ -2340,33 +2213,20 @@ pub fn tokenizer(cfg: TokenizerConfig) -> anyhow::Result<Tokenizer> {
 #[pyo3(name = "Tokenizer")]
 struct PyTokenizer {
     tokenizer: Tokenizer,
-    constraint: Option<TokenizationConstraint>,
 }

 #[pymethods]
 impl PyTokenizer {
     #[staticmethod]
-    #[pyo3(signature = (config, constraint = None))]
-    fn from_config(
-        config: TokenizerConfig,
-        constraint: Option<TokenizationConstraintConfig>,
-    ) -> anyhow::Result<Self> {
+    fn from_config(config: TokenizerConfig) -> anyhow::Result<Self> {
         Ok(PyTokenizer {
             tokenizer: tokenizer(config)?,
-            constraint: constraint
-                .map(TokenizationConstraint::from_config)
-                .transpose()?,
         })
     }

     #[pyo3(signature = (s, ignore_special_tokens = false))]
     fn tokenize(&self, s: &str, ignore_special_tokens: bool) -> anyhow::Result<Tokenization> {
-        if let Some(constraint) = &self.constraint {
-            self.tokenizer
-                .tokenize_with_constraint(s, ignore_special_tokens, constraint)
-        } else {
-            self.tokenizer.tokenize(s, ignore_special_tokens)
-        }
+        self.tokenizer.tokenize(s, ignore_special_tokens)
     }

     fn token_to_id(&self, token: &str) -> Option<u32> {
diff --git a/src/unicode.rs b/src/unicode.rs
index 094c77f..3d3b781 100644
--- a/src/unicode.rs
+++ b/src/unicode.rs
@@ -26,8 +26,8 @@ pub(crate) type CS<'a> = CharString<'a>;

 // CharString::new("नमस्ते", true).len() -> 4; num grapheme clusters, closest to what
 // humans consider to be characters (in Python available via third party libraries)
-impl<'a> CharString<'a> {
-    pub fn new(str: &'a str, use_graphemes: bool) -> CharString {
+impl<'s> CharString<'s> {
+    pub fn new(str: &'s str, use_graphemes: bool) -> CharString<'s> {
         let cluster_lengths: Vec<usize> = if use_graphemes {
             str.graphemes(true).map(str::len).collect()
         } else {
@@ -86,7 +86,7 @@ impl<'a> CharString<'a> {
         (start_byte, end_byte)
     }

-    pub fn get(&self, n: usize) -> Option<&'a str> {
+    pub fn get(&self, n: usize) -> Option<&'s str> {
         if n >= self.len() {
             return None;
         }
@@ -94,11 +94,11 @@ impl<'a> CharString<'a> {
         Some(&self.str[start..end])
     }

-    pub fn get_char(&self, n: usize) -> Option<Character<'a>> {
-        self.get(n).map(|s| Character { str: s })
+    pub fn get_char(&self, n: usize) -> Option<Character<'s>> {
+        self.get(n).map(|str| Character { str })
     }

-    pub fn sub(&self, start: usize, end: usize) -> &'a str {
+    pub fn sub(&self, start: usize, end: usize) -> &'s str {
         assert!(start <= end, "start cannot be larger than end");
         let start = start.min(self.len());
         let end = end.min(self.len());
@@ -109,7 +109,7 @@ impl<'a> CharString<'a> {
         &self.str[start..end]
     }

-    pub fn chars(&self) -> impl Iterator<Item = Character> {
+    pub fn chars(&self) -> impl Iterator<Item = Character<'s>> {
         (0..self.len()).map(|i| self.get_char(i).unwrap())
     }

@@ -134,8 +134,8 @@ impl Display for CharString<'_> {
 }

 #[derive(Debug)]
-pub struct Character<'a> {
-    pub str: &'a str,
+pub struct Character<'s> {
+    pub str: &'s str,
 }

 #[inline]
@@ -178,7 +178,7 @@ pub(crate) fn is_right_punctuation(s: &str) -> bool {
     Regex::new(r"^[\p{Pf}\p{Pe}]+$").unwrap().is_match(s)
 }

-impl<'a> Character<'a> {
+impl Character<'_> {
     pub fn byte_len(&self) -> usize {
         self.str.len()
     }
diff --git a/src/utils.rs b/src/utils.rs
index 420166a..0036ceb 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -1,5 +1,5 @@
 use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
-use pyo3::exceptions::PyTypeError;
+use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use serde::de::DeserializeOwned;
 use serde::Serialize;
@@ -205,6 +205,10 @@ pub(crate) fn py_invalid_type_error(name: impl Display, type_name: impl Display)
     PyTypeError::new_err(format!("\"{name}\" is not a valid {type_name} type"))
 }

+pub(crate) fn py_value_error(msg: impl Display) -> PyErr {
+    PyValueError::new_err(msg.to_string())
+}
+
 #[cfg(test)]
 mod tests {
     use crate::utils::{
diff --git a/src/windows.rs b/src/windows.rs
index b1cc644..189477e 100644
--- a/src/windows.rs
+++ b/src/windows.rs
@@ -8,7 +8,7 @@ use pyo3::pybacked::PyBackedStr;
 use pyo3::types::PyDict;

 #[derive(Debug, Clone)]
-pub struct Window<'a> {
+pub struct Window<'s> {
     ctx_start: usize,
     ctx_end: usize,
     window_start: usize,
@@ -17,10 +17,10 @@ pub struct Window<'a> {
     byte_ctx_end: usize,
     byte_window_start: usize,
     byte_window_end: usize,
-    pub str: &'a str,
+    pub str: &'s str,
 }

-impl<'a> Window<'a> {
+impl Window<'_> {
     pub fn boundaries(&self) -> (usize, usize, usize, usize) {
         (
             self.ctx_start,
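
With this patch, InferenceBatch::indices() returns (item index, window index) pairs instead of bare item indices. Below is a minimal, self-contained sketch (not part of the patch) of how a caller could use those pairs to regroup window-level outputs by item; the concrete pair values and the BTreeMap grouping are illustrative assumptions, only the (item_idx, window_idx) shape comes from the change above.

    use std::collections::BTreeMap;

    fn main() {
        // Pairs as returned by the changed indices(): (item index, window index within that item).
        let indices: Vec<(usize, usize)> = vec![(0, 0), (0, 1), (1, 0)];
        let mut windows_per_item: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
        for (item_idx, window_idx) in indices {
            windows_per_item.entry(item_idx).or_default().push(window_idx);
        }
        // Item 0 was split into two windows, item 1 into one.
        assert_eq!(windows_per_item[&0], vec![0, 1]);
        assert_eq!(windows_per_item[&1], vec![0]);
    }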
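The new classification task maps each configured class name to a dense i32 label by its position in the "classes" list and rejects targets that are not in that list. Below is a standalone sketch (not part of the patch) mirroring the mapping built in classification_input; the class names are made up for illustration.

    use std::collections::HashMap;

    fn main() {
        // The order of the configured classes defines the label index, as in classification_input.
        let classes = vec![
            "negative".to_string(),
            "neutral".to_string(),
            "positive".to_string(),
        ];
        // Same guard as in the patch: labels are stored as i32.
        assert!(classes.len() <= i32::MAX as usize, "too many classes");
        let class_to_index: HashMap<_, _> = classes
            .into_iter()
            .enumerate()
            .map(|(i, c)| (c, i as i32))
            .collect();
        // A known target class maps to its index ...
        assert_eq!(class_to_index.get("positive").copied(), Some(2));
        // ... while an unknown one yields None, which the task turns into an error.
        assert!(class_to_index.get("unknown").is_none());
    }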