Skip to content

Commit

Permalink
feat: add __setitem__ impl to post_processor::PySequence
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed Jan 14, 2025
1 parent 5cc53bf commit 3fc5e55
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 52 deletions.
6 changes: 3 additions & 3 deletions bindings/python/src/pre_tokenizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,8 @@ impl PySequence {
}

fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
let norm: PyPreTokenizer = value.extract()?;
let PyPreTokenizerTypeWrapper::Single(norm) = norm.pretok else { return Err(PyException::new_err("normalizer should not be a sequence")); };
let pretok: PyPreTokenizer = value.extract()?;
let PyPreTokenizerTypeWrapper::Single(norm) = pretok.pretok else { return Err(PyException::new_err("normalizer should not be a sequence")); };
match &self_.as_ref().pretok {
PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => {
Expand All @@ -500,7 +500,7 @@ impl PySequence {
}
},
PyPreTokenizerTypeWrapper::Single(_) => {
return Err(PyException::new_err("normalizer is not a sequence"))
return Err(PyException::new_err("pre tokenizer is not a sequence"))
}
};
Ok(())
Expand Down
94 changes: 50 additions & 44 deletions bindings/python/src/processors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,29 @@ impl PyPostProcessor {

pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
let base = self.clone();
Ok(match self.processor.read().unwrap().clone() {
PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into_py(py),
PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into_py(py),
PostProcessorWrapper::Roberta(_) => {
Py::new(py, (PyRobertaProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Template(_) => {
Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
})
Ok(
match &*self
.processor
.read()
.map_err(|_| PyException::new_err("pre tokenizer rwlock is poisoned"))?
{
PostProcessorWrapper::ByteLevel(_) => {
Py::new(py, (PyByteLevel {}, base))?.into_py(py)
}
PostProcessorWrapper::Bert(_) => {
Py::new(py, (PyBertProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Roberta(_) => {
Py::new(py, (PyRobertaProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Template(_) => {
Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Sequence(_) => {
Py::new(py, (PySequence {}, base))?.into_py(py)
}
},
)
}
}

Expand Down Expand Up @@ -527,19 +539,14 @@ impl PySequence {
}

fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
let super_ = self_.as_ref();
let mut wrapper = super_.processor.write().unwrap();
// if let PostProcessorWrapper::Sequence(ref mut post) = *wrapper {
// match post.get(index) {
// Some(item) => PyPostProcessor::new(Arc::clone(item)).get_as_subtype(py),
// _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
// "Index not found",
// )),
// }
// }
let wrapper = self_
.as_ref()
.processor
.read()
.map_err(|_| PyException::new_err("post processor rwlock is poisoned"))?;

match *wrapper {
PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) {
PostProcessorWrapper::Sequence(ref inner) => match inner.get(index) {
Some(item) => {
PyPostProcessor::new(Arc::new(RwLock::new(item.to_owned()))).get_as_subtype(py)
}
Expand All @@ -553,32 +560,31 @@ impl PySequence {
}
}

fn __setitem__(
self_: PyRefMut<'_, Self>,
index: usize,
value: PyRef<'_, PyPostProcessor>,
) -> PyResult<()> {
let super_ = self_.as_ref();
let mut wrapper = super_.processor.write().unwrap();
let value = value.processor.read().unwrap().clone();
fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
let processor: PyPostProcessor = value.extract()?;
let mut wrapper = self_
.as_ref()
.processor
.write()
.map_err(|_| PyException::new_err("post processor rwlock is poisoned"))?;
match *wrapper {
PostProcessorWrapper::Sequence(ref mut inner) => {
// Convert the Py<PyAny> into the appropriate Rust type
// Ensure we can set an item at the given index
if index < inner.get_processors().len() {
inner.set_mut(index, value); // Assuming you want to wrap the new item in Arc<RwLock>

Ok(())
} else {
Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index out of bounds",
PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) {
Some(item) => {
*item = processor.processor.read().unwrap().clone();
}
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
))
}
},
_ => {
return Err(PyException::new_err(
"This processor is not a Sequence, it does not support __setitem__",
))
}
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"This processor is not a Sequence, it does not support __setitem__",
)),
}
};
Ok(())
}
}

Expand Down
23 changes: 18 additions & 5 deletions tokenizers/src/processors/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,40 @@ impl Sequence {
Self { processors }
}

pub fn get(&self, index: usize) -> Option<& PostProcessorWrapper> {
self.processors.get(index as usize)
pub fn get(&self, index: usize) -> Option<&PostProcessorWrapper> {
self.processors.get(index)
}

pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> {
self.processors.get_mut(index)
}

pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) {
self.processors[index as usize] = post_proc;
self.processors[index] = post_proc;
}
}

pub fn get_processors(&self) -> &[PostProcessorWrapper] {
impl AsRef<[PostProcessorWrapper]> for Sequence {
fn as_ref(&self) -> &[PostProcessorWrapper] {
&self.processors
}
}

pub fn get_processors_mut(&mut self) -> &mut [PostProcessorWrapper] {
impl AsMut<[PostProcessorWrapper]> for Sequence {
fn as_mut(&mut self) -> &mut [PostProcessorWrapper] {
&mut self.processors
}
}

impl IntoIterator for Sequence {
type Item = PostProcessorWrapper;
type IntoIter = std::vec::IntoIter<Self::Item>;

fn into_iter(self) -> Self::IntoIter {
self.processors.into_iter()
}
}

impl PostProcessor for Sequence {
fn added_tokens(&self, is_pair: bool) -> usize {
self.processors
Expand Down

0 comments on commit 3fc5e55

Please sign in to comment.