Skip to content

Commit

Permalink
feat(llm): add hunyuan
Browse files Browse the repository at this point in the history
  • Loading branch information
MorvanZhou committed May 28, 2024
1 parent 670d2c0 commit c97904f
Show file tree
Hide file tree
Showing 7 changed files with 353 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/retk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Settings(BaseSettings):
DB_HOST: str = Field(env='DB_HOST', default="")
DB_PORT: int = Field(env='DB_PORT', default=-1)
DB_SALT: str = Field(env='BD_SALT', default="")
HUNYUAN_SECRET_ID: str = Field(env='HUNYUAN_SECRET_ID', default="")
HUNYUAN_SECRET_KEY: str = Field(env='HUNYUAN_SECRET_KEY', default="")
ES_USER: str = Field(env='ES_USER', default="")
ES_PASSWORD: str = Field(env='ES_PASSWORD', default="")
ES_HOSTS: str = Field(env='ES_HOSTS', default="")
Expand Down
9 changes: 9 additions & 0 deletions src/retk/const/response_codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class CodeEnum(IntEnum):
INVALID_PARAMS = 35
INVALID_SCHEDULE_JOB_ID = 36
NOTICE_NOT_FOUND = 37
LLM_TIMEOUT = 38
LLM_SERVICE_ERROR = 39
LLM_NO_CHOICE = 40


@dataclass
Expand Down Expand Up @@ -102,6 +105,9 @@ class CodeMessage:
CodeEnum.INVALID_PARAMS: CodeMessage(zh="无效参数", en="Invalid parameter"),
CodeEnum.INVALID_SCHEDULE_JOB_ID: CodeMessage(zh="无效的任务 ID", en="Invalid schedule job ID"),
CodeEnum.NOTICE_NOT_FOUND: CodeMessage(zh="通知未找到", en="Notice not found"),
CodeEnum.LLM_TIMEOUT: CodeMessage(zh="模型超时", en="Model timeout"),
CodeEnum.LLM_SERVICE_ERROR: CodeMessage(zh="模型服务错误", en="Model service error"),
CodeEnum.LLM_NO_CHOICE: CodeMessage(zh="无回复", en="No response"),
}

CODE2STATUS_CODE: Dict[CodeEnum, int] = {
Expand Down Expand Up @@ -143,6 +149,9 @@ class CodeMessage:
CodeEnum.INVALID_PARAMS: 400,
CodeEnum.INVALID_SCHEDULE_JOB_ID: 400,
CodeEnum.NOTICE_NOT_FOUND: 404,
CodeEnum.LLM_TIMEOUT: 408,
CodeEnum.LLM_SERVICE_ERROR: 500,
CodeEnum.LLM_NO_CHOICE: 404,
}


Expand Down
Empty file added src/retk/core/ai/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions src/retk/core/ai/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import (
hunyuan,
)
22 changes: 22 additions & 0 deletions src/retk/core/ai/llm/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Literal, AsyncIterable, Tuple

from retk import const

MessagesType = List[Dict[Literal["Role", "Content"], str]]


class BaseLLM(ABC):
name: str = None

def __init__(self):
if self.name is None:
raise ValueError("llm model name must be defined")

@abstractmethod
async def complete(self, *args, **kwargs) -> Tuple[str, const.CodeEnum]:
...

@abstractmethod
async def stream_complete(self, *args, **kwargs) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
...
214 changes: 214 additions & 0 deletions src/retk/core/ai/llm/hunyuan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import hashlib
import hmac
import json
import time
from abc import ABC
from datetime import datetime
from typing import TypedDict, Tuple, Dict, AsyncIterable

import httpx

from retk import config, const
from retk.logger import logger
from .base import BaseLLM, MessagesType

Headers = TypedDict("Headers", {
"Authorization": str,
"Content-Type": str,
"Host": str,
"X-TC-Action": str,
"X-TC-Timestamp": str,
"X-TC-Version": str,
"X-TC-Language": str,
})


# 计算签名摘要函数
def sign(key, msg):
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()


class _Hunyuan(BaseLLM, ABC):
service = "hunyuan"
host = "hunyuan.tencentcloudapi.com"
version = "2023-09-01"
endpoint = f"https://{host}"

def __init__(
self,
name: str,
top_p: float = 0.9,
temperature: float = 0.7,
timeout: float = 60.,
):
self.name = name
super().__init__()
self.top_p = top_p
self.temperature = temperature
self.timeout = timeout

