
Replace js-tiktoken BPE merge algorithm with faster heap based algorithm #101

Open · wants to merge 3 commits into base: main

Conversation

@mikolalysenko commented Apr 19, 2024

Several open issues note that, for certain pathological inputs, the BPE merge algorithm in js-tiktoken takes time quadratic in the number of input characters.

This PR fixes the problem by using a heap to avoid recalculating the ranks of all tokens at every character. The same technique would also work for the Rust/WASM tokenizer, but it seems less important there since the native parsers are already quite fast.

I also added a new test fixture and an example string that triggers the pathological behavior.

Related issues:

Note: this should arguably be treated as a mild CVE, since an attacker can use this behavior to mount a denial of service against services that run js-tiktoken on user input.
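For context, the pre-PR behavior can be sketched like this (a simplified illustration, not the actual js-tiktoken source): every merge rescans all remaining adjacent pairs, so an input that forces many merges costs O(n) work per merge and O(n²) overall.

```typescript
// Simplified sketch of the quadratic merge strategy this PR replaces.
// `ranks` maps a comma-joined byte sequence to its BPE merge priority.
type Part = { start: number; end: number };

function naiveBytePairMerge(
  piece: Uint8Array,
  ranks: Map<string, number>
): Part[] {
  const parts: Part[] = Array.from({ length: piece.length }, (_, i) => ({
    start: i,
    end: i + 1,
  }));

  while (parts.length > 1) {
    // Rescan EVERY adjacent pair to find the lowest-ranked merge:
    // O(n) per merge, O(n^2) total for pathological inputs.
    let bestRank = Infinity;
    let bestIndex = -1;
    for (let i = 0; i < parts.length - 1; ++i) {
      const key = piece.slice(parts[i].start, parts[i + 1].end).join(",");
      const rank = ranks.get(key);
      if (rank !== undefined && rank < bestRank) {
        bestRank = rank;
        bestIndex = i;
      }
    }
    if (bestIndex < 0) break; // no mergeable pair remains
    parts[bestIndex] = {
      start: parts[bestIndex].start,
      end: parts[bestIndex + 1].end,
    };
    parts.splice(bestIndex + 1, 1);
  }
  return parts;
}
```

The heap-based version in this PR replaces the inner rescan with a priority queue, bringing each merge down to O(log n).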

@huytool157

@dqbd are we considering merging this? We're noticing performance issues on large inputs as well.

@enricoros

@mikolalysenko have you benchmarked this vs baseline vs wasm? curious for 1, 1000, 1M, 100M tokens.

@jonschlinkert

> have you benchmarked this vs baseline vs wasm? curious for 1, 1000, 1M, 100M tokens.

If there is a perf difference at different scales, we should consider toggling algorithms based on length.
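That toggle could look something like the sketch below. `CUTOFF`, `naiveMerge`, and `heapMerge` are hypothetical names, not part of js-tiktoken's API, and the crossover point would have to come from real benchmarks:

```typescript
// Hypothetical dispatch: pick a merge strategy by input length.
type Part = { start: number; end: number };
type MergeFn = (piece: Uint8Array, ranks: Map<string, number>) => Part[];

// Placeholder threshold; a real value would need benchmarking.
const CUTOFF = 64;

function makeDispatchingMerge(naiveMerge: MergeFn, heapMerge: MergeFn): MergeFn {
  return (piece, ranks) =>
    piece.length <= CUTOFF ? naiveMerge(piece, ranks) : heapMerge(piece, ranks);
}
```

For short pieces the constant factors of the heap may dominate, which is the only case where keeping the simple algorithm around would pay off.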

@tmcw

tmcw commented Nov 20, 2024

It would be lovely if this got merged - this module has been causing issues for my application in prod.

@danny-avila

+1

@timothycarambat

I am not sure if this fully works, but modifying core.ts as follows passes the tests:

import base64 from "base64-js";
import type { TiktokenModel } from "./ranks/ranks";
import { never } from "./utils";

type BPEMergeNode = {
  listNext: BPEMergeNode | null;
  listPrev: BPEMergeNode | null;

  deleted: boolean;
  updated: boolean;
  updatedRank: number;
  removed: boolean;

  rank: number;
  start: number;
  end: number;
};

function compareNode(a: BPEMergeNode, b: BPEMergeNode) {
  return a.rank - b.rank || a.start - b.start;
}

// Helper function to swap elements at two indices
function swap(heap: BPEMergeNode[], i: number, j: number) {
  const temp = heap[i];
  heap[i] = heap[j];
  heap[j] = temp;
}

// standard binary heap push, generated by gpt4
function heapPush(heap: BPEMergeNode[], part: BPEMergeNode) {
  heap.push(part); // Add the new element to the end
  let currentIndex = heap.length - 1;
  let parentIndex = Math.floor((currentIndex - 1) / 2);

  // Bubble the new element up to its correct position
  while (
    currentIndex > 0 &&
    compareNode(heap[currentIndex], heap[parentIndex]) < 0
  ) {
    swap(heap, currentIndex, parentIndex);
    currentIndex = parentIndex;
    parentIndex = Math.floor((currentIndex - 1) / 2);
  }
}

// standard heap pop, also ai generated
function heapPop(heap: BPEMergeNode[]) {
  if (heap.length === 0) {
    return undefined; // Return undefined if the heap is empty
  }

  const rootValue = heap[0]; // The root element to return
  const lastValue = heap.pop(); // Remove the last element

  if (heap.length > 0 && lastValue) {
    heap[0] = lastValue; // Move the last element to the root
    let currentIndex = 0;

    // Bubble down the new root element to its correct position
    while (true) {
      let leftChildIndex = 2 * currentIndex + 1;
      let rightChildIndex = 2 * currentIndex + 2;
      let smallestIndex = currentIndex;

      if (
        leftChildIndex < heap.length &&
        compareNode(heap[leftChildIndex], heap[smallestIndex]) < 0
      ) {
        smallestIndex = leftChildIndex;
      }

      if (
        rightChildIndex < heap.length &&
        compareNode(heap[rightChildIndex], heap[smallestIndex]) < 0
      ) {
        smallestIndex = rightChildIndex;
      }

      if (smallestIndex !== currentIndex) {
        swap(heap, currentIndex, smallestIndex);
        currentIndex = smallestIndex;
      } else {
        break;
      }
    }
  }

  return rootValue;
}

function bytePairMerge(
  piece: Uint8Array,
  ranks: Map<string, number>
): Array<{ start: number; end: number }> {
  const parts: BPEMergeNode[] = Array.from(
    { length: piece.length },
    (_, i) => ({
      start: i,
      end: i + 1,
      rank: Infinity,
      deleted: false,
      updated: false,
      updatedRank: 0,
      removed: true,
      listNext: null,
      listPrev: null,
    })
  );

  if (parts.length === 0) {
    return [];
  }

  // Initialize linked list
  const head = parts[0];
  for (let i = 0; i < parts.length; ++i) {
    parts[i].listPrev = parts[i - 1] ?? null;
    parts[i].listNext = parts[i + 1] ?? null;
  }

  // Initialize heap with valid merges
  const heap: BPEMergeNode[] = [];
  for (let i = 0; i < parts.length - 1; ++i) {
    const slice = piece.slice(parts[i].start, parts[i + 1].end);
    const rank = ranks.get(slice.join(","));
    if (rank == null) continue;
    const part = parts[i];
    part.removed = false;
    part.rank = rank;
    heapPush(heap, part);
  }

  while (heap.length > 0) {
    const part = heapPop(heap);
    if (!part) break;

    if (part.deleted || !part.listNext) {
      continue;
    }

    if (part.updated) {
      part.rank = part.updatedRank;
      part.updated = false;
      heapPush(heap, part);
      continue;
    }

    // Verify the merge is still valid
    const currentSlice = piece.slice(part.start, part.listNext.end);
    const currentRank = ranks.get(currentSlice.join(","));
    if (currentRank !== part.rank) {
      continue;
    }

    // Perform merge
    part.end = part.listNext.end;
    part.listNext.deleted = true;
    part.listNext = part.listNext.listNext;
    if (part.listNext) {
      part.listNext.listPrev = part;
    }

    // Check for new possible merges
    let addedNewMerge = false;
    if (part.listNext) {
      const slice = piece.slice(part.start, part.listNext.end);
      const rank = ranks.get(slice.join(","));
      if (rank != null) {
        part.rank = rank;
        part.removed = false;
        heapPush(heap, part);
        addedNewMerge = true;
      }
    }

    if (part.listPrev && !part.listPrev.deleted) {
      const slice = piece.slice(part.listPrev.start, part.end);
      const rank = ranks.get(slice.join(","));
      if (rank != null) {
        if (!part.listPrev.removed) {
          part.listPrev.updated = true;
          part.listPrev.updatedRank = rank;
        } else {
          part.listPrev.removed = false;
          part.listPrev.rank = rank;
          heapPush(heap, part.listPrev);
        }
        addedNewMerge = true;
      }
    }

    if (!addedNewMerge) {
      part.removed = true;
    }
  }

  const result: Array<{ start: number; end: number }> = [];
  let current: BPEMergeNode | null = head;
  while (current) {
    if (!current.deleted) {
      result.push({ start: current.start, end: current.end });
    }
    current = current.listNext;
  }
  return result;
}


// rest of code unchanged

With this change the tests pass; the current PR has a decoding error with non-Latin characters:

FAIL  test/compatibility.test.ts > LiteTokenizer matches the behavior of tiktoken > Emojis and non-latin characters
js-tiktoken:test: AssertionError: expected [ 9468, 239, 102, 378, 235, …(109) ] to deeply equal [ 9468, 239, 102, 378, 235, …(111) ]
js-tiktoken:test:  ❯ test/compatibility.test.ts:50:38
js-tiktoken:test:      48| 
js-tiktoken:test:      49|     for (const text of fixtures) {
js-tiktoken:test:      50|       expect([...lite.encode(text)]).toEqual([...full.encode(text)]);
js-tiktoken:test:        |                                      ^
js-tiktoken:test:      51|     }
js-tiktoken:test:      52|   });
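The exact bug isn't visible from the diff alone, but the failure class is typical for byte-level BPE: a multi-byte UTF-8 character whose bytes end up split across boundaries cannot be decoded piecewise. A minimal illustration using the standard `TextEncoder`/`TextDecoder`:

```typescript
// "é" is two bytes in UTF-8; decoding the bytes separately destroys it.
const bytes = new TextEncoder().encode("é"); // Uint8Array [0xC3, 0xA9]

// Decoding the full byte sequence round-trips cleanly.
const whole = new TextDecoder().decode(bytes);

// Decoding each byte in isolation yields U+FFFD replacement characters.
const split =
  new TextDecoder().decode(bytes.slice(0, 1)) +
  new TextDecoder().decode(bytes.slice(1));
```

Any merge-order change that regroups bytes differently around such boundaries will show up exactly as the token-count mismatch in the fixture above.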

With this lite_performance.test.js:

import { test, expect, describe } from "vitest";
import { getEncoding } from "../src/index";
const TARGET_TIME = 30_000;
const TARGET_STRING_LENGTH = 1_000_000; // Crazy high number to test the limits of the lite tokenizer
const EVIL_STRING = Array.from({ length: TARGET_STRING_LENGTH }, () => {
  return String.fromCharCode(Math.floor(Math.random() * 256));
}).join("");

// This test will be flaky - so perhaps we should run it externally
// from the main CI pipeline since it depends on the machine it's run on
describe(`Lite tokenizer resolves ${EVIL_STRING.length / 1000}K string in acceptable time (${TARGET_TIME}ms)`, () => {
  const lite = getEncoding("cl100k_base");
  test("Test lite performance", () => {
    const start = Date.now();
    const result = lite.encode(EVIL_STRING);
    const end = Date.now();
    console.log(`Lite encoding time: ${end - start}ms`);
    expect(end - start).toBeLessThanOrEqual(TARGET_TIME);
  });

  test("Test encoding/decoding", () => {
    const result = lite.encode(EVIL_STRING);
    const decoded = lite.decode(result);
    expect(decoded).toEqual(EVIL_STRING);
  });
});

With a crazy length of 1,000,000 chars I get encoding in 1908ms on an Intel MBP.

8 participants