diff --git a/wasm/src/lib.rs b/wasm/src/lib.rs index 65b29a0..b884e06 100644 --- a/wasm/src/lib.rs +++ b/wasm/src/lib.rs @@ -439,7 +439,7 @@ export type TiktokenModel = | "gpt-4o-realtime-preview-2024-10-01" /** - * @param {TiktokenModel} encoding + * @param {TiktokenModel} model * @param {Record<string, number>} [extend_special_tokens] * @returns {Tiktoken} */ @@ -452,84 +452,8 @@ pub fn encoding_for_model( model: &str, extend_special_tokens: JsValue, ) -> Result<Tiktoken, JsError> { - let encoding = match model { - "text-davinci-003" => Ok("p50k_base"), - "text-davinci-002" => Ok("p50k_base"), - "text-davinci-001" => Ok("r50k_base"), - "text-curie-001" => Ok("r50k_base"), - "text-babbage-001" => Ok("r50k_base"), - "text-ada-001" => Ok("r50k_base"), - "davinci" => Ok("r50k_base"), - "davinci-002" => Ok("cl100k_base"), - "curie" => Ok("r50k_base"), - "babbage" => Ok("r50k_base"), - "babbage-002" => Ok("cl100k_base"), - "ada" => Ok("r50k_base"), - "code-davinci-002" => Ok("p50k_base"), - "code-davinci-001" => Ok("p50k_base"), - "code-cushman-002" => Ok("p50k_base"), - "code-cushman-001" => Ok("p50k_base"), - "davinci-codex" => Ok("p50k_base"), - "cushman-codex" => Ok("p50k_base"), - "text-davinci-edit-001" => Ok("p50k_edit"), - "code-davinci-edit-001" => Ok("p50k_edit"), - "text-embedding-ada-002" => Ok("cl100k_base"), - "text-embedding-3-small" => Ok("cl100k_base"), - "text-embedding-3-large" => Ok("cl100k_base"), - "text-similarity-davinci-001" => Ok("r50k_base"), - "text-similarity-curie-001" => Ok("r50k_base"), - "text-similarity-babbage-001" => Ok("r50k_base"), - "text-similarity-ada-001" => Ok("r50k_base"), - "text-search-davinci-doc-001" => Ok("r50k_base"), - "text-search-curie-doc-001" => Ok("r50k_base"), - "text-search-babbage-doc-001" => Ok("r50k_base"), - "text-search-ada-doc-001" => Ok("r50k_base"), - "code-search-babbage-code-001" => Ok("r50k_base"), - "code-search-ada-code-001" => Ok("r50k_base"), - "gpt2" => Ok("gpt2"), - "gpt-3.5-turbo" => Ok("cl100k_base"), - "gpt-3.5-turbo-0301" => 
Ok("cl100k_base"), - "gpt-3.5-turbo-0613" => Ok("cl100k_base"), - "gpt-3.5-turbo-16k" => Ok("cl100k_base"), - "gpt-3.5-turbo-16k-0613" => Ok("cl100k_base"), - "gpt-3.5-turbo-instruct" => Ok("cl100k_base"), - "gpt-3.5-turbo-instruct-0914" => Ok("cl100k_base"), - "gpt-4" => Ok("cl100k_base"), - "gpt-4-0314" => Ok("cl100k_base"), - "gpt-4-0613" => Ok("cl100k_base"), - "gpt-4-32k" => Ok("cl100k_base"), - "gpt-4-32k-0314" => Ok("cl100k_base"), - "gpt-4-32k-0613" => Ok("cl100k_base"), - "gpt-3.5-turbo-1106" => Ok("cl100k_base"), - "gpt-35-turbo" => Ok("cl100k_base"), - "gpt-4-1106-preview" => Ok("cl100k_base"), - "gpt-4-vision-preview" => Ok("cl100k_base"), - "gpt-3.5-turbo-0125" => Ok("cl100k_base"), - "gpt-4-turbo" => Ok("cl100k_base"), - "gpt-4-turbo-2024-04-09" => Ok("cl100k_base"), - "gpt-4-turbo-preview" => Ok("cl100k_base"), - "gpt-4-0125-preview" => Ok("cl100k_base"), - "gpt-4o" => Ok("o200k_base"), - "gpt-4o-2024-05-13" => Ok("o200k_base"), - "gpt-4o-2024-08-06" => Ok("o200k_base"), - "gpt-4o-2024-11-20" => Ok("o200k_base"), - "gpt-4o-mini-2024-07-18" => Ok("o200k_base"), - "gpt-4o-mini" => Ok("o200k_base"), - "o1" => Ok("o200k_base"), - "o1-2024-12-17" => Ok("o200k_base"), - "o1-mini" => Ok("o200k_base"), - "o1-preview" => Ok("o200k_base"), - "o1-preview-2024-09-12" => Ok("o200k_base"), - "o1-mini-2024-09-12" => Ok("o200k_base"), - "chatgpt-4o-latest" => Ok("o200k_base"), - "gpt-4o-realtime" => Ok("o200k_base"), - "gpt-4o-realtime-preview-2024-10-01" => Ok("o200k_base"), - "o3-mini" => Ok("o200k_base"), - "o3-mini-2025-01-31" => Ok("o200k_base"), - model => Err(JsError::new( - format!("Invalid model: {}", model.to_string()).as_str(), - )), - }?; + let binding = get_encoding_name_for_model(model)?; + let encoding = binding.as_str(); Tiktoken::with_encoding( encoding, @@ -538,3 +462,96 @@ pub fn encoding_for_model( .ok(), ) } + +#[cfg(feature = "inline")] +#[wasm_bindgen(typescript_custom_section)] +const _: &'static str = r#" +/** + * @param {TiktokenModel} 
model + * @returns {TiktokenEncoding} + */ +export function get_encoding_name_for_model(model: TiktokenModel): TiktokenEncoding; +"#; + +#[cfg(feature="inline")] +#[wasm_bindgen(skip_typescript)] +pub fn get_encoding_name_for_model(model: &str) -> Result<String, JsError> { + match model { + "text-davinci-003" => Ok("p50k_base".into()), + "text-davinci-002" => Ok("p50k_base".into()), + "text-davinci-001" => Ok("r50k_base".into()), + "text-curie-001" => Ok("r50k_base".into()), + "text-babbage-001" => Ok("r50k_base".into()), + "text-ada-001" => Ok("r50k_base".into()), + "davinci" => Ok("r50k_base".into()), + "davinci-002" => Ok("cl100k_base".into()), + "curie" => Ok("r50k_base".into()), + "babbage" => Ok("r50k_base".into()), + "babbage-002" => Ok("cl100k_base".into()), + "ada" => Ok("r50k_base".into()), + "code-davinci-002" => Ok("p50k_base".into()), + "code-davinci-001" => Ok("p50k_base".into()), + "code-cushman-002" => Ok("p50k_base".into()), + "code-cushman-001" => Ok("p50k_base".into()), + "davinci-codex" => Ok("p50k_base".into()), + "cushman-codex" => Ok("p50k_base".into()), + "text-davinci-edit-001" => Ok("p50k_edit".into()), + "code-davinci-edit-001" => Ok("p50k_edit".into()), + "text-embedding-ada-002" => Ok("cl100k_base".into()), + "text-embedding-3-small" => Ok("cl100k_base".into()), + "text-embedding-3-large" => Ok("cl100k_base".into()), + "text-similarity-davinci-001" => Ok("r50k_base".into()), + "text-similarity-curie-001" => Ok("r50k_base".into()), + "text-similarity-babbage-001" => Ok("r50k_base".into()), + "text-similarity-ada-001" => Ok("r50k_base".into()), + "text-search-davinci-doc-001" => Ok("r50k_base".into()), + "text-search-curie-doc-001" => Ok("r50k_base".into()), + "text-search-babbage-doc-001" => Ok("r50k_base".into()), + "text-search-ada-doc-001" => Ok("r50k_base".into()), + "code-search-babbage-code-001" => Ok("r50k_base".into()), + "code-search-ada-code-001" => Ok("r50k_base".into()), + "gpt2" => Ok("gpt2".into()), + "gpt-3.5-turbo" => 
Ok("cl100k_base".into()), + "gpt-3.5-turbo-0301" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-0613" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-16k" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-16k-0613" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-instruct" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-instruct-0914" => Ok("cl100k_base".into()), + "gpt-4" => Ok("cl100k_base".into()), + "gpt-4-0314" => Ok("cl100k_base".into()), + "gpt-4-0613" => Ok("cl100k_base".into()), + "gpt-4-32k" => Ok("cl100k_base".into()), + "gpt-4-32k-0314" => Ok("cl100k_base".into()), + "gpt-4-32k-0613" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-1106" => Ok("cl100k_base".into()), + "gpt-35-turbo" => Ok("cl100k_base".into()), + "gpt-4-1106-preview" => Ok("cl100k_base".into()), + "gpt-4-vision-preview" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-0125" => Ok("cl100k_base".into()), + "gpt-4-turbo" => Ok("cl100k_base".into()), + "gpt-4-turbo-2024-04-09" => Ok("cl100k_base".into()), + "gpt-4-turbo-preview" => Ok("cl100k_base".into()), + "gpt-4-0125-preview" => Ok("cl100k_base".into()), + "gpt-4o" => Ok("o200k_base".into()), + "gpt-4o-2024-05-13" => Ok("o200k_base".into()), + "gpt-4o-2024-08-06" => Ok("o200k_base".into()), + "gpt-4o-2024-11-20" => Ok("o200k_base".into()), + "gpt-4o-mini-2024-07-18" => Ok("o200k_base".into()), + "gpt-4o-mini" => Ok("o200k_base".into()), + "o1" => Ok("o200k_base".into()), + "o1-2024-12-17" => Ok("o200k_base".into()), + "o1-mini" => Ok("o200k_base".into()), + "o1-preview" => Ok("o200k_base".into()), + "o1-preview-2024-09-12" => Ok("o200k_base".into()), + "o1-mini-2024-09-12" => Ok("o200k_base".into()), + "chatgpt-4o-latest" => Ok("o200k_base".into()), + "gpt-4o-realtime" => Ok("o200k_base".into()), + "gpt-4o-realtime-preview-2024-10-01" => Ok("o200k_base".into()), + "o3-mini" => Ok("o200k_base".into()), + "o3-mini-2025-01-31" => Ok("o200k_base".into()), + model => Err(JsError::new( + format!("Invalid model: {}", model.to_string()).as_str(), + )), + } +} 
diff --git a/wasm/test/test_simple_public.test.ts b/wasm/test/test_simple_public.test.ts index 56839f4..a48a29f 100644 --- a/wasm/test/test_simple_public.test.ts +++ b/wasm/test/test_simple_public.test.ts @@ -1,5 +1,5 @@ import { it, expect, describe } from "vitest"; -import { encoding_for_model, get_encoding } from "../dist"; +import { encoding_for_model, get_encoding, get_encoding_name_for_model } from "../dist"; it("encoding_for_model initialization", () => { expect(() => encoding_for_model("gpt2")).not.toThrowError(); @@ -106,6 +106,17 @@ it("test_encoding_for_model", () => { expect(encoding_for_model("gpt-3.5-turbo").name).toEqual("cl100k_base"); }); +it("test_get_encoding_name_for_model", () => { + expect(get_encoding_name_for_model("gpt2")).toEqual("gpt2"); + expect(get_encoding_name_for_model("text-davinci-003")).toEqual("p50k_base"); + expect(get_encoding_name_for_model("gpt-3.5-turbo")).toEqual("cl100k_base"); + + // @ts-expect-error - explicitly testing for invalid model + expect(() => get_encoding_name_for_model("gpt2-unknown")).toThrowError( + "Invalid model: gpt2-unknown" + ); +}) + it("test_custom_tokens", () => { const enc = encoding_for_model("gpt2", { "<|im_start|>": 100264,