diff --git a/src/models/index.ts b/src/models/index.ts
index a446944..3af477e 100644
--- a/src/models/index.ts
+++ b/src/models/index.ts
@@ -131,6 +131,10 @@ export function isChatModel(
   );
 }
 
+export function isOpenAIModel(model: AllOptions): boolean {
+  return model.indexOf("/") === -1;
+}
+
 export function isValidOption(model: unknown): model is AllOptions {
   return allOptions.safeParse(model).success;
 }
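`isOpenAIModel` leans on a naming convention: OpenAI ids are bare (`gpt-4o`), while Hugging Face ids carry an `org/model` slash. A minimal sketch of the intended behavior (both ids must be members of `AllOptions` to typecheck; the Llama id is illustrative):

```ts
import { isOpenAIModel } from "~/models";

isOpenAIModel("gpt-4o"); // true  — bare id, no "/"
isOpenAIModel("meta-llama/Llama-2-7b-hf"); // false — "org/model" Hugging Face id
```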
diff --git a/src/models/tokenizer.ts b/src/models/tokenizer.ts
index 477b27c..2e75670 100644
--- a/src/models/tokenizer.ts
+++ b/src/models/tokenizer.ts
@@ -17,21 +17,35 @@ export interface TokenizerResult {
   count: number;
 }
 
+export interface TokenInfo {
+  id: number;
+  text?: string;
+  bytes?: Uint8Array;
+  // If this is a merge, the original token ids that were merged to form this token
+  merge?: [number, number];
+  special?: boolean;
+}
+
 export interface Tokenizer {
   name: string;
+  type: string;
   tokenize(text: string): TokenizerResult;
+  getInfo(token: number): TokenInfo;
+  specialTokens: Record<string, number>;
+  tokenCount: number;
   free?(): void;
 }
 
 export class TiktokenTokenizer implements Tokenizer {
   private enc: Tiktoken;
+  readonly specialTokens: Record<string, number> = {};
   name: string;
+  type = "BPE";
 
   constructor(model: z.infer<typeof oaiModels> | z.infer<typeof oaiEncodings>) {
     const isModel = oaiModels.safeParse(model);
     const isEncoding = oaiEncodings.safeParse(model);
-    console.log(isModel.success, isEncoding.success, model)
+    console.log(isModel.success, isEncoding.success, model);
     if (isModel.success) {
       if (
         model === "text-embedding-3-small" ||
         model === "text-embedding-3-large"
@@ -39,21 +53,31 @@ export class TiktokenTokenizer implements Tokenizer {
         throw new Error("Model may be too new");
       }
 
-      const enc =
-        model === "gpt-3.5-turbo" || model === "gpt-4" || model === "gpt-4-32k"
-          ? get_encoding("cl100k_base", {
-              "<|im_start|>": 100264,
-              "<|im_end|>": 100265,
-              "<|im_sep|>": 100266,
-            })
-          : model === "gpt-4o"
-          ? get_encoding("o200k_base", {
-              "<|im_start|>": 200264,
-              "<|im_end|>": 200265,
-              "<|im_sep|>": 200266,
-            })
-          : // @ts-expect-error r50k broken?
-            encoding_for_model(model);
+      let specialTokens: Record<string, number> = {};
+      let enc;
+      if (
+        model === "gpt-3.5-turbo" ||
+        model === "gpt-4" ||
+        model === "gpt-4-32k"
+      ) {
+        specialTokens = {
+          "<|im_start|>": 100264,
+          "<|im_end|>": 100265,
+          "<|im_sep|>": 100266,
+        };
+        enc = get_encoding("cl100k_base", specialTokens);
+      } else if (model === "gpt-4o") {
+        specialTokens = {
+          "<|im_start|>": 200264,
+          "<|im_end|>": 200265,
+          "<|im_sep|>": 200266,
+        };
+        enc = get_encoding("o200k_base", specialTokens);
+      } else {
+        // @ts-expect-error r50k broken?
+        enc = encoding_for_model(model);
+      }
+      this.specialTokens = specialTokens;
       this.name = enc.name ?? model;
       this.enc = enc;
     } else if (isEncoding.success) {
@@ -64,6 +88,10 @@ export class TiktokenTokenizer implements Tokenizer {
     }
   }
 
+  get tokenCount(): number {
+    return this.enc.token_byte_values().length;
+  }
+
   tokenize(text: string): TokenizerResult {
     const tokens = [...(this.enc?.encode(text, "all") ?? [])];
     return {
@@ -74,18 +102,38 @@ export class TiktokenTokenizer implements Tokenizer {
     };
   }
 
+  getInfo(token: number): TokenInfo {
+    const special = Object.entries(this.specialTokens).find(
+      ([_, value]) => value === token
+    );
+    // Search merges. TODO: how to do this?
+    return {
+      id: token,
+      bytes: this.enc.decode_single_token_bytes(token),
+      text: new TextDecoder("utf-8", { fatal: false }).decode(
+        this.enc.decode_single_token_bytes(token)
+      ),
+      special: special !== undefined,
+    };
+  }
+
   free(): void {
     this.enc.free();
   }
 }
 
 export class OpenSourceTokenizer implements Tokenizer {
+  readonly specialTokens: Record<string, number> = {};
+  name: string;
+  type = "SentencePieceBPE";
+
   constructor(private tokenizer: PreTrainedTokenizer, name?: string) {
     this.name = name ?? tokenizer.name;
+    this.specialTokens = Object.fromEntries(
+      tokenizer.added_tokens.map((t) => [t.content, t.id])
+    );
   }
 
-  name: string;
-
   static async load(
     model: z.infer<typeof openSourceModels>
   ): Promise<PreTrainedTokenizer> {
@@ -117,6 +165,23 @@ export class OpenSourceTokenizer implements Tokenizer {
       count: tokens.length,
     };
   }
+
+  get tokenCount(): number {
+    return this.tokenizer.model.tokens_to_ids.size;
+  }
+
+  getInfo(token: number): TokenInfo {
+    const t = this.tokenizer.decode([token]);
+    const special = Object.entries(this.specialTokens).find(
+      ([_, value]) => value === token
+    );
+    return {
+      id: token,
+      bytes: new TextEncoder().encode(t),
+      text: t,
+      special: special !== undefined,
+    };
+  }
 }
 
 export async function createTokenizer(name: string): Promise<Tokenizer> {
diff --git a/src/pages/[hf_org]/[model]/[token]/index.tsx b/src/pages/[hf_org]/[model]/[token]/index.tsx
new file mode 100644
index 0000000..4bdaaff
--- /dev/null
+++ b/src/pages/[hf_org]/[model]/[token]/index.tsx
@@ -0,0 +1,24 @@
+import { useRouter } from "next/router";
+import { z } from "zod";
+import { TokenInfo } from "~/pages/openai/[model]/[token]";
+
+const OrganizationModelAndToken = z.object({
+  hf_org: z.string(),
+  model: z.string(),
+  token: z.string(),
+});
+
+export default function Page() {
+  const router = useRouter();
+
+  // On initial load the router query is {}, so just return
+  if (Object.keys(router.query).length === 0) {
+    return null;
+  }
+
+  const { hf_org, model, token } = OrganizationModelAndToken.parse(
+    router.query
+  );
+
+  return <TokenInfo model={`${hf_org}/${model}`} token={token} />;
+}
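The widened `Tokenizer` interface (`type`, `tokenCount`, `specialTokens`, `getInfo`) is usable outside the pages as well. A minimal sketch that walks the first few vocabulary entries (model id illustrative):

```ts
import { createTokenizer } from "~/models/tokenizer";

// Print id, decoded text, and the special flag for the first n tokens.
async function dumpTokens(model: string, n = 5) {
  const tokenizer = await createTokenizer(model);
  for (let id = 0; id < Math.min(n, tokenizer.tokenCount); id++) {
    const { text, special } = tokenizer.getInfo(id);
    console.log(id, JSON.stringify(text), special ? "(special)" : "");
  }
  tokenizer.free?.(); // only the tiktoken-backed tokenizer defines free()
}

void dumpTokens("gpt-4o");
```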
diff --git a/src/pages/[hf_org]/[model]/index.tsx b/src/pages/[hf_org]/[model]/index.tsx
new file mode 100644
index 0000000..4b131d1
--- /dev/null
+++ b/src/pages/[hf_org]/[model]/index.tsx
@@ -0,0 +1,102 @@
+import { useQuery } from "@tanstack/react-query";
+import { useRouter } from "next/router";
+import { createTokenizer, Tokenizer } from "~/models/tokenizer";
+import { z } from "zod";
+import { useState } from "react";
+import Link from "next/link";
+
+const OrganizationModel = z.object({
+  hf_org: z.string(),
+  model: z.string(),
+});
+
+export default function Page() {
+  const router = useRouter();
+  if (Object.keys(router.query).length === 0) {
+    return null;
+  }
+  const { hf_org, model } = OrganizationModel.parse(router.query);
+  return <TokenizerInfo model={`${hf_org}/${model}`} />;
+}
+
+function* range(n: number) {
+  for (let i = 0; i < n; i++) {
+    yield i;
+  }
+}
+
+function TokensList({
+  model,
+  tokenizer,
+  max = 1000,
+}: {
+  model: string;
+  tokenizer: Tokenizer;
+  max?: number;
+}) {
+  const tokens = Array.from(range(max));
+  return (
+    <pre>
+      {tokens.map((i) => (
+        <span key={i}>
+          <Link href={`/${model}/${i}`}>{tokenizer.getInfo(i)?.text}</Link>
+          {","}
+        </span>
+      ))}
+    </pre>
+  );
+}
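Review note: `Array.from(range(max))` is correct, but the helper generator isn't strictly needed; `Array.from` can synthesize the index list on its own. A hypothetical simplification, not part of the patch:

```ts
// Equivalent to Array.from(range(max)), with no generator function:
const indexes = (max: number) => Array.from({ length: max }, (_, i) => i);
indexes(3); // [0, 1, 2]
```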
+
+export function TokenizerInfo({ model }: { model: string }) {
+  const tq = useQuery({
+    queryKey: [model],
+    queryFn: ({ queryKey: [model] }) => createTokenizer(model!),
+  });
+
+  const [hoveredToken, setHoveredToken] = useState<number | null>(null);
+  const tokenizer = tq.data;
+  if (tq.isLoading) {
+    return (
+      <div>
+        <p>Tokenizer Info</p>
+        <p>Loading {model}...</p>
+      </div>
+    );
+  }
+
+  return (
+    <div>
+      <p>Tokenizer Info</p>
+      <p>Model: {tokenizer?.name}</p>
+      <p>Added tokens: {Object.keys(tokenizer?.specialTokens ?? {}).length}</p>
+      <pre>
+        {tokenizer?.specialTokens &&
+          Object.entries(tokenizer.specialTokens).map(([token, id]) => (
+            <span
+              key={id}
+              onMouseEnter={() => setHoveredToken(id)}
+              onMouseLeave={() => setHoveredToken(null)}
+            >
+              {token}
+            </span>
+          ))}
+      </pre>
+      {hoveredToken && (
+        <div>
+          <p>ID: {hoveredToken}</p>
+        </div>
+      )}
+      <p>Tokens: {tokenizer?.tokenCount}</p>
+      {tokenizer && (
+        <TokensList model={model} tokenizer={tokenizer} />
+      )}
+    </div>
+  );
+}
diff --git a/src/pages/index.tsx b/src/pages/index.tsx
index b314439..e486a02 100644
--- a/src/pages/index.tsx
+++ b/src/pages/index.tsx
@@ -11,7 +11,12 @@ import { ChatGPTEditor } from "../sections/ChatGPTEditor";
 import { EncoderSelect } from "~/sections/EncoderSelect";
 import { TokenViewer } from "~/sections/TokenViewer";
 import { TextArea } from "~/components/Input";
-import { type AllOptions, isChatModel, isValidOption } from "~/models";
+import {
+  type AllOptions,
+  isChatModel,
+  isOpenAIModel,
+  isValidOption,
+} from "~/models";
 import { createTokenizer } from "~/models/tokenizer";
 import { useQuery } from "@tanstack/react-query";
 import { useRouter } from "next/router";
@@ -20,9 +25,7 @@
 function useQueryParamsState() {
   const router = useRouter();
 
   const params = useMemo((): AllOptions => {
-    return isValidOption(router.query?.model)
-      ? router.query.model
-      : "gpt-4o";
+    return isValidOption(router.query?.model) ? router.query.model : "gpt-4o";
   }, [router.query]);
 
@@ -44,7 +47,6 @@ const Home: NextPage<
     queryKey: [model],
     queryFn: ({ queryKey: [model] }) => createTokenizer(model!),
   });
-
   const tokens = tokenizer.data?.tokenize(inputText);
 
   return (
@@ -83,7 +85,12 @@ const Home: NextPage<
-        <TokenViewer model={model} data={tokens} isFetching={false} />
+        <TokenViewer
+          model={model}
+          data={tokens}
+          isFetching={false}
+        />
+        {!isOpenAIModel(model) && <a href={`/${model}`}>Tokenizer details</a>}
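Taken together, the patch gives every tokenizer a uniform inspection surface and two routes to browse it: `/{hf_org}/{model}` for the vocabulary overview and `/{hf_org}/{model}/{token}` for a single token, mirroring the existing `/openai/[model]/[token]` pages. A rough end-to-end smoke test of that surface (model ids illustrative; run with the project's dependencies installed):

```ts
import { createTokenizer } from "~/models/tokenizer";

async function smokeTest() {
  // One tiktoken-backed encoding, one Hugging Face tokenizer.
  for (const id of ["gpt-4o", "meta-llama/Llama-2-7b-hf"]) {
    const t = await createTokenizer(id);
    console.log(t.name, `(${t.type})`);
    console.log("  vocab size:", t.tokenCount);
    console.log("  added/special tokens:", Object.keys(t.specialTokens).length);
    console.log("  token 0:", t.getInfo(0));
    t.free?.(); // no-op for tokenizers that don't hold WASM memory
  }
}

void smokeTest();
```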