From aa5fcf54698ddd5adb3c5017ca504f19bf608ccf Mon Sep 17 00:00:00 2001 From: Pavel Alekseev Date: Wed, 5 Feb 2025 12:59:51 +0300 Subject: [PATCH] Fork update + o3 model implementation --- composer.json | 1 - src/Encoder.php | 138 ++++++++++++++---------- src/EncoderProvider.php | 81 +++++++++++--- src/Util/EncodeUtil.php | 17 +-- src/Vocab/Loader/DefaultVocabLoader.php | 68 +++++++++--- src/Vocab/Vocab.php | 61 ++++++----- src/Vocab/VocabLoader.php | 8 +- 7 files changed, 250 insertions(+), 124 deletions(-) diff --git a/composer.json b/composer.json index da74c03..c2db699 100644 --- a/composer.json +++ b/composer.json @@ -13,7 +13,6 @@ } }, "minimum-stability": "stable", - "version": "1.0.0", "config": { "sort-packages": true, "allow-plugins": { diff --git a/src/Encoder.php b/src/Encoder.php index 5e7d184..716ef4e 100644 --- a/src/Encoder.php +++ b/src/Encoder.php @@ -4,32 +4,44 @@ namespace guttedgarden\Tiktoken; -use Closure; use guttedgarden\Tiktoken\Exception\RegexError; use guttedgarden\Tiktoken\Util\EncodeUtil; use guttedgarden\Tiktoken\Vocab\Vocab; use function array_map; -use function array_slice; +use function array_merge; use function array_values; -use function assert; use function count; use function implode; +use function preg_last_error; use function preg_match_all; -use function range; use function sprintf; +use function strlen; +use function substr; use const PHP_INT_MAX; /** @psalm-import-type NonEmptyByteVector from EncodeUtil */ final class Encoder { - private string $name; - private Vocab $vocab; - private string $pattern; /** - * @param non-empty-string $name - * @param non-empty-string $pattern + * @var string + */ + private $name; + + /** + * @var Vocab + */ + private $vocab; + + /** + * @var string + */ + private $pattern; + + /** + * @param string $name + * @param string $pattern */ public function __construct(string $name, Vocab $vocab, string $pattern) { @@ -43,7 +55,7 @@ public function __toString(): string return sprintf('Encoder(name="%s", vocab=%d)', $this->name, count($this->vocab)); } - /** @return list */ + /** @return int[] */ public function encode(string $text): array { if ($text === '') { @@ -51,7 +63,7 @@ public function encode(string $text): array } if (preg_match_all($this->pattern, $text, $matches) === false) { - throw new RegexError(sprintf('Matching failed with error: %s', $this->pregErrorString(preg_last_error()))); + throw new RegexError(sprintf('Matching failed with error code: %d', preg_last_error())); } $tokens = []; @@ -61,16 +73,14 @@ public function encode(string $text): array continue; } - $piece = EncodeUtil::toBytes($match); - $rank = $this->vocab->tryGetRank($piece); + $rank = $this->vocab->tryGetRank($match); if ($rank !== null) { $tokens[] = $rank; - continue; } - foreach ($this->mergeBytePairs($piece) as $rank) { + foreach ($this->mergeBytePairs($match) as $rank) { $tokens[] = $rank; } } @@ -78,49 +88,81 @@ public function encode(string $text): array return $tokens; } - /** @param array $tokens */ + /** @return int[][] */ + public function encodeInChunks(string $text, int $maxTokensPerChunk): array + { + if ($text === '') { + return []; + } + + if (preg_match_all($this->pattern, $text, $matches) === false) { + throw new RegexError(sprintf('Matching failed with error code: %d', preg_last_error())); + } + + $chunks = []; + $tokensInCurrentChunk = []; + + foreach ($matches[0] as $match) { + if ($match === '') { + continue; + } + + $rank = $this->vocab->tryGetRank($match); + $tokens = $rank !== null ? [$rank] : $this->mergeBytePairs($match); + + if (count($tokensInCurrentChunk) + count($tokens) > $maxTokensPerChunk) { + $chunks[] = $tokensInCurrentChunk; + $tokensInCurrentChunk = []; + } + + $tokensInCurrentChunk = array_merge($tokensInCurrentChunk, $tokens); + } + + if (count($tokensInCurrentChunk) > 0) { + $chunks[] = $tokensInCurrentChunk; + } + + return $chunks; + } + + /** @param int[] $tokens */ public function decode(array $tokens): string { if ($tokens === []) { return ''; } - return implode(array_map(Closure::fromCallable([$this->vocab, 'getToken']), $tokens)); + return implode(array_map([$this->vocab, 'getToken'], $tokens)); } /** - * @psalm-param NonEmptyByteVector $bytes + * @param string $piece * - * @return list + * @return int[] */ - private function mergeBytePairs(array $bytes): array + private function mergeBytePairs(string $piece): array { - /** @var list $parts */ - $parts = array_map( - function (int $i) use ($bytes): array { - if ($i + 1 < count($bytes)) { - $piece = array_slice($bytes, $i, 2); - assert(count($piece) === 2); - - return [$i, $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX]; - } + $parts = []; + + for ($i = 0; $i <= strlen($piece); $i++) { + $parts[] = [$i, PHP_INT_MAX]; + } - return [$i, PHP_INT_MAX]; - }, - range(0, count($bytes)), - ); - $getRank = function (array $parts, int $startIndex) use ($bytes): int { - if ($startIndex + 2 >= count($parts)) { + $getRank = function (array $parts, int $startIndex, int $skip = 0) use ($piece): int { + if (($startIndex + $skip + 2) >= count($parts)) { return PHP_INT_MAX; } $offset = $parts[$startIndex][0]; - $piece = array_slice($bytes, $offset, $parts[$startIndex + 2][0] - $offset); - assert(count($piece) > 0); + $length = $parts[$startIndex + $skip + 2][0] - $offset; - return $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX; + return $this->vocab->tryGetRank(substr($piece, $offset, $length)) ?? PHP_INT_MAX; }; + for ($i = 0; $i < count($parts) - 2; $i++) { + $parts[$i][1] = $getRank($parts, $i); + } + while (count($parts) > 1) { $minRank = PHP_INT_MAX; $partIndex = 0; @@ -155,26 +197,12 @@ function (int $i) use ($bytes): array { $res = []; for ($i = 0; $i < $stop; $i++) { - $piece = array_slice($bytes, $parts[$i][0], $parts[$i + 1][0] - $parts[$i][0]); - assert(count($piece) > 0); + $offset = $parts[$i][0]; + $length = $parts[$i + 1][0] - $offset; - $res[] = $this->vocab->getRank($piece); + $res[] = $this->vocab->getRank(substr($piece, $offset, $length)); } return $res; } - - private function pregErrorString($errorConstant) - { - static $errorMessages = [ - PREG_NO_ERROR => 'No error', - PREG_INTERNAL_ERROR => 'Internal error', - PREG_BACKTRACK_LIMIT_ERROR=> 'Backtrack limit error', - PREG_RECURSION_LIMIT_ERROR=> 'Recursion limit error', - PREG_BAD_UTF8_ERROR => 'Bad UTF8 error', - PREG_BAD_UTF8_OFFSET_ERROR=> 'Bad UTF8 offset error', - ]; - - return $errorMessages[$errorConstant] ?? 'Unknown error'; - } } diff --git a/src/EncoderProvider.php b/src/EncoderProvider.php index 103db80..1c0c882 100644 --- a/src/EncoderProvider.php +++ b/src/EncoderProvider.php @@ -17,31 +17,58 @@ final class EncoderProvider { + /** @var array> */ private const ENCODINGS = [ 'r50k_base' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken', + 'hash' => '306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930', 'pat' => '/\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/u', ], 'p50k_base' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken', + 'hash' => '94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069', 'pat' => '/\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/u', ], 'p50k_edit' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken', + 'hash' => '94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069', 'pat' => '/\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/u', ], 'cl100k_base' => [ 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken', + 'hash' => '223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7', 'pat' => '/(?i:\'s|\'t|\'re|\'ve|\'m|\'ll|\'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/u', ], + 'o200k_base' => [ + 'vocab' => 'https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken', + 'hash' => '446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d', + 'pat' => '/[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:\'s|\'t|\'re|\'ve|\'m|\'ll|\'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:\'s|\'t|\'re|\'ve|\'m|\'ll|\'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n\/]*|\s*[\r\n]+|\s+(?!\S)|\s+/u', + ], ]; + + /** @var array */ private const MODEL_PREFIX_TO_ENCODING = [ + 'o3-' => 'o200k_base', + 'o1-' => 'o200k_base', + 'chatgpt-4o-' => 'o200k_base', + 'gpt-4o-' => 'o200k_base', 'gpt-4-' => 'cl100k_base', 'gpt-3.5-turbo-' => 'cl100k_base', ]; + + /** @var array */ private const MODEL_TO_ENCODING = [ + 'o3' => 'o200k_base', + 'o1' => 'o200k_base', + 'gpt-4o' => 'o200k_base', 'gpt-4' => 'cl100k_base', 'gpt-3.5-turbo' => 'cl100k_base', + 'gpt-3.5' => 'cl100k_base', + 'davinci-002' => 'cl100k_base', + 'babbage-002' => 'cl100k_base', + 'text-embedding-ada-002' => 'cl100k_base', + 'text-embedding-3-small' => 'cl100k_base', + 'text-embedding-3-large' => 'cl100k_base', 'text-davinci-003' => 'p50k_base', 'text-davinci-002' => 'p50k_base', 'text-davinci-001' => 'r50k_base', @@ -60,7 +87,6 @@ final class EncoderProvider 'cushman-codex' => 'p50k_base', 'text-davinci-edit-001' => 'p50k_edit', 'code-davinci-edit-001' => 'p50k_edit', - 'text-embedding-ada-002' => 'cl100k_base', 'text-similarity-davinci-001' => 'r50k_base', 'text-similarity-curie-001' => 'r50k_base', 'text-similarity-babbage-001' => 'r50k_base', @@ -73,14 +99,17 @@ final class EncoderProvider 'code-search-ada-code-001' => 'r50k_base', ]; - private ?VocabLoader $vocabLoader = null; - private ?string $vocabCacheDir; + /** @var VocabLoader|null */ + private $vocabLoader = null; + + /** @var string|null */ + private $vocabCacheDir; - /** @var array */ - private array $encoders = []; + /** @var array */ + private $encoders = []; /** @var array */ - private array $vocabs = []; + private $vocabs = []; public function __construct() { @@ -93,7 +122,11 @@ public function __construct() $this->vocabCacheDir = $cacheDir !== '' ? $cacheDir : null; } - /** @param non-empty-string $model */ + /** + * @param string $model + * @return Encoder + * @throws InvalidArgumentException + */ public function getForModel(string $model): Encoder { if (isset(self::MODEL_TO_ENCODING[$model])) { @@ -109,45 +142,62 @@ public function getForModel(string $model): Encoder throw new InvalidArgumentException(sprintf('Unknown model name: %s', $model)); } - /** @param non-empty-string $encodingName */ + /** + * @param string $encodingName + * @return Encoder + * @throws InvalidArgumentException + */ public function get(string $encodingName): Encoder { - if (! isset(self::ENCODINGS[$encodingName])) { + if (!isset(self::ENCODINGS[$encodingName])) { throw new InvalidArgumentException(sprintf('Unknown encoding: %s', $encodingName)); } - if (! isset($this->encoders[$encodingName])) { + if (!isset($this->encoders[$encodingName])) { $options = self::ENCODINGS[$encodingName]; return $this->encoders[$encodingName] = new Encoder( $encodingName, $this->getVocab($encodingName), - $options['pat'], + $options['pat'] ); } return $this->encoders[$encodingName]; } - /** @param non-empty-string|null $cacheDir */ + /** + * @param string|null $cacheDir + * @return void + */ public function setVocabCache(?string $cacheDir): void { $this->vocabCacheDir = $cacheDir; $this->vocabLoader = null; } - /** @psalm-api */ + /** + * @param VocabLoader $loader + * @return void + */ public function setVocabLoader(VocabLoader $loader): void { $this->vocabLoader = $loader; } + /** + * @return void + */ public function reset(): void { $this->encoders = []; $this->vocabs = []; } + /** + * @param string $encodingName + * @return Vocab + */ private function getVocab(string $encodingName): Vocab { if (isset($this->vocabs[$encodingName])) { @@ -160,6 +210,9 @@ private function getVocab(string $encodingName): Vocab $loader = $this->vocabLoader = new DefaultVocabLoader($this->vocabCacheDir); } - return $this->vocabs[$encodingName] = $loader->load(self::ENCODINGS[$encodingName]['vocab']); + return $this->vocabs[$encodingName] = $loader->load( + self::ENCODINGS[$encodingName]['vocab'], + isset(self::ENCODINGS[$encodingName]['hash']) ? self::ENCODINGS[$encodingName]['hash'] : null + ); } } diff --git a/src/Util/EncodeUtil.php b/src/Util/EncodeUtil.php index 1fed38a..f8f4aa3 100644 --- a/src/Util/EncodeUtil.php +++ b/src/Util/EncodeUtil.php @@ -4,10 +4,9 @@ namespace guttedgarden\Tiktoken\Util; -use Closure; use function array_map; use function bin2hex; -use function pack; +use function hexdec; use function str_split; /** @psalm-type NonEmptyByteVector = non-empty-list> */ @@ -20,16 +19,8 @@ final class EncodeUtil */ public static function toBytes(string $text): array { - return array_map(Closure::fromCallable('hexdec'), str_split(bin2hex($text), 2)); - } - - /** - * @psalm-param NonEmptyByteVector $bytes - * - * @return non-empty-string - */ - public static function fromBytes(array $bytes): string - { - return pack('C*', ...$bytes); + return array_map(static function ($hex) { + return hexdec($hex); + }, str_split(bin2hex($text), 2)); } } diff --git a/src/Vocab/Loader/DefaultVocabLoader.php b/src/Vocab/Loader/DefaultVocabLoader.php index 54dc76c..0064c34 100644 --- a/src/Vocab/Loader/DefaultVocabLoader.php +++ b/src/Vocab/Loader/DefaultVocabLoader.php @@ -12,47 +12,52 @@ use function fclose; use function file_exists; use function fopen; +use function hash_equals; +use function hash_final; +use function hash_init; +use function hash_update_file; +use function hash_update_stream; use function is_dir; +use function is_resource; use function is_writable; use function mkdir; -use function preg_match; +use function rewind; use function sha1; use function sprintf; use function stream_copy_to_stream; +use function stream_get_meta_data; use const DIRECTORY_SEPARATOR; final class DefaultVocabLoader implements VocabLoader { - private ?string $cacheDir; + /** @var string|null */ + private $cacheDir; + public function __construct(?string $cacheDir = null) { $this->cacheDir = $cacheDir; } - public function load(string $uri): Vocab + public function load(string $uri, ?string $checksum = null): Vocab { - if ($this->cacheDir !== null && preg_match('@^https?://@i', $uri)) { - $cacheFile = $this->cacheDir . DIRECTORY_SEPARATOR . sha1($uri); - } else { - $cacheFile = null; - } + $cacheFile = $this->cacheDir !== null ? $this->cacheDir . DIRECTORY_SEPARATOR . sha1($uri) : null; if ($cacheFile !== null) { - if (file_exists($cacheFile)) { + if (file_exists($cacheFile) && $this->checkHash($cacheFile, $checksum)) { return Vocab::fromFile($cacheFile); } assert($this->cacheDir !== null); - if (! is_dir($this->cacheDir) && ! @mkdir($this->cacheDir, 0750, true)) { + if (!is_dir($this->cacheDir) && !@mkdir($this->cacheDir, 0750, true)) { throw new RuntimeException(sprintf( 'Directory does not exist and cannot be created: %s', - $this->cacheDir, + $this->cacheDir )); } - if (! is_writable($this->cacheDir)) { + if (!is_writable($this->cacheDir)) { throw new RuntimeException(sprintf('Directory is not writable: %s', $this->cacheDir)); } } @@ -64,6 +69,17 @@ public function load(string $uri): Vocab } try { + if ($checksum !== null && $this->isRewindable($stream)) { + if (!$this->checkHash($stream, $checksum)) { + throw new RuntimeException(sprintf( + 'Checksum failed. Could not load vocab from URI: %s', + $uri + )); + } + + rewind($stream); + } + if ($cacheFile !== null) { $cacheStream = fopen($cacheFile, 'w+'); @@ -85,4 +101,32 @@ public function load(string $uri): Vocab fclose($stream); } } + + /** @param string|resource $resource */ + private function checkHash($resource, ?string $expectedHash): bool + { + if ($expectedHash === null) { + return true; + } + + $ctx = hash_init('sha256'); + + if (is_resource($resource)) { + hash_update_stream($ctx, $resource); + } else { + hash_update_file($ctx, $resource); + } + + $hash = hash_final($ctx); + + return hash_equals($hash, $expectedHash); + } + + /** @param resource $stream */ + private function isRewindable($stream): bool + { + $meta = stream_get_meta_data($stream); + + return $meta['seekable']; + } } diff --git a/src/Vocab/Vocab.php b/src/Vocab/Vocab.php index 486d7b8..22b4a26 100644 --- a/src/Vocab/Vocab.php +++ b/src/Vocab/Vocab.php @@ -4,7 +4,6 @@ namespace guttedgarden\Tiktoken\Vocab; -use Closure; use Countable; use InvalidArgumentException; use OutOfBoundsException; @@ -26,22 +25,25 @@ use function rewind; use function sprintf; use function stream_get_meta_data; +use function strval; /** @psalm-import-type NonEmptyByteVector from EncodeUtil */ final class Vocab implements Countable { /** @var array */ - private array $tokenToRankMap; + private $tokenToRankMap; /** @var array */ - private array $rankToTokenMap; + private $rankToTokenMap; /** @param array $tokenRankMap */ private function __construct(array $tokenRankMap) { $this->tokenToRankMap = $tokenRankMap; /** @psalm-suppress PropertyTypeCoercion */ - $this->rankToTokenMap = array_map(Closure::fromCallable('strval'), array_flip($tokenRankMap)); + $this->rankToTokenMap = array_map(static function ($value) { + return strval($value); + }, array_flip($tokenRankMap)); if (count($this->tokenToRankMap) !== count($this->rankToTokenMap)) { throw new InvalidArgumentException('The map of tokens and ranks has duplicates of rank'); @@ -51,7 +53,7 @@ private function __construct(array $tokenRankMap) /** @param non-empty-string $bpeFile */ public static function fromFile(string $bpeFile): self { - if (! file_exists($bpeFile)) { + if (!file_exists($bpeFile)) { throw new RuntimeException(sprintf('File "%s" does not exist', $bpeFile)); } @@ -71,7 +73,7 @@ public static function fromFile(string $bpeFile): self /** * @param resource $stream * - * @return static + * @return self */ public static function fromStream($stream): self { @@ -87,7 +89,7 @@ public static function fromStream($stream): self while ($line !== false) { [$encodedToken, $rank] = explode(' ', $line); - $token = base64_decode($encodedToken); + $token = base64_decode($encodedToken, true); if ($token === false) { throw new ParseError(sprintf('Could not decode token "%s" at line %d', $encodedToken, $lineNo)); @@ -104,44 +106,49 @@ public static function fromStream($stream): self return new self($map); } - /** @psalm-param NonEmptyByteVector $bytes */ - public function tryGetRank(array $bytes): ?int + /** + * @param string $binary + * @return int|null + */ + public function tryGetRank(string $binary): ?int { - return $this->tokenToRankMap[EncodeUtil::fromBytes($bytes)] ?? null; + if ($binary === '') { + throw new InvalidArgumentException('Argument $binary cannot be an empty string'); + } + + return $this->tokenToRankMap[$binary] ?? null; } - /** - * @psalm-param NonEmptyByteVector $bytes - * - * @throws OutOfBoundsException - */ - public function getRank(array $bytes): int + /** @throws OutOfBoundsException */ + public function getRank(string $binary): int { - $key = EncodeUtil::fromBytes($bytes); - if (isset($this->tokenToRankMap[$key])) { - return $this->tokenToRankMap[$key]; - } else { + if ($binary === '') { + throw new InvalidArgumentException('Argument $binary cannot be an empty string'); + } + + if (!isset($this->tokenToRankMap[$binary])) { throw new OutOfBoundsException(sprintf( 'No rank for bytes vector: [%s]', - implode(', ', $bytes) + implode(', ', EncodeUtil::toBytes($binary)) )); } + + return $this->tokenToRankMap[$binary]; } /** - * @return non-empty-string - * + * @param int $rank + * @return string * @throws OutOfBoundsException */ public function getToken(int $rank): string { - if (isset($this->rankToTokenMap[$rank])) { - return $this->rankToTokenMap[$rank]; - } else { + if (!isset($this->rankToTokenMap[$rank])) { throw new OutOfBoundsException(sprintf('No token for rank: %d', $rank)); } - } + return $this->rankToTokenMap[$rank]; + } /** @psalm-api */ public function count(): int diff --git a/src/Vocab/VocabLoader.php b/src/Vocab/VocabLoader.php index db75d0e..045fe68 100644 --- a/src/Vocab/VocabLoader.php +++ b/src/Vocab/VocabLoader.php @@ -6,6 +6,10 @@ interface VocabLoader { - /** @param non-empty-string $uri */ - public function load(string $uri): Vocab; + /** + * @param string $uri + * @param string|null $checksum + * @return Vocab + */ + public function load(string $uri, ?string $checksum = null): Vocab; }