Skip to content

Commit

Permalink
Add get_encoding_name_for_model to tiktoken
Browse files Browse the repository at this point in the history
The `tiktoken-js` library includes a very helpful function,
`getEncodingNameForModel()`. This function is buried in the
implementation of `encoding_for_model()` in the rust based
`tiktoken` package.

This function is very useful when implementing an encoding cache based
on the model used. In this case, having a mapping from model ->
encoding and then caching based on the encoding name conserves
resources since so many models re-use the same encoding.

I've refactored the typescript definition generation a little bit so
that all the types are declared and referenced within the same block,
eliminating use-before-declaration warnings.

Finally, I've exposed a new `get_encoding_name_for_model()` function
that behaves like the one in the `tiktoken-js` package, and used it
inside `encoding_for_model()`. I also added a test to ensure that this
function can be called properly from typescript code, and that it
throws an exception when given an invalid model name.

Fixes: #123
  • Loading branch information
noseworthy committed Feb 16, 2025
1 parent 8963e56 commit 0a28e20
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 102 deletions.
212 changes: 111 additions & 101 deletions wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,28 +341,6 @@ impl Tiktoken {
// Hand-written typescript declarations emitted into the generated .d.ts.
// NOTE(review): the attribute lines for this const sit above this hunk —
// presumably `#[cfg(feature = "inline")]` + `#[wasm_bindgen(typescript_custom_section)]`
// like the sibling section below; confirm against the full file.
// The raw string below is literal output — do not reformat its contents.
const _: &'static str = r#"
export type TiktokenEncoding = "gpt2" | "r50k_base" | "p50k_base" | "p50k_edit" | "cl100k_base" | "o200k_base";
/**
 * @param {TiktokenEncoding} encoding
 * @param {Record<string, number>} [extend_special_tokens]
 * @returns {Tiktoken}
 */
export function get_encoding(encoding: TiktokenEncoding, extend_special_tokens?: Record<string, number>): Tiktoken;
"#;

/// Builds a [`Tiktoken`] tokenizer for the named encoding.
///
/// `extend_special_tokens` is an optional JS object mapping special-token
/// strings to token ids.
#[cfg(feature = "inline")]
#[wasm_bindgen(skip_typescript)]
pub fn get_encoding(encoding: &str, extend_special_tokens: JsValue) -> Result<Tiktoken, JsError> {
    Tiktoken::with_encoding(
        encoding,
        // `.ok()` turns a failed deserialization (e.g. `undefined` passed from
        // JS) into `None`, i.e. "no extra special tokens" rather than an error.
        &extend_special_tokens
            .into_serde::<HashMap<String, usize>>()
            .ok(),
    )
}

#[cfg(feature = "inline")]
#[wasm_bindgen(typescript_custom_section)]
const _: &'static str = r#"
export type TiktokenModel =
| "davinci-002"
| "babbage-002"
Expand Down Expand Up @@ -438,98 +416,130 @@ export type TiktokenModel =
| "gpt-4o-realtime"
| "gpt-4o-realtime-preview-2024-10-01"
/**
* @param {TiktokenEncoding} encoding
* @param {Record<string, number>} [extend_special_tokens]
* @returns {Tiktoken}
*/
export function get_encoding(encoding: TiktokenEncoding, extend_special_tokens?: Record<string, number>): Tiktoken;
/**
* @param {TiktokenModel} encoding
* @param {TiktokenModel} model
* @returns {TiktokenEncoding}
*/
export function get_encoding_name_for_model(model: TiktokenModel): TiktokenEncoding;
/**
* @param {TiktokenModel} model
* @param {Record<string, number>} [extend_special_tokens]
* @returns {Tiktoken}
*/
export function encoding_for_model(model: TiktokenModel, extend_special_tokens?: Record<string, number>): Tiktoken;
"#;

/// Builds a [`Tiktoken`] tokenizer for the named encoding, optionally
/// extending its special-token table with entries supplied from JavaScript.
#[cfg(feature = "inline")]
#[wasm_bindgen(skip_typescript)]
pub fn get_encoding(encoding: &str, extend_special_tokens: JsValue) -> Result<Tiktoken, JsError> {
    // A value that does not deserialize to a string -> id map (e.g. JS
    // `undefined`) is treated as "no extra special tokens" via `.ok()`.
    let special_tokens = extend_special_tokens
        .into_serde::<HashMap<String, usize>>()
        .ok();
    Tiktoken::with_encoding(encoding, &special_tokens)
}

/// Returns the name of the tokenizer encoding used by the given OpenAI model.
///
/// Mirrors `getEncodingNameForModel()` from the `tiktoken-js` library, letting
/// callers cache encodings per encoding name (many models share one encoding)
/// instead of per model.
///
/// # Errors
///
/// Returns a `JsError` with the message `Invalid model: <model>` when the
/// model name is not recognized.
#[cfg(feature = "inline")]
#[wasm_bindgen(skip_typescript)]
pub fn get_encoding_name_for_model(model: &str) -> Result<String, JsError> {
    // Arms are grouped by encoding; keep them in sync with the
    // `TiktokenModel` typescript union declared in this file.
    let encoding = match model {
        "gpt2" => "gpt2",
        // r50k_base: original GPT-3 base, similarity and search models.
        "text-davinci-001"
        | "text-curie-001"
        | "text-babbage-001"
        | "text-ada-001"
        | "davinci"
        | "curie"
        | "babbage"
        | "ada"
        | "text-similarity-davinci-001"
        | "text-similarity-curie-001"
        | "text-similarity-babbage-001"
        | "text-similarity-ada-001"
        | "text-search-davinci-doc-001"
        | "text-search-curie-doc-001"
        | "text-search-babbage-doc-001"
        | "text-search-ada-doc-001"
        | "code-search-babbage-code-001"
        | "code-search-ada-code-001" => "r50k_base",
        // p50k_base: later GPT-3 and Codex models.
        "text-davinci-003"
        | "text-davinci-002"
        | "code-davinci-002"
        | "code-davinci-001"
        | "code-cushman-002"
        | "code-cushman-001"
        | "davinci-codex"
        | "cushman-codex" => "p50k_base",
        // p50k_edit: edit-mode models.
        "text-davinci-edit-001" | "code-davinci-edit-001" => "p50k_edit",
        // cl100k_base: GPT-3.5 / GPT-4 era models and current embeddings.
        "davinci-002"
        | "babbage-002"
        | "text-embedding-ada-002"
        | "text-embedding-3-small"
        | "text-embedding-3-large"
        | "gpt-3.5-turbo"
        | "gpt-3.5-turbo-0301"
        | "gpt-3.5-turbo-0613"
        | "gpt-3.5-turbo-16k"
        | "gpt-3.5-turbo-16k-0613"
        | "gpt-3.5-turbo-instruct"
        | "gpt-3.5-turbo-instruct-0914"
        | "gpt-3.5-turbo-1106"
        | "gpt-3.5-turbo-0125"
        | "gpt-35-turbo"
        | "gpt-4"
        | "gpt-4-0314"
        | "gpt-4-0613"
        | "gpt-4-32k"
        | "gpt-4-32k-0314"
        | "gpt-4-32k-0613"
        | "gpt-4-1106-preview"
        | "gpt-4-vision-preview"
        | "gpt-4-turbo"
        | "gpt-4-turbo-2024-04-09"
        | "gpt-4-turbo-preview"
        | "gpt-4-0125-preview" => "cl100k_base",
        // o200k_base: GPT-4o family and o-series reasoning models.
        "gpt-4o"
        | "gpt-4o-2024-05-13"
        | "gpt-4o-2024-08-06"
        | "gpt-4o-2024-11-20"
        | "gpt-4o-mini"
        | "gpt-4o-mini-2024-07-18"
        | "gpt-4o-realtime"
        | "gpt-4o-realtime-preview-2024-10-01"
        | "chatgpt-4o-latest"
        | "o1"
        | "o1-2024-12-17"
        | "o1-mini"
        | "o1-mini-2024-09-12"
        | "o1-preview"
        | "o1-preview-2024-09-12"
        | "o3-mini"
        | "o3-mini-2025-01-31" => "o200k_base",
        // Unknown model names surface as a JS exception.
        _ => return Err(JsError::new(&format!("Invalid model: {}", model))),
    };
    Ok(encoding.into())
}

#[cfg(feature = "inline")]
#[wasm_bindgen(skip_typescript)]
pub fn encoding_for_model(
model: &str,
extend_special_tokens: JsValue,
) -> Result<Tiktoken, JsError> {
let encoding = match model {
"text-davinci-003" => Ok("p50k_base"),
"text-davinci-002" => Ok("p50k_base"),
"text-davinci-001" => Ok("r50k_base"),
"text-curie-001" => Ok("r50k_base"),
"text-babbage-001" => Ok("r50k_base"),
"text-ada-001" => Ok("r50k_base"),
"davinci" => Ok("r50k_base"),
"davinci-002" => Ok("cl100k_base"),
"curie" => Ok("r50k_base"),
"babbage" => Ok("r50k_base"),
"babbage-002" => Ok("cl100k_base"),
"ada" => Ok("r50k_base"),
"code-davinci-002" => Ok("p50k_base"),
"code-davinci-001" => Ok("p50k_base"),
"code-cushman-002" => Ok("p50k_base"),
"code-cushman-001" => Ok("p50k_base"),
"davinci-codex" => Ok("p50k_base"),
"cushman-codex" => Ok("p50k_base"),
"text-davinci-edit-001" => Ok("p50k_edit"),
"code-davinci-edit-001" => Ok("p50k_edit"),
"text-embedding-ada-002" => Ok("cl100k_base"),
"text-embedding-3-small" => Ok("cl100k_base"),
"text-embedding-3-large" => Ok("cl100k_base"),
"text-similarity-davinci-001" => Ok("r50k_base"),
"text-similarity-curie-001" => Ok("r50k_base"),
"text-similarity-babbage-001" => Ok("r50k_base"),
"text-similarity-ada-001" => Ok("r50k_base"),
"text-search-davinci-doc-001" => Ok("r50k_base"),
"text-search-curie-doc-001" => Ok("r50k_base"),
"text-search-babbage-doc-001" => Ok("r50k_base"),
"text-search-ada-doc-001" => Ok("r50k_base"),
"code-search-babbage-code-001" => Ok("r50k_base"),
"code-search-ada-code-001" => Ok("r50k_base"),
"gpt2" => Ok("gpt2"),
"gpt-3.5-turbo" => Ok("cl100k_base"),
"gpt-3.5-turbo-0301" => Ok("cl100k_base"),
"gpt-3.5-turbo-0613" => Ok("cl100k_base"),
"gpt-3.5-turbo-16k" => Ok("cl100k_base"),
"gpt-3.5-turbo-16k-0613" => Ok("cl100k_base"),
"gpt-3.5-turbo-instruct" => Ok("cl100k_base"),
"gpt-3.5-turbo-instruct-0914" => Ok("cl100k_base"),
"gpt-4" => Ok("cl100k_base"),
"gpt-4-0314" => Ok("cl100k_base"),
"gpt-4-0613" => Ok("cl100k_base"),
"gpt-4-32k" => Ok("cl100k_base"),
"gpt-4-32k-0314" => Ok("cl100k_base"),
"gpt-4-32k-0613" => Ok("cl100k_base"),
"gpt-3.5-turbo-1106" => Ok("cl100k_base"),
"gpt-35-turbo" => Ok("cl100k_base"),
"gpt-4-1106-preview" => Ok("cl100k_base"),
"gpt-4-vision-preview" => Ok("cl100k_base"),
"gpt-3.5-turbo-0125" => Ok("cl100k_base"),
"gpt-4-turbo" => Ok("cl100k_base"),
"gpt-4-turbo-2024-04-09" => Ok("cl100k_base"),
"gpt-4-turbo-preview" => Ok("cl100k_base"),
"gpt-4-0125-preview" => Ok("cl100k_base"),
"gpt-4o" => Ok("o200k_base"),
"gpt-4o-2024-05-13" => Ok("o200k_base"),
"gpt-4o-2024-08-06" => Ok("o200k_base"),
"gpt-4o-2024-11-20" => Ok("o200k_base"),
"gpt-4o-mini-2024-07-18" => Ok("o200k_base"),
"gpt-4o-mini" => Ok("o200k_base"),
"o1" => Ok("o200k_base"),
"o1-2024-12-17" => Ok("o200k_base"),
"o1-mini" => Ok("o200k_base"),
"o1-preview" => Ok("o200k_base"),
"o1-preview-2024-09-12" => Ok("o200k_base"),
"o1-mini-2024-09-12" => Ok("o200k_base"),
"chatgpt-4o-latest" => Ok("o200k_base"),
"gpt-4o-realtime" => Ok("o200k_base"),
"gpt-4o-realtime-preview-2024-10-01" => Ok("o200k_base"),
"o3-mini" => Ok("o200k_base"),
"o3-mini-2025-01-31" => Ok("o200k_base"),
model => Err(JsError::new(
format!("Invalid model: {}", model.to_string()).as_str(),
)),
}?;
let binding = get_encoding_name_for_model(model)?;
let encoding = binding.as_str();

Tiktoken::with_encoding(
encoding,
Expand Down
11 changes: 10 additions & 1 deletion wasm/test/test_simple_public.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { it, expect, describe } from "vitest";
import { encoding_for_model, get_encoding } from "../dist";
import { encoding_for_model, get_encoding, get_encoding_name_for_model } from "../dist";

it("encoding_for_model initialization", () => {
expect(() => encoding_for_model("gpt2")).not.toThrowError();
Expand Down Expand Up @@ -106,6 +106,15 @@ it("test_encoding_for_model", () => {
expect(encoding_for_model("gpt-3.5-turbo").name).toEqual("cl100k_base");
});

it("test_get_encoding_name_for_model", () => {
  // One representative model per encoding family exercised here.
  const knownModels = [
    ["gpt2", "gpt2"],
    ["text-davinci-003", "p50k_base"],
    ["gpt-3.5-turbo", "cl100k_base"],
  ] as const;
  for (const [model, encodingName] of knownModels) {
    expect(get_encoding_name_for_model(model)).toEqual(encodingName);
  }

  // @ts-expect-error - explicitly testing for invalid model
  expect(() => get_encoding_name_for_model("unknown")).toThrowError("Invalid model: unknown");
});

it("test_custom_tokens", () => {
const enc = encoding_for_model("gpt2", {
"<|im_start|>": 100264,
Expand Down

0 comments on commit 0a28e20

Please sign in to comment.