diff --git a/src/data/mod.rs b/src/data/mod.rs
index 9954573..80017e9 100644
--- a/src/data/mod.rs
+++ b/src/data/mod.rs
@@ -15,6 +15,7 @@ use loading::MultiTrainDataGenerator;
 use log::warn;
 use numpy::ndarray::prelude::*;
 use numpy::IntoPyArray;
+use pyo3::exceptions::PyStopIteration;
 use pyo3::prelude::*;
 use pyo3::types::{PyDict, PyIterator};
 use std::collections::HashMap;
@@ -777,26 +778,29 @@ impl InferenceLoader {
 
 pub struct InferenceIterator(Py<PyIterator>);
 
-impl InferenceIterator {
-    pub fn new(iter: Py<PyIterator>) -> Self {
-        InferenceIterator(iter)
-    }
-}
-
 impl Iterator for InferenceIterator {
     type Item = anyhow::Result<String>;
 
     fn next(&mut self) -> Option<Self::Item> {
         Python::with_gil(|py| {
-            let mut bound = self.0.clone_ref(py).into_bound(py);
-            let item = bound.next()?;
-            let Ok(item) = item else {
-                return Some(Err(anyhow!("error extracting item from iterator")));
+            let item = match self.0.call_method0(py, "__next__") {
+                Ok(item) => item,
+                Err(e) if e.is_instance_of::<PyStopIteration>(py) => {
+                    return None;
+                }
+                Err(e) => {
+                    return Some(Err(anyhow!(
+                        "error calling next on inference iterator: {e}"
+                    )))
+                }
             };
-            Some(
-                item.extract()
-                    .map_err(|py_err| anyhow!("error extracting item from iterator: {py_err}")),
-            )
+            match item.extract(py) {
+                Ok(None) => None,
+                Ok(Some(item)) => Some(Ok(item)),
+                Err(e) => Some(Err(anyhow!(
+                    "error extracting item from inference iterator: {e}"
+                ))),
+            }
         })
     }
 }
@@ -818,8 +822,7 @@ impl InferenceLoader {
         sort = false
     ))]
     pub fn from_iterator(
-        py: Python<'_>,
-        iterator: Py<PyIterator>,
+        iterator: Bound<'_, PyIterator>,
        tokenizer: TokenizerConfig,
         window: WindowConfig,
         ignore_special_tokens: bool,
@@ -830,19 +833,8 @@
         prefetch_factor: usize,
         sort: bool,
     ) -> anyhow::Result<Self> {
-        let items: Vec<_> = iterator
-            .into_bound(py)
-            .map(|item| -> anyhow::Result<_> {
-                item.map_err(|e| anyhow!("error in inference iterator: {e}"))
-                    .and_then(|item| {
-                        item.extract().map_err(|e| {
-                            anyhow!("error extracting item in inference iterator: {e}")
-                        })
-                    })
-            })
-            .collect();
         Self::new(
-            items.into_iter(),
+            InferenceIterator(iterator.unbind()),
             tokenizer,
             ignore_special_tokens,
             window,
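
For reference, here is a minimal, self-contained sketch of the pattern this change adopts: calling `__next__` directly on the Python object and translating `StopIteration` into iterator exhaustion on the Rust side. It assumes pyo3 0.23 with the `auto-initialize` feature; `PyIter` and the `String` item type are illustrative stand-ins, not the crate's actual `InferenceIterator`.

```rust
use anyhow::anyhow;
use pyo3::exceptions::PyStopIteration;
use pyo3::prelude::*;
use pyo3::types::{PyIterator, PyList};

/// Owns a Python iterator and drains it lazily from Rust.
struct PyIter(Py<PyIterator>);

impl Iterator for PyIter {
    type Item = anyhow::Result<String>;

    fn next(&mut self) -> Option<Self::Item> {
        Python::with_gil(|py| {
            // Call __next__ directly so that StopIteration (normal
            // exhaustion) can be told apart from a genuine error
            // raised inside the Python iterator.
            let obj = match self.0.call_method0(py, "__next__") {
                Ok(obj) => obj,
                Err(e) if e.is_instance_of::<PyStopIteration>(py) => return None,
                Err(e) => return Some(Err(anyhow!("iterator raised: {e}"))),
            };
            Some(
                obj.extract::<String>(py)
                    .map_err(|e| anyhow!("expected a string item: {e}")),
            )
        })
    }
}

fn main() -> anyhow::Result<()> {
    Python::with_gil(|py| -> anyhow::Result<()> {
        // Build a small Python list and hold an unbound iterator over it.
        let list = PyList::new(py, ["a", "b", "c"])?;
        for item in PyIter(list.try_iter()?.unbind()) {
            println!("{}", item?);
        }
        Ok(())
    })
}
```

The key difference from the removed code path is laziness: instead of eagerly collecting every item into a `Vec` inside `from_iterator`, the wrapper holds the unbound iterator and reacquires the GIL once per `next` call, so items are pulled from Python on demand and errors are reported per item.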