-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
670d2c0
commit c97904f
Showing
7 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from . import ( | ||
hunyuan, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from abc import ABC, abstractmethod
from typing import AsyncIterable, Dict, List, Literal, Optional, Tuple

from retk import const
# A chat transcript: a list of messages, each a dict with exactly the
# "Role" and "Content" keys (e.g. [{"Role": "user", "Content": "hi"}]).
MessagesType = List[Dict[Literal["Role", "Content"], str]]


class BaseLLM(ABC):
    """Abstract base class for LLM provider clients.

    Subclasses must assign ``name`` (the provider-side model identifier)
    before calling ``super().__init__()``, and implement both
    ``complete`` and ``stream_complete``.
    """

    # Provider-side model identifier; ``Optional`` because the base class
    # deliberately leaves it unset and validates it in ``__init__``.
    name: Optional[str] = None

    def __init__(self):
        # Fail fast on mis-declared subclasses instead of at request time.
        if self.name is None:
            raise ValueError("llm model name must be defined")

    @abstractmethod
    async def complete(self, *args, **kwargs) -> Tuple[str, const.CodeEnum]:
        """Return the full completion text together with a status code."""
        ...

    @abstractmethod
    async def stream_complete(self, *args, **kwargs) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
        """Yield raw response chunks and a status code as they arrive."""
        ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
import hashlib | ||
import hmac | ||
import json | ||
import time | ||
from abc import ABC | ||
from datetime import datetime | ||
from typing import TypedDict, Tuple, Dict, AsyncIterable | ||
|
||
import httpx | ||
|
||
from retk import config, const | ||
from retk.logger import logger | ||
from .base import BaseLLM, MessagesType | ||
|
||
# HTTP headers for a signed Tencent Cloud API request, as produced by
# _Hunyuan.get_headers.  Functional TypedDict syntax is required because
# several keys contain hyphens and so are not valid identifiers.
Headers = TypedDict("Headers", {
    "Authorization": str,   # TC3-HMAC-SHA256 credential/signature line (see get_auth)
    "Content-Type": str,    # "application/json"
    "Host": str,            # API host; also part of the signed headers
    "X-TC-Action": str,     # API action name, e.g. "ChatCompletions"
    "X-TC-Timestamp": str,  # unix timestamp (seconds) used in the signature
    "X-TC-Version": str,    # API version date, e.g. "2023-09-01"
    "X-TC-Language": str,   # response language, e.g. "zh-CN"
})
|
||
|
||
def sign(key, msg):
    """Return the HMAC-SHA256 digest of ``msg`` (str) keyed by ``key`` (bytes).

    Building block for the TC3 signing-key derivation chain.
    """
    # One-shot form; equivalent to hmac.new(key, ..., hashlib.sha256).digest().
    return hmac.digest(key, msg.encode("utf-8"), "sha256")
|
||
|
||
class _Hunyuan(BaseLLM, ABC):
    """Async client for Tencent Hunyuan chat completions.

    Signs every request with Tencent Cloud's TC3-HMAC-SHA256 scheme
    (``get_auth``) and posts the ChatCompletions action either as a
    single JSON response (``complete``) or as a byte stream
    (``stream_complete``).
    """

    service = "hunyuan"
    host = "hunyuan.tencentcloudapi.com"
    version = "2023-09-01"
    endpoint = f"https://{host}"

    def __init__(
            self,
            name: str,
            top_p: float = 0.9,
            temperature: float = 0.7,
            timeout: float = 60.,
    ):
        """Create a client for model ``name``.

        Args:
            name: provider-side model identifier (e.g. "hunyuan-lite").
            top_p: nucleus-sampling parameter forwarded in the payload.
            temperature: sampling temperature forwarded in the payload.
            timeout: per-request timeout in seconds for httpx calls.
        """
        # ``name`` must be set before BaseLLM.__init__ validates it.
        self.name = name
        super().__init__()
        self.top_p = top_p
        self.temperature = temperature
        self.timeout = timeout

        self.secret_id = config.get_settings().HUNYUAN_SECRET_ID
        self.secret_key = config.get_settings().HUNYUAN_SECRET_KEY

    def get_auth(self, action: str, payload: bytes, timestamp: int, content_type: str) -> str:
        """Build the TC3-HMAC-SHA256 ``Authorization`` header value.

        Follows the four canonical steps of Tencent Cloud's signing
        process. ``timestamp`` is a unix timestamp in seconds.
        """
        algorithm = "TC3-HMAC-SHA256"
        # UTC date of the request, part of the credential scope.
        # time.gmtime replaces the deprecated datetime.utcfromtimestamp.
        date = time.strftime("%Y-%m-%d", time.gmtime(timestamp))

        # ************* step 1: build the canonical request string *************
        http_request_method = "POST"
        canonical_uri = "/"
        canonical_querystring = ""
        canonical_headers = f"content-type:{content_type}\nhost:{self.host}\nx-tc-action:{action.lower()}\n"
        signed_headers = "content-type;host;x-tc-action"
        hashed_request_payload = hashlib.sha256(payload).hexdigest()
        canonical_request = f"{http_request_method}\n" \
                            f"{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n" \
                            f"{signed_headers}\n{hashed_request_payload}"

        # ************* step 2: build the string to sign *************
        credential_scope = f"{date}/{self.service}/tc3_request"
        hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
        string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}"

        # ************* step 3: derive the signing key, compute the signature *************
        secret_date = sign(f"TC3{self.secret_key}".encode("utf-8"), date)
        secret_service = sign(secret_date, self.service)
        secret_signing = sign(secret_service, "tc3_request")
        signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()

        # ************* step 4: assemble the Authorization header *************
        authorization = f"{algorithm}" \
                        f" Credential={self.secret_id}/{credential_scope}," \
                        f" SignedHeaders={signed_headers}," \
                        f" Signature={signature}"
        return authorization

    def get_headers(self, action: str, payload: bytes) -> Headers:
        """Return all headers, including the signature, for one request."""
        ct = "application/json"
        timestamp = int(time.time())
        authorization = self.get_auth(action=action, payload=payload, timestamp=timestamp, content_type=ct)
        return {
            "Authorization": authorization,
            "Host": self.host,
            "X-TC-Action": action,
            "X-TC-Version": self.version,
            "X-TC-Timestamp": str(timestamp),
            "X-TC-Language": "zh-CN",
            "Content-Type": ct,
        }

    def get_payload(self, messages: MessagesType, stream: bool) -> bytes:
        """Serialize the request body.

        compact separators and ensure_ascii=False must match what was
        signed byte-for-byte, so do not change the dump options casually.
        """
        return json.dumps(
            {
                "Model": self.name,
                "Messages": messages,
                "Stream": stream,
                "TopP": self.top_p,
                "Temperature": self.temperature,
                "EnableEnhancement": False,
            }, ensure_ascii=False, separators=(",", ":")
        ).encode("utf-8")

    @staticmethod
    def handle_err(error: Dict) -> Tuple[str, const.CodeEnum]:
        """Map a service-level error object to (message, code)."""
        msg = error.get("Message")
        code = error.get("Code")
        logger.error(f"Model error code={code}, msg={msg}")
        # 4001 is treated as a backend timeout; everything else is generic.
        if code == 4001:
            ccode = const.CodeEnum.LLM_TIMEOUT
        else:
            ccode = const.CodeEnum.LLM_SERVICE_ERROR
        return msg, ccode

    @staticmethod
    def handle_normal_response(rj: Dict, stream: bool) -> Tuple[str, const.CodeEnum]:
        """Extract the completion text from a successful "Response" object."""
        choices = rj["Choices"]
        if len(choices) == 0:
            return "No response", const.CodeEnum.LLM_NO_CHOICE
        choice = choices[0]
        # Streaming chunks carry "Delta"; blocking responses carry "Message".
        m = choice["Delta"] if stream else choice["Message"]
        return m["Content"], const.CodeEnum.OK

    async def complete(self, messages: MessagesType) -> Tuple[str, const.CodeEnum]:
        """Run one blocking chat completion; returns (text, status code)."""
        action = "ChatCompletions"
        payload = self.get_payload(messages=messages, stream=False)
        headers = self.get_headers(action=action, payload=payload)

        async with httpx.AsyncClient() as client:
            try:
                resp = await client.post(
                    url=self.endpoint,
                    headers=headers,
                    content=payload,
                    follow_redirects=False,
                    timeout=self.timeout,
                )
            except (
                    httpx.ConnectTimeout,
                    httpx.ConnectError,
                    httpx.ReadTimeout,
            ) as e:
                logger.error(f"Model error: {e}")
                return "Model timeout, please try later", const.CodeEnum.LLM_TIMEOUT
            except httpx.HTTPError as e:
                logger.error(f"Model error: {e}")
                return "Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR
            if resp.status_code != 200:
                logger.error(f"Model error: {resp.text}")
                return "Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR

            rj = resp.json()["Response"]
            # Service-level errors come back with HTTP 200 and an "Error" object.
            error = rj.get("Error")
            if error is not None:
                return self.handle_err(error)
            return self.handle_normal_response(rj=rj, stream=False)

    async def stream_complete(self, messages: MessagesType) -> AsyncIterable[Tuple[bytes, const.CodeEnum]]:
        """Stream a chat completion, yielding (raw bytes chunk, status code)."""
        action = "ChatCompletions"
        payload = self.get_payload(messages=messages, stream=True)
        headers = self.get_headers(action=action, payload=payload)

        async with httpx.AsyncClient() as client:
            async with client.stream(
                    method="POST",
                    url=self.endpoint,
                    headers=headers,
                    content=payload,
                    follow_redirects=False,
                    timeout=self.timeout,
            ) as resp:
                if resp.status_code != 200:
                    # A streamed response body is not loaded yet; accessing
                    # resp.text here would raise httpx.ResponseNotRead, so
                    # read it explicitly before logging.
                    body = await resp.aread()
                    logger.error(f"Model error: {body.decode('utf-8', errors='replace')}")
                    # Yield bytes, matching the declared chunk type.
                    yield b"Model error, please try later", const.CodeEnum.LLM_SERVICE_ERROR
                    return

                async for chunk in resp.aiter_bytes():
                    yield chunk, const.CodeEnum.OK
