From c97904fdd63670c248247b4e4beaeea7abef7d7e Mon Sep 17 00:00:00 2001
From: morvanzhou
Date: Wed, 29 May 2024 02:47:35 +0800
Subject: [PATCH] feat(llm): add hunyuan

---
 src/retk/config.py               |   2 +
 src/retk/const/response_codes.py |   9 ++
 src/retk/core/ai/__init__.py     |   0
 src/retk/core/ai/llm/__init__.py |   3 +
 src/retk/core/ai/llm/base.py     |  22 ++
 src/retk/core/ai/llm/hunyuan.py  | 214 +++++++++++++++++++++++++++++++
 tests/test_llm.py                | 103 +++++++++++++++
 7 files changed, 353 insertions(+)
 create mode 100644 src/retk/core/ai/__init__.py
 create mode 100644 src/retk/core/ai/llm/__init__.py
 create mode 100644 src/retk/core/ai/llm/base.py
 create mode 100644 src/retk/core/ai/llm/hunyuan.py
 create mode 100644 tests/test_llm.py

diff --git a/src/retk/config.py b/src/retk/config.py
index 5754bed..f54553d 100644
--- a/src/retk/config.py
+++ b/src/retk/config.py
@@ -24,6 +24,8 @@ class Settings(BaseSettings):
     DB_HOST: str = Field(env='DB_HOST', default="")
     DB_PORT: int = Field(env='DB_PORT', default=-1)
     DB_SALT: str = Field(env='BD_SALT', default="")
+    HUNYUAN_SECRET_ID: str = Field(env='HUNYUAN_SECRET_ID', default="")
+    HUNYUAN_SECRET_KEY: str = Field(env='HUNYUAN_SECRET_KEY', default="")
     ES_USER: str = Field(env='ES_USER', default="")
     ES_PASSWORD: str = Field(env='ES_PASSWORD', default="")
     ES_HOSTS: str = Field(env='ES_HOSTS', default="")
diff --git a/src/retk/const/response_codes.py b/src/retk/const/response_codes.py
index 3a4c882..339cd63 100644
--- a/src/retk/const/response_codes.py
+++ b/src/retk/const/response_codes.py
@@ -45,6 +45,9 @@ class CodeEnum(IntEnum):
     INVALID_PARAMS = 35
     INVALID_SCHEDULE_JOB_ID = 36
     NOTICE_NOT_FOUND = 37
+    LLM_TIMEOUT = 38
+    LLM_SERVICE_ERROR = 39
+    LLM_NO_CHOICE = 40
 
 
 @dataclass
@@ -102,6 +105,9 @@ class CodeMessage:
     CodeEnum.INVALID_PARAMS: CodeMessage(zh="无效参数", en="Invalid parameter"),
     CodeEnum.INVALID_SCHEDULE_JOB_ID: CodeMessage(zh="无效的任务 ID", en="Invalid schedule job ID"),
     CodeEnum.NOTICE_NOT_FOUND: CodeMessage(zh="通知未找到", en="Notice not found"),
+    CodeEnum.LLM_TIMEOUT: CodeMessage(zh="模型超时", en="Model timeout"),
+    CodeEnum.LLM_SERVICE_ERROR: CodeMessage(zh="模型服务错误", en="Model service error"),
+    CodeEnum.LLM_NO_CHOICE: CodeMessage(zh="无回复", en="No response"),
 }
 
 CODE2STATUS_CODE: Dict[CodeEnum, int] = {
@@ -143,6 +149,9 @@ class CodeMessage:
     CodeEnum.INVALID_PARAMS: 400,
     CodeEnum.INVALID_SCHEDULE_JOB_ID: 400,
     CodeEnum.NOTICE_NOT_FOUND: 404,
+    CodeEnum.LLM_TIMEOUT: 408,
+    CodeEnum.LLM_SERVICE_ERROR: 500,
+    CodeEnum.LLM_NO_CHOICE: 404,
 }
 
 
diff --git a/src/retk/core/ai/__init__.py b/src/retk/core/ai/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/retk/core/ai/llm/__init__.py b/src/retk/core/ai/llm/__init__.py
new file mode 100644
index 0000000..347e18d
--- /dev/null
+++ b/src/retk/core/ai/llm/__init__.py
@@ -0,0 +1,3 @@
+from . import (
+    hunyuan,
+)
diff --git a/src/retk/core/ai/llm/base.py b/src/retk/core/ai/llm/base.py
new file mode 100644
index 0000000..a5aedf5
--- /dev/null
+++ b/src/retk/core/ai/llm/base.py
@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+from typing import List, Dict, Literal, AsyncIterable, Tuple
+
+from retk import const
+
+MessagesType = List[Dict[Literal["Role", "Content"], str]]
+
+
+class BaseLLM(ABC):
+    name: str = None
+
+    def __init__(self):
+        if self.name is None:
+            raise ValueError("llm model name must be defined")
+
+    @abstractmethod
+    async def complete(self, *args, **kwargs) -> Tuple[str, const.CodeEnum]:
+        ...
+
+    @abstractmethod
+    async def stream_complete(self, *args, **kwargs) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
+        ...
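Note on the base contract: complete() resolves to a (text, CodeEnum) pair, while stream_complete() yields (bytes, CodeEnum) pairs. For illustration, a minimal sketch of a conforming subclass; EchoLLM is hypothetical and not part of this patch:

    # hypothetical subclass, for illustration only
    from typing import AsyncIterable, Tuple

    from retk import const
    from retk.core.ai.llm.base import BaseLLM, MessagesType


    class EchoLLM(BaseLLM):
        name = "echo"  # subclasses must set `name`, or BaseLLM.__init__ raises ValueError

        async def complete(self, messages: MessagesType) -> Tuple[str, const.CodeEnum]:
            # echo the last user message back, tagged OK
            return messages[-1]["Content"], const.CodeEnum.OK

        async def stream_complete(self, messages: MessagesType) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
            # emit the reply word by word, as raw bytes
            for word in messages[-1]["Content"].split():
                yield word.encode("utf-8"), const.CodeEnum.OK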
diff --git a/src/retk/core/ai/llm/hunyuan.py b/src/retk/core/ai/llm/hunyuan.py
new file mode 100644
index 0000000..88e8e2c
--- /dev/null
+++ b/src/retk/core/ai/llm/hunyuan.py
@@ -0,0 +1,214 @@
+import hashlib
+import hmac
+import json
+import time
+from abc import ABC
+from datetime import datetime
+from typing import TypedDict, Tuple, Dict, AsyncIterable
+
+import httpx
+
+from retk import config, const
+from retk.logger import logger
+from .base import BaseLLM, MessagesType
+
+Headers = TypedDict("Headers", {
+    "Authorization": str,
+    "Content-Type": str,
+    "Host": str,
+    "X-TC-Action": str,
+    "X-TC-Timestamp": str,
+    "X-TC-Version": str,
+    "X-TC-Language": str,
+})
+
+
+# compute an HMAC-SHA256 digest of msg, keyed with key
+def sign(key, msg):
+    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
+
+
+class _Hunyuan(BaseLLM, ABC):
+    service = "hunyuan"
+    host = "hunyuan.tencentcloudapi.com"
+    version = "2023-09-01"
+    endpoint = f"https://{host}"
+
+    def __init__(
+            self,
+            name: str,
+            top_p: float = 0.9,
+            temperature: float = 0.7,
+            timeout: float = 60.,
+    ):
+        self.name = name
+        super().__init__()
+        self.top_p = top_p
+        self.temperature = temperature
+        self.timeout = timeout
+
+        self.secret_id = config.get_settings().HUNYUAN_SECRET_ID
+        self.secret_key = config.get_settings().HUNYUAN_SECRET_KEY
+
+    def get_auth(self, action: str, payload: bytes, timestamp: int, content_type: str) -> str:
+        algorithm = "TC3-HMAC-SHA256"
+        date = datetime.utcfromtimestamp(timestamp).strftime("%Y-%m-%d")
+
+        # ************* Step 1: build the canonical request string *************
+        http_request_method = "POST"
+        canonical_uri = "/"
+        canonical_querystring = ""
+        canonical_headers = f"content-type:{content_type}\nhost:{self.host}\nx-tc-action:{action.lower()}\n"
+        signed_headers = "content-type;host;x-tc-action"
+        hashed_request_payload = hashlib.sha256(payload).hexdigest()
+        canonical_request = f"{http_request_method}\n" \
+                            f"{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n" \
+                            f"{signed_headers}\n{hashed_request_payload}"
+
+        # ************* Step 2: build the string to sign *************
+        credential_scope = f"{date}/{self.service}/tc3_request"
+        hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
+        string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}"
+
+        # ************* Step 3: compute the signature *************
+        secret_date = sign(f"TC3{self.secret_key}".encode("utf-8"), date)
+        secret_service = sign(secret_date, self.service)
+        secret_signing = sign(secret_service, "tc3_request")
+        signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
+
+        # ************* Step 4: assemble the Authorization header *************
+        authorization = f"{algorithm}" \
+                        f" Credential={self.secret_id}/{credential_scope}," \
+                        f" SignedHeaders={signed_headers}," \
+                        f" Signature={signature}"
+        return authorization
+
+    def get_headers(self, action: str, payload: bytes) -> Headers:
+        ct = "application/json"
+        timestamp = int(time.time())
+        authorization = self.get_auth(action=action, payload=payload, timestamp=timestamp, content_type=ct)
+        return {
+            "Authorization": authorization,
+            "Host": self.host,
+            "X-TC-Action": action,
+            "X-TC-Version": self.version,
+            "X-TC-Timestamp": str(timestamp),
+            "X-TC-Language": "zh-CN",
+            "Content-Type": ct,
+        }
+
+    def get_payload(self, messages: MessagesType, stream: bool) -> bytes:
+        return json.dumps(
+            {
+                "Model": self.name,
+                "Messages": messages,
+                "Stream": stream,
+                "TopP": self.top_p,
+                "Temperature": self.temperature,
+                "EnableEnhancement": False,
+            }, ensure_ascii=False, separators=(",", ":")
+        ).encode("utf-8")
+
+    @staticmethod
+    def handle_err(error: Dict) -> Tuple[str, const.CodeEnum]:
+        msg = error.get("Message")
+        code = error.get("Code")
+        logger.error(f"Model error code={code}, msg={msg}")
+        if code == 4001:
+            ccode = const.CodeEnum.LLM_TIMEOUT
+        else:
+            ccode = const.CodeEnum.LLM_SERVICE_ERROR
+        return msg, ccode
+
+    @staticmethod
+    def handle_normal_response(rj: Dict, stream: bool) -> Tuple[str, const.CodeEnum]:
+        choices = rj["Choices"]
+        if len(choices) == 0:
+            return "No response", const.CodeEnum.LLM_NO_CHOICE
+        choice = choices[0]
+        m = choice["Delta"] if stream else choice["Message"]
+        return m["Content"], const.CodeEnum.OK
+
+    async def complete(self, messages: MessagesType) -> Tuple[str, const.CodeEnum]:
+        action = "ChatCompletions"
+        payload = self.get_payload(messages=messages, stream=False)
+        headers = self.get_headers(action=action, payload=payload)
+
+        async with httpx.AsyncClient() as client:
+            try:
+                resp = await client.post(
+                    url=self.endpoint,
+                    headers=headers,
+                    content=payload,
+                    follow_redirects=False,
+                    timeout=self.timeout,
+                )
+            except (
+                    httpx.ConnectTimeout,
+                    httpx.ConnectError,
+                    httpx.ReadTimeout,
+            ) as e:
+                logger.error(f"Model error: {e}")
+                return "Model timeout, please try later", const.CodeEnum.LLM_TIMEOUT
+            except httpx.HTTPError as e:
+                logger.error(f"Model error: {e}")
+                return "Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR
+            if resp.status_code != 200:
+                logger.error(f"Model error: {resp.text}")
+                return "Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR
+
+        rj = resp.json()["Response"]
+        error = rj.get("Error")
+        if error is not None:
+            return self.handle_err(error)
+        return self.handle_normal_response(rj=rj, stream=False)
+
+    async def stream_complete(self, messages: MessagesType) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
+        action = "ChatCompletions"
+        payload = self.get_payload(messages=messages, stream=True)
+        headers = self.get_headers(action=action, payload=payload)
+
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                    method="POST",
+                    url=self.endpoint,
+                    headers=headers,
+                    content=payload,
+                    follow_redirects=False,
+                    timeout=self.timeout,
+            ) as resp:
+                if resp.status_code != 200:
+                    logger.error(f"Model error: {(await resp.aread()).decode('utf-8', 'replace')}")
+                    yield b"Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR
+                    return
+
+                async for chunk in resp.aiter_bytes():
+                    yield chunk, const.CodeEnum.OK
+
+
+class HunyuanPro(_Hunyuan):
+    model_name = "hunyuan-pro"
+
+    def __init__(self, top_p: float = 0.9, temperature: float = 0.7):
+        super().__init__(name=self.model_name, top_p=top_p, temperature=temperature)
+
+
+class HunyuanStandard(_Hunyuan):
+    model_name = "hunyuan-standard"
+
+    def __init__(self, top_p: float = 0.9, temperature: float = 0.7):
+        super().__init__(name=self.model_name, top_p=top_p, temperature=temperature)
+
+
+class HunyuanStandard256K(_Hunyuan):
+    model_name = "hunyuan-standard-256K"
+
+    def __init__(self, top_p: float = 0.9, temperature: float = 0.7):
+        super().__init__(name=self.model_name, top_p=top_p, temperature=temperature)
+
+
+class HunyuanLite(_Hunyuan):
+    model_name = "hunyuan-lite"
+
+    def __init__(self, top_p: float = 0.9, temperature: float = 0.7):
+        super().__init__(name=self.model_name, top_p=top_p, temperature=temperature)
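get_auth implements Tencent Cloud's TC3-HMAC-SHA256 scheme: the request is canonicalized and hashed, then the signing key is derived by chaining HMACs over the UTC date, the service name, and the literal string "tc3_request". A standalone sketch of that derivation, with hypothetical values standing in for the real credential and string to sign:

    import hashlib
    import hmac


    def _hmac(key: bytes, msg: str) -> bytes:
        # HMAC-SHA256 digest, mirroring sign() in hunyuan.py
        return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()


    secret_key = "testkey"      # hypothetical; normally HUNYUAN_SECRET_KEY
    date = "2024-05-28"         # UTC date derived from the X-TC-Timestamp
    string_to_sign = "..."      # algorithm, timestamp, scope, hashed canonical request

    # the signing key is derived in three chained HMAC steps:
    key_date = _hmac(("TC3" + secret_key).encode("utf-8"), date)
    key_service = _hmac(key_date, "hunyuan")
    key_signing = _hmac(key_service, "tc3_request")
    signature = hmac.new(key_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()

Because the derived key depends only on the date, service, and credential, it could in principle be cached per day; the patch recomputes it per request, which is simpler and still cheap.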
diff --git a/tests/test_llm.py b/tests/test_llm.py
new file mode 100644
index 0000000..f44c213
--- /dev/null
+++ b/tests/test_llm.py
@@ -0,0 +1,103 @@
+import json
+import os
+import unittest
+from unittest.mock import patch, AsyncMock
+
+from httpx import Response
+
+from retk import const
+from retk.core.ai import llm
+from . import utils
+
+
+class HunyuanTest(unittest.IsolatedAsyncioTestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["HUNYUAN_SECRET_ID"] = "testid"
+        os.environ["HUNYUAN_SECRET_KEY"] = "testkey"
+        utils.set_env(".env.test.local")
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        utils.drop_env(".env.test.local")
+
+    def test_authorization(self):
+        payload = {
+            "Model": "hunyuan-lite",
+            "Messages": [{"Role": "user", "Content": "你是谁"}],
+            "Stream": False,
+            "TopP": 0.9,
+            "Temperature": 0.7,
+            "EnableEnhancement": False,
+        }
+        payload = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
+        m = llm.hunyuan.HunyuanLite()
+        auth = m.get_auth(
+            action="ChatCompletions",
+            payload=payload,
+            timestamp=1716913478,
+            content_type="application/json"
+        )
+        sid = os.environ["HUNYUAN_SECRET_ID"]
+        self.assertEqual(
+            f"TC3-HMAC-SHA256 Credential={sid}/2024-05-28/hunyuan/tc3_request,"
+            f" SignedHeaders=content-type;host;x-tc-action,"
+            f" Signature=f628e271c4acdf72a4618fe59e3a31591f2ddedbd44e5befe6e02c05949b01b3",
+            auth)
+
+    async def test_hunyuan_auth_failed(self):
+        m = llm.hunyuan.HunyuanLite()
+        self.assertEqual("hunyuan-lite", m.model_name)
+        text, code = await m.complete([{"Role": "user", "Content": "你是谁"}])
+        self.assertEqual(const.CodeEnum.LLM_SERVICE_ERROR, code, msg=text)
+        self.assertEqual("SecretId不存在,请输入正确的密钥。", text)
+
+    @patch("httpx.AsyncClient.post", new_callable=AsyncMock)
+    async def test_hunyuan_complete(self, mock_post):
+        mock_post.return_value = Response(
+            status_code=200,
+            json={
+                "Response": {
+                    "Choices": [
+                        {
+                            "Message": {
+                                "Role": "assistant",
+                                "Content": "我是一个AI助手。"
+                            }
+                        }
+                    ]
+                }
+            }
+        )
+        m = llm.hunyuan.HunyuanLite()
+        self.assertEqual("hunyuan-lite", m.model_name)
+        text, code = await m.complete([{"Role": "user", "Content": "你是谁"}])
+        self.assertEqual(const.CodeEnum.OK, code, msg=text)
+        self.assertEqual("我是一个AI助手。", text)
+        mock_post.assert_called_once()
+
+    # async def test_hunyuan_stream_complete(self):
+    #     m = llm.hunyuan.HunyuanLite()
+    #
+    #     async for b, code in m.stream_complete([{"Role": "user", "Content": "你是谁"}]):
+    #         self.assertEqual(const.CodeEnum.OK, code)
+    #         s = b.decode("utf-8")
+    #         lines = s.splitlines()
+    #         for line in lines:
+    #             if line.strip() == "":
+    #                 continue
+    #             self.assertTrue(line.startswith("data: "))
+    #             json_str = line[6:]
+    #             json_data = json.loads(json_str)
+    #             self.assertIn("Choices", json_data)
+    #             choices = json_data["Choices"]
+    #             self.assertEqual(1, len(choices))
+    #             choice = choices[0]
+    #             self.assertIn("Delta", choice)
+    #             delta = choice["Delta"]
+    #             self.assertIn("Content", delta)
+    #             if choice["FinishReason"] == "":
+    #                 self.assertGreater(len(delta["Content"]), 0)
+    #             else:
+    #                 self.assertEqual("", delta["Content"])
+    #             self.assertEqual("assistant", delta["Role"])
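For reference, a hypothetical call site for the non-streaming path, assuming HUNYUAN_SECRET_ID and HUNYUAN_SECRET_KEY are set in the environment before retk.config loads its settings:

    import asyncio

    from retk import const
    from retk.core.ai import llm


    async def main():
        m = llm.hunyuan.HunyuanLite()
        text, code = await m.complete([{"Role": "user", "Content": "hello"}])
        if code == const.CodeEnum.OK:
            print(text)  # the assistant's reply
        else:
            print(f"request failed: {code.name}: {text}")  # error text from complete()/handle_err()


    asyncio.run(main())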
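The disabled streaming test above exercises the live endpoint, which is why it stays commented out. If the API streams Server-Sent Events in the shape that test expects ("data: {...}" lines with Choices[0].Delta), a consumer sketch might look like the following; note it naively assumes each chunk from aiter_bytes() contains whole lines, which real SSE chunking does not guarantee:

    import asyncio
    import json

    from retk.core.ai import llm


    async def stream_demo():
        m = llm.hunyuan.HunyuanLite()
        async for chunk, code in m.stream_complete([{"Role": "user", "Content": "hello"}]):
            for line in chunk.decode("utf-8").splitlines():
                if not line.startswith("data: "):
                    continue  # skip blank separators between events
                delta = json.loads(line[len("data: "):])["Choices"][0]["Delta"]
                print(delta["Content"], end="", flush=True)


    asyncio.run(stream_demo())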