Skip to content

Commit

Permalink
fix: support o1 models
Browse files Browse the repository at this point in the history
  • Loading branch information
binary-husky committed Sep 14, 2024
1 parent 0d0575a commit 18290fd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
4 changes: 4 additions & 0 deletions request_llms/bridge_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ def decode(self, *args, **kwargs):
"max_token": 128000,
"tokenizer": tokenizer_gpt4,
"token_cnt": get_token_num_gpt4,
"openai_disable_system_prompt": True,
"openai_disable_stream": True,
},
"o1-mini": {
"fn_with_ui": chatgpt_ui,
Expand All @@ -263,6 +265,8 @@ def decode(self, *args, **kwargs):
"max_token": 128000,
"tokenizer": tokenizer_gpt4,
"token_cnt": get_token_num_gpt4,
"openai_disable_system_prompt": True,
"openai_disable_stream": True,
},

"gpt-4-turbo": {
Expand Down
64 changes: 51 additions & 13 deletions request_llms/bridge_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,22 +133,33 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
observe_window = None:
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
"""
from request_llms.bridge_all import model_info

watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)

if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
else: stream = True

headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=stream)
retry = 0
while True:
try:
# make a POST request to the API endpoint, stream=False
from .bridge_all import model_info
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
json=payload, stream=stream, timeout=TIMEOUT_SECONDS); break
except requests.exceptions.ReadTimeout as e:
retry += 1
traceback.print_exc()
if retry > MAX_RETRY: raise TimeoutError
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')

if not stream:
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
chunkjson = json.loads(response.content.decode())
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
return gpt_replying_buffer

stream_response = response.iter_lines()
result = ''
json_data = None
Expand Down Expand Up @@ -208,7 +219,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
additional_fn代表点击的哪个按钮,按钮见functional.py
"""
from .bridge_all import model_info
from request_llms.bridge_all import model_info
if is_any_api_key(inputs):
chatbot._cookies['api_key'] = inputs
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
Expand Down Expand Up @@ -237,6 +248,10 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
chatbot.append((_inputs, ""))
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面

# 禁用stream的特殊模型处理
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
else: stream = True

# check mis-behavior
if is_the_upload_folder(user_input):
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需点击“**函数插件区**”按钮进行处理,请勿点击“提交”按钮或者“基础功能区”按钮。")
Expand Down Expand Up @@ -270,18 +285,23 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
try:
# make a POST request to the API endpoint, stream=True
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
json=payload, stream=stream, timeout=TIMEOUT_SECONDS);break
except:
retry += 1
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
if retry > MAX_RETRY: raise TimeoutError

gpt_replying_buffer = ""

is_head_of_the_stream = True
if not stream:
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
yield from handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history)
return

if stream:
gpt_replying_buffer = ""
is_head_of_the_stream = True
stream_response = response.iter_lines()
while True:
try:
Expand Down Expand Up @@ -343,12 +363,24 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
chunk_decoded = chunk.decode()
error_msg = chunk_decoded
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + error_msg) # 刷新界面
print(error_msg)
return
return # return from stream-branch

def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
try:
chunkjson = json.loads(response.content.decode())
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
history[-1] = gpt_replying_buffer
chatbot[-1] = (history[-2], history[-1])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
except Exception as e:
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + response.text) # 刷新界面

def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
from .bridge_all import model_info
from request_llms.bridge_all import model_info
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
if "reduce the length" in error_msg:
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
Expand Down Expand Up @@ -381,6 +413,8 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
"""
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
"""
from request_llms.bridge_all import model_info

if not is_any_api_key(llm_kwargs['api_key']):
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")

Expand Down Expand Up @@ -409,10 +443,16 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
else:
enable_multimodal_capacity = False

conversation_cnt = len(history) // 2
openai_disable_system_prompt = model_info[llm_kwargs['llm_model']].get('openai_disable_system_prompt', False)

if openai_disable_system_prompt:
messages = []
else:
messages = [{"role": "system", "content": system_prompt}]

if not enable_multimodal_capacity:
# 不使用多模态能力
conversation_cnt = len(history) // 2
messages = [{"role": "system", "content": system_prompt}]
if conversation_cnt:
for index in range(0, 2*conversation_cnt, 2):
what_i_have_asked = {}
Expand All @@ -434,8 +474,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
messages.append(what_i_ask_now)
else:
# 多模态能力
conversation_cnt = len(history) // 2
messages = [{"role": "system", "content": system_prompt}]
if conversation_cnt:
for index in range(0, 2*conversation_cnt, 2):
what_i_have_asked = {}
Expand Down

0 comments on commit 18290fd

Please sign in to comment.