|
||
|
||
class HunyuanPro(_Hunyuan):
    """Client for the "hunyuan-pro" model."""

    model_name = "hunyuan-pro"

    def __init__(self, top_p: float = 0.9, temperature: float = 0.7, timeout: float = 60.):
        # Forward ``timeout`` (new, defaulted, backward-compatible) so
        # callers can tune it instead of always getting the base default.
        super().__init__(name=self.model_name, top_p=top_p, temperature=temperature, timeout=timeout)
|
||
|
||
class HunyuanStandard(_Hunyuan):
    """Client for the "hunyuan-standard" model."""

    model_name = "hunyuan-standard"

    def __init__(self, top_p: float = 0.9, temperature: float = 0.7, timeout: float = 60.):
        # Forward ``timeout`` (new, defaulted, backward-compatible) so
        # callers can tune it instead of always getting the base default.
        super().__init__(name=self.model_name, top_p=top_p, temperature=temperature, timeout=timeout)
|
||
|
||
class HunyuanStandard256K(_Hunyuan):
    """Client for the "hunyuan-standard-256K" long-context model."""

    model_name = "hunyuan-standard-256K"

    def __init__(self, top_p: float = 0.9, temperature: float = 0.7, timeout: float = 60.):
        # Forward ``timeout`` (new, defaulted, backward-compatible) so
        # callers can tune it instead of always getting the base default.
        super().__init__(name=self.model_name, top_p=top_p, temperature=temperature, timeout=timeout)
