Skip to content

Commit

Permalink
make python inference iterator finally lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Dec 13, 2024
1 parent c5581d0 commit a6291cb
Showing 1 changed file with 20 additions and 28 deletions.
48 changes: 20 additions & 28 deletions src/data/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use loading::MultiTrainDataGenerator;
use log::warn;
use numpy::ndarray::prelude::*;
use numpy::IntoPyArray;
use pyo3::exceptions::PyStopIteration;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyIterator};
use std::collections::HashMap;
Expand Down Expand Up @@ -777,26 +778,29 @@ impl InferenceLoader {

/// Newtype around a Python iterator handle (`Py<PyIterator>`).
///
/// The `Iterator` impl for this type in this file yields `anyhow::Result<String>` items.
pub struct InferenceIterator(Py<PyIterator>);

impl InferenceIterator {
pub fn new(iter: Py<PyIterator>) -> Self {
InferenceIterator(iter)
}
}

// NOTE(review): this span comes from a diff view whose +/- markers were stripped, so the
// pre-commit and post-commit bodies of `next` appear interleaved and the block does not
// parse as a single Rust function. The "pre-commit" / "post-commit" labels below are
// inferred from the hunk header (-777,26 +778,29) and the "20 additions & 28 deletions"
// summary; confirm against the actual commit before relying on them.
impl Iterator for InferenceIterator {
// Every element is either an extracted `String` or an `anyhow` error.
type Item = anyhow::Result<String>;

// Produces the next element while holding the Python GIL.
fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
// Presumably pre-commit: clone the stored handle, bind it to the GIL token and pull one
// item from the bound iterator; `?` returns `None` when that iterator yields nothing.
let mut bound = self.0.clone_ref(py).into_bound(py);
let item = bound.next()?;
// Presumably pre-commit: a failed item becomes a fixed error message.
let Ok(item) = item else {
return Some(Err(anyhow!("error extracting item from iterator")));
// Presumably post-commit: call Python's `__next__` directly on the stored handle.
let item = match self.0.call_method0(py, "__next__") {
Ok(item) => item,
// A `StopIteration` exception ends the Rust iteration.
Err(e) if e.is_instance_of::<PyStopIteration>(py) => {
return None;
}
// Any other Python exception is surfaced as an error item that includes the exception text.
Err(e) => {
return Some(Err(anyhow!(
"error calling next on inference iterator: {e}"
)))
}
};
// Presumably pre-commit: extract the item and wrap an extraction failure in an `anyhow` error.
Some(
item.extract()
.map_err(|py_err| anyhow!("error extracting item from iterator: {py_err}")),
)
// Presumably post-commit: extract into an `Option`; `Ok(None)` ends the iteration
// (presumably a Python `None` item; verify), `Ok(Some(..))` yields the value, and an
// extraction failure yields an error item.
match item.extract(py) {
Ok(None) => None,
Ok(Some(item)) => Some(Ok(item)),
Err(e) => Some(Err(anyhow!(
"error extracting item from inference iterator: {e}"
))),
}
})
}
}
Expand All @@ -818,8 +822,7 @@ impl InferenceLoader {
sort = false
))]
pub fn from_iterator(
py: Python<'_>,
iterator: Py<PyIterator>,
iterator: Bound<'_, PyIterator>,
tokenizer: TokenizerConfig,
window: WindowConfig,
ignore_special_tokens: bool,
Expand All @@ -830,19 +833,8 @@ impl InferenceLoader {
prefetch_factor: usize,
sort: bool,
) -> anyhow::Result<Self> {
let items: Vec<_> = iterator
.into_bound(py)
.map(|item| -> anyhow::Result<_> {
item.map_err(|e| anyhow!("error in inference iterator: {e}"))
.and_then(|item| {
item.extract().map_err(|e| {
anyhow!("error extracting item in inference iterator: {e}")
})
})
})
.collect();
Self::new(
items.into_iter(),
InferenceIterator(iterator.unbind()),
tokenizer,
ignore_special_tokens,
window,
Expand Down

0 comments on commit a6291cb

Please sign in to comment.