-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace js-tiktoken BPE merge algorithm with faster heap based algorithm #101
base: main
Are you sure you want to change the base?
Conversation
@dqbd are we considering merging this? We have noticed some performance issues when the input is large as well. |
@mikolalysenko have you benchmarked this vs baseline vs wasm? curious for 1, 1000, 1M, 100M tokens. |
If there is a perf difference at different scales, we should consider toggling algorithms based on length. |
It would be lovely if this got merged - this module has been causing issues for my application in prod. |
+1 |
I am not sure if this fully works, but modify import base64 from "base64-js";
import type { TiktokenModel } from "./ranks/ranks";
import { never } from "./utils";
// Node in the doubly linked list of byte-piece segments. Each node
// represents the candidate merge of this segment with its right
// neighbor (`listNext`); `rank` is the BPE rank of that merge.
type BPEMergeNode = {
  listNext: BPEMergeNode | null; // right neighbor in the piece, or null at the tail
  listPrev: BPEMergeNode | null; // left neighbor in the piece, or null at the head
  deleted: boolean; // true once this segment has been absorbed into its left neighbor
  updated: boolean; // true when the node sits in the heap under a stale rank; see updatedRank
  updatedRank: number; // replacement rank, applied lazily when the stale heap entry is popped
  removed: boolean; // true when the node currently has no live entry in the heap
  rank: number; // BPE rank of merging this segment with listNext (Infinity = none known)
  start: number; // inclusive start offset of the segment within the piece
  end: number; // exclusive end offset of the segment within the piece
};
/**
 * Min-heap ordering for merge candidates: lower rank wins; ties are
 * broken by the leftmost segment (smaller start offset).
 */
function compareNode(a: BPEMergeNode, b: BPEMergeNode) {
  if (a.rank !== b.rank) {
    return a.rank - b.rank;
  }
  return a.start - b.start;
}
// Helper function to swap elements at two indices
// Exchange the heap entries stored at positions i and j.
function swap(heap: BPEMergeNode[], i: number, j: number) {
  [heap[i], heap[j]] = [heap[j], heap[i]];
}
// standard binary heap push, generated by gpt4
// standard binary heap push, generated by gpt4
// Appends the entry at the end of the array, then sifts it up until its
// parent compares no larger (restoring the min-heap invariant).
function heapPush(heap: BPEMergeNode[], part: BPEMergeNode) {
  heap.push(part);
  let child = heap.length - 1;
  while (child > 0) {
    const parent = (child - 1) >> 1;
    if (compareNode(heap[child], heap[parent]) >= 0) {
      break; // parent is already <= child; heap property holds
    }
    swap(heap, child, parent);
    child = parent;
  }
}
// standard heap pop, also ai generated
// standard heap pop, also ai generated
// Removes and returns the minimum entry (heap[0]), or undefined when the
// heap is empty. The last entry is moved to the root and sifted down.
function heapPop(heap: BPEMergeNode[]) {
  if (heap.length === 0) {
    return undefined;
  }
  const top = heap[0];
  const tail = heap.pop();
  if (heap.length > 0 && tail) {
    heap[0] = tail;
    // Sift the relocated entry down: repeatedly swap it with its
    // smaller child until neither child compares lower.
    let parent = 0;
    for (;;) {
      const left = 2 * parent + 1;
      const right = left + 1;
      let smallest = parent;
      if (left < heap.length && compareNode(heap[left], heap[smallest]) < 0) {
        smallest = left;
      }
      if (right < heap.length && compareNode(heap[right], heap[smallest]) < 0) {
        smallest = right;
      }
      if (smallest === parent) {
        break; // both children are >= parent; heap property holds
      }
      swap(heap, parent, smallest);
      parent = smallest;
    }
  }
  return top;
}
/**
 * Heap-based BPE merge: repeatedly merges the adjacent pair of segments
 * with the lowest rank in `ranks` until no mergeable pair remains.
 * Returns the surviving segments as [start, end) offsets into `piece`.
 *
 * Segments live in a doubly linked list; each node's `rank` is the rank
 * of merging it with its right neighbor. A binary min-heap orders the
 * candidate merges; stale heap entries are detected lazily, either via
 * the `updated` flag or by re-checking the rank against `ranks` on pop.
 */
