some fixes + return window indices in batch.indices()
bastiscode committed Dec 12, 2024
1 parent 202e079 commit d99440d
Showing 6 changed files with 103 additions and 264 deletions.
9 changes: 7 additions & 2 deletions src/data/mod.rs
@@ -516,13 +516,18 @@ impl InferenceBatch {
})
}

fn indices(&self) -> anyhow::Result<Batch<usize>> {
fn indices(&self) -> anyhow::Result<Batch<(usize, usize)>> {
self.batch
.as_ref()
.ok_or_else(|| {
anyhow!("can only get indices before getting items, because they are moved")
})
.map(|batch| batch.iter().map(|item| item.item_idx).collect())
.map(|batch| {
batch
.iter()
.map(|item| (item.item_idx, item.window_idx))
.collect()
})
}

fn items(&mut self) -> anyhow::Result<Batch<InferenceItem>> {
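The indices() change above means callers now receive (item_idx, window_idx) pairs rather than bare item indices. A minimal standalone sketch of how such pairs could be used to regroup per-window results by item; the function name and the String output type are illustrative stand-ins, not the crate's API:

use std::collections::BTreeMap;

// Regroup per-window outputs by originating item, then order each item's
// windows so they can be stitched back together in sequence.
fn regroup_by_item(
    indices: Vec<(usize, usize)>, // (item_idx, window_idx), as indices() now returns
    outputs: Vec<String>,         // stand-in for per-window model outputs
) -> BTreeMap<usize, Vec<String>> {
    let mut by_item: BTreeMap<usize, Vec<(usize, String)>> = BTreeMap::new();
    for ((item_idx, window_idx), out) in indices.into_iter().zip(outputs) {
        by_item.entry(item_idx).or_default().push((window_idx, out));
    }
    by_item
        .into_iter()
        .map(|(item, mut windows)| {
            windows.sort_by_key(|&(w, _)| w);
            (item, windows.into_iter().map(|(_, out)| out).collect())
        })
        .collect()
}

fn main() {
    let indices = vec![(0, 0), (0, 1), (1, 0)];
    let outputs = vec!["a".into(), "b".into(), "c".into()];
    println!("{:?}", regroup_by_item(indices, outputs)); // {0: ["a", "b"], 1: ["c"]}
}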
182 changes: 76 additions & 106 deletions src/data/task.rs
@@ -1,11 +1,13 @@
use anyhow::anyhow;
use std::collections::HashMap;

use itertools::Itertools;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::{
tokenization::{
tokenizer, TokenizationConstraint, TokenizationConstraintConfig, TokenizerConfig,
},
utils::{py_invalid_type_error, py_required_key_error},
tokenization::{tokenizer, TokenizerConfig},
utils::{py_invalid_type_error, py_required_key_error, py_value_error},
whitespace::operations,
};

@@ -19,17 +21,10 @@ pub enum TrainTaskConfig {
WhitespaceCorrection(bool, TokenizerConfig),
// text generation aka language modeling
Generation(bool, TokenizerConfig, bool, Option<String>),
// constrained text generation
ConstrainedGeneration(
bool,
TokenizerConfig,
bool,
TokenizationConstraintConfig,
Option<String>,
Option<String>,
),
// conditional generation aka text-to-text
ConditionalGeneration(TokenizerConfig, bool, TokenizerConfig, bool),
// classification
Classification(TokenizerConfig, bool, Vec<String>),
}

impl<'a> FromPyObject<'a> for TrainTaskConfig {
@@ -79,46 +74,6 @@ impl<'a> FromPyObject<'a> for TrainTaskConfig {
separator,
)
}
"constrained_generation" => {
let mask_input = d
.get_item("mask_input")?
.map(|item| item.extract())
.transpose()?
.unwrap_or_default();
let Some(tokenizer_config) = d.get_item("tokenizer")? else {
return Err(py_required_key_error(
"tokenizer",
"constrained generation config",
));
};
let ignore_special_tokens = d
.get_item("ignore_special_tokens")?
.map(|item| item.extract())
.transpose()?
.unwrap_or_default();
let Some(constraint_config) = d.get_item("constraint")? else {
return Err(py_required_key_error(
"constraint",
"constrained generation config",
));
};
let separator = d
.get_item("separator")?
.map(|item| item.extract())
.transpose()?;
let suffix = d
.get_item("suffix")?
.map(|item| item.extract())
.transpose()?;
TrainTaskConfig::ConstrainedGeneration(
mask_input,
tokenizer_config.extract()?,
ignore_special_tokens,
constraint_config.extract()?,
separator,
suffix,
)
}
"conditional_generation" => {
let Some(input_tokenizer) = d.get_item("input_tokenizer")? else {
return Err(py_required_key_error(
@@ -149,6 +104,32 @@ impl<'a> FromPyObject<'a> for TrainTaskConfig {
target_ignore_special_tokens,
)
}
"classification" => {
let Some(tokenizer_config) = d.get_item("tokenizer")? else {
return Err(py_required_key_error("tokenizer", "classification config"));
};
let ignore_special_tokens = d
.get_item("ignore_special_tokens")?
.map(|item| item.extract())
.transpose()?
.unwrap_or_default();
let Some(classes) = d.get_item("classes")? else {
return Err(py_required_key_error("classes", "classification config"));
};
let classes: Vec<_> = classes.extract()?;
if classes.len() < 2 {
return Err(py_value_error(
"classification requires at least two classes",
));
} else if classes.iter().unique().count() != classes.len() {
return Err(py_value_error("classes must be unique"));
}
TrainTaskConfig::Classification(
tokenizer_config.extract()?,
ignore_special_tokens,
classes,
)
}
k => {
return Err(py_invalid_type_error(k, "task"));
}
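The new "classification" arm above rejects configs with fewer than two classes or duplicate class names. A standalone sketch of just that validation, assuming only the itertools crate and simplifying the error type to String:

use itertools::Itertools;

fn validate_classes(classes: &[String]) -> Result<(), String> {
    if classes.len() < 2 {
        return Err("classification requires at least two classes".to_string());
    }
    if classes.iter().unique().count() != classes.len() {
        return Err("classes must be unique".to_string());
    }
    Ok(())
}

fn main() {
    assert!(validate_classes(&["neg".to_string(), "pos".to_string()]).is_ok());
    assert!(validate_classes(&["neg".to_string(), "neg".to_string()]).is_err());
}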
@@ -183,16 +164,11 @@ fn generation_input(
mask_prefix: bool,
tokenizer_cfg: TokenizerConfig,
ignore_special_tokens: bool,
constraint: Option<TokenizationConstraintConfig>,
separator: Option<String>,
suffix: Option<String>,
) -> Box<TaskFn> {
let tokenizer =
tokenizer(tokenizer_cfg).expect("failed to create tokenizer for generation input function");
let constraint = constraint
.map(TokenizationConstraint::from_config)
.transpose()
.expect("failed to create tokenization constraint for generation input function");
Box::new(move |item| {
let mask_len = if mask_prefix {
let pfx = format!("{}{}", item.input, separator.as_deref().unwrap_or_default());
@@ -204,38 +180,16 @@
} else {
0
};
let mut token_ids = if let Some(constraint) = &constraint {
let mut token_ids = tokenizer
.tokenize(
&format!("{}{}", item.input, separator.as_deref().unwrap_or_default()),
ignore_special_tokens,
)?
.token_ids;
if let Err(e) =
tokenizer.tokenize_with_constraint(&item.target, ignore_special_tokens, constraint)
{
println!("Failed to tokenize with constraint: {e}");
};
let constrained_token_ids = tokenizer
.tokenize_with_constraint(&item.target, ignore_special_tokens, constraint)?
.token_ids;
token_ids.extend(constrained_token_ids);
if let Some(suffix) = suffix.as_ref() {
token_ids.extend(tokenizer.tokenize(suffix, ignore_special_tokens)?.token_ids);
}
token_ids
} else {
let joined = format!(
"{}{}{}{}",
item.input,
separator.as_deref().unwrap_or_default(),
item.target,
suffix.as_deref().unwrap_or_default()
);
tokenizer
.tokenize(&joined, ignore_special_tokens)?
.token_ids
};
let joined = format!(
"{}{}{}{}",
item.input,
separator.as_deref().unwrap_or_default(),
item.target,
suffix.as_deref().unwrap_or_default()
);
let mut token_ids = tokenizer
.tokenize(&joined, ignore_special_tokens)?
.token_ids;
// for n tokens, 1..n-1 are input, 2..n are labels
let labels = vec![-1; mask_len]
.into_iter()
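With the constraint path removed, generation_input now simply concatenates input, separator, target, and suffix before tokenizing. A minimal sketch of that string assembly, with the tokenizer omitted and all names illustrative:

fn join_example(
    input: &str,
    target: &str,
    separator: Option<&str>,
    suffix: Option<&str>,
) -> String {
    // mirrors the format! call above: empty string when separator/suffix are None
    format!(
        "{}{}{}{}",
        input,
        separator.unwrap_or_default(),
        target,
        suffix.unwrap_or_default()
    )
}

fn main() {
    assert_eq!(
        join_example("fix ths", "fix this", Some(" -> "), None),
        "fix ths -> fix this"
    );
}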
@@ -280,6 +234,35 @@ fn conditional_generation_input(
})
}

fn classification_input(
tokenizer_cfg: TokenizerConfig,
ignore_special_tokens: bool,
classes: Vec<String>,
) -> Box<TaskFn> {
assert!(
classes.len() <= i32::MAX as usize,
"too many classes for classification task"
);
let tokenizer = tokenizer(tokenizer_cfg)
.expect("failed to create tokenizer for classification input function");
let class_to_index: HashMap<_, _> = classes
.into_iter()
.enumerate()
.map(|(i, c)| (c, i as i32))
.collect();
Box::new(move |item| {
Ok(TrainTaskInput::Classification {
token_ids: tokenizer
.tokenize(&item.input, ignore_special_tokens)?
.token_ids,
pad_token_id: tokenizer.pad_token_id(),
label: class_to_index.get(&item.target).copied().ok_or_else(|| {
anyhow!("class '{}' not found in classification task", item.target)
})?,
})
})
}
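classification_input turns the configured class list into a name-to-index map and errors on unseen targets. A runnable sketch of just that lookup, using anyhow as the surrounding code does; the class names are made up:

use std::collections::HashMap;

use anyhow::anyhow;

fn main() -> anyhow::Result<()> {
    let classes = vec!["negative".to_string(), "positive".to_string()];
    let class_to_index: HashMap<_, _> = classes
        .into_iter()
        .enumerate()
        .map(|(i, c)| (c, i as i32))
        .collect();
    // a known class resolves to its position as the i32 training label
    let label = class_to_index
        .get("positive")
        .copied()
        .ok_or_else(|| anyhow!("class 'positive' not found in classification task"))?;
    assert_eq!(label, 1);
    // an unknown class would take the ok_or_else error path instead
    assert!(class_to_index.get("neutral").is_none());
    Ok(())
}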

pub fn train_task(task: TrainTaskConfig) -> Box<TaskFn> {
match task {
TrainTaskConfig::WhitespaceCorrection(use_graphemes, tokenizer) => {
@@ -294,25 +277,9 @@ pub fn train_task(task: TrainTaskConfig) -> Box<TaskFn> {
mask_input,
tokenizer_cfg,
ignore_special_tokens,
None,
separator,
None,
),
TrainTaskConfig::ConstrainedGeneration(
mask_input,
tokenizer_cfg,
ignore_special_tokens,
constraint,
separator,
suffix,
) => generation_input(
mask_input,
tokenizer_cfg,
ignore_special_tokens,
Some(constraint),
separator,
suffix,
),
TrainTaskConfig::ConditionalGeneration(
input_tokenizer,
input_ignore_special_tokens,
@@ -324,5 +291,8 @@ pub fn train_task(task: TrainTaskConfig) -> Box<TaskFn> {
target_tokenizer,
target_ignore_special_tokens,
),
TrainTaskConfig::Classification(tokenizer_config, ignore_special_tokens, classes) => {
classification_input(tokenizer_config, ignore_special_tokens, classes)
}
}
}
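train_task maps each remaining config variant to a boxed task closure. A minimal standalone sketch of that enum-to-Box<dyn Fn> dispatch pattern, with simplified stand-in types rather than the crate's real TrainTaskConfig or TaskFn:

enum TaskConfig {
    Uppercase,
    Classification(Vec<String>),
}

type TaskFn = Box<dyn Fn(&str) -> String>;

fn build_task(cfg: TaskConfig) -> TaskFn {
    match cfg {
        TaskConfig::Uppercase => Box::new(|s| s.to_uppercase()),
        // the closure captures its config by move, like classification_input
        TaskConfig::Classification(classes) => Box::new(move |s| {
            classes
                .iter()
                .position(|c| c.as_str() == s)
                .map(|i| i.to_string())
                .unwrap_or_else(|| "unknown".to_string())
        }),
    }
}

fn main() {
    let task = build_task(TaskConfig::Classification(vec![
        "neg".to_string(),
        "pos".to_string(),
    ]));
    println!("{}", task("pos")); // prints 1
}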