Skip to content

Commit

Permalink
Fork update + o3 model implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel Alekseev committed Feb 5, 2025
1 parent f382ea9 commit aa5fcf5
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 124 deletions.
1 change: 0 additions & 1 deletion composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
}
},
"minimum-stability": "stable",
"version": "1.0.0",
"config": {
"sort-packages": true,
"allow-plugins": {
Expand Down
138 changes: 83 additions & 55 deletions src/Encoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,44 @@

namespace guttedgarden\Tiktoken;

use Closure;
use guttedgarden\Tiktoken\Exception\RegexError;
use guttedgarden\Tiktoken\Util\EncodeUtil;
use guttedgarden\Tiktoken\Vocab\Vocab;

use function array_map;
use function array_slice;
use function array_merge;
use function array_values;
use function assert;
use function count;
use function implode;
use function preg_last_error;
use function preg_match_all;
use function range;
use function sprintf;
use function strlen;
use function substr;

use const PHP_INT_MAX;

/** @psalm-import-type NonEmptyByteVector from EncodeUtil */
final class Encoder
{
private string $name;
private Vocab $vocab;
private string $pattern;
/**
* @param non-empty-string $name
* @param non-empty-string $pattern
* @var string
*/
private $name;

/**
* @var Vocab
*/
private $vocab;

/**
* @var string
*/
private $pattern;

/**
* @param string $name
* @param string $pattern
*/
public function __construct(string $name, Vocab $vocab, string $pattern)
{
Expand All @@ -43,15 +55,15 @@ public function __toString(): string
return sprintf('Encoder(name="%s", vocab=%d)', $this->name, count($this->vocab));
}

/** @return list<int> */
/** @return int[] */
public function encode(string $text): array
{
if ($text === '') {
return [];
}

if (preg_match_all($this->pattern, $text, $matches) === false) {
throw new RegexError(sprintf('Matching failed with error: %s', $this->pregErrorString(preg_last_error())));
throw new RegexError(sprintf('Matching failed with error code: %d', preg_last_error()));
}

$tokens = [];
Expand All @@ -61,66 +73,96 @@ public function encode(string $text): array
continue;
}

$piece = EncodeUtil::toBytes($match);
$rank = $this->vocab->tryGetRank($piece);
$rank = $this->vocab->tryGetRank($match);

if ($rank !== null) {
$tokens[] = $rank;

continue;
}

foreach ($this->mergeBytePairs($piece) as $rank) {
foreach ($this->mergeBytePairs($match) as $rank) {
$tokens[] = $rank;
}
}

return $tokens;
}

/** @param array<int> $tokens */
/** @return int[][] */
public function encodeInChunks(string $text, int $maxTokensPerChunk): array
{
if ($text === '') {
return [];
}

if (preg_match_all($this->pattern, $text, $matches) === false) {
throw new RegexError(sprintf('Matching failed with error code: %d', preg_last_error()));
}

$chunks = [];
$tokensInCurrentChunk = [];

foreach ($matches[0] as $match) {
if ($match === '') {
continue;
}

$rank = $this->vocab->tryGetRank($match);
$tokens = $rank !== null ? [$rank] : $this->mergeBytePairs($match);

if (count($tokensInCurrentChunk) + count($tokens) > $maxTokensPerChunk) {
$chunks[] = $tokensInCurrentChunk;
$tokensInCurrentChunk = [];
}

$tokensInCurrentChunk = array_merge($tokensInCurrentChunk, $tokens);
}

if (count($tokensInCurrentChunk) > 0) {
$chunks[] = $tokensInCurrentChunk;
}

return $chunks;
}

/** @param int[] $tokens */
public function decode(array $tokens): string
{
if ($tokens === []) {
return '';
}

return implode(array_map(Closure::fromCallable([$this->vocab, 'getToken']), $tokens));
return implode(array_map([$this->vocab, 'getToken'], $tokens));
}

/**
* @psalm-param NonEmptyByteVector $bytes
* @param string $piece
*
* @return list<int>
* @return int[]
*/
private function mergeBytePairs(array $bytes): array
private function mergeBytePairs(string $piece): array
{
/** @var list<array{int, int}> $parts */
$parts = array_map(
function (int $i) use ($bytes): array {
if ($i + 1 < count($bytes)) {
$piece = array_slice($bytes, $i, 2);
assert(count($piece) === 2);

return [$i, $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX];
}
$parts = [];

for ($i = 0; $i <= strlen($piece); $i++) {
$parts[] = [$i, PHP_INT_MAX];
}

return [$i, PHP_INT_MAX];
},
range(0, count($bytes)),
);
$getRank = function (array $parts, int $startIndex) use ($bytes): int {
if ($startIndex + 2 >= count($parts)) {
$getRank = function (array $parts, int $startIndex, int $skip = 0) use ($piece): int {
if (($startIndex + $skip + 2) >= count($parts)) {
return PHP_INT_MAX;
}

$offset = $parts[$startIndex][0];
$piece = array_slice($bytes, $offset, $parts[$startIndex + 2][0] - $offset);
assert(count($piece) > 0);
$length = $parts[$startIndex + $skip + 2][0] - $offset;

return $this->vocab->tryGetRank($piece) ?? PHP_INT_MAX;
return $this->vocab->tryGetRank(substr($piece, $offset, $length)) ?? PHP_INT_MAX;
};

for ($i = 0; $i < count($parts) - 2; $i++) {
$parts[$i][1] = $getRank($parts, $i);
}

while (count($parts) > 1) {
$minRank = PHP_INT_MAX;
$partIndex = 0;
Expand Down Expand Up @@ -155,26 +197,12 @@ function (int $i) use ($bytes): array {
$res = [];

for ($i = 0; $i < $stop; $i++) {
$piece = array_slice($bytes, $parts[$i][0], $parts[$i + 1][0] - $parts[$i][0]);
assert(count($piece) > 0);
$offset = $parts[$i][0];
$length = $parts[$i + 1][0] - $offset;

$res[] = $this->vocab->getRank($piece);
$res[] = $this->vocab->getRank(substr($piece, $offset, $length));
}

return $res;
}

private function pregErrorString($errorConstant)
{
static $errorMessages = [
PREG_NO_ERROR => 'No error',
PREG_INTERNAL_ERROR => 'Internal error',
PREG_BACKTRACK_LIMIT_ERROR=> 'Backtrack limit error',
PREG_RECURSION_LIMIT_ERROR=> 'Recursion limit error',
PREG_BAD_UTF8_ERROR => 'Bad UTF8 error',
PREG_BAD_UTF8_OFFSET_ERROR=> 'Bad UTF8 offset error',
];

return $errorMessages[$errorConstant] ?? 'Unknown error';
}
}
Loading

0 comments on commit aa5fcf5

Please sign in to comment.