diff --git a/src/lib.rs b/src/lib.rs index 2179dc5..72c5a48 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,33 +5,14 @@ use std::os::raw::c_char; use tiktoken_rs; use tiktoken_rs::CoreBPE; +mod utils; +use utils::c_str_to_string; + #[no_mangle] pub extern "C" fn tiktoken_init_logger() { SimpleLogger::new().init().unwrap(); } -fn get_string_from_c_char(ptr: *const c_char) -> Result { - let c_str = unsafe { CStr::from_ptr(ptr) }; - let str_slice = c_str.to_str()?; - Ok(str_slice.to_string()) -} - -fn c_str_to_string(ptr: *const c_char) -> Option { - if ptr.is_null() { - return None; - } - - let c_str = match get_string_from_c_char(ptr) { - Ok(str) => str, - Err(_) => { - warn!("Invalid UTF-8 sequence provided!"); - return None; - } - }; - - Some(c_str) -} - #[no_mangle] pub extern "C" fn tiktoken_r50k_base() -> *mut CoreBPE { let bpe = tiktoken_rs::r50k_base(); @@ -470,6 +451,7 @@ pub extern "C" fn tiktoken_c_version() -> *const c_char { mod tests { use super::*; use std::ffi::CString; + use utils::get_string_from_c_char; #[test] fn test_tiktoken_c_version() { @@ -478,20 +460,6 @@ mod tests { assert_eq!(version, env!("CARGO_PKG_VERSION")); } - #[test] - fn test_get_string_from_c_char() { - let c_str = CString::new("I am a cat.").unwrap(); - let str = get_string_from_c_char(c_str.as_ptr()).unwrap(); - assert_eq!(str, "I am a cat."); - } - - #[test] - fn test_c_str_to_string() { - let c_str = CString::new("I am a cat.").unwrap(); - let str = c_str_to_string(c_str.as_ptr()).unwrap(); - assert_eq!(str, "I am a cat."); - } - #[test] fn test_c50k_base() { let corebpe = tiktoken_r50k_base(); diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..d41a910 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,45 @@ +use log::warn; +use std::ffi::CStr; +use std::os::raw::c_char; + +pub fn get_string_from_c_char(ptr: *const c_char) -> Result { + let c_str = unsafe { CStr::from_ptr(ptr) }; + let str_slice = c_str.to_str()?; + Ok(str_slice.to_string()) +} + +pub fn c_str_to_string(ptr: *const c_char) -> Option { + if ptr.is_null() { + return None; + } + + let c_str = match get_string_from_c_char(ptr) { + Ok(str) => str, + Err(_) => { + warn!("Invalid UTF-8 sequence provided!"); + return None; + } + }; + + Some(c_str) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::ffi::CString; + + #[test] + fn test_get_string_from_c_char() { + let c_str = CString::new("I am a cat.").unwrap(); + let str = get_string_from_c_char(c_str.as_ptr()).unwrap(); + assert_eq!(str, "I am a cat."); + } + + #[test] + fn test_c_str_to_string() { + let c_str = CString::new("I am a cat.").unwrap(); + let str = c_str_to_string(c_str.as_ptr()).unwrap(); + assert_eq!(str, "I am a cat."); + } +}