Skip to content

Commit

Permalink
Ensure allowed_special is not null
Browse files (browse the repository at this point in the history)
  • Loading branch information
kojix2 committed May 16, 2024
1 parent 2539e9a commit c5f7a04
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/lib.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -347,6 +347,10 @@ pub extern "C" fn c_corebpe_encode(
warn!("Null pointer provided for text!");
return std::ptr::null_mut();
}
if allowed_special.is_null() {
warn!("Null pointer provided for allowed_special!");
return std::ptr::null_mut();
}
let text = unsafe {
let raw = CStr::from_ptr(text);
match raw.to_str() {
Expand Down Expand Up @@ -799,7 +803,20 @@ mod tests {
let text = CString::new("I am a cat. <|endoftext|>").unwrap();
let mut num_tokens: usize = 0;
let corebpe = c_get_bpe_from_model(model.as_ptr());
let tokens = c_corebpe_encode(corebpe, text.as_ptr(), std::ptr::null(), 0, &mut num_tokens);

// zero-length slices require a non-null pointer
// Create a CString and an array of pointers to pass to the function
let placeholder = CString::new("").unwrap();
let placeholder_ptr: *const i8 = placeholder.as_ptr();
let ptr_array: [*const i8; 1] = [placeholder_ptr];

let tokens = c_corebpe_encode(
corebpe,
text.as_ptr(),
ptr_array.as_ptr(),
0,
&mut num_tokens,
);
assert_eq!(num_tokens, 11);
let tokens = unsafe { std::slice::from_raw_parts(tokens, num_tokens) };
let tokens: Vec<usize> = tokens.iter().map(|&x| x as usize).collect();
Expand Down

0 comments on commit c5f7a04

Please sign in to comment.