Current File : /home/mdkeenpw/www/wp-content/plugins/ai-engine/vendor/yethee/tiktoken/src/Encoder.php
<?php

declare(strict_types=1);

namespace Yethee\Tiktoken;

use Stringable;
use Yethee\Tiktoken\Exception\RegexError;
use Yethee\Tiktoken\Util\EncodeUtil;
use Yethee\Tiktoken\Vocab\Vocab;

use function array_map;
use function array_merge;
use function array_values;
use function count;
use function implode;
use function preg_last_error_msg;
use function preg_match_all;
use function sprintf;
use function strlen;
use function substr;

use const PHP_INT_MAX;

/**
 * Byte-Pair Encoding (BPE) tokenizer backed by a rank vocabulary.
 *
 * The input text is first split into pieces with the configured regular expression. Each piece is then
 * looked up in the vocabulary as a whole; pieces that are not a single vocabulary entry are broken down
 * into tokens by repeatedly merging the adjacent byte pair with the lowest rank.
 *
 * @psalm-import-type NonEmptyByteVector from EncodeUtil
 */
final class Encoder implements Stringable
{
    /**
     * @param non-empty-string $name    Name of the encoding, exposed read-only and used by __toString().
     * @param Vocab            $vocab   Vocabulary used to translate between byte sequences and ranks.
     * @param non-empty-string $pattern PCRE pattern (including delimiters) used to split text into pieces.
     */
    public function __construct(public readonly string $name, private Vocab $vocab, private string $pattern)
    {
    }

    /** Returns a short human-readable description: the encoding name and the vocabulary size. */
    public function __toString(): string
    {
        return sprintf('Encoder(name="%s", vocab=%d)', $this->name, count($this->vocab));
    }

    /**
     * Encodes a text into a flat list of BPE tokens.
     *
     * @param string $text The input text; an empty string yields an empty list.
     *
     * @return list<int>
     *
     * @throws RegexError When splitting the text with the configured pattern fails.
     */
    public function encode(string $text): array
    {
        if ($text === '') {
            return [];
        }

        $tokens = [];

        foreach ($this->splitIntoPieces($text) as $piece) {
            if ($piece === '') {
                continue;
            }

            // Fast path: the whole piece is a single vocabulary entry.
            $rank = $this->vocab->tryGetRank($piece);

            if ($rank !== null) {
                $tokens[] = $rank;

                continue;
            }

            foreach ($this->mergeBytePairs($piece) as $token) {
                $tokens[] = $token;
            }
        }

        return $tokens;
    }

    /**
     * Encodes a given text into chunks of Byte-Pair Encoded (BPE) tokens, with each chunk containing a specified
     * maximum number of tokens.
     *
     * Chunk boundaries always fall between regex pieces: the tokens of one piece are never spread over two
     * chunks. Consequently a single piece that produces more tokens than the limit is emitted as a chunk of
     * its own, and that chunk exceeds the limit. Empty chunks are never returned.
     *
     * @param string       $text              The input text to be encoded.
     * @param positive-int $maxTokensPerChunk The maximum number of tokens allowed per chunk.
     *
     * @return list<list<int>> An array of arrays containing BPE token chunks.
     *
     * @throws RegexError When splitting the text with the configured pattern fails.
     */
    public function encodeInChunks(string $text, int $maxTokensPerChunk): array
    {
        if ($text === '') {
            return [];
        }

        $chunks = [];
        // Token lists of the pieces collected for the current chunk; merged once when the chunk is closed
        // so the growing chunk is not copied again for every piece.
        $pendingPieces = [];
        $pendingTokenCount = 0;

        foreach ($this->splitIntoPieces($text) as $piece) {
            if ($piece === '') {
                continue;
            }

            $rank = $this->vocab->tryGetRank($piece);
            $tokens = $rank !== null ? [$rank] : $this->mergeBytePairs($piece);
            $tokenCount = count($tokens);

            // Close the current chunk only if it actually holds tokens; otherwise an oversized first piece
            // would push an empty chunk into the result.
            if ($pendingTokenCount > 0 && $pendingTokenCount + $tokenCount > $maxTokensPerChunk) {
                $chunks[] = array_merge(...$pendingPieces);
                $pendingPieces = [];
                $pendingTokenCount = 0;
            }

            $pendingPieces[] = $tokens;
            $pendingTokenCount += $tokenCount;
        }

        if ($pendingTokenCount > 0) {
            $chunks[] = array_merge(...$pendingPieces);
        }

        return $chunks;
    }