|
||
|
||
class HunyuanLite(_Hunyuan):
    """Client for the "hunyuan-lite" model."""

    model_name = "hunyuan-lite"

    def __init__(self, top_p: float = 0.9, temperature: float = 0.7, timeout: float = 60.):
        # Forward ``timeout`` (new, defaulted, backward-compatible) so
        # callers can tune it instead of always getting the base default.
        super().__init__(name=self.model_name, top_p=top_p, temperature=temperature, timeout=timeout)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import json | ||
import os | ||
import unittest | ||
from unittest.mock import patch, AsyncMock | ||
|
||
from httpx import Response | ||
|
||
from retk import const | ||
from retk.core.ai import llm | ||
from . import utils | ||
|
||
|
||
class HunyuanTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the Tencent Hunyuan LLM client (retk.core.ai.llm.hunyuan)."""

    @classmethod
    def setUpClass(cls):
        # Inject dummy credentials before loading the test environment so
        # config.get_settings() can resolve HUNYUAN_SECRET_ID/KEY.
        os.environ["HUNYUAN_SECRET_ID"] = "testid"
        os.environ["HUNYUAN_SECRET_KEY"] = "testkey"
        utils.set_env(".env.test.local")

    @classmethod
    def tearDownClass(cls) -> None:
        utils.drop_env(".env.test.local")

    def test_authorization(self):
        # The payload bytes must match what get_payload produces
        # (ensure_ascii=False, compact separators) or the signature changes.
        payload = {
            "Model": "hunyuan-lite",
            "Messages": [{"Role": "user", "Content": "你是谁"}],
            "Stream": False,
            "TopP": 0.9,
            "Temperature": 0.7,
            "EnableEnhancement": False,
        }
        payload = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
        m = llm.hunyuan.HunyuanLite()
        # Fixed timestamp (2024-05-28 UTC) pins the TC3-HMAC-SHA256 signature
        # to a known-good value computed offline.
        auth = m.get_auth(
            action="ChatCompletions",
            payload=payload,
            timestamp=1716913478,
            content_type="application/json"
        )
        sid = os.environ["HUNYUAN_SECRET_ID"]
        self.assertEqual(
            f"TC3-HMAC-SHA256 Credential={sid}/2024-05-28/hunyuan/tc3_request,"
            f" SignedHeaders=content-type;host;x-tc-action,"
            f" Signature=f628e271c4acdf72a4618fe59e3a31591f2ddedbd44e5befe6e02c05949b01b3",
            auth)

    async def test_hunyuan_auth_failed(self):
        # NOTE(review): this performs a real network call with the dummy
        # credentials and asserts the service's exact auth-failure message —
        # it will fail offline or if the service rewords the error; consider
        # mocking or skipping in CI.
        m = llm.hunyuan.HunyuanLite()
        self.assertEqual("hunyuan-lite", m.model_name)
        text, code = await m.complete([{"Role": "user", "Content": "你是谁"}])
        self.assertEqual(const.CodeEnum.LLM_SERVICE_ERROR, code, msg=text)
        self.assertEqual("SecretId不存在,请输入正确的密钥。", text)

    @patch("httpx.AsyncClient.post", new_callable=AsyncMock)
    async def test_hunyuan_complete(self, mock_post):
        # Mock the HTTP layer with a minimal successful "Response" object
        # so complete() is exercised without the network.
        mock_post.return_value = Response(
            status_code=200,
            json={
                "Response": {
                    "Choices": [
                        {
                            "Message": {
                                "Role": "assistant",
                                "Content": "我是一个AI助手。"
                            }
                        }
                    ]
                }
            }
        )
        m = llm.hunyuan.HunyuanLite()
        self.assertEqual("hunyuan-lite", m.model_name)
        text, code = await m.complete([{"Role": "user", "Content": "你是谁"}])
        self.assertEqual(const.CodeEnum.OK, code, msg=text)
        self.assertEqual("我是一个AI助手。", text)
        mock_post.assert_called_once()

    # NOTE(review): disabled streaming test — presumably kept off because it
    # hits the live API with real credentials; confirm before enabling.
    # async def test_hunyuan_stream_complete(self):
    #     m = llm.hunyuan.HunyuanLite()
    #
    #     async for b, code in m.stream_complete([{"Role": "user", "Content": "你是谁"}]):
    #         self.assertEqual(const.CodeEnum.OK, code)
    #         s = b.decode("utf-8")
    #         lines = s.splitlines()
    #         for line in lines:
    #             if line.strip() == "":
    #                 continue
    #             self.assertTrue(line.startswith("data: "))
    #             json_str = line[6:]
    #             json_data = json.loads(json_str)
    #             self.assertIn("Choices", json_data)
    #             choices = json_data["Choices"]
    #             self.assertEqual(1, len(choices))
    #             choice = choices[0]
    #             self.assertIn("Delta", choice)
    #             delta = choice["Delta"]
    #             self.assertIn("Content", delta)
    #             if choice["FinishReason"] == "":
    #                 self.assertGreater(len(delta["Content"]), 0)
    #             else:
    #                 self.assertEqual("", delta["Content"])
    #             self.assertEqual("assistant", delta["Role"])