feat(llm):
- concurrent requests
- update LLM functions
MorvanZhou committed Jul 18, 2024
1 parent 83c21f8 commit 0b439a0
Showing 23 changed files with 699 additions and 282 deletions.
3 changes: 3 additions & 0 deletions src/retk/const/response_codes.py
@@ -49,6 +49,7 @@ class CodeEnum(IntEnum):
LLM_SERVICE_ERROR = 39
LLM_NO_CHOICE = 40
LLM_INVALID_RESPONSE_FORMAT = 41
LLM_API_LIMIT_EXCEEDED = 42


@dataclass
@@ -110,6 +111,7 @@ class CodeMessage:
CodeEnum.LLM_SERVICE_ERROR: CodeMessage(zh="模型服务错误", en="Model service error"),
CodeEnum.LLM_NO_CHOICE: CodeMessage(zh="无回复", en="No response"),
CodeEnum.LLM_INVALID_RESPONSE_FORMAT: CodeMessage(zh="无效的回复格式", en="Invalid response format"),
CodeEnum.LLM_API_LIMIT_EXCEEDED: CodeMessage(zh="LLM API 调用次数超过限制", en="LLM API call limit exceeded"),
}

CODE2STATUS_CODE: Dict[CodeEnum, int] = {
@@ -155,6 +157,7 @@ class CodeMessage:
CodeEnum.LLM_SERVICE_ERROR: 500,
CodeEnum.LLM_NO_CHOICE: 404,
CodeEnum.LLM_INVALID_RESPONSE_FORMAT: 500,
CodeEnum.LLM_API_LIMIT_EXCEEDED: 429,
}


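
The new LLM_API_LIMIT_EXCEEDED code maps to HTTP 429 in CODE2STATUS_CODE. A minimal sketch of how a caller might translate the enum into a status code; the helper and its 500 fallback are hypothetical, only the 429 mapping comes from this commit, and the import path is assumed from the file location above.

from retk import const
from retk.const.response_codes import CODE2STATUS_CODE  # path assumed from src/retk/const/response_codes.py


def to_http_status(code: const.CodeEnum) -> int:
    # fall back to 500 for any code without an explicit mapping (hypothetical convention)
    return CODE2STATUS_CODE.get(code, 500)


# per this commit, the rate-limit code resolves to HTTP 429
assert to_http_status(const.CodeEnum.LLM_API_LIMIT_EXCEEDED) == 429
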
53 changes: 41 additions & 12 deletions src/retk/core/ai/llm/api/aliyun.py
@@ -1,8 +1,10 @@
import asyncio
import json
from enum import Enum
from typing import Tuple, AsyncIterable, Optional, Dict
from typing import Tuple, AsyncIterable, Optional, Dict, List

from retk import config, const
from retk.core.utils import ratelimiter
from retk.logger import logger
from .base import BaseLLMService, MessagesType, NoAPIKeyError, ModelConfig

@@ -49,11 +51,14 @@ class AliyunModelEnum(Enum):
)


_key2model: Dict[str, AliyunModelEnum] = {m.value.key: m for m in AliyunModelEnum}


class AliyunService(BaseLLMService):
def __init__(
self,
top_p: float = 0.9,
temperature: float = 0.7,
temperature: float = 0.4,
timeout: float = 60.,
):
super().__init__(
@@ -63,10 +68,7 @@ def __init__(
timeout=timeout,
default_model=AliyunModelEnum.QWEN1_5_05B.value,
)

@staticmethod
def get_concurrency() -> int:
return 1
self.concurrency = 5

@staticmethod
def get_headers(stream: bool) -> Dict[str, str]:
@@ -113,10 +115,14 @@ async def complete(
)
if code != const.CodeEnum.OK:
return "Aliyun model error, please try later", code
if rj.get("code") is not None:
logger.error(f"ReqId={req_id} Aliyun model error: code={rj['code']} {rj['message']}")
rcode = rj.get("code")
if rcode is not None:
logger.error(f"ReqId={req_id} | Aliyun {model} | error: code={rj['code']} {rj['message']}")
if rcode == "Throttling.RateQuota":
return "Aliyun model rate limit exceeded", const.CodeEnum.LLM_API_LIMIT_EXCEEDED
return "Aliyun model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR
logger.info(f"ReqId={req_id} Aliyun model usage: {rj['usage']}")

logger.info(f"ReqId={req_id} | Aliyun {model} | usage: {rj['usage']}")
return rj["output"]["choices"][0]["message"]["content"], const.CodeEnum.OK

async def stream_complete(
@@ -144,16 +150,39 @@ async def stream_complete(
try:
json_str = s[5:]
except IndexError:
logger.error(f"ReqId={req_id} Aliyun model stream error: string={s}")
logger.error(f"ReqId={req_id} | Aliyun {model} | stream error: string={s}")
continue
try:
json_data = json.loads(json_str)
except json.JSONDecodeError as e:
logger.error(f"ReqId={req_id} Aliyun model stream error: string={s}, error={e}")
logger.error(f"ReqId={req_id} | Aliyun {model} | stream error: string={s}, error={e}")
continue
choice = json_data["output"]["choices"][0]
if choice["finish_reason"] != "null":
logger.info(f"ReqId={req_id} Aliyun model usage: {json_data['usage']}")
logger.info(f"ReqId={req_id} | Aliyun {model} | usage: {json_data['usage']}")
break
txt += choice["message"]["content"]
yield txt.encode("utf-8"), code

async def batch_complete(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
if model is None:
m = self.default_model
else:
m = _key2model[model].value
concurrent_limiter = ratelimiter.ConcurrentLimiter(n=self.concurrency)
rate_limiter = ratelimiter.RateLimiter(requests=m.RPM, period=60)

tasks = [
self._batch_complete(
limiters=[concurrent_limiter, rate_limiter],
messages=m,
model=model,
req_id=req_id,
) for m in messages
]
return await asyncio.gather(*tasks)
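
The new batch_complete leans on retk.core.utils.ratelimiter, which is not part of this diff. A minimal sketch of what the two limiters used above might look like as async context managers, assuming only the constructor arguments seen at the call sites (ConcurrentLimiter(n=...), RateLimiter(requests=..., period=...)); the internals are assumptions, not the project's implementation.

import asyncio
import time


class ConcurrentLimiter:
    # assumed shape: cap the number of in-flight requests with a semaphore
    def __init__(self, n: int):
        self._sem = asyncio.Semaphore(n)

    async def __aenter__(self):
        await self._sem.acquire()

    async def __aexit__(self, exc_type, exc, tb):
        self._sem.release()


class RateLimiter:
    # assumed shape: allow at most `requests` acquisitions per `period` seconds
    def __init__(self, requests: int, period: float):
        self.requests = requests
        self.period = period
        self._stamps = []
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        while True:
            async with self._lock:
                now = time.monotonic()
                # keep only acquisitions that are still inside the rolling window
                self._stamps = [t for t in self._stamps if now - t < self.period]
                if len(self._stamps) < self.requests:
                    self._stamps.append(now)
                    return
                wait = self.period - (now - self._stamps[0])
            await asyncio.sleep(wait)

    async def __aexit__(self, exc_type, exc, tb):
        pass
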
50 changes: 37 additions & 13 deletions src/retk/core/ai/llm/api/baidu.py
@@ -1,11 +1,13 @@
import asyncio
import json
from datetime import datetime
from enum import Enum
from typing import Tuple, AsyncIterable
from typing import Tuple, AsyncIterable, List, Dict

import httpx

from retk import config, const
from retk.core.utils import ratelimiter
from retk.logger import logger
from .base import BaseLLMService, MessagesType, NoAPIKeyError, ModelConfig

@@ -56,11 +58,14 @@ class BaiduModelEnum(Enum):
) # free


_key2model: Dict[str, BaiduModelEnum] = {m.value.key: m for m in BaiduModelEnum}


class BaiduService(BaseLLMService):
def __init__(
self,
top_p: float = 0.9,
temperature: float = 0.7,
temperature: float = 0.4,
timeout: float = 60.,
):
super().__init__(
@@ -70,17 +75,14 @@ def __init__(
timeout=timeout,
default_model=BaiduModelEnum.ERNIE_SPEED_8K.value,
)

self.headers = {
"Content-Type": "application/json",
}

self.token_expires_at = datetime.now().timestamp()
self.token = ""

@staticmethod
def get_concurrency() -> int:
return 9999

async def set_token(self, req_id: str = None):
_s = config.get_settings()
if _s.BAIDU_QIANFAN_API_KEY == "" or _s.BAIDU_QIANFAN_SECRET_KEY == "":
@@ -101,11 +103,11 @@ async def set_token(self, req_id: str = None):
}
)
if resp.status_code != 200:
logger.error(f"ReqId={req_id} Baidu model error: {resp.text}")
logger.error(f"ReqId={req_id} | Baidu | error: {resp.text}")
return ""
rj = resp.json()
if rj.get("error") is not None:
logger.error(f"ReqId={req_id} Baidu model token error: {rj['error_description']}")
logger.error(f"ReqId={req_id} | Baidu | token error: {rj['error_description']}")
return ""

self.token_expires_at = rj["expires_in"] + datetime.now().timestamp()
@@ -148,9 +150,9 @@ async def complete(
return "Model error, please try later", code

if resp.get("error_code") is not None:
logger.error(f"ReqId={req_id} Baidu model error: code={resp['error_code']} {resp['error_msg']}")
logger.error(f"ReqId={req_id} | Baidu {model} | error: code={resp['error_code']} {resp['error_msg']}")
return resp["error_msg"], const.CodeEnum.LLM_SERVICE_ERROR
logger.info(f"ReqId={req_id} Baidu model usage: {resp['usage']}")
logger.info(f"ReqId={req_id} | Baidu {model} | usage: {resp['usage']}")
return resp["result"], const.CodeEnum.OK

async def stream_complete(
@@ -183,16 +185,38 @@ async def stream_complete(
try:
json_str = s[6:]
except IndexError:
logger.error(f"ReqId={req_id} Baidu model stream error: string={s}")
logger.error(f"ReqId={req_id} | Baidu {model} | stream error: string={s}")
continue
try:
json_data = json.loads(json_str)
except json.JSONDecodeError as e:
logger.error(f"ReqId={req_id} Baidu model stream error: string={s}, error={e}")
logger.error(f"ReqId={req_id} | Baidu {model} | stream error: string={s}, error={e}")
continue

if json_data["is_end"]:
logger.info(f"ReqId={req_id} Baidu model usage: {json_data['usage']}")
logger.info(f"ReqId={req_id} | Baidu {model} | usage: {json_data['usage']}")
break
txt += json_data["result"]
yield txt.encode("utf-8"), code

async def batch_complete(
self,
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
if model is None:
m = self.default_model
else:
m = _key2model[model].value
limiter = ratelimiter.RateLimiter(requests=m.RPM, period=60)

tasks = [
self._batch_complete(
limiters=[limiter],
messages=m,
model=model,
req_id=req_id,
) for m in messages
]
return await asyncio.gather(*tasks)
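
Aliyun and Baidu now share the same batch_complete signature, so several prompts can be fanned out in a single call while the limiters pace the underlying requests. A hypothetical usage sketch; the prompts and request id are illustrative, and configured API keys are assumed.

import asyncio

from retk.core.ai.llm.api.baidu import BaiduService


async def main():
    service = BaiduService()
    prompts = [
        [{"role": "user", "content": "Summarize note A"}],
        [{"role": "user", "content": "Summarize note B"}],
    ]
    # returns one (text, CodeEnum) pair per prompt, in the same order as the prompts
    results = await service.batch_complete(messages=prompts, req_id="demo-req")
    for text, code in results:
        print(code, text)


asyncio.run(main())
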
93 changes: 34 additions & 59 deletions src/retk/core/ai/llm/api/base.py
@@ -1,13 +1,11 @@
import asyncio
import datetime as dt
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import wraps
from dataclasses import dataclass, field
from typing import List, Dict, Literal, AsyncIterable, Tuple, Optional, Union

import httpx

from retk import const
from retk.core.utils import ratelimiter
from retk.logger import logger

MessagesType = List[Dict[Literal["role", "content"], str]]
@@ -17,8 +15,9 @@
class ModelConfig:
key: str
max_tokens: int
RPM: Optional[int] = None
TPM: Optional[int] = None
RPM: int = field(default=999999)
RPD: int = field(default=9999999999)
TPM: int = field(default=9999999999)


class NoAPIKeyError(Exception):
@@ -32,17 +31,15 @@ def __init__(
self,
endpoint: str,
top_p: float = 1.,
temperature: float = 1.,
temperature: float = 0.4,
timeout: float = None,
default_model: Optional[ModelConfig] = None,
concurrency: int = -1,
):
self.top_p = top_p
self.temperature = temperature
self.timeout = self.default_timeout if timeout is not None else timeout
self.default_model: Optional[ModelConfig] = default_model
self.endpoint = endpoint
self.concurrency = concurrency

async def _complete(
self,
@@ -118,6 +115,28 @@ async def _stream_complete(
yield chunk, const.CodeEnum.OK
await client.aclose()

async def _batch_complete(
self,
limiters: List[Union[ratelimiter.RateLimiter, ratelimiter.ConcurrentLimiter]],
messages: MessagesType,
model: str = None,
req_id: str = None,
) -> Tuple[str, const.CodeEnum]:
if len(limiters) == 4:
async with limiters[0], limiters[1], limiters[2], limiters[3]:
return await self.complete(messages=messages, model=model, req_id=req_id)
elif len(limiters) == 3:
async with limiters[0], limiters[1], limiters[2]:
return await self.complete(messages=messages, model=model, req_id=req_id)
elif len(limiters) == 2:
async with limiters[0], limiters[1]:
return await self.complete(messages=messages, model=model, req_id=req_id)
elif len(limiters) == 1:
async with limiters[0]:
return await self.complete(messages=messages, model=model, req_id=req_id)
else:
raise ValueError("Invalid number of limiters, should less than 4")

@abstractmethod
async def stream_complete(
self,
Expand All @@ -127,55 +146,11 @@ async def stream_complete(
) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
...

@staticmethod
@abstractmethod
def get_concurrency() -> int:
...


# unless you keep a strong reference to a running task, it can be dropped during execution
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
_background_tasks = set()


class RateLimitedClient(httpx.AsyncClient):
"""httpx.AsyncClient with a rate limit."""

def __init__(
async def batch_complete(
self,
interval: Union[dt.timedelta, float],
count=1,
**kwargs
):
"""
Parameters
----------
interval : Union[dt.timedelta, float]
Length of interval.
If a float is given, seconds are assumed.
numerator : int, optional
Number of requests which can be sent in any given interval (default 1).
"""
if isinstance(interval, dt.timedelta):
interval = interval.total_seconds()

self.interval = interval
self.semaphore = asyncio.Semaphore(count)
super().__init__(**kwargs)

def _schedule_semaphore_release(self):
wait = asyncio.create_task(asyncio.sleep(self.interval))
_background_tasks.add(wait)

def wait_cb(task):
self.semaphore.release()
_background_tasks.discard(task)

wait.add_done_callback(wait_cb)

@wraps(httpx.AsyncClient.send)
async def send(self, *args, **kwargs):
await self.semaphore.acquire()
send = asyncio.create_task(super().send(*args, **kwargs))
self._schedule_semaphore_release()
return await send
messages: List[MessagesType],
model: str = None,
req_id: str = None,
) -> List[Tuple[str, const.CodeEnum]]:
...
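
The if/elif chain in _batch_complete caps the number of limiters at four. As an alternative, contextlib.AsyncExitStack can enter an arbitrary number of async context managers; the sketch below is an equivalent standalone formulation, not code from this commit.

import contextlib
from typing import List, Tuple

from retk import const
from retk.core.ai.llm.api.base import BaseLLMService, MessagesType


async def complete_with_limiters(
        service: BaseLLMService,
        limiters: List,
        messages: MessagesType,
        model: str = None,
        req_id: str = None,
) -> Tuple[str, const.CodeEnum]:
    # enter every limiter before issuing the request, regardless of how many were passed
    async with contextlib.AsyncExitStack() as stack:
        for limiter in limiters:
            await stack.enter_async_context(limiter)
        return await service.complete(messages=messages, model=model, req_id=req_id)
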