function bytePairMerge(
  piece: Uint8Array,
  ranks: Map<string, number>
): Array<{ start: number; end: number }> {
  // One node per byte to start with; every segment is initially [i, i+1).
  const parts: BPEMergeNode[] = Array.from(
    { length: piece.length },
    (_, i) => ({
      start: i,
      end: i + 1,
      rank: Infinity,
      deleted: false,
      updated: false,
      updatedRank: 0,
      removed: true, // not in the heap yet
      listNext: null,
      listPrev: null,
    })
  );
  if (parts.length === 0) {
    return [];
  }
  // Initialize linked list
  const head = parts[0];
  for (let i = 0; i < parts.length; ++i) {
    parts[i].listPrev = parts[i - 1] ?? null;
    parts[i].listNext = parts[i + 1] ?? null;
  }
  // Initialize heap with valid merges
  const heap: BPEMergeNode[] = [];
  for (let i = 0; i < parts.length - 1; ++i) {
    // Ranks are keyed by the comma-joined byte values of the merged slice.
    const slice = piece.slice(parts[i].start, parts[i + 1].end);
    const rank = ranks.get(slice.join(","));
    if (rank == null) continue;
    const part = parts[i];
    part.removed = false;
    part.rank = rank;
    heapPush(heap, part);
  }
  while (heap.length > 0) {
    const part = heapPop(heap);
    if (!part) break;
    // Skip nodes already absorbed into a neighbor, or with no right
    // neighbor left to merge with.
    if (part.deleted || !part.listNext) {
      continue;
    }
    // The node's rank changed while it sat in the heap: re-insert it
    // under the fresh rank instead of merging now.
    // NOTE(review): this lazy re-keying only takes effect once the stale
    // entry reaches the top of the heap. If updatedRank is *smaller* than
    // the stale rank, other candidates may be popped first, deviating from
    // strict lowest-rank-first merge order — confirm this cannot change
    // the resulting tokenization (cf. the failing non-latin fixture).
    if (part.updated) {
      part.rank = part.updatedRank;
      part.updated = false;
      heapPush(heap, part);
      continue;
    }
    // Verify the merge is still valid
    const currentSlice = piece.slice(part.start, part.listNext.end);
    const currentRank = ranks.get(currentSlice.join(","));
    if (currentRank !== part.rank) {
      // NOTE(review): the node leaves the heap here but `removed` stays
      // false, so a later neighbor merge would take the `updated = true`
      // branch for a node that is no longer in the heap — that deferred
      // update would never be processed. Consider setting
      // part.removed = true before continuing; verify against tests.
      continue;
    }
    // Perform merge
    part.end = part.listNext.end;
    part.listNext.deleted = true;
    part.listNext = part.listNext.listNext;
    if (part.listNext) {
      part.listNext.listPrev = part;
    }
    // Check for new possible merges
    let addedNewMerge = false;
    if (part.listNext) {
      // Forward: can the grown segment merge with its new right neighbor?
      const slice = piece.slice(part.start, part.listNext.end);
      const rank = ranks.get(slice.join(","));
      if (rank != null) {
        part.rank = rank;
        part.removed = false;
        heapPush(heap, part);
        addedNewMerge = true;
      }
    }
    if (part.listPrev && !part.listPrev.deleted) {
      // Backward: the left neighbor's candidate merge now spans the grown
      // segment, so its rank must be refreshed.
      const slice = piece.slice(part.listPrev.start, part.end);
      const rank = ranks.get(slice.join(","));
      if (rank != null) {
        if (!part.listPrev.removed) {
          // Still in the heap: defer the rank change until it is popped.
          part.listPrev.updated = true;
          part.listPrev.updatedRank = rank;
        } else {
          // Not in the heap: insert it with the fresh rank.
          part.listPrev.removed = false;
          part.listPrev.rank = rank;
          heapPush(heap, part.listPrev);
        }
        addedNewMerge = true;
      }
    }
    if (!addedNewMerge) {
      part.removed = true;
    }
  }
  // Collect surviving (non-deleted) segments in piece order.
  const result: Array<{ start: number; end: number }> = [];
  let current: BPEMergeNode | null = head;
  while (current) {
    if (!current.deleted) {
      result.push({ start: current.start, end: current.end });
    }
    current = current.listNext;
  }
  return result;
}
// rest of code unchanged It will then pass tests since in the current PR there is a decoding error with non-latin chars FAIL test/compatibility.test.ts > LiteTokenizer matches the behavior of tiktoken > Emojis and non-latin characters
js-tiktoken:test: AssertionError: expected [ 9468, 239, 102, 378, 235, …(109) ] to deeply equal [ 9468, 239, 102, 378, 235, …(111) ]
js-tiktoken:test: ❯ test/compatibility.test.ts:50:38
js-tiktoken:test: 48|
js-tiktoken:test: 49| for (const text of fixtures) {
js-tiktoken:test: 50| expect([...lite.encode(text)]).toEqual([...full.encode(text)]);
js-tiktoken:test: | ^
js-tiktoken:test: 51| }
js-tiktoken:test: 52| }); With this import { test, expect, describe, } from "vitest";
import { getEncoding } from "../src/index";
const TARGET_TIME = 30_000;
const TARGET_STRING_LENGTH = 1_000_000; // Crazy high number to test the limits of the lite tokenizer
const EVIL_STRING = Array.from({ length: TARGET_STRING_LENGTH }, () => {
return String.fromCharCode(Math.floor(Math.random() * 256));
}).join("");
// This test will be flaky - so perhaps we should run it externally
// from the main CI pipeline since it depends on the machine it's run on
describe(`Lite tokenizer resolves ${EVIL_STRING.length / 1000}K string in acceptable time (${TARGET_TIME}ms)`, () => {
const lite = getEncoding("cl100k_base");
test("Test lite performance", () => {
const start = Date.now();
const result = lite.encode(EVIL_STRING);
const end = Date.now();
console.log(`Lite encoding time: ${end - start}ms`);
expect(end - start).toBeLessThanOrEqual(TARGET_TIME);
});
test("Test encoding/decoding", () => {
const result = lite.encode(EVIL_STRING);
const decoded = lite.decode(result);
expect(decoded).toEqual(EVIL_STRING);
});
}); With a crazy length of |
There are several open issues noting that, for certain pathological inputs, the BPE merge algorithm in js-tiktoken takes time quadratic in the number of input characters.
This PR fixes this problem using a heap to avoid recalculating the ranks of all tokens at each character. This technique should also work for the rust/wasm tokenizer but it seems less important in those cases since the native parsers are already pretty fast.
I also added a new test fixture and an example string which causes pathological behavior.
Related issues:
Note: This is arguably a mild security vulnerability (worth a CVE), since an attacker could exploit this quadratic behavior to mount a denial-of-service attack against services that run js-tiktoken on user-supplied input.