From 0b439a033b7d49a646f7d61230f0dcf1f1daacbc Mon Sep 17 00:00:00 2001 From: morvanzhou Date: Thu, 18 Jul 2024 22:07:35 +0800 Subject: [PATCH] feat(llm): - concurrent request - update llm functions --- src/retk/const/response_codes.py | 3 + src/retk/core/ai/llm/api/aliyun.py | 53 +++++-- src/retk/core/ai/llm/api/baidu.py | 50 ++++-- src/retk/core/ai/llm/api/base.py | 93 ++++------- src/retk/core/ai/llm/api/moonshot.py | 27 +++- src/retk/core/ai/llm/api/openai.py | 53 +++++-- src/retk/core/ai/llm/api/tencent.py | 41 +++-- src/retk/core/ai/llm/api/xfyun.py | 37 +++-- src/retk/core/ai/llm/knowledge/__init__.py | 4 +- src/retk/core/ai/llm/knowledge/extending.py | 13 -- src/retk/core/ai/llm/knowledge/ops.py | 146 +++++++++++++----- .../core/ai/llm/knowledge/system_extend.md | 16 +- .../core/ai/llm/knowledge/system_summary.md | 7 +- src/retk/core/ai/llm/knowledge/utils.py | 12 +- src/retk/core/scheduler/schedule.py | 2 +- src/retk/core/scheduler/tasks/extend_node.py | 72 +++++---- src/retk/core/utils/__init__.py | 0 src/retk/core/utils/ratelimiter.py | 42 +++++ tests/test_ai_llm_api.py | 62 ++++++++ tests/test_ai_llm_knowledge.py | 81 +++++++--- tests/test_core_local.py | 42 ++--- tests/test_core_remote.py | 28 ++-- tests/test_core_utils.py | 97 ++++++++++++ 23 files changed, 699 insertions(+), 282 deletions(-) create mode 100644 src/retk/core/utils/__init__.py create mode 100644 src/retk/core/utils/ratelimiter.py create mode 100644 tests/test_core_utils.py diff --git a/src/retk/const/response_codes.py b/src/retk/const/response_codes.py index 5d78911..28a5220 100644 --- a/src/retk/const/response_codes.py +++ b/src/retk/const/response_codes.py @@ -49,6 +49,7 @@ class CodeEnum(IntEnum): LLM_SERVICE_ERROR = 39 LLM_NO_CHOICE = 40 LLM_INVALID_RESPONSE_FORMAT = 41 + LLM_API_LIMIT_EXCEEDED = 42 @dataclass @@ -110,6 +111,7 @@ class CodeMessage: CodeEnum.LLM_SERVICE_ERROR: CodeMessage(zh="模型服务错误", en="Model service error"), CodeEnum.LLM_NO_CHOICE: CodeMessage(zh="无回复", en="No response"), CodeEnum.LLM_INVALID_RESPONSE_FORMAT: CodeMessage(zh="无效的回复格式", en="Invalid response format"), + CodeEnum.LLM_API_LIMIT_EXCEEDED: CodeMessage(zh="LLM API 调用次数超过限制", en="LLM API call limit exceeded"), } CODE2STATUS_CODE: Dict[CodeEnum, int] = { @@ -155,6 +157,7 @@ class CodeMessage: CodeEnum.LLM_SERVICE_ERROR: 500, CodeEnum.LLM_NO_CHOICE: 404, CodeEnum.LLM_INVALID_RESPONSE_FORMAT: 500, + CodeEnum.LLM_API_LIMIT_EXCEEDED: 429, } diff --git a/src/retk/core/ai/llm/api/aliyun.py b/src/retk/core/ai/llm/api/aliyun.py index 5dbd5e3..a62a8d8 100644 --- a/src/retk/core/ai/llm/api/aliyun.py +++ b/src/retk/core/ai/llm/api/aliyun.py @@ -1,8 +1,10 @@ +import asyncio import json from enum import Enum -from typing import Tuple, AsyncIterable, Optional, Dict +from typing import Tuple, AsyncIterable, Optional, Dict, List from retk import config, const +from retk.core.utils import ratelimiter from retk.logger import logger from .base import BaseLLMService, MessagesType, NoAPIKeyError, ModelConfig @@ -49,11 +51,14 @@ class AliyunModelEnum(Enum): ) +_key2model: Dict[str, AliyunModelEnum] = {m.value.key: m for m in AliyunModelEnum} + + class AliyunService(BaseLLMService): def __init__( self, top_p: float = 0.9, - temperature: float = 0.7, + temperature: float = 0.4, timeout: float = 60., ): super().__init__( @@ -63,10 +68,7 @@ def __init__( timeout=timeout, default_model=AliyunModelEnum.QWEN1_5_05B.value, ) - - @staticmethod - def get_concurrency() -> int: - return 1 + self.concurrency = 5 @staticmethod def get_headers(stream: bool) -> 
Dict[str, str]: @@ -113,10 +115,14 @@ async def complete( ) if code != const.CodeEnum.OK: return "Aliyun model error, please try later", code - if rj.get("code") is not None: - logger.error(f"ReqId={req_id} Aliyun model error: code={rj['code']} {rj['message']}") + rcode = rj.get("code") + if rcode is not None: + logger.error(f"ReqId={req_id} | Aliyun {model} | error: code={rj['code']} {rj['message']}") + if rcode == "Throttling.RateQuota": + return "Aliyun model rate limit exceeded", const.CodeEnum.LLM_API_LIMIT_EXCEEDED return "Aliyun model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR - logger.info(f"ReqId={req_id} Aliyun model usage: {rj['usage']}") + + logger.info(f"ReqId={req_id} | Aliyun {model} | usage: {rj['usage']}") return rj["output"]["choices"][0]["message"]["content"], const.CodeEnum.OK async def stream_complete( @@ -144,16 +150,39 @@ async def stream_complete( try: json_str = s[5:] except IndexError: - logger.error(f"ReqId={req_id} Aliyun model stream error: string={s}") + logger.error(f"ReqId={req_id} | Aliyun {model} | stream error: string={s}") continue try: json_data = json.loads(json_str) except json.JSONDecodeError as e: - logger.error(f"ReqId={req_id} Aliyun model stream error: string={s}, error={e}") + logger.error(f"ReqId={req_id} | Aliyun {model} | stream error: string={s}, error={e}") continue choice = json_data["output"]["choices"][0] if choice["finish_reason"] != "null": - logger.info(f"ReqId={req_id} Aliyun model usage: {json_data['usage']}") + logger.info(f"ReqId={req_id} | Aliyun {model} | usage: {json_data['usage']}") break txt += choice["message"]["content"] yield txt.encode("utf-8"), code + + async def batch_complete( + self, + messages: List[MessagesType], + model: str = None, + req_id: str = None, + ) -> List[Tuple[str, const.CodeEnum]]: + if model is None: + m = self.default_model + else: + m = _key2model[model].value + concurrent_limiter = ratelimiter.ConcurrentLimiter(n=self.concurrency) + rate_limiter = ratelimiter.RateLimiter(requests=m.RPM, period=60) + + tasks = [ + self._batch_complete( + limiters=[concurrent_limiter, rate_limiter], + messages=m, + model=model, + req_id=req_id, + ) for m in messages + ] + return await asyncio.gather(*tasks) diff --git a/src/retk/core/ai/llm/api/baidu.py b/src/retk/core/ai/llm/api/baidu.py index 3fc4fe8..ac7898b 100644 --- a/src/retk/core/ai/llm/api/baidu.py +++ b/src/retk/core/ai/llm/api/baidu.py @@ -1,11 +1,13 @@ +import asyncio import json from datetime import datetime from enum import Enum -from typing import Tuple, AsyncIterable +from typing import Tuple, AsyncIterable, List, Dict import httpx from retk import config, const +from retk.core.utils import ratelimiter from retk.logger import logger from .base import BaseLLMService, MessagesType, NoAPIKeyError, ModelConfig @@ -56,11 +58,14 @@ class BaiduModelEnum(Enum): ) # free +_key2model: Dict[str, BaiduModelEnum] = {m.value.key: m for m in BaiduModelEnum} + + class BaiduService(BaseLLMService): def __init__( self, top_p: float = 0.9, - temperature: float = 0.7, + temperature: float = 0.4, timeout: float = 60., ): super().__init__( @@ -70,6 +75,7 @@ def __init__( timeout=timeout, default_model=BaiduModelEnum.ERNIE_SPEED_8K.value, ) + self.headers = { "Content-Type": "application/json", } @@ -77,10 +83,6 @@ def __init__( self.token_expires_at = datetime.now().timestamp() self.token = "" - @staticmethod - def get_concurrency() -> int: - return 9999 - async def set_token(self, req_id: str = None): _s = config.get_settings() if _s.BAIDU_QIANFAN_API_KEY 
== "" or _s.BAIDU_QIANFAN_SECRET_KEY == "": @@ -101,11 +103,11 @@ async def set_token(self, req_id: str = None): } ) if resp.status_code != 200: - logger.error(f"ReqId={req_id} Baidu model error: {resp.text}") + logger.error(f"ReqId={req_id} | Baidu | error: {resp.text}") return "" rj = resp.json() if rj.get("error") is not None: - logger.error(f"ReqId={req_id} Baidu model token error: {rj['error_description']}") + logger.error(f"ReqId={req_id} | Baidu | token error: {rj['error_description']}") return "" self.token_expires_at = rj["expires_in"] + datetime.now().timestamp() @@ -148,9 +150,9 @@ async def complete( return "Model error, please try later", code if resp.get("error_code") is not None: - logger.error(f"ReqId={req_id} Baidu model error: code={resp['error_code']} {resp['error_msg']}") + logger.error(f"ReqId={req_id} | Baidu {model} | error: code={resp['error_code']} {resp['error_msg']}") return resp["error_msg"], const.CodeEnum.LLM_SERVICE_ERROR - logger.info(f"ReqId={req_id} Baidu model usage: {resp['usage']}") + logger.info(f"ReqId={req_id} | Baidu {model} | usage: {resp['usage']}") return resp["result"], const.CodeEnum.OK async def stream_complete( @@ -183,16 +185,38 @@ async def stream_complete( try: json_str = s[6:] except IndexError: - logger.error(f"ReqId={req_id} Baidu model stream error: string={s}") + logger.error(f"ReqId={req_id} | Baidu {model} | stream error: string={s}") continue try: json_data = json.loads(json_str) except json.JSONDecodeError as e: - logger.error(f"ReqId={req_id} Baidu model stream error: string={s}, error={e}") + logger.error(f"ReqId={req_id} | Baidu {model} | stream error: string={s}, error={e}") continue if json_data["is_end"]: - logger.info(f"ReqId={req_id} Baidu model usage: {json_data['usage']}") + logger.info(f"ReqId={req_id} | Baidu {model} | usage: {json_data['usage']}") break txt += json_data["result"] yield txt.encode("utf-8"), code + + async def batch_complete( + self, + messages: List[MessagesType], + model: str = None, + req_id: str = None, + ) -> List[Tuple[str, const.CodeEnum]]: + if model is None: + m = self.default_model + else: + m = _key2model[model].value + limiter = ratelimiter.RateLimiter(requests=m.RPM, period=60) + + tasks = [ + self._batch_complete( + limiters=[limiter], + messages=m, + model=model, + req_id=req_id, + ) for m in messages + ] + return await asyncio.gather(*tasks) diff --git a/src/retk/core/ai/llm/api/base.py b/src/retk/core/ai/llm/api/base.py index c5fa82d..2b3f188 100644 --- a/src/retk/core/ai/llm/api/base.py +++ b/src/retk/core/ai/llm/api/base.py @@ -1,13 +1,11 @@ -import asyncio -import datetime as dt from abc import ABC, abstractmethod -from dataclasses import dataclass -from functools import wraps +from dataclasses import dataclass, field from typing import List, Dict, Literal, AsyncIterable, Tuple, Optional, Union import httpx from retk import const +from retk.core.utils import ratelimiter from retk.logger import logger MessagesType = List[Dict[Literal["role", "content"], str]] @@ -17,8 +15,9 @@ class ModelConfig: key: str max_tokens: int - RPM: Optional[int] = None - TPM: Optional[int] = None + RPM: int = field(default=999999) + RPD: int = field(default=9999999999) + TPM: int = field(default=9999999999) class NoAPIKeyError(Exception): @@ -32,17 +31,15 @@ def __init__( self, endpoint: str, top_p: float = 1., - temperature: float = 1., + temperature: float = 0.4, timeout: float = None, default_model: Optional[ModelConfig] = None, - concurrency: int = -1, ): self.top_p = top_p self.temperature = 
temperature self.timeout = self.default_timeout if timeout is not None else timeout self.default_model: Optional[ModelConfig] = default_model self.endpoint = endpoint - self.concurrency = concurrency async def _complete( self, @@ -118,6 +115,28 @@ async def _stream_complete( yield chunk, const.CodeEnum.OK await client.aclose() + async def _batch_complete( + self, + limiters: List[Union[ratelimiter.RateLimiter, ratelimiter.ConcurrentLimiter]], + messages: MessagesType, + model: str = None, + req_id: str = None, + ) -> Tuple[str, const.CodeEnum]: + if len(limiters) == 4: + async with limiters[0], limiters[1], limiters[2], limiters[3]: + return await self.complete(messages=messages, model=model, req_id=req_id) + elif len(limiters) == 3: + async with limiters[0], limiters[1], limiters[2]: + return await self.complete(messages=messages, model=model, req_id=req_id) + elif len(limiters) == 2: + async with limiters[0], limiters[1]: + return await self.complete(messages=messages, model=model, req_id=req_id) + elif len(limiters) == 1: + async with limiters[0]: + return await self.complete(messages=messages, model=model, req_id=req_id) + else: + raise ValueError("Invalid number of limiters, should less than 4") + @abstractmethod async def stream_complete( self, @@ -127,55 +146,11 @@ async def stream_complete( ) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]: ... - @staticmethod @abstractmethod - def get_concurrency() -> int: - ... - - -# unless you keep a strong reference to a running task, it can be dropped during execution -# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task -_background_tasks = set() - - -class RateLimitedClient(httpx.AsyncClient): - """httpx.AsyncClient with a rate limit.""" - - def __init__( + async def batch_complete( self, - interval: Union[dt.timedelta, float], - count=1, - **kwargs - ): - """ - Parameters - ---------- - interval : Union[dt.timedelta, float] - Length of interval. - If a float is given, seconds are assumed. - numerator : int, optional - Number of requests which can be sent in any given interval (default 1). - """ - if isinstance(interval, dt.timedelta): - interval = interval.total_seconds() - - self.interval = interval - self.semaphore = asyncio.Semaphore(count) - super().__init__(**kwargs) - - def _schedule_semaphore_release(self): - wait = asyncio.create_task(asyncio.sleep(self.interval)) - _background_tasks.add(wait) - - def wait_cb(task): - self.semaphore.release() - _background_tasks.discard(task) - - wait.add_done_callback(wait_cb) - - @wraps(httpx.AsyncClient.send) - async def send(self, *args, **kwargs): - await self.semaphore.acquire() - send = asyncio.create_task(super().send(*args, **kwargs)) - self._schedule_semaphore_release() - return await send + messages: List[MessagesType], + model: str = None, + req_id: str = None, + ) -> List[Tuple[str, const.CodeEnum]]: + ... 
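Reviewer note on the new `_batch_complete` helper above: the explicit `if/elif` chain handles at most four limiters. The same acquire-in-order / release-in-reverse behaviour can be expressed for any number of limiters with `contextlib.AsyncExitStack`. The sketch below is not part of this patch; the helper name `complete_with_limiters` is illustrative, while the module paths and signatures follow the files changed here.

```python
import contextlib
from typing import List, Tuple, Union

from retk import const
from retk.core.ai.llm.api.base import BaseLLMService, MessagesType
from retk.core.utils import ratelimiter


async def complete_with_limiters(
        service: BaseLLMService,
        limiters: List[Union[ratelimiter.RateLimiter, ratelimiter.ConcurrentLimiter]],
        messages: MessagesType,
        model: str = None,
        req_id: str = None,
) -> Tuple[str, const.CodeEnum]:
    # Acquire every limiter in the given order and release them in reverse
    # once the underlying request returns, however many limiters are passed.
    async with contextlib.AsyncExitStack() as stack:
        for limiter in limiters:
            await stack.enter_async_context(limiter)
        return await service.complete(messages=messages, model=model, req_id=req_id)
```
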
diff --git a/src/retk/core/ai/llm/api/moonshot.py b/src/retk/core/ai/llm/api/moonshot.py index 00d8ca9..4d4ab0c 100644 --- a/src/retk/core/ai/llm/api/moonshot.py +++ b/src/retk/core/ai/llm/api/moonshot.py @@ -1,7 +1,10 @@ +import asyncio from enum import Enum +from typing import List, Tuple -from retk import config -from .base import ModelConfig +from retk import config, const +from retk.core.utils import ratelimiter +from .base import ModelConfig, MessagesType from .openai import OpenaiLLMStyle @@ -43,3 +46,23 @@ def get_api_key(): @staticmethod def get_concurrency(): return config.get_settings().MOONSHOT_CONCURRENCY + + async def batch_complete( + self, + messages: List[MessagesType], + model: str = None, + req_id: str = None, + ) -> List[Tuple[str, const.CodeEnum]]: + settings = config.get_settings() + rate_limiter = ratelimiter.RateLimiter(requests=settings.MOONSHOT_RPM, period=60) + concurrent_limiter = ratelimiter.ConcurrentLimiter(n=settings.MOONSHOT_CONCURRENCY) + + tasks = [ + self._batch_complete( + limiters=[concurrent_limiter, rate_limiter], + messages=m, + model=model, + req_id=req_id, + ) for m in messages + ] + return await asyncio.gather(*tasks) diff --git a/src/retk/core/ai/llm/api/openai.py b/src/retk/core/ai/llm/api/openai.py index 7f13bef..8775842 100644 --- a/src/retk/core/ai/llm/api/openai.py +++ b/src/retk/core/ai/llm/api/openai.py @@ -1,9 +1,11 @@ +import asyncio import json from abc import ABC, abstractmethod from enum import Enum -from typing import Tuple, AsyncIterable, Optional +from typing import Tuple, AsyncIterable, Optional, List, Dict from retk import config, const +from retk.core.utils import ratelimiter from retk.logger import logger from .base import BaseLLMService, MessagesType, NoAPIKeyError, ModelConfig @@ -12,29 +14,38 @@ class OpenaiModelEnum(Enum): GPT4 = ModelConfig( key="gpt-4", - max_tokens=8192, + max_tokens=8_192, + RPM=500, # tier 1 ) GPT4_TURBO = ModelConfig( key="gpt-4-turbo", - max_tokens=128000, + max_tokens=128_000, + RPM=500, # tier 1 ) GPT4_32K = ModelConfig( key="gpt-4-32k", - max_tokens=32000, + max_tokens=32_000, + RPM=500, # tier 1 ) GPT35_TURBO = ModelConfig( key="gpt-3.5-turbo", - max_tokens=16385, + max_tokens=16_385, + RPM=3, # free, other tiers are different + # https://platform.openai.com/docs/guides/rate-limits/usage-tiers?context=tier-free + RPD=200, # free ) +_key2model: Dict[str, OpenaiModelEnum] = {m.value.key: m for m in OpenaiModelEnum} + + class OpenaiLLMStyle(BaseLLMService, ABC): def __init__( self, endpoint: str, default_model: ModelConfig, top_p: float = 0.9, - temperature: float = 0.7, + temperature: float = 0.4, timeout: float = 60., ): super().__init__( @@ -89,7 +100,7 @@ async def complete( return "", code if rj.get("error") is not None: return rj["error"]["message"], const.CodeEnum.LLM_SERVICE_ERROR - logger.info(f"ReqId={req_id} {self.__class__.__name__} model usage: {rj['usage']}") + logger.info(f"ReqId={req_id} | {self.__class__.__name__} {model} | usage: {rj['usage']}") return rj["choices"][0]["message"]["content"], code async def stream_complete( @@ -115,7 +126,7 @@ async def stream_complete( try: json_data = json.loads(json_str) except json.JSONDecodeError: - logger.error(f"ReqId={req_id} {self.__class__.__name__} model stream error: json={json_str}") + logger.error(f"ReqId={req_id} | {self.__class__.__name__} {model} | stream error: json={json_str}") continue choice = json_data["choices"][0] if choice["finish_reason"] is not None: @@ -123,7 +134,7 @@ async def stream_complete( usage = 
json_data["usage"] except KeyError: usage = choice["usage"] - logger.info(f"ReqId={req_id} {self.__class__.__name__} model usage: {usage}") + logger.info(f"ReqId={req_id} | {self.__class__.__name__} {model} | usage: {usage}") break txt += choice["delta"]["content"] yield txt.encode("utf-8"), code @@ -148,6 +159,24 @@ def __init__( def get_api_key(): return config.get_settings().OPENAI_API_KEY - @staticmethod - def get_concurrency(): - return 1 + async def batch_complete( + self, + messages: List[MessagesType], + model: str = None, + req_id: str = None, + ) -> List[Tuple[str, const.CodeEnum]]: + if model is None: + m = self.default_model + else: + m = _key2model[model].value + limiter = ratelimiter.RateLimiter(requests=m.RPM, period=60) + + tasks = [ + self._batch_complete( + limiters=[limiter], + messages=m, + model=model, + req_id=req_id, + ) for m in messages + ] + return await asyncio.gather(*tasks) diff --git a/src/retk/core/ai/llm/api/tencent.py b/src/retk/core/ai/llm/api/tencent.py index df532e6..917974c 100644 --- a/src/retk/core/ai/llm/api/tencent.py +++ b/src/retk/core/ai/llm/api/tencent.py @@ -1,12 +1,14 @@ +import asyncio import hashlib import hmac import json import time from datetime import datetime from enum import Enum -from typing import TypedDict, Tuple, Dict, AsyncIterable, Optional +from typing import TypedDict, Tuple, Dict, AsyncIterable, Optional, List from retk import config, const +from retk.core.utils import ratelimiter from retk.logger import logger from .base import BaseLLMService, MessagesType, NoAPIKeyError, ModelConfig @@ -50,11 +52,12 @@ class TencentService(BaseLLMService): service = "hunyuan" host = "hunyuan.tencentcloudapi.com" version = "2023-09-01" + concurrency = 5 def __init__( self, top_p: float = 0.9, - temperature: float = 0.7, + temperature: float = 0.4, timeout: float = 60., ): super().__init__( @@ -65,10 +68,6 @@ def __init__( default_model=TencentModelEnum.HUNYUAN_LITE.value, ) - @staticmethod - def get_concurrency() -> int: - return 5 - def get_auth(self, action: str, payload: bytes, timestamp: int, content_type: str) -> str: _s = config.get_settings() if _s.HUNYUAN_SECRET_KEY == "" or _s.HUNYUAN_SECRET_ID == "": @@ -139,9 +138,11 @@ def get_payload(self, model: Optional[str], messages: MessagesType, stream: bool def handle_err(req_id: str, error: Dict): msg = error.get("Message") code = error.get("Code") - logger.error(f"ReqId={req_id} Tencent model error code={code}, msg={msg}") + logger.error(f"ReqId={req_id} | Tencent | error code={code}, msg={msg}") if code == 4001: ccode = const.CodeEnum.LLM_TIMEOUT + elif code == "LimitExceeded": + ccode = const.CodeEnum.LLM_API_LIMIT_EXCEEDED else: ccode = const.CodeEnum.LLM_SERVICE_ERROR return msg, ccode @@ -153,7 +154,7 @@ def handle_normal_response(req_id: str, resp: Dict, stream: bool) -> Tuple[str, return "No response", const.CodeEnum.LLM_NO_CHOICE choice = choices[0] m = choice["Delta"] if stream else choice["Message"] - logger.info(f"ReqId={req_id} Tencent model usage: {resp['Usage']}") + logger.info(f"ReqId={req_id} | Tencent | usage: {resp['Usage']}") return m["Content"], const.CodeEnum.OK async def complete( @@ -210,17 +211,35 @@ async def stream_complete( try: json_str = s[6:] except IndexError: - logger.error(f"ReqId={req_id} Tencent model stream error: string={s}") + logger.error(f"ReqId={req_id} | Tencent {model} | stream error: string={s}") continue try: json_data = json.loads(json_str) except json.JSONDecodeError as e: - logger.error(f"ReqId={req_id} Tencent model stream error: 
string={s}, error={e}") + logger.error(f"ReqId={req_id} | Tencent {model} | stream error: string={s}, error={e}") continue choice = json_data["Choices"][0] if choice["FinishReason"] != "": - logger.info(f"ReqId={req_id} Tencent model usage: {json_data['Usage']}") + logger.info(f"ReqId={req_id} | Tencent {model} | usage: {json_data['Usage']}") break content = choice["Delta"]["Content"] txt += content yield txt.encode("utf-8"), code + + async def batch_complete( + self, + messages: List[MessagesType], + model: str = None, + req_id: str = None, + ) -> List[Tuple[str, const.CodeEnum]]: + limiter = ratelimiter.ConcurrentLimiter(n=self.concurrency) + + tasks = [ + self._batch_complete( + limiters=[limiter], + messages=m, + model=model, + req_id=req_id, + ) for m in messages + ] + return await asyncio.gather(*tasks) diff --git a/src/retk/core/ai/llm/api/xfyun.py b/src/retk/core/ai/llm/api/xfyun.py index 63bd75c..26456dd 100644 --- a/src/retk/core/ai/llm/api/xfyun.py +++ b/src/retk/core/ai/llm/api/xfyun.py @@ -1,8 +1,10 @@ +import asyncio import json from enum import Enum -from typing import Optional, Tuple, Dict, AsyncIterable +from typing import Optional, Tuple, Dict, AsyncIterable, List from retk import config, const +from retk.core.utils import ratelimiter from retk.logger import logger from .base import BaseLLMService, MessagesType, NoAPIKeyError, ModelConfig @@ -40,7 +42,7 @@ class XfYunService(BaseLLMService): def __init__( self, top_p: float = 0.9, - temperature: float = 0.7, + temperature: float = 0.4, timeout: float = 60., ): super().__init__( @@ -50,10 +52,7 @@ def __init__( timeout=timeout, default_model=XfYunModelEnum.SPARK_LITE.value, ) - - @staticmethod - def get_concurrency() -> int: - return 1 + self.concurrency = 1 @staticmethod def get_headers() -> Dict: @@ -94,7 +93,7 @@ async def complete( return "", code if rj["code"] != 0: return rj["message"], const.CodeEnum.LLM_SERVICE_ERROR - logger.info(f"ReqId={req_id} {self.__class__.__name__} model usage: {rj['usage']}") + logger.info(f"ReqId={req_id} | {self.__class__.__name__} {model} | usage: {rj['usage']}") return rj["choices"][0]["message"]["content"], code async def stream_complete( @@ -122,11 +121,11 @@ async def stream_complete( try: json_data = json.loads(json_str) except json.JSONDecodeError: - logger.error(f"ReqId={req_id} {self.__class__.__name__} model stream error: json={json_str}") + logger.error(f"ReqId={req_id} | {self.__class__.__name__} {model} | stream error: json={json_str}") continue if json_data["code"] != 0: logger.error( - f"ReqId={req_id} {self.__class__.__name__} model error:" + f"ReqId={req_id} | {self.__class__.__name__} {model} | error:" f" code={json_data['code']} {json_data['message']}" ) break @@ -136,7 +135,25 @@ async def stream_complete( except KeyError: pass else: - logger.info(f"ReqId={req_id} {self.__class__.__name__} model usage: {usage}") + logger.info(f"ReqId={req_id} | {self.__class__.__name__} {model} | usage: {usage}") break txt += choice["delta"]["content"] yield txt.encode("utf-8"), code + + async def batch_complete( + self, + messages: List[MessagesType], + model: str = None, + req_id: str = None, + ) -> List[Tuple[str, const.CodeEnum]]: + limiter = ratelimiter.ConcurrentLimiter(n=self.concurrency) + + tasks = [ + self._batch_complete( + limiters=[limiter], + messages=m, + model=model, + req_id=req_id, + ) for m in messages + ] + return await asyncio.gather(*tasks) diff --git a/src/retk/core/ai/llm/knowledge/__init__.py b/src/retk/core/ai/llm/knowledge/__init__.py index ffea231..cc89e3e 
100644 --- a/src/retk/core/ai/llm/knowledge/__init__.py +++ b/src/retk/core/ai/llm/knowledge/__init__.py @@ -1,3 +1,3 @@ from . import extended -from .extending import extend_on_node_update, extend_on_node_post, LLM_SERVICES -from .ops import summary, extend +from .extending import extend_on_node_update, extend_on_node_post +from .ops import batch_summary, batch_extend, ExtendCase diff --git a/src/retk/core/ai/llm/knowledge/extending.py b/src/retk/core/ai/llm/knowledge/extending.py index 8f5012b..6ddb378 100644 --- a/src/retk/core/ai/llm/knowledge/extending.py +++ b/src/retk/core/ai/llm/knowledge/extending.py @@ -8,19 +8,6 @@ from retk.models.tps.node import Node from .. import api -TOP_P = 0.9 -TEMPERATURE = 0.5 -TIMEOUT = 60 - -LLM_SERVICES = { - "tencent": api.TencentService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), - "ali": api.AliyunService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), - "openai": api.OpenaiService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), - "moonshot": api.MoonshotService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), - "xf": api.XfYunService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), - "baidu": api.BaiduService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), -} - async def extend_on_node_post(data: Node): if data["md"].strip() == "": diff --git a/src/retk/core/ai/llm/knowledge/ops.py b/src/retk/core/ai/llm/knowledge/ops.py index 59bd7ae..387dadc 100644 --- a/src/retk/core/ai/llm/knowledge/ops.py +++ b/src/retk/core/ai/llm/knowledge/ops.py @@ -1,64 +1,126 @@ +from dataclasses import dataclass from pathlib import Path -from typing import Tuple +from typing import List + +from bson import ObjectId from retk import const from retk.logger import logger from .utils import parse_json_pattern, remove_links -from ..api.base import BaseLLMService, MessagesType +from .. 
import api +from ..api.base import MessagesType system_summary_prompt = (Path(__file__).parent / "system_summary.md").read_text(encoding="utf-8") system_extend_prompt = (Path(__file__).parent / "system_extend.md").read_text(encoding="utf-8") -async def _send( - llm_service: BaseLLMService, - model: str, +@dataclass +class ExtendCase: + _id: ObjectId + uid: str + nid: str + service: str + model: str + md: str + stripped_md: str = "" + summary: str = "" + summary_code: const.CodeEnum = const.CodeEnum.OK + extend: str = "" + extend_code: const.CodeEnum = const.CodeEnum.OK + + def __post_init__(self): + self.stripped_md = remove_links(self.md) + + +TOP_P = 0.9 +TEMPERATURE = 0.4 +TIMEOUT = 60 + +LLM_SERVICES = { + "tencent": api.TencentService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), + "ali": api.AliyunService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), + "openai": api.OpenaiService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), + "moonshot": api.MoonshotService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), + "xf": api.XfYunService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), + "baidu": api.BaiduService(top_p=TOP_P, temperature=TEMPERATURE, timeout=TIMEOUT), +} + + +async def _batch_send( + is_extend: bool, system_prompt: str, - md: str, + cases: List[ExtendCase], req_id: str, -) -> Tuple[str, const.CodeEnum]: - _msgs: MessagesType = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": md}, - ] - return await llm_service.complete(messages=_msgs, model=model, req_id=req_id) - - -async def summary( - llm_service: BaseLLMService, - model: str, - md: str, +) -> List[ExtendCase]: + svr_group = {} + for case in cases: + if case.service not in svr_group: + svr_group[case.service] = {} + if case.model not in svr_group[case.service]: + svr_group[case.service][case.model] = {"case": [], "msgs": []} + _m: MessagesType = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": case.summary if is_extend else case.stripped_md}, + ] + svr_group[case.service][case.model]["case"].append(case) + svr_group[case.service][case.model]["msgs"].append(_m) + for service, models in svr_group.items(): + for model, model_cases in models.items(): + llm_service = LLM_SERVICES[service] + results = await llm_service.batch_complete( + messages=model_cases["msgs"], + model=model, + req_id=req_id, + ) + for (_text, code), case in zip(results, model_cases["case"]): + if is_extend: + case.extend = _text + case.extend_code = code + else: + case.summary = _text + case.summary_code = code + + oneline_s = _text.replace('\n', '\\n') + phase = "extend" if is_extend else "summary" + logger.debug(f"reqId={req_id} | knowledge {phase}: {oneline_s}") + if code != const.CodeEnum.OK: + logger.error(f"reqId={req_id} | knowledge {phase} error: {code}") + return cases + + +async def batch_summary( + cases: List[ExtendCase], req_id: str = None, -) -> Tuple[str, const.CodeEnum]: - md_ = remove_links(md) - return await _send( - llm_service=llm_service, - model=model, +) -> List[ExtendCase]: + return await _batch_send( + is_extend=False, system_prompt=system_summary_prompt, - md=md_, + cases=cases, req_id=req_id, ) -async def extend( - llm_service: BaseLLMService, - model: str, - md: str, +async def batch_extend( + cases: List[ExtendCase], req_id: str = None, -) -> Tuple[str, const.CodeEnum]: - msg, code = await _send( - llm_service=llm_service, - model=model, +) -> List[ExtendCase]: + cases = await _batch_send( + is_extend=True, 
system_prompt=system_extend_prompt, - md=md, + cases=cases, req_id=req_id, ) - if code != const.CodeEnum.OK: - return msg, code - - try: - title, content = parse_json_pattern(msg) - except ValueError as e: - logger.error(f"parse_json_pattern error: {e}. msg: {msg}") - return str(e), const.CodeEnum.LLM_INVALID_RESPONSE_FORMAT - return f"{title}\n\n{content}", const.CodeEnum.OK + + for case in cases: + if case.extend_code != const.CodeEnum.OK: + continue + + try: + title, content = parse_json_pattern(case.extend) + except ValueError as e: + oneline = case.extend.replace('\n', '\\n') + logger.error(f"reqId={req_id} | parse_json_pattern error: {e}. msg: {oneline}") + case.extend_code = const.CodeEnum.LLM_INVALID_RESPONSE_FORMAT + else: + case.extend = f"{title}\n\n{content}" + return cases diff --git a/src/retk/core/ai/llm/knowledge/system_extend.md b/src/retk/core/ai/llm/knowledge/system_extend.md index 48a3ed9..a198685 100644 --- a/src/retk/core/ai/llm/knowledge/system_extend.md +++ b/src/retk/core/ai/llm/knowledge/system_extend.md @@ -20,12 +20,12 @@ # 你需要返回的结果 +```json { -"title": "儿童发展中的同理心培养", -"content": "- 富有同理心的小孩能理解和感受他人情感,有助于儿童建立良好的人际关系和社交技巧。\n- -儿童的同理心发展分为不同阶段,从2岁开始,他们能够感知到他人的情感,而4-5岁时,他们开始能够理解他人的观点和需求。\n- -家长和教育者可以通过共情、角色扮演、讲述故事、以及引导儿童关注他人的感受等方法,帮助儿童培养同理心。" + "title": "儿童发展中的同理心培养", + "content": "- 富有同理心的小孩能理解和感受他人情感,有助于儿童建立良好的人际关系和社交技巧。\n- 儿童的同理心发展分为不同阶段,从2岁开始,他们能够感知到他人的情感,而4-5岁时,他们开始能够理解他人的观点和需求。\n- 家长和教育者可以通过共情、角色扮演、讲述故事、以及引导儿童关注他人的感受等方法,帮助儿童培养同理心。" } +``` 案例 2: @@ -45,8 +45,10 @@ # 你需要返回的结果 +```json { -"title": "水的凝聚力和表面张力现象", -"content": " -凝聚力使水分子紧密相连,表面张力导致水成球状以减小表面积。这些现象在植物水分运输、清洁剂使用和雨伞设计等方面具有重要作用。通过探讨这些现象,可以更深入地理解水的特性及其在自然和生活中的应用。" + "title": "水的凝聚力和表面张力现象", + "content": "凝聚力使水分子紧密相连,表面张力导致水成球状以减小表面积。这些现象在植物水分运输、清洁剂使用和雨伞设计等方面具有重要作用。通过探讨这些现象,可以更深入地理解水的特性及其在自然和生活中的应用。" } +``` + diff --git a/src/retk/core/ai/llm/knowledge/system_summary.md b/src/retk/core/ai/llm/knowledge/system_summary.md index 24cfbe6..e8e7460 100644 --- a/src/retk/core/ai/llm/knowledge/system_summary.md +++ b/src/retk/core/ai/llm/knowledge/system_summary.md @@ -10,6 +10,7 @@ 小孩建立长线的因果关系 因为脑部发育阶段的特性,2-4岁的儿童没办法推理比较长线的因果关系,比如不吃饭,晚上会肚子饿。其中的一个原因是前额叶的发展不够,没办法模拟和推理未来发生的事情,也就没办法思考一段时间后的结果。 但是短期反馈还是有的,比如挥手要打人或者给脸色的时候能有直接的映射结果,这点他们理解 + """ # 你需要返回的总结格式: @@ -22,7 +23,8 @@ 1. **特点**:2-4岁的儿童由于脑部发育阶段的特性,无法推理长线的因果关系。 2. **原因**:这种现象的一个原因是前额叶的发展不足,无法模拟和推理未来发生的事情。 3. **长短期反馈**:他们无法理解一段时间后的结果,例如不吃饭会导致晚上肚子饿。尽管如此,他们可以理解短期反馈,例如挥手打人或者给脸色会有直接的结果。 - """ + +""" 案例 2: @@ -47,4 +49,5 @@ 3. 水的偏电性:氢正电荷,氧负电荷 4. 水作为良好溶剂:吸附其他分子,如盐 5. 
生命过程中水的作用:输送养分和排除废物 - """ + +""" diff --git a/src/retk/core/ai/llm/knowledge/utils.py b/src/retk/core/ai/llm/knowledge/utils.py index 66f40f4..339a998 100644 --- a/src/retk/core/ai/llm/knowledge/utils.py +++ b/src/retk/core/ai/llm/knowledge/utils.py @@ -3,18 +3,26 @@ from typing import Tuple JSON_PTN = re.compile(r"^{\s*?\"title\":\s?\"(.+?)\",\s*?\"content\":\s?\"(.+?)\"\s*?}", re.DOTALL | re.MULTILINE) +JSON_PTN2 = re.compile(r"^{\s*?\"标题\":\s?\"(.+?)\",\s*?\"内容\":\s?\"(.+?)\"\s*?}", re.DOTALL | re.MULTILINE) IMG_PTN = re.compile(r"!\[.*?\]\(.+?\)") LINK_PTN = re.compile(r"\[(.*?)]\(.+?\)") def parse_json_pattern(text: str) -> Tuple[str, str]: - m = JSON_PTN.search(text) - if m: + def get_title_content(m): title, content = m.group(1), m.group(2) title = title.replace("\n", "\\n") content = content.replace("\n", "\\n") d = json.loads(f'{{"title": "{title}", "content": "{content}"}}') return d["title"], d["content"] + + m = JSON_PTN.search(text) + if m: + return get_title_content(m) + m = JSON_PTN2.search(text) + if m: + return get_title_content(m) + raise ValueError(f"Invalid JSON pattern: {text}") diff --git a/src/retk/core/scheduler/schedule.py b/src/retk/core/scheduler/schedule.py index 6beee1a..d008229 100644 --- a/src/retk/core/scheduler/schedule.py +++ b/src/retk/core/scheduler/schedule.py @@ -97,7 +97,7 @@ def init_tasks(): run_every_interval( job_id="deliver_unscheduled_node_extend", func=tasks.extend_node.deliver_unscheduled_extend_nodes, - hours=1, + minutes=40, ) return diff --git a/src/retk/core/scheduler/tasks/extend_node.py b/src/retk/core/scheduler/tasks/extend_node.py index 1d2a9c2..1d2132d 100644 --- a/src/retk/core/scheduler/tasks/extend_node.py +++ b/src/retk/core/scheduler/tasks/extend_node.py @@ -21,7 +21,7 @@ def deliver_unscheduled_extend_nodes(): async def async_deliver_unscheduled_extend_nodes() -> str: _, db = init_mongo(connection_timeout=5) - batch_size = 5 + batch_size = 40 total_success_count = 0 total_summary_time = 0 total_extend_time = 0 @@ -31,50 +31,54 @@ async def async_deliver_unscheduled_extend_nodes() -> str: batch_size).to_list(None) if len(batch) == 0: break + cases: List[knowledge.ExtendCase] = [] + req_id = "".join([str(random.randint(0, 9)) for _ in range(10)]) for item in batch: - req_id = "".join([str(random.randint(0, 9)) for _ in range(10)]) node = await db[CollNameEnum.nodes.value].find_one({"id": item["nid"]}) - s0 = time.perf_counter() - _summary, code = await knowledge.summary( - llm_service=knowledge.LLM_SERVICES[item["summaryService"]], - model=item["summaryModel"], - md=node["md"], - req_id=req_id, + cases.append( + knowledge.ExtendCase( + _id=item["_id"], + uid=item["uid"], + nid=item["nid"], + service=item["summaryService"], + model=item["summaryModel"], + md=node["md"], + ) ) - s1 = time.perf_counter() - if code != const.CodeEnum.OK: - logger.error(f"knowledge summary error: {code}") - continue - oneline_s = _summary.replace('\n', '\\n') - logger.debug(f"summary: {oneline_s}") - e0 = time.perf_counter() - _extended, code = await knowledge.extend( - llm_service=knowledge.LLM_SERVICES[item["extendService"]], - model=item["extendModel"], - md=_summary, - req_id=req_id, - ) - e1 = time.perf_counter() - if code != const.CodeEnum.OK: - logger.error(f"knowledge extend error: code={code}") + + s0 = time.perf_counter() + cases = await knowledge.batch_summary( + cases=cases, + req_id=req_id, + ) + s1 = time.perf_counter() + + e0 = time.perf_counter() + cases = await knowledge.batch_extend( + cases=cases, + req_id=req_id, + ) + e1 = 
time.perf_counter() + + for case in cases: + done_id_list.append(case._id) + if case.summary_code != const.CodeEnum.OK or case.extend_code != const.CodeEnum.OK: continue - oneline_e = _extended.replace('\n', '\\n') - logger.debug(f"extended: {oneline_e}") ext = ExtendedNode( - uid=item["uid"], - sourceNid=item["nid"], - sourceMd=node["md"], - extendMd=_extended, + uid=case.uid, + sourceNid=case.nid, + sourceMd=case.md, + extendMd=case.extend, ) await db[CollNameEnum.llm_extended_node.value].update_one( - {"uid": item["uid"], "sourceNid": item["nid"]}, + {"uid": case.uid, "sourceNid": case.nid}, {"$set": ext}, upsert=True ) - done_id_list.append(item["_id"]) - total_summary_time += s1 - s0 - total_extend_time += e1 - e0 + + total_summary_time += s1 - s0 + total_extend_time += e1 - e0 if len(done_id_list) > 0: res = await db[CollNameEnum.llm_extend_node_queue.value].delete_many({"_id": {"$in": done_id_list}}) diff --git a/src/retk/core/utils/__init__.py b/src/retk/core/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/retk/core/utils/ratelimiter.py b/src/retk/core/utils/ratelimiter.py new file mode 100644 index 0000000..3295c04 --- /dev/null +++ b/src/retk/core/utils/ratelimiter.py @@ -0,0 +1,42 @@ +import asyncio +import time +from datetime import timedelta +from typing import Union + + +class RateLimiter: + def __init__(self, requests: int, period: Union[int, float, timedelta]): + if isinstance(period, (int, float)): + period = timedelta(seconds=period) + self.max_tokens = requests + self.period = period + self.tokens = requests + self.last_refill = time.monotonic() + + async def __aenter__(self): + while self.tokens < 1: + await asyncio.sleep(0.1) + now = time.monotonic() + elapsed_time = now - self.last_refill + if elapsed_time > self.period.total_seconds(): + self.tokens = self.max_tokens + self.last_refill = now + break + self.tokens -= 1 + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class ConcurrentLimiter: + def __init__(self, n: int): + self.semaphore = asyncio.Semaphore(n) + + async def __aenter__(self): + await self.semaphore.acquire() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.semaphore.release() + return False diff --git a/tests/test_ai_llm_api.py b/tests/test_ai_llm_api.py index 79c0f8e..cc186c4 100644 --- a/tests/test_ai_llm_api.py +++ b/tests/test_ai_llm_api.py @@ -89,6 +89,28 @@ async def test_hunyuan_stream_complete(self): s = b.decode("utf-8") print(s) + @skip_no_api_key + async def test_hunyuan_batch_complete(self): + m = llm.api.TencentService() + res = await m.batch_complete( + [[{"role": "user", "content": "你是谁"}]] * 11, + ) + for text, code in res: + self.assertEqual(const.CodeEnum.OK, code, msg=text) + print(text) + + m = llm.api.TencentService() + m.concurrency = 6 + res = await m.batch_complete( + [[{"role": "user", "content": "你是谁"}]] * 11, + ) + reach_limit = False + for text, code in res: + if code == const.CodeEnum.LLM_API_LIMIT_EXCEEDED: + reach_limit = True + break + self.assertTrue(reach_limit) + @skip_no_api_key async def test_aliyun_complete(self): m = llm.api.AliyunService() @@ -103,6 +125,16 @@ async def test_aliyun_stream_complete(self): self.assertEqual(const.CodeEnum.OK, code) print(b.decode("utf-8")) + @skip_no_api_key + async def test_aliyun_batch_complete(self): + m = llm.api.AliyunService() + res = await m.batch_complete( + [[{"role": "user", "content": "你是谁"}]] * 11, + ) + for text, code in res: + self.assertEqual(const.CodeEnum.OK, code, 
msg=text) + print(text) + @skip_no_api_key async def test_baidu_complete(self): m = llm.api.BaiduService() @@ -117,6 +149,16 @@ async def test_baidu_stream_complete(self): self.assertEqual(const.CodeEnum.OK, code) print(b.decode("utf-8")) + @skip_no_api_key + async def test_baidu_batch_complete(self): + m = llm.api.BaiduService() + res = await m.batch_complete( + [[{"role": "user", "content": "你是谁"}]] * 11, + ) + for text, code in res: + self.assertEqual(const.CodeEnum.OK, code, msg=text) + print(text) + @skip_no_api_key async def test_openai_complete(self): m = llm.api.OpenaiService() @@ -131,6 +173,16 @@ async def test_openai_stream_complete(self): self.assertEqual(const.CodeEnum.OK, code) print(b.decode("utf-8")) + @skip_no_api_key + async def test_openai_batch_complete(self): + m = llm.api.OpenaiService() + res = await m.batch_complete( + [[{"role": "user", "content": "你是谁"}]] * 11, + ) + for text, code in res: + self.assertEqual(const.CodeEnum.OK, code, msg=text) + print(text) + @skip_no_api_key async def test_xfyun_complete(self): m = llm.api.XfYunService() @@ -145,6 +197,16 @@ async def test_xfyun_stream_complete(self): self.assertEqual(const.CodeEnum.OK, code) print(b.decode("utf-8")) + @skip_no_api_key + async def test_xfyun_batch_complete(self): + m = llm.api.XfYunService() + res = await m.batch_complete( + [[{"role": "user", "content": "你是谁"}]] * 11, + ) + for text, code in res: + self.assertEqual(const.CodeEnum.OK, code, msg=text) + print(text) + @skip_no_api_key async def test_moonshot_complete(self): m = llm.api.MoonshotService() diff --git a/tests/test_ai_llm_knowledge.py b/tests/test_ai_llm_knowledge.py index 915d917..f00cf0c 100644 --- a/tests/test_ai_llm_knowledge.py +++ b/tests/test_ai_llm_knowledge.py @@ -1,8 +1,11 @@ import unittest from textwrap import dedent +from bson import ObjectId + from retk import const from retk.core.ai import llm +from retk.core.ai.llm.knowledge.ops import ExtendCase from retk.core.ai.llm.knowledge.utils import parse_json_pattern from . 
import utils from .test_ai_llm_api import skip_no_api_key, clear_all_api_key @@ -73,41 +76,58 @@ def tearDown(self): @skip_no_api_key async def test_summary(self): for service, model in [ - (llm.api.TencentService(), llm.api.TencentModelEnum.HUNYUAN_LITE), - (llm.api.AliyunService(), llm.api.AliyunModelEnum.QWEN_2B), - (llm.api.BaiduService(), llm.api.BaiduModelEnum.ERNIE_SPEED_8K), - # (llm.api.OpenaiService(), llm.api.OpenaiModelEnum.GPT4), - (llm.api.XfYunService(), llm.api.XfYunModelEnum.SPARK_LITE), - (llm.api.MoonshotService(), llm.api.MoonshotModelEnum.V1_8K), # 这个总结比较好 + ("tencent", llm.api.TencentModelEnum.HUNYUAN_LITE), + ("ali", llm.api.AliyunModelEnum.QWEN_2B), + ("baidu", llm.api.BaiduModelEnum.ERNIE_SPEED_8K), + # ("openai", llm.api.OpenaiModelEnum.GPT4), + ("xf", llm.api.XfYunModelEnum.SPARK_LITE), + ("moonshot", llm.api.MoonshotModelEnum.V1_8K), # 这个总结比较好 ]: - for md in md_source: - text, code = await llm.knowledge.summary( - llm_service=service, + cases = [ + ExtendCase( + _id=ObjectId(), + uid="testuid", + nid="testnid", + service=service, model=model.value.key, md=md, - ) - self.assertEqual(const.CodeEnum.OK, code, msg=text) - print(f"{service.__class__.__name__} {model.name}\n{text}\n\n") + ) for md in md_source + ] + await llm.knowledge.batch_summary( + cases=cases, + ) + for case in cases: + self.assertEqual(const.CodeEnum.OK, case.extend_code, msg=case.md) + print(f"{service} {model.value.key}\n{case.summary}\n\n") @skip_no_api_key async def test_extend(self): for service, model in [ - # (llm.api.TencentService(), llm.api.TencentModelEnum.HUNYUAN_PRO), - (llm.api.TencentService(), llm.api.TencentModelEnum.HUNYUAN_STANDARD), - (llm.api.AliyunService(), llm.api.AliyunModelEnum.QWEN_PLUS), - (llm.api.BaiduService(), llm.api.BaiduModelEnum.ERNIE35_8K), - # (llm.api.OpenaiService(), llm.api.OpenaiModelEnum.GPT4), - (llm.api.XfYunService(), llm.api.XfYunModelEnum.SPARK_PRO), - (llm.api.MoonshotService(), llm.api.MoonshotModelEnum.V1_8K), # 这个延伸比较好 + # ("tencent", llm.api.TencentModelEnum.HUNYUAN_PRO), + ("tencent", llm.api.TencentModelEnum.HUNYUAN_STANDARD), + ("ali", llm.api.AliyunModelEnum.QWEN_PLUS), + ("baidu", llm.api.BaiduModelEnum.ERNIE35_8K), + # ("openai", llm.api.OpenaiModelEnum.GPT4), + ("xf", llm.api.XfYunModelEnum.SPARK_PRO), + ("moonshot", llm.api.MoonshotModelEnum.V1_8K), # 这个延伸比较好 ]: - for md in md_source: - text, code = await llm.knowledge.extend( - llm_service=service, + cases = [ + ExtendCase( + _id=ObjectId(), + uid="testuid", + nid="testnid", + service=service, model=model.value.key, md=md, - ) - self.assertEqual(const.CodeEnum.OK, code, msg=text) - print(f"{service.__class__.__name__} {model.name}\n{text}\n\n") + summary=md + ) for md in md_summary + ] + await llm.knowledge.batch_extend( + cases=cases + ) + for case in cases: + self.assertEqual(const.CodeEnum.OK, case.extend_code, msg=case.summary) + print(f"{service} {model.value.key}\n{case.extend}\n\n") def test_json_pattern(self): title, content = parse_json_pattern("""{"title": "tttt", "content": "cccc\n21\n2"}""") @@ -137,6 +157,17 @@ def test_json_pattern(self): } 23423saq1是当前 """, + """\ + 这是一个关于午睡对大脑健康益处的内容描述,以下是按照要求以json格式返回的结果: + + ```json + { + "标题": "tttt", + "内容": "cccc" + } + ```. msg: 这结果:```json + { "标题": "午睡对", "内容": "午睡进身心健康。"}``` + """ ] for case in cases: case = dedent(case) diff --git a/tests/test_core_local.py b/tests/test_core_local.py index 2cfac8d..9f777b4 100644 --- a/tests/test_core_local.py +++ b/tests/test_core_local.py @@ -27,7 +27,7 @@ from . 
import utils -@patch("retk.core.ai.llm.knowledge.ops._send", new_callable=AsyncMock, return_value=["", const.CodeEnum.OK]) +@patch("retk.core.ai.llm.knowledge.ops._batch_send", new_callable=AsyncMock, return_value=["", const.CodeEnum.OK]) class LocalModelsTest(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls) -> None: @@ -56,7 +56,7 @@ async def asyncTearDown(self) -> None: shutil.rmtree(Path(__file__).parent / "temp" / const.settings.DOT_DATA / "files", ignore_errors=True) shutil.rmtree(Path(__file__).parent / "temp" / const.settings.DOT_DATA / "md", ignore_errors=True) - async def test_user(self, mock_send): + async def test_user(self, mock_batch_send): u, code = await core.user.get_by_email(email=const.DEFAULT_USER["email"]) self.assertEqual(const.CodeEnum.OK, code) self.assertEqual("rethink", u["nickname"]) @@ -115,7 +115,7 @@ async def test_user(self, mock_send): await core.account.manager.delete_by_uid(uid=_uid) - async def test_node(self, mock_send): + async def test_node(self, mock_batch_send): node, code = await core.node.post( au=self.au, md="a" * (const.settings.MD_MAX_LENGTH + 1), type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -214,7 +214,7 @@ async def test_node(self, mock_send): self.assertEqual(2, len(nodes)) self.assertEqual(2, total) - async def test_node_in_ai_extend_queue(self, mock_send): + async def test_node_in_ai_extend_queue(self, mock_batch_send): node, code = await core.node.post( au=self.au, md="knowledge test\nthis is a knowledge test" ) @@ -249,7 +249,7 @@ async def test_node_in_ai_extend_queue(self, mock_send): self.assertEqual(1, len(q)) self.assertGreater(q_[0]["modifiedAt"], q_time) - async def test_parse_at(self, mock_send): + async def test_parse_at(self, mock_batch_send): nid1, _ = await core.node.post( au=self.au, md="c", type_=const.NodeTypeEnum.MARKDOWN.value, ) @@ -302,7 +302,7 @@ async def test_parse_at(self, mock_send): self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(0, len(n["fromNodeIds"])) - async def test_add_set(self, mock_send): + async def test_add_set(self, mock_batch_send): node, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -315,7 +315,7 @@ async def test_add_set(self, mock_send): self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(1, len(node["toNodeIds"])) - async def test_cursor_text(self, mock_send): + async def test_cursor_text(self, mock_batch_send): n1, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -359,7 +359,7 @@ async def test_cursor_text(self, mock_send): self.assertEqual(3, total) self.assertEqual("Welcome to Rethink", recom[2].title) - async def test_to_trash(self, mock_send): + async def test_to_trash(self, mock_batch_send): n1, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -387,7 +387,7 @@ async def test_to_trash(self, mock_send): self.assertEqual(4, len(nodes)) self.assertEqual(4, total) - async def test_search(self, mock_send): + async def test_search(self, mock_batch_send): code = await core.recent.put_recent_search(au=self.au, query="a") self.assertEqual(const.CodeEnum.OK, code) await core.recent.put_recent_search(au=self.au, query="c") @@ -397,7 +397,7 @@ async def test_search(self, mock_send): self.assertIsNotNone(doc) self.assertEqual(["b", "c", "a"], doc["lastState"]["recentSearch"]) - async def test_batch(self, mock_send): + async def test_batch(self, mock_batch_send): ns = [] for i in range(10): n, code = 
await core.node.post( @@ -430,7 +430,7 @@ async def test_batch(self, mock_send): self.assertEqual(0, total) self.assertEqual(0, len(tns)) - async def test_files_upload_process(self, mock_send): + async def test_files_upload_process(self, mock_batch_send): now = datetime.datetime.now(tz=utc) doc = ImportData( _id=ObjectId(), @@ -460,7 +460,7 @@ async def test_files_upload_process(self, mock_send): await client.coll.import_data.delete_one({"uid": "xxx"}) - async def test_update_title_and_from_nodes_updates(self, mock_send): + async def test_update_title_and_from_nodes_updates(self, mock_batch_send): n1, code = await core.node.post( au=self.au, md="title1\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -476,7 +476,7 @@ async def test_update_title_and_from_nodes_updates(self, mock_send): self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(f"title2\n[@title1Changed](/n/{n1['id']})", n2["md"]) - async def test_upload_image_vditor(self, mock_send): + async def test_upload_image_vditor(self, mock_batch_send): u, code = await core.user.get(self.au.u.id) self.assertEqual(const.CodeEnum.OK, code) used_space = u["usedSpace"] @@ -504,7 +504,7 @@ async def test_upload_image_vditor(self, mock_send): self.assertEqual(used_space + size, u["usedSpace"]) @patch("retk.core.files.upload.httpx.AsyncClient.get", ) - async def test_fetch_image_vditor(self, mock_get, mock_send): + async def test_fetch_image_vditor(self, mock_get, mock_batch_send): f = open(Path(__file__).parent / "temp" / "fake.png", "rb") mock_get.return_value = httpx.Response( 200, @@ -528,7 +528,7 @@ async def test_fetch_image_vditor(self, mock_get, mock_send): self.assertEqual(used_space + f.tell(), u["usedSpace"]) f.close() - async def test_update_used_space(self, mock_send): + async def test_update_used_space(self, mock_batch_send): u, code = await core.user.get(self.au.u.id) base_used_space = u["usedSpace"] for delta, value in [ @@ -548,7 +548,7 @@ async def test_update_used_space(self, mock_send): base_used_space = 0 self.assertAlmostEqual(value, now, msg=f"delta: {delta}, value: {value}") - async def test_node_version(self, mock_send): + async def test_node_version(self, mock_batch_send): node, code = await core.node.post( au=self.au, md="[title](/qqq)\nbody", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -569,7 +569,7 @@ async def test_node_version(self, mock_send): self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(2, len(list(hist_dir.glob("*.md")))) - async def test_md_history(self, mock_send): + async def test_md_history(self, mock_batch_send): bi = config.get_settings().MD_BACKUP_INTERVAL config.get_settings().MD_BACKUP_INTERVAL = 0.0001 n1, code = await core.node.post( @@ -606,14 +606,14 @@ async def test_md_history(self, mock_send): config.get_settings().MD_BACKUP_INTERVAL = bi - async def test_get_version(self, mock_send): + async def test_get_version(self, mock_batch_send): v, code = await core.self_hosted.get_latest_pkg_version() self.assertEqual(const.CodeEnum.OK, code) self.assertEqual(3, len(v)) for num in v: self.assertTrue(isinstance(num, int)) - async def test_system_notice(self, mock_send): + async def test_system_notice(self, mock_batch_send): au = deepcopy(self.au) au.u.type = const.USER_TYPE.MANAGER.id publish_at = datetime.datetime.now() @@ -641,7 +641,7 @@ async def test_system_notice(self, mock_send): docs, total = await core.notice.get_system_notices(0, 10) self.assertTrue(docs[0]["scheduled"]) - async def test_notice(self, mock_send): + async def test_notice(self, mock_batch_send): au = 
deepcopy(self.au) doc, code = await core.notice.post_in_manager_delivery( au=au, @@ -696,7 +696,7 @@ async def test_notice(self, mock_send): self.assertFalse(sn[0]["read"]) self.assertIsNone(sn[0]["readTime"]) - async def test_mark_read(self, mock_send): + async def test_mark_read(self, mock_batch_send): au = deepcopy(self.au) au.u.type = const.USER_TYPE.MANAGER.id for i in range(3): diff --git a/tests/test_core_remote.py b/tests/test_core_remote.py index 7566d0c..4c7fd8c 100644 --- a/tests/test_core_remote.py +++ b/tests/test_core_remote.py @@ -21,7 +21,7 @@ from . import utils -@patch("retk.core.ai.llm.knowledge.ops._send", new_callable=AsyncMock, return_value=["", const.CodeEnum.OK]) +@patch("retk.core.ai.llm.knowledge.ops._batch_send", new_callable=AsyncMock, return_value=["", const.CodeEnum.OK]) class RemoteModelsTest(unittest.IsolatedAsyncioTestCase): default_pwd = "rethink123" @@ -88,7 +88,7 @@ async def asyncTearDown(self) -> None: utils.skip_no_connect.skip = True @utils.skip_no_connect - async def test_same_key(self, mock_send): + async def test_same_key(self, mock_batch_send): async def add(): oid = ObjectId() await client.coll.users.insert_one({ @@ -129,7 +129,7 @@ async def add(): await add() @utils.skip_no_connect - async def test_user(self, mock_send): + async def test_user(self, mock_batch_send): u, code = await core.user.get_by_email(email=const.DEFAULT_USER["email"]) self.assertEqual(const.CodeEnum.OK, code) self.assertEqual("rethink", u["nickname"]) @@ -190,7 +190,7 @@ async def test_user(self, mock_send): @patch("retk.core.node.backup.__save_md_to_cos") async def test_node( self, - mock_send, + mock_batch_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, @@ -286,7 +286,7 @@ async def test_node( @patch("retk.core.node.backup.__save_md_to_cos") async def test_node_in_ai_extend_queue( self, - mock_send, + mock_batch_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, @@ -338,7 +338,7 @@ async def test_node_in_ai_extend_queue( @patch("retk.core.node.backup.__save_md_to_cos") async def test_parse_at( self, - mock_send, + mock_batch_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, @@ -404,7 +404,7 @@ async def test_parse_at( self.assertEqual(0, len(n["fromNodeIds"])) @utils.skip_no_connect - async def test_add_set(self, mock_send): + async def test_add_set(self, mock_batch_send): node, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -418,7 +418,7 @@ async def test_add_set(self, mock_send): self.assertEqual(1, len(node["toNodeIds"])) @utils.skip_no_connect - async def test_to_trash(self, mock_send): + async def test_to_trash(self, mock_batch_send): n1, code = await core.node.post( au=self.au, md="title\ntext", type_=const.NodeTypeEnum.MARKDOWN.value ) @@ -456,7 +456,7 @@ async def test_to_trash(self, mock_send): @patch("retk.core.node.backup.__save_md_to_cos") async def test_batch( self, - mock_send, + mock_batch_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, @@ -505,7 +505,7 @@ async def test_batch( self.assertEqual(0, len(tns)) @utils.skip_no_connect - async def test_update_used_space(self, mock_send): + async def test_update_used_space(self, mock_batch_send): u, code = await core.user.get(self.au.u.id) base_used_space = u["usedSpace"] for delta, value in [ @@ -532,7 +532,7 @@ async def test_update_used_space(self, mock_send): @patch("retk.core.node.backup.__save_md_to_cos") async def test_md_history( self, - mock_send, + 
mock_batch_send, mock_save_md_to_cos, mock_get_md_from_cos, mock_remove_md_from_cos, @@ -580,7 +580,7 @@ async def test_md_history( config.get_settings().MD_BACKUP_INTERVAL = bi @utils.skip_no_connect - async def test_system_notice(self, mock_send): + async def test_system_notice(self, mock_batch_send): au = deepcopy(self.au) au.u.type = const.USER_TYPE.MANAGER.id publish_at = datetime.datetime.now() @@ -609,7 +609,7 @@ async def test_system_notice(self, mock_send): self.assertTrue(docs[0]["scheduled"]) @utils.skip_no_connect - async def test_notice(self, mock_send): + async def test_notice(self, mock_batch_send): au = deepcopy(self.au) doc, code = await core.notice.post_in_manager_delivery( au=au, @@ -665,7 +665,7 @@ async def test_notice(self, mock_send): self.assertIsNone(sn[0]["readTime"]) @utils.skip_no_connect - async def test_mark_read(self, mock_send): + async def test_mark_read(self, mock_batch_send): au = deepcopy(self.au) au.u.type = const.USER_TYPE.MANAGER.id for i in range(3): diff --git a/tests/test_core_utils.py b/tests/test_core_utils.py new file mode 100644 index 0000000..204b59e --- /dev/null +++ b/tests/test_core_utils.py @@ -0,0 +1,97 @@ +import asyncio +import time +import unittest +from datetime import timedelta +from unittest.mock import patch + +import httpx + +from retk.core.utils import ratelimiter + + +class UtilsTest(unittest.IsolatedAsyncioTestCase): + + @patch("httpx.AsyncClient.get") + async def test_single_rate_limiter(self, mock_get): + mock_get.return_value = "mock" + rate_limiter = ratelimiter.RateLimiter(requests=5, period=timedelta(seconds=0.1)) + st = time.time() + count = 0 + + async def fetch(url: str): + nonlocal count + async with rate_limiter: + async with httpx.AsyncClient() as client: + await client.get(url) + count += 1 + + tasks = [fetch("https://xxx") for _ in range(11)] + await asyncio.gather(*tasks) + total_time = time.time() - st + self.assertGreaterEqual(total_time, 0.3) + self.assertLess(total_time, 0.4) + self.assertEqual(11, count) + + @patch("httpx.AsyncClient.get") + async def test_rate_limiter(self, mock_get): + mock_get.return_value = "mock" + + rate_limiter_1 = ratelimiter.RateLimiter(requests=15, period=timedelta(seconds=1)) + rate_limiter_2 = ratelimiter.RateLimiter(requests=5, period=timedelta(seconds=0.1)) + + st = time.time() + count = 0 + + async def fetch(url: str): + nonlocal count + async with rate_limiter_1, rate_limiter_2: + async with httpx.AsyncClient() as client: + await client.get(url) + count += 1 + + tasks = [fetch("https://xxx") for _ in range(16)] + await asyncio.gather(*tasks) + total_time = time.time() - st + self.assertGreaterEqual(total_time, 1) + self.assertLess(total_time, 2) + self.assertEqual(16, count) + + async def test_concurrent_limiter(self): + concurrent_limiter = ratelimiter.ConcurrentLimiter(n=2) + st = time.time() + count = 0 + + async def fetch(): + nonlocal count + async with concurrent_limiter: + async with httpx.AsyncClient() as _: + await asyncio.sleep(0.1) + count += 1 + + tasks = [fetch() for _ in range(5)] + await asyncio.gather(*tasks) + total_time = time.time() - st + self.assertGreaterEqual(total_time, 0.3) + self.assertLess(total_time, 0.4) + self.assertEqual(5, count) + + async def test_concurrent_with_rate_limiter(self): + concurrent_limiter = ratelimiter.ConcurrentLimiter(n=2) + rate_limiter = ratelimiter.RateLimiter(requests=3, period=timedelta(seconds=0.2)) + + st = time.time() + count = 0 + + async def fetch(): + nonlocal count + async with concurrent_limiter, rate_limiter: + 
async with httpx.AsyncClient() as _: + await asyncio.sleep(0.1) + count += 1 + + tasks = [fetch() for _ in range(4)] + await asyncio.gather(*tasks) + total_time = time.time() - st + self.assertGreaterEqual(total_time, 0.3) + self.assertLess(total_time, 0.4) + self.assertEqual(4, count)
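
Usage note on the new limiters exercised by the tests above: `RateLimiter` is a simple token bucket that grants `requests` tokens per `period` and refills the whole bucket only once a full period has elapsed, so callers that drain it sleep in 0.1 s steps until the refill; `ConcurrentLimiter` bounds in-flight requests. The sketch below mirrors how services such as Aliyun and Moonshot stack both limiters in `batch_complete`; it assumes the `retk` package layout from this patch, and the limit values and sleep are illustrative only.

```python
import asyncio

from retk.core.utils import ratelimiter


async def main():
    # Illustrative limits: at most 2 requests in flight, at most 5 per second.
    concurrent = ratelimiter.ConcurrentLimiter(n=2)
    rate = ratelimiter.RateLimiter(requests=5, period=1)

    async def call(i: int) -> str:
        # Both limiters are async context managers; they compose in one statement.
        async with concurrent, rate:
            await asyncio.sleep(0.1)  # stand-in for an actual LLM request
            return f"done {i}"

    results = await asyncio.gather(*(call(i) for i in range(12)))
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```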