Skip to content

Commit

Permalink
v0.0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
JiauZhang committed May 21, 2024
1 parent a9399ab commit d6908ed
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 25 deletions.
29 changes: 21 additions & 8 deletions chatchat/baidu.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
import chatchat.utils as utils
from chatchat.base import Base
import httpx, json, time

class Completion():
def __init__(self, jfile):
class Completion(Base):
def __init__(self, jfile, name='ERNIE-Speed-8K'):
# https://console.bce.baidu.com/qianfan/ais/console/onlineService
self.api_list = {
'ERNIE-Speed-8K': 'ernie_speed',
'ERNIE-Speed-128K': 'ernie-speed-128k',
'ERNIE Speed-AppBuilder': 'ai_apaas',
'ERNIE-Lite-8K': 'ernie-lite-8k',
'ERNIE-Tiny-8K': 'ernie-tiny-8k',
'Yi-34B-Chat': 'yi_34b_chat',
}

if name not in self.api_list:
raise RuntimeError(f'supported chat type: {self.api_list.keys()}')
self.api = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/' + self.api_list[name]

# jfile: https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application
# {
# "baidu": {
Expand All @@ -13,7 +27,7 @@ def __init__(self, jfile):
# }
# }
self.jfile = jfile
self.jdata = utils.load_json(jfile)['baidu']
self.jdata = self.load_json(jfile)['baidu']
self.update_interval = 3600
self.headers = {
'Content-Type': 'application/json',
Expand Down Expand Up @@ -46,17 +60,16 @@ def update_access_token(self):
r = httpx.post(url, headers=self.headers, params=params).json()
self.jdata['access_token'] = r['access_token']
self.jdata['expires_in'] = cur_time + float(r['expires_in'])
jdata = utils.load_json(self.jfile)
jdata = self.load_json(self.jfile)
jdata.update({'baidu': self.jdata})
utils.write_json(self.jfile, jdata)
self.write_json(self.jfile, jdata)

def get_access_token(self):
self.update_access_token()
return self.jdata['access_token']

def create(self, json):
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" \
+ self.get_access_token()
url = f'{self.api}?access_token={self.get_access_token()}'
r = httpx.request("POST", url, headers=self.headers, json=json)
return r.json()

Expand Down
14 changes: 14 additions & 0 deletions chatchat/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import json

class Base():
def __init__(self):
...

def load_json(self, jfile):
with open(jfile) as jf:
data = json.load(jf)
return data

def write_json(self, jfile, jdata):
with open(jfile, 'w+') as jd:
json.dump(jdata, jd, indent=4)
10 changes: 0 additions & 10 deletions chatchat/utils.py

This file was deleted.

11 changes: 5 additions & 6 deletions chatchat/xunfei.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from wsgiref.handlers import format_date_time
from urllib.parse import urlencode
import chatchat.utils as utils
from chatchat.base import Base
import base64, hashlib, hmac
from datetime import datetime
from time import mktime
import time, websocket, json, ssl
import _thread as thread

class Completion():
class Completion(Base):
def __init__(self, jfile, version='2.0'):
# jfile: https://console.xfyun.cn/services/bm2
# "xunfei": {
Expand All @@ -16,7 +15,7 @@ def __init__(self, jfile, version='2.0'):
# "api_key": "z"
# }
self.jfile = jfile
self.jdata = utils.load_json(jfile)['xunfei']
self.jdata = self.load_json(jfile)['xunfei']
self.update_interval = 150
self.headers = {
'Content-Type': 'application/json',
Expand Down Expand Up @@ -67,9 +66,9 @@ def update_url(self):
url = self.create_url()
self.jdata['url'] = url
self.jdata['expires_in'] = cur_time + 300
jdata = utils.load_json(self.jfile)
jdata = self.load_json(self.jfile)
jdata.update({'xunfei': self.jdata})
utils.write_json(self.jfile, jdata)
self.write_json(self.jfile, jdata)

def get_url(self):
self.update_url()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'chatchat',
packages = find_packages(exclude=['examples']),
version = '0.0.3',
version = '0.0.4',
license = 'GPL-2.0',
description = 'large language model api',
author = 'JiauZhang',
Expand Down

0 comments on commit d6908ed

Please sign in to comment.