    /**
     * Decodes a list of tokens back into a string by concatenating the vocabulary entry of every token.
     *
     * @param array<int> $tokens Tokens to decode; an empty array yields an empty string.
     */
    public function decode(array $tokens): string
    {
        if ($tokens === []) {
            return '';
        }

        return implode(array_map($this->vocab->getToken(...), $tokens));
    }

    /**
     * Splits the text into pieces using the configured pattern.
     *
     * @return list<string> Full-pattern matches in order of appearance; callers skip empty matches.
     *
     * @throws RegexError When preg_match_all() reports a failure.
     */
    private function splitIntoPieces(string $text): array
    {
        if (preg_match_all($this->pattern, $text, $matches) === false) {
            throw new RegexError(sprintf('Matching failed with error: %s', preg_last_error_msg()));
        }

        return $matches[0];
    }

    /**
     * Tokenizes a piece that is not a single vocabulary entry.
     *
     * The piece starts out split into single bytes. The adjacent pair of segments whose concatenation has
     * the lowest rank in the vocabulary is merged, and this repeats until no adjacent pair is a known
     * vocabulary entry. The remaining segments are then translated to their ranks.
     *
     * @param non-empty-string $piece
     *
     * @return list<int>
     */
    private function mergeBytePairs(string $piece): array
    {
        $byteCount = strlen($piece);

        // Each entry is [byte offset where a segment starts, rank of merging that segment with the next one].
        // The extra last entry (offset === byte count) only marks the end of the final segment.
        $parts = [];

        for ($i = 0; $i <= $byteCount; $i++) {
            $parts[] = [$i, PHP_INT_MAX];
        }

        // Rank of the bytes covered by the segment at $startIndex and its right neighbour,
        // or PHP_INT_MAX when there is no such pair or it is not in the vocabulary.
        $getRank = function (array $parts, int $startIndex) use ($piece): int {
            if ($startIndex + 2 >= count($parts)) {
                return PHP_INT_MAX;
            }

            $offset = $parts[$startIndex][0];
            $length = $parts[$startIndex + 2][0] - $offset;

            return $this->vocab->tryGetRank(substr($piece, $offset, $length)) ?? PHP_INT_MAX;
        };

        $pairCount = count($parts) - 2;

        for ($i = 0; $i < $pairCount; $i++) {
            $parts[$i][1] = $getRank($parts, $i);
        }

        while (count($parts) > 1) {
            $minRank = PHP_INT_MAX;
            $partIndex = 0;
            $stop = count($parts) - 1;

            // Find the left-most pair with the lowest rank.
            for ($i = 0; $i < $stop; $i++) {
                if ($minRank <= $parts[$i][1]) {
                    continue;
                }

                $minRank = $parts[$i][1];
                $partIndex = $i;
            }

            // No adjacent pair is a vocabulary entry any more: merging is finished.
            if ($minRank === PHP_INT_MAX) {
                break;
            }

            // Merge by dropping the boundary between the two segments, then re-index the list.
            unset($parts[$partIndex + 1]);
            $parts = array_values($parts);

            // The merge changed the right neighbour of this segment and of the one before it,
            // so both pair ranks have to be recomputed.
            $parts[$partIndex][1] = $getRank($parts, $partIndex);

            if ($partIndex <= 0) {
                continue;
            }

            $parts[$partIndex - 1][1] = $getRank($parts, $partIndex - 1);
        }

        $stop = count($parts) - 1;
        $res = [];

        for ($i = 0; $i < $stop; $i++) {
            $offset = $parts[$i][0];
            $length = $parts[$i + 1][0] - $offset;

            $res[] = $this->vocab->getRank(substr($piece, $offset, $length));
        }

        return $res;
    }
}