From 0a28e200dd8c4d14b01c3cf938c97c053657c4cc Mon Sep 17 00:00:00 2001
From: Mike Noseworthy
Date: Sun, 16 Feb 2025 12:20:44 -0330
Subject: [PATCH] Add get_encoding_name_for_model to tiktoken

The `tiktoken-js` library includes a very helpful function,
`getEncodingNameForModel()`. In the Rust-based `tiktoken` package, the
equivalent logic is buried inside the implementation of
`encoding_for_model()`.

This function is very useful when implementing an encoding cache keyed
by the model in use: mapping each model to its encoding name and then
caching by that name conserves resources, since many models re-use the
same encoding.
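
As a rough sketch of the kind of cache this enables (the helper name
`encoderForModel`, the cache, and the `"../dist"` import path mirror
the test setup and are illustrative only, not part of this change):

    // Sketch: helper and cache names are illustrative; the import path
    // assumes the built package, as in the tests.
    import { get_encoding, get_encoding_name_for_model } from "../dist";
    import type { Tiktoken, TiktokenEncoding, TiktokenModel } from "../dist";

    // Keep one Tiktoken instance per encoding rather than per model,
    // since e.g. every gpt-3.5-turbo-* variant maps to "cl100k_base".
    const cache = new Map<TiktokenEncoding, Tiktoken>();

    function encoderForModel(model: TiktokenModel): Tiktoken {
      const name = get_encoding_name_for_model(model); // throws for unknown models
      let enc = cache.get(name);
      if (!enc) {
        enc = get_encoding(name);
        cache.set(name, enc);
      }
      return enc;
    }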
Ok("r50k_base".into()), + "text-curie-001" => Ok("r50k_base".into()), + "text-babbage-001" => Ok("r50k_base".into()), + "text-ada-001" => Ok("r50k_base".into()), + "davinci" => Ok("r50k_base".into()), + "davinci-002" => Ok("cl100k_base".into()), + "curie" => Ok("r50k_base".into()), + "babbage" => Ok("r50k_base".into()), + "babbage-002" => Ok("cl100k_base".into()), + "ada" => Ok("r50k_base".into()), + "code-davinci-002" => Ok("p50k_base".into()), + "code-davinci-001" => Ok("p50k_base".into()), + "code-cushman-002" => Ok("p50k_base".into()), + "code-cushman-001" => Ok("p50k_base".into()), + "davinci-codex" => Ok("p50k_base".into()), + "cushman-codex" => Ok("p50k_base".into()), + "text-davinci-edit-001" => Ok("p50k_edit".into()), + "code-davinci-edit-001" => Ok("p50k_edit".into()), + "text-embedding-ada-002" => Ok("cl100k_base".into()), + "text-embedding-3-small" => Ok("cl100k_base".into()), + "text-embedding-3-large" => Ok("cl100k_base".into()), + "text-similarity-davinci-001" => Ok("r50k_base".into()), + "text-similarity-curie-001" => Ok("r50k_base".into()), + "text-similarity-babbage-001" => Ok("r50k_base".into()), + "text-similarity-ada-001" => Ok("r50k_base".into()), + "text-search-davinci-doc-001" => Ok("r50k_base".into()), + "text-search-curie-doc-001" => Ok("r50k_base".into()), + "text-search-babbage-doc-001" => Ok("r50k_base".into()), + "text-search-ada-doc-001" => Ok("r50k_base".into()), + "code-search-babbage-code-001" => Ok("r50k_base".into()), + "code-search-ada-code-001" => Ok("r50k_base".into()), + "gpt2" => Ok("gpt2".into()), + "gpt-3.5-turbo" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-0301" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-0613" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-16k" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-16k-0613" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-instruct" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-instruct-0914" => Ok("cl100k_base".into()), + "gpt-4" => Ok("cl100k_base".into()), + "gpt-4-0314" => Ok("cl100k_base".into()), + "gpt-4-0613" => Ok("cl100k_base".into()), + "gpt-4-32k" => Ok("cl100k_base".into()), + "gpt-4-32k-0314" => Ok("cl100k_base".into()), + "gpt-4-32k-0613" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-1106" => Ok("cl100k_base".into()), + "gpt-35-turbo" => Ok("cl100k_base".into()), + "gpt-4-1106-preview" => Ok("cl100k_base".into()), + "gpt-4-vision-preview" => Ok("cl100k_base".into()), + "gpt-3.5-turbo-0125" => Ok("cl100k_base".into()), + "gpt-4-turbo" => Ok("cl100k_base".into()), + "gpt-4-turbo-2024-04-09" => Ok("cl100k_base".into()), + "gpt-4-turbo-preview" => Ok("cl100k_base".into()), + "gpt-4-0125-preview" => Ok("cl100k_base".into()), + "gpt-4o" => Ok("o200k_base".into()), + "gpt-4o-2024-05-13" => Ok("o200k_base".into()), + "gpt-4o-2024-08-06" => Ok("o200k_base".into()), + "gpt-4o-2024-11-20" => Ok("o200k_base".into()), + "gpt-4o-mini-2024-07-18" => Ok("o200k_base".into()), + "gpt-4o-mini" => Ok("o200k_base".into()), + "o1" => Ok("o200k_base".into()), + "o1-2024-12-17" => Ok("o200k_base".into()), + "o1-mini" => Ok("o200k_base".into()), + "o1-preview" => Ok("o200k_base".into()), + "o1-preview-2024-09-12" => Ok("o200k_base".into()), + "o1-mini-2024-09-12" => Ok("o200k_base".into()), + "chatgpt-4o-latest" => Ok("o200k_base".into()), + "gpt-4o-realtime" => Ok("o200k_base".into()), + "gpt-4o-realtime-preview-2024-10-01" => Ok("o200k_base".into()), + "o3-mini" => Ok("o200k_base".into()), + "o3-mini-2025-01-31" => Ok("o200k_base".into()), + model => Err(JsError::new( + format!("Invalid model: {}", 
 #[cfg(feature = "inline")]
 #[wasm_bindgen(skip_typescript)]
 pub fn encoding_for_model(
     model: &str,
     extend_special_tokens: JsValue,
 ) -> Result<Tiktoken, JsError> {
-    let encoding = match model {
-        "text-davinci-003" => Ok("p50k_base"),
-        "text-davinci-002" => Ok("p50k_base"),
-        "text-davinci-001" => Ok("r50k_base"),
-        "text-curie-001" => Ok("r50k_base"),
-        "text-babbage-001" => Ok("r50k_base"),
-        "text-ada-001" => Ok("r50k_base"),
-        "davinci" => Ok("r50k_base"),
-        "davinci-002" => Ok("cl100k_base"),
-        "curie" => Ok("r50k_base"),
-        "babbage" => Ok("r50k_base"),
-        "babbage-002" => Ok("cl100k_base"),
-        "ada" => Ok("r50k_base"),
-        "code-davinci-002" => Ok("p50k_base"),
-        "code-davinci-001" => Ok("p50k_base"),
-        "code-cushman-002" => Ok("p50k_base"),
-        "code-cushman-001" => Ok("p50k_base"),
-        "davinci-codex" => Ok("p50k_base"),
-        "cushman-codex" => Ok("p50k_base"),
-        "text-davinci-edit-001" => Ok("p50k_edit"),
-        "code-davinci-edit-001" => Ok("p50k_edit"),
-        "text-embedding-ada-002" => Ok("cl100k_base"),
-        "text-embedding-3-small" => Ok("cl100k_base"),
-        "text-embedding-3-large" => Ok("cl100k_base"),
-        "text-similarity-davinci-001" => Ok("r50k_base"),
-        "text-similarity-curie-001" => Ok("r50k_base"),
-        "text-similarity-babbage-001" => Ok("r50k_base"),
-        "text-similarity-ada-001" => Ok("r50k_base"),
-        "text-search-davinci-doc-001" => Ok("r50k_base"),
-        "text-search-curie-doc-001" => Ok("r50k_base"),
-        "text-search-babbage-doc-001" => Ok("r50k_base"),
-        "text-search-ada-doc-001" => Ok("r50k_base"),
-        "code-search-babbage-code-001" => Ok("r50k_base"),
-        "code-search-ada-code-001" => Ok("r50k_base"),
-        "gpt2" => Ok("gpt2"),
-        "gpt-3.5-turbo" => Ok("cl100k_base"),
-        "gpt-3.5-turbo-0301" => Ok("cl100k_base"),
-        "gpt-3.5-turbo-0613" => Ok("cl100k_base"),
-        "gpt-3.5-turbo-16k" => Ok("cl100k_base"),
-        "gpt-3.5-turbo-16k-0613" => Ok("cl100k_base"),
-        "gpt-3.5-turbo-instruct" => Ok("cl100k_base"),
-        "gpt-3.5-turbo-instruct-0914" => Ok("cl100k_base"),
-        "gpt-4" => Ok("cl100k_base"),
-        "gpt-4-0314" => Ok("cl100k_base"),
-        "gpt-4-0613" => Ok("cl100k_base"),
-        "gpt-4-32k" => Ok("cl100k_base"),
-        "gpt-4-32k-0314" => Ok("cl100k_base"),
-        "gpt-4-32k-0613" => Ok("cl100k_base"),
-        "gpt-3.5-turbo-1106" => Ok("cl100k_base"),
-        "gpt-35-turbo" => Ok("cl100k_base"),
-        "gpt-4-1106-preview" => Ok("cl100k_base"),
-        "gpt-4-vision-preview" => Ok("cl100k_base"),
-        "gpt-3.5-turbo-0125" => Ok("cl100k_base"),
-        "gpt-4-turbo" => Ok("cl100k_base"),
-        "gpt-4-turbo-2024-04-09" => Ok("cl100k_base"),
-        "gpt-4-turbo-preview" => Ok("cl100k_base"),
-        "gpt-4-0125-preview" => Ok("cl100k_base"),
-        "gpt-4o" => Ok("o200k_base"),
-        "gpt-4o-2024-05-13" => Ok("o200k_base"),
-        "gpt-4o-2024-08-06" => Ok("o200k_base"),
-        "gpt-4o-2024-11-20" => Ok("o200k_base"),
-        "gpt-4o-mini-2024-07-18" => Ok("o200k_base"),
-        "gpt-4o-mini" => Ok("o200k_base"),
-        "o1" => Ok("o200k_base"),
-        "o1-2024-12-17" => Ok("o200k_base"),
-        "o1-mini" => Ok("o200k_base"),
-        "o1-preview" => Ok("o200k_base"),
-        "o1-preview-2024-09-12" => Ok("o200k_base"),
-        "o1-mini-2024-09-12" => Ok("o200k_base"),
-        "chatgpt-4o-latest" => Ok("o200k_base"),
-        "gpt-4o-realtime" => Ok("o200k_base"),
-        "gpt-4o-realtime-preview-2024-10-01" => Ok("o200k_base"),
-        "o3-mini" => Ok("o200k_base"),
-        "o3-mini-2025-01-31" => Ok("o200k_base"),
-        model => Err(JsError::new(
-            format!("Invalid model: {}", model.to_string()).as_str(),
-        )),
-    }?;
+    let binding = get_encoding_name_for_model(model)?;
+    let encoding = binding.as_str();
 
     Tiktoken::with_encoding(
         encoding,
diff --git a/wasm/test/test_simple_public.test.ts b/wasm/test/test_simple_public.test.ts
index 56839f4..0f74247 100644
--- a/wasm/test/test_simple_public.test.ts
+++ b/wasm/test/test_simple_public.test.ts
@@ -1,5 +1,5 @@
 import { it, expect, describe } from "vitest";
-import { encoding_for_model, get_encoding } from "../dist";
+import { encoding_for_model, get_encoding, get_encoding_name_for_model } from "../dist";
 
 it("encoding_for_model initialization", () => {
   expect(() => encoding_for_model("gpt2")).not.toThrowError();
@@ -106,6 +106,15 @@ it("test_encoding_for_model", () => {
   expect(encoding_for_model("gpt-3.5-turbo").name).toEqual("cl100k_base");
 });
 
+it("test_get_encoding_name_for_model", () => {
+  expect(get_encoding_name_for_model("gpt2")).toEqual("gpt2");
+  expect(get_encoding_name_for_model("text-davinci-003")).toEqual("p50k_base");
+  expect(get_encoding_name_for_model("gpt-3.5-turbo")).toEqual("cl100k_base");
+
+  // @ts-expect-error - explicitly testing for invalid model
+  expect(() => get_encoding_name_for_model("unknown")).toThrowError("Invalid model: unknown");
+});
+
 it("test_custom_tokens", () => {
   const enc = encoding_for_model("gpt2", {
     "<|im_start|>": 100264,