From e2c66d0f7a8eb83aab5bfdcee21f3fd1e19b8566 Mon Sep 17 00:00:00 2001
From: Atomie CHEN
Date: Fri, 30 Jun 2023 00:25:09 +0800
Subject: [PATCH 1/4] Add logger feature; convert chat to raw format

---
 src/handyllm/openai_api.py       | 101 +++++++++++++++++++++++++++++--
 src/handyllm/prompt_converter.py |  13 ++++
 2 files changed, 109 insertions(+), 5 deletions(-)

diff --git a/src/handyllm/openai_api.py b/src/handyllm/openai_api.py
index c81f8f3..254ad57 100644
--- a/src/handyllm/openai_api.py
+++ b/src/handyllm/openai_api.py
@@ -1,11 +1,21 @@
 import os
+import time
 import requests
+import logging
+import json
+import copy
+
+from .prompt_converter import PromptConverter
+
+pkg_logger = logging.getLogger(__name__)
 
 
 class OpenAIAPI:
     base_url = "https://api.openai.com/v1"
     api_key = os.environ.get('OPENAI_API_KEY')
     organization = None
+
+    converter = PromptConverter()
 
     @staticmethod
     def api_request(url, api_key, organization=None, timeout=None, **kwargs):
@@ -13,9 +23,18 @@ def api_request(url, api_key, organization=None, timeout=None, **kwargs):
             raise Exception("OpenAI API key is not set")
         if url is None:
             raise Exception("OpenAI API url is not set")
+
+        ## log request info
+        log_strs = []
         # avoid printing the api_key in plain text
         plaintext_len = 8
-        print(f"API request: url={url} api_key={api_key[:plaintext_len]}{'*'*(len(api_key)-plaintext_len)}")
+        log_strs.append(f"API request {url}")
+        log_strs.append(f"api_key: {api_key[:plaintext_len]}{'*'*(len(api_key)-plaintext_len)}")
+        if organization is not None:
+            log_strs.append(f"organization={organization[:plaintext_len]}{'*'*(len(organization)-plaintext_len)}")
+        log_strs.append(f"timeout: {timeout}")
+        pkg_logger.info('\n'.join(log_strs))
+
         request_data = kwargs
         headers = {
             'Authorization': 'Bearer ' + api_key,
@@ -53,14 +72,86 @@ def api_request_endpoint(request_url, endpoint_manager=None, **kwargs):
         return OpenAIAPI.api_request(url, api_key, organization=organization, **kwargs)
 
     @staticmethod
-    def chat(timeout=None, endpoint_manager=None, **kwargs):
+    def chat(timeout=None, endpoint_manager=None, logger=None, log_marks=[], **kwargs):
         request_url = '/chat/completions'
-        return OpenAIAPI.api_request_endpoint(request_url, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs)
+
+        if logger is not None and 'messages' in kwargs:
+            arguments = copy.deepcopy(kwargs)
+            arguments.pop('messages', None)
+            input_lines = [str(item) for item in log_marks]
+            input_lines.append(json.dumps(arguments, indent=2, ensure_ascii=False))
+            input_lines.append(" INPUT START ".center(50, '-'))
+            input_lines.append(OpenAIAPI.converter.chat2raw(kwargs['messages']))
+            input_lines.append(" INPUT END ".center(50, '-')+"\n")
+            input_str = "\n".join(input_lines)
+
+        start_time = time.time()
+        try:
+            response = OpenAIAPI.api_request_endpoint(request_url, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs)
+
+            if logger is not None:
+                end_time = time.time()
+                ## log this on result
+                log_strs = []
+                log_strs.append(f"Chat request result ({end_time-start_time:.2f}s)")
+                log_strs.append(input_str)
+
+                log_strs.append(" OUTPUT START ".center(50, '-'))
+                log_strs.append(response['choices'][0]['message']['content'])
+                log_strs.append(" OUTPUT END ".center(50, '-')+"\n")
+                logger.info('\n'.join(log_strs))
+        except Exception as e:
+            if logger is not None:
+                end_time = time.time()
+                log_strs = []
+                log_strs.append(f"Chat request error ({end_time-start_time:.2f}s)")
+                log_strs.append(input_str)
+                log_strs.append(str(e))
+                logger.error('\n'.join(log_strs))
+            raise e
+
+        return response
 
     @staticmethod
-    def completions(timeout=None, endpoint_manager=None, **kwargs):
+    def completions(timeout=None, endpoint_manager=None, logger=None, log_marks=[], **kwargs):
         request_url = '/completions'
-        return OpenAIAPI.api_request_endpoint(request_url, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs)
+
+        if logger is not None and 'prompt' in kwargs:
+            arguments = copy.deepcopy(kwargs)
+            arguments.pop('prompt', None)
+            input_lines = [str(item) for item in log_marks]
+            input_lines.append(json.dumps(arguments, indent=2, ensure_ascii=False))
+            input_lines.append(" INPUT START ".center(50, '-'))
+            input_lines.append(kwargs['prompt'])
+            input_lines.append(" INPUT END ".center(50, '-')+"\n")
+            input_str = "\n".join(input_lines)
+
+        start_time = time.time()
+        try:
+            response = OpenAIAPI.api_request_endpoint(request_url, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs)
+
+            if logger is not None:
+                end_time = time.time()
+                ## log this on result
+                log_strs = []
+                log_strs.append(f"Completions request result ({end_time-start_time:.2f}s)")
+                log_strs.append(input_str)
+
+                log_strs.append(" OUTPUT START ".center(50, '-'))
+                log_strs.append(response['choices'][0]['text'])
+                log_strs.append(" OUTPUT END ".center(50, '-')+"\n")
+                logger.info('\n'.join(log_strs))
+        except Exception as e:
+            if logger is not None:
+                end_time = time.time()
+                log_strs = []
+                log_strs.append(f"Completions request error ({end_time-start_time:.2f}s)")
+                log_strs.append(input_str)
+                log_strs.append(str(e))
+                logger.error('\n'.join(log_strs))
+            raise e
+
+        return response
 
     @staticmethod
     def embeddings(timeout=None, endpoint_manager=None, **kwargs):
diff --git a/src/handyllm/prompt_converter.py b/src/handyllm/prompt_converter.py
index 51478cd..b9f7b58 100644
--- a/src/handyllm/prompt_converter.py
+++ b/src/handyllm/prompt_converter.py
@@ -35,6 +35,19 @@ def raw2chat(self, raw_prompt_path: str):
 
         return chat
 
+    def chat2raw(self, chat, raw_prompt_path: str=None):
+        # convert chat format to plain text
+        messages = []
+        for message in chat:
+            messages.append(f"${message['role']}$\n{message['content']}")
+        raw_prompt = "\n\n".join(messages)
+
+        if raw_prompt_path is not None:
+            with open(raw_prompt_path, 'w', encoding='utf-8') as fout:
+                fout.write(raw_prompt)
+
+        return raw_prompt
+
     def chat_replace_variables(self, chat, variable_map: dict, inplace=False):
         # replace every variable in chat content
         if inplace:

From cd0748e8893ed39518a1a1f9a5c71b632a358b0d Mon Sep 17 00:00:00 2001
From: Atomie CHEN
Date: Fri, 30 Jun 2023 00:53:43 +0800
Subject: [PATCH 2/4] improve logging

---
 src/handyllm/openai_api.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/handyllm/openai_api.py b/src/handyllm/openai_api.py
index 254ad57..01c9f3a 100644
--- a/src/handyllm/openai_api.py
+++ b/src/handyllm/openai_api.py
@@ -7,7 +7,7 @@
 
 from .prompt_converter import PromptConverter
 
-pkg_logger = logging.getLogger(__name__)
+module_logger = logging.getLogger(__name__)
 
 
 class OpenAIAPI:
@@ -31,9 +31,9 @@ def api_request(url, api_key, organization=None, timeout=None, **kwargs):
         log_strs.append(f"API request {url}")
         log_strs.append(f"api_key: {api_key[:plaintext_len]}{'*'*(len(api_key)-plaintext_len)}")
         if organization is not None:
-            log_strs.append(f"organization={organization[:plaintext_len]}{'*'*(len(organization)-plaintext_len)}")
+            log_strs.append(f"organization: {organization[:plaintext_len]}{'*'*(len(organization)-plaintext_len)}")
         log_strs.append(f"timeout: {timeout}")
-        pkg_logger.info('\n'.join(log_strs))
+        module_logger.info('\n'.join(log_strs))
 
         request_data = kwargs
         headers = {
@@ -56,7 +56,9 @@ def api_request(url, api_key, organization=None, timeout=None, **kwargs):
                 message = response.json()['error']['message']
             except:
                 message = response.text
-            raise Exception(f"OpenAI API error ({response.status_code} {response.reason}): {message}")
+            err_msg = f"OpenAI API error ({url} {response.status_code} {response.reason}): {message}"
+            module_logger.error(err_msg)
+            raise Exception(err_msg)
         return response.json()
 
     @staticmethod

From f2e2525f8aaf7d3f388054c498427757084adafc Mon Sep 17 00:00:00 2001
From: Atomie CHEN
Date: Fri, 30 Jun 2023 11:05:30 +0800
Subject: [PATCH 3/4] rename chat/raw conversion API

---
 src/handyllm/prompt_converter.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/handyllm/prompt_converter.py b/src/handyllm/prompt_converter.py
index b9f7b58..8dd1b55 100644
--- a/src/handyllm/prompt_converter.py
+++ b/src/handyllm/prompt_converter.py
@@ -17,10 +17,7 @@ def read_substitute_content(self, path: str):
                 value = blocks[idx+1]
                 self.substitute_map[key] = value.strip()
 
-    def raw2chat(self, raw_prompt_path: str):
-        with open(raw_prompt_path, 'r', encoding='utf-8') as fin:
-            raw_prompt = fin.read()
-
+    def raw2chat(self, raw_prompt: str):
         # substitute pre-defined variables
         for key, value in self.substitute_map.items():
             raw_prompt = raw_prompt.replace(key, value)
@@ -35,19 +32,25 @@ def raw2chat(self, raw_prompt_path: str):
 
         return chat
 
-    def chat2raw(self, chat, raw_prompt_path: str=None):
+    def rawfile2chat(self, raw_prompt_path: str):
+        with open(raw_prompt_path, 'r', encoding='utf-8') as fin:
+            raw_prompt = fin.read()
+
+        return self.raw2chat(raw_prompt)
+
+    def chat2raw(self, chat):
         # convert chat format to plain text
         messages = []
         for message in chat:
             messages.append(f"${message['role']}$\n{message['content']}")
         raw_prompt = "\n\n".join(messages)
-
-        if raw_prompt_path is not None:
-            with open(raw_prompt_path, 'w', encoding='utf-8') as fout:
-                fout.write(raw_prompt)
-
         return raw_prompt
 
+    def chat2rawfile(self, chat, raw_prompt_path: str):
+        raw_prompt = self.chat2raw(chat)
+        with open(raw_prompt_path, 'w', encoding='utf-8') as fout:
+            fout.write(raw_prompt)
+
     def chat_replace_variables(self, chat, variable_map: dict, inplace=False):
         # replace every variable in chat content
         if inplace:

From 9bcf859afe1b0d2dcb04aabc727e93b5be612910 Mon Sep 17 00:00:00 2001
From: Atomie CHEN
Date: Fri, 30 Jun 2023 13:15:46 +0800
Subject: [PATCH 4/4] bump version to 0.1.0

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index b035c89..f0a73eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "HandyLLM"
-version = "0.0.2"
+version = "0.1.0"
 authors = [
   { name="Atomie CHEN", email="atomic_cwh@163.com" },
 ]
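
Usage sketch for the feature added in this series (illustrative only; the logger name, model, and prompt file are placeholders, not taken from the patches; OPENAI_API_KEY is still read from the environment by OpenAIAPI):

import logging

from handyllm.openai_api import OpenAIAPI
from handyllm.prompt_converter import PromptConverter

logging.basicConfig(level=logging.INFO)
chat_logger = logging.getLogger("my_app.chat")  # hypothetical application logger

converter = PromptConverter()
# rawfile2chat() (renamed in PATCH 3/4) reads a $role$-delimited prompt file into chat messages
chat = converter.rawfile2chat("prompt.txt")  # hypothetical prompt file

response = OpenAIAPI.chat(
    model="gpt-3.5-turbo",   # forwarded to the API via **kwargs
    messages=chat,
    timeout=10,
    logger=chat_logger,      # chat() logs the request input and output through this logger
    log_marks=["run-1"],     # marks are str()'d and prepended to the logged input
)
print(response['choices'][0]['message']['content'])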