From c5f7a045dac8a07d96d2bde32b623a7f13be1f11 Mon Sep 17 00:00:00 2001
From: kojix2 <2xijok@gmail.com>
Date: Thu, 16 May 2024 13:19:22 +0900
Subject: [PATCH] Ensure allowed_special is not null

---
 src/lib.rs | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/src/lib.rs b/src/lib.rs
index 154a04f..00f34ed 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -347,6 +347,10 @@ pub extern "C" fn c_corebpe_encode(
         warn!("Null pointer provided for text!");
         return std::ptr::null_mut();
     }
+    if allowed_special.is_null() {
+        warn!("Null pointer provided for allowed_special!");
+        return std::ptr::null_mut();
+    }
     let text = unsafe {
         let raw = CStr::from_ptr(text);
         match raw.to_str() {
@@ -799,7 +803,20 @@ mod tests {
         let text = CString::new("I am a cat. <|endoftext|>").unwrap();
         let mut num_tokens: usize = 0;
         let corebpe = c_get_bpe_from_model(model.as_ptr());
-        let tokens = c_corebpe_encode(corebpe, text.as_ptr(), std::ptr::null(), 0, &mut num_tokens);
+
+        // zero-length slices require a non-null pointer
+        // Create a CString and an array of pointers to pass to the function
+        let placeholder = CString::new("").unwrap();
+        let placeholder_ptr: *const i8 = placeholder.as_ptr();
+        let ptr_array: [*const i8; 1] = [placeholder_ptr];
+
+        let tokens = c_corebpe_encode(
+            corebpe,
+            text.as_ptr(),
+            ptr_array.as_ptr(),
+            0,
+            &mut num_tokens,
+        );
         assert_eq!(num_tokens, 11);
         let tokens = unsafe { std::slice::from_raw_parts(tokens, num_tokens) };
         let tokens: Vec<usize> = tokens.iter().map(|&x| x as usize).collect();