Skip to content

Commit

Permalink
Merge pull request #1 from atomiechen/dev
Browse files Browse the repository at this point in the history
version 0.1.0
  • Loading branch information
atomiechen authored Jun 30, 2023
2 parents c7b2734 + 9bcf859 commit 027d1b0
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "HandyLLM"
version = "0.0.2"
version = "0.1.0"
authors = [
{ name="Atomie CHEN", email="atomic_cwh@163.com" },
]
Expand Down
105 changes: 99 additions & 6 deletions src/handyllm/openai_api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,40 @@
import os
import time
import requests
import logging
import json
import copy

from .prompt_converter import PromptConverter

module_logger = logging.getLogger(__name__)

class OpenAIAPI:

base_url = "https://api.openai.com/v1"
api_key = os.environ.get('OPENAI_API_KEY')
organization = None

converter = PromptConverter()

@staticmethod
def api_request(url, api_key, organization=None, timeout=None, **kwargs):
if api_key is None:
raise Exception("OpenAI API key is not set")
if url is None:
raise Exception("OpenAI API url is not set")

## log request info
log_strs = []
# 避免直接打印api_key
plaintext_len = 8
print(f"API request: url={url} api_key={api_key[:plaintext_len]}{'*'*(len(api_key)-plaintext_len)}")
log_strs.append(f"API request {url}")
log_strs.append(f"api_key: {api_key[:plaintext_len]}{'*'*(len(api_key)-plaintext_len)}")
if organization is not None:
log_strs.append(f"organization: {organization[:plaintext_len]}{'*'*(len(organization)-plaintext_len)}")
log_strs.append(f"timeout: {timeout}")
module_logger.info('\n'.join(log_strs))

request_data = kwargs
headers = {
'Authorization': 'Bearer ' + api_key,
Expand All @@ -37,7 +56,9 @@ def api_request(url, api_key, organization=None, timeout=None, **kwargs):
message = response.json()['error']['message']
except:
message = response.text
raise Exception(f"OpenAI API error ({response.status_code} {response.reason}): {message}")
err_msg = f"OpenAI API error ({url} {response.status_code} {response.reason}): {message}"
module_logger.error(err_msg)
raise Exception(err_msg)
return response.json()

@staticmethod
Expand All @@ -53,14 +74,86 @@ def api_request_endpoint(request_url, endpoint_manager=None, **kwargs):
return OpenAIAPI.api_request(url, api_key, organization=organization, **kwargs)

@staticmethod
def chat(timeout=None, endpoint_manager=None, **kwargs):
def chat(timeout=None, endpoint_manager=None, logger=None, log_marks=[], **kwargs):
request_url = '/chat/completions'
return OpenAIAPI.api_request_endpoint(request_url, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs)

if logger is not None and 'messages' in kwargs:
arguments = copy.deepcopy(kwargs)
arguments.pop('messages', None)
input_lines = [str(item) for item in log_marks]
input_lines.append(json.dumps(arguments, indent=2, ensure_ascii=False))
input_lines.append(" INPUT START ".center(50, '-'))
input_lines.append(OpenAIAPI.converter.chat2raw(kwargs['messages']))
input_lines.append(" INPUT END ".center(50, '-')+"\n")
input_str = "\n".join(input_lines)

start_time = time.time()
try:
response = OpenAIAPI.api_request_endpoint(request_url, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs)

if logger is not None:
end_time = time.time()
## log this on result
log_strs = []
log_strs.append(f"Chat request result ({end_time-start_time:.2f}s)")
log_strs.append(input_str)

log_strs.append(" OUTPUT START ".center(50, '-'))
log_strs.append(response['choices'][0]['message']['content'])
log_strs.append(" OUTPUT END ".center(50, '-')+"\n")
logger.info('\n'.join(log_strs))
except Exception as e:
if logger is not None:
end_time = time.time()
log_strs = []
log_strs.append(f"Chat request error ({end_time-start_time:.2f}s)")
log_strs.append(input_str)
log_strs.append(str(e))
logger.error('\n'.join(log_strs))
raise e

return response

@staticmethod
def completions(timeout=None, endpoint_manager=None, **kwargs):
def completions(timeout=None, endpoint_manager=None, logger=None, log_marks=[], **kwargs):
request_url = '/completions'
return OpenAIAPI.api_request_endpoint(request_url, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs)

if logger is not None and 'prompt' in kwargs:
arguments = copy.deepcopy(kwargs)
arguments.pop('prompt', None)
input_lines = [str(item) for item in log_marks]
input_lines.append(json.dumps(arguments, indent=2, ensure_ascii=False))
input_lines.append(" INPUT START ".center(50, '-'))
input_lines.append(kwargs['prompt'])
input_lines.append(" INPUT END ".center(50, '-')+"\n")
input_str = "\n".join(input_lines)

start_time = time.time()
try:
response = OpenAIAPI.api_request_endpoint(request_url, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs)

if logger is not None:
end_time = time.time()
## log this on result
log_strs = []
log_strs.append(f"Completions request result ({end_time-start_time:.2f}s)")
log_strs.append(input_str)

log_strs.append(" OUTPUT START ".center(50, '-'))
log_strs.append(response['choices'][0]['text'])
log_strs.append(" OUTPUT END ".center(50, '-')+"\n")
logger.info('\n'.join(log_strs))
except Exception as e:
if logger is not None:
end_time = time.time()
log_strs = []
log_strs.append(f"Completions request error ({end_time-start_time:.2f}s)")
log_strs.append(input_str)
log_strs.append(str(e))
logger.error('\n'.join(log_strs))
raise e

return response

@staticmethod
def embeddings(timeout=None, endpoint_manager=None, **kwargs):
Expand Down
24 changes: 20 additions & 4 deletions src/handyllm/prompt_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ def read_substitute_content(self, path: str):
value = blocks[idx+1]
self.substitute_map[key] = value.strip()

def raw2chat(self, raw_prompt_path: str):
with open(raw_prompt_path, 'r', encoding='utf-8') as fin:
raw_prompt = fin.read()

def raw2chat(self, raw_prompt: str):
# substitute pre-defined variables
for key, value in self.substitute_map.items():
raw_prompt = raw_prompt.replace(key, value)
Expand All @@ -35,6 +32,25 @@ def raw2chat(self, raw_prompt_path: str):

return chat

def rawfile2chat(self, raw_prompt_path: str):
with open(raw_prompt_path, 'r', encoding='utf-8') as fin:
raw_prompt = fin.read()

return self.raw2chat(raw_prompt)

def chat2raw(self, chat):
# convert chat format to plain text
messages = []
for message in chat:
messages.append(f"${message['role']}$\n{message['content']}")
raw_prompt = "\n\n".join(messages)
return raw_prompt

def chat2rawfile(self, chat, raw_prompt_path: str):
raw_prompt = self.chat2raw(chat)
with open(raw_prompt_path, 'w', encoding='utf-8') as fout:
fout.write(raw_prompt)

def chat_replace_variables(self, chat, variable_map: dict, inplace=False):
# replace every variable in chat content
if inplace:
Expand Down

0 comments on commit 027d1b0

Please sign in to comment.