self.secret_id = config.get_settings().HUNYUAN_SECRET_ID
self.secret_key = config.get_settings().HUNYUAN_SECRET_KEY

def get_auth(self, action: str, payload: bytes, timestamp: int, content_type: str) -> str:
algorithm = "TC3-HMAC-SHA256"
date = datetime.utcfromtimestamp(timestamp).strftime("%Y-%m-%d")

# ************* 步骤 1:拼接规范请求串 *************
http_request_method = "POST"
canonical_uri = "/"
canonical_querystring = ""
canonical_headers = f"content-type:{content_type}\nhost:{self.host}\nx-tc-action:{action.lower()}\n"
signed_headers = "content-type;host;x-tc-action"
hashed_request_payload = hashlib.sha256(payload).hexdigest()
canonical_request = f"{http_request_method}\n" \
f"{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n" \
f"{signed_headers}\n{hashed_request_payload}"

# ************* 步骤 2:拼接待签名字符串 *************
credential_scope = f"{date}/{self.service}/tc3_request"
hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}"

# ************* 步骤 3:计算签名 *************
secret_date = sign(f"TC3{self.secret_key}".encode("utf-8"), date)
secret_service = sign(secret_date, self.service)
secret_signing = sign(secret_service, "tc3_request")
signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()

# ************* 步骤 4:拼接 Authorization *************
authorization = f"{algorithm}" \
f" Credential={self.secret_id}/{credential_scope}," \
f" SignedHeaders={signed_headers}," \
f" Signature={signature}"
return authorization

def get_headers(self, action: str, payload: bytes) -> Headers:
ct = "application/json"
timestamp = int(time.time())
authorization = self.get_auth(action=action, payload=payload, timestamp=timestamp, content_type=ct)
return {
"Authorization": authorization,
"Host": self.host,
"X-TC-Action": action,
"X-TC-Version": self.version,
"X-TC-Timestamp": str(timestamp),
"X-TC-Language": "zh-CN",
"Content-Type": ct,
}

def get_payload(self, messages: MessagesType, stream: bool) -> bytes:
return json.dumps(
{
"Model": self.name,
"Messages": messages,
"Stream": stream,
"TopP": self.top_p,
"Temperature": self.temperature,
"EnableEnhancement": False,
}, ensure_ascii=False, separators=(",", ":")
).encode("utf-8")

@staticmethod
def handle_err(error: Dict):
msg = error.get("Message")
code = error.get("Code")
logger.error(f"Model error code={code}, msg={msg}")
if code == 4001:
ccode = const.CodeEnum.LLM_TIMEOUT
else:
ccode = const.CodeEnum.LLM_SERVICE_ERROR
return msg, ccode

@staticmethod
def handle_normal_response(rj: Dict, stream: bool) -> Tuple[str, const.CodeEnum]:
choices = rj["Choices"]
if len(choices) == 0:
return "No response", const.CodeEnum.LLM_NO_CHOICE
choice = choices[0]
m = choice["Delta"] if stream else choice["Message"]
return m["Content"], const.CodeEnum.OK

async def complete(self, messages: MessagesType) -> Tuple[str, const.CodeEnum]:
action = "ChatCompletions"
payload = self.get_payload(messages=messages, stream=False)
headers = self.get_headers(action=action, payload=payload)

async with httpx.AsyncClient() as client:
try:
resp = await client.post(
url=self.endpoint,
headers=headers,
content=payload,
follow_redirects=False,
timeout=self.timeout,
)
except (
httpx.ConnectTimeout,
httpx.ConnectError,
httpx.ReadTimeout,
) as e:
logger.error(f"Model error: {e}")
return "Model timeout, please try later", const.CodeEnum.LLM_TIMEOUT
except httpx.HTTPError as e:
logger.error(f"Model error: {e}")
return "Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR
if resp.status_code != 200:
logger.error(f"Model error: {resp.text}")
return "Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR

rj = resp.json()["Response"]
error = rj.get("Error")
if error is not None:
return self.handle_err(error)
return self.handle_normal_response(rj=rj, stream=False)

async def stream_complete(self, messages: MessagesType) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
action = "ChatCompletions"
payload = self.get_payload(messages=messages, stream=True)
headers = self.get_headers(action=action, payload=payload)

async with httpx.AsyncClient() as client:
async with client.stream(
method="POST",
url=self.endpoint,
headers=headers,
content=payload,
follow_redirects=False,
timeout=self.timeout,
) as resp:
if resp.status_code != 200:
logger.error(f"Model error: {resp.text}")
yield "Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR
return

async for chunk in resp.aiter_bytes():
yield chunk, const.CodeEnum.OK


class HunyuanPro(_Hunyuan):
model_name = "hunyuan-pro"

def __init__(self, top_p: float = 0.9, temperature: float = 0.7):
super().__init__(name=self.model_name, top_p=top_p, temperature=temperature)


class HunyuanStandard(_Hunyuan):
model_name = "hunyuan-standard"

def __init__(self, top_p: float = 0.9, temperature: float = 0.7):
super().__init__(name=self.model_name, top_p=top_p, temperature=temperature)


class HunyuanStandard256K(_Hunyuan):
model_name = "hunyuan-standard-256K"

def __init__(self, top_p: float = 0.9, temperature: float = 0.7):
super().__init__(name=self.model_name, top_p=top_p, temperature=temperature)


class HunyuanLite(_Hunyuan):
model_name = "hunyuan-lite"

def __init__(self, top_p: float = 0.9, temperature: float = 0.7):
super().__init__(name=self.model_name, top_p=top_p, temperature=temperature)
103 changes: 103 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import json
import os
import unittest
from unittest.mock import patch, AsyncMock

from httpx import Response

from retk import const
from retk.core.ai import llm
from . import utils


class HunyuanTest(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
os.environ["HUNYUAN_SECRET_ID"] = "testid"
os.environ["HUNYUAN_SECRET_KEY"] = "testkey"
utils.set_env(".env.test.local")

@classmethod
def tearDownClass(cls) -> None:
utils.drop_env(".env.test.local")

def test_authorization(self):
payload = {
"Model": "hunyuan-lite",
"Messages": [{"Role": "user", "Content": "你是谁"}],
"Stream": False,
"TopP": 0.9,
"Temperature": 0.7,
"EnableEnhancement": False,
}
payload = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
m = llm.hunyuan.HunyuanLite()
auth = m.get_auth(
action="ChatCompletions",
payload=payload,
timestamp=1716913478,
content_type="application/json"
)
sid = os.environ["HUNYUAN_SECRET_ID"]
self.assertEqual(
f"TC3-HMAC-SHA256 Credential={sid}/2024-05-28/hunyuan/tc3_request,"
f" SignedHeaders=content-type;host;x-tc-action,"
f" Signature=f628e271c4acdf72a4618fe59e3a31591f2ddedbd44e5befe6e02c05949b01b3",
auth)

async def test_hunyuan_auth_failed(self):
m = llm.hunyuan.HunyuanLite()
self.assertEqual("hunyuan-lite", m.model_name)
text, code = await m.complete([{"Role": "user", "Content": "你是谁"}])
self.assertEqual(const.CodeEnum.LLM_SERVICE_ERROR, code, msg=text)
self.assertEqual("SecretId不存在,请输入正确的密钥。", text)

@patch("httpx.AsyncClient.post", new_callable=AsyncMock)
async def test_hunyuan_complete(self, mock_post):
mock_post.return_value = Response(
status_code=200,
json={
"Response": {
"Choices": [
{
"Message": {
"Role": "assistant",
"Content": "我是一个AI助手。"
}
}
]
}
}
)
m = llm.hunyuan.HunyuanLite()
self.assertEqual("hunyuan-lite", m.model_name)
text, code = await m.complete([{"Role": "user", "Content": "你是谁"}])
self.assertEqual(const.CodeEnum.OK, code, msg=text)
self.assertEqual("我是一个AI助手。", text)
mock_post.assert_called_once()

# async def test_hunyuan_stream_complete(self):
# m = llm.hunyuan.HunyuanLite()
#
# async for b, code in m.stream_complete([{"Role": "user", "Content": "你是谁"}]):
# self.assertEqual(const.CodeEnum.OK, code)
# s = b.decode("utf-8")
# lines = s.splitlines()
# for line in lines:
# if line.strip() == "":
# continue
# self.assertTrue(line.startswith("data: "))
# json_str = line[6:]
# json_data = json.loads(json_str)
# self.assertIn("Choices", json_data)
# choices = json_data["Choices"]
# self.assertEqual(1, len(choices))
# choice = choices[0]
# self.assertIn("Delta", choice)
# delta = choice["Delta"]
# self.assertIn("Content", delta)
# if choice["FinishReason"] == "":
# self.assertGreater(len(delta["Content"]), 0)
# else:
# self.assertEqual("", delta["Content"])
# self.assertEqual("assistant", delta["Role"])

0 comments on commit c97904f

Please sign in to comment.