From 7bb98d56a43baf8bc39bf235f63ba17d283fb391 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Sun, 30 Jul 2023 12:54:40 +0800 Subject: [PATCH 1/5] Feature: upgrade EndpointManager and add Endpoint also add test script --- src/handyllm/__init__.py | 2 +- src/handyllm/endpoint_manager.py | 143 ++++++++++++++++--------------- src/handyllm/openai_api.py | 127 ++++++++++++++------------- src/handyllm/utils.py | 7 ++ tests/test_endpoint.py | 64 ++++++++++++++ 5 files changed, 214 insertions(+), 129 deletions(-) create mode 100644 tests/test_endpoint.py diff --git a/src/handyllm/__init__.py b/src/handyllm/__init__.py index 74cfb78..7e9f65c 100644 --- a/src/handyllm/__init__.py +++ b/src/handyllm/__init__.py @@ -1,3 +1,3 @@ from .openai_api import OpenAIAPI -from .endpoint_manager import EndpointManager +from .endpoint_manager import Endpoint, EndpointManager from .prompt_converter import PromptConverter diff --git a/src/handyllm/endpoint_manager.py b/src/handyllm/endpoint_manager.py index a0813bb..eea11bf 100644 --- a/src/handyllm/endpoint_manager.py +++ b/src/handyllm/endpoint_manager.py @@ -1,81 +1,84 @@ from threading import Lock -from . import OpenAIAPI +from collections.abc import MutableSequence + + +class Endpoint: + def __init__( + self, + name=None, + api_key=None, + organization=None, + api_base=None, + api_type=None, + api_version=None, + ): + self.name = name if name else f"ep_{id(self)}" + self.api_key = api_key + self.organization = organization + self.api_base = api_base + self.api_type = api_type + self.api_version = api_version + + def __str__(self) -> str: + # do not print api_key + listed_attributes = [ + f'name={repr(self.name)}' if self.name else None, + f'api_key=*' if self.api_key else None, + f'organization={repr(self.organization)}' if self.organization else None, + f'api_base={repr(self.api_base)}' if self.api_base else None, + f'api_type={repr(self.api_type)}' if self.api_type else None, + f'api_version={repr(self.api_version)}' if self.api_version else None, + ] + # remove None in listed_attributes + listed_attributes = [item for item in listed_attributes if item] + return f"Endpoint({', '.join(listed_attributes)})" + + def get_api_info(self): + return ( + self.api_key, + self.organization, + self.api_base, + self.api_type, + self.api_version + ) + + +class EndpointManager(MutableSequence): -class EndpointManager: - def __init__(self): self._lock = Lock() + self._last_idx_endpoint = 0 + self._endpoints = [] + + def clear(self): + self._last_idx_endpoint = 0 + self._endpoints.clear() - self._base_urls = [] - self._last_idx_url = 0 - - self._keys = [] - self._last_idx_key = 0 - - self._organizations = [] - self._last_idx_organization = 0 + def __len__(self) -> int: + return len(self._endpoints) - def add_base_url(self, base_url: str): - if isinstance(base_url, str) and base_url.strip() != '': - self._base_urls.append(base_url) + def __getitem__(self, idx: int) -> Endpoint: + return self._endpoints[idx] - def add_key(self, key: str): - if isinstance(key, str) and key.strip() != '': - self._keys.append(key) - - def add_organization(self, organization: str): - if isinstance(organization, str) and organization.strip() != '': - self._organizations.append(organization) - - def set_base_urls(self, base_urls): - self._base_urls = [url for url in base_urls if isinstance(url, str) and url.strip() != ''] - - def set_keys(self, keys): - self._keys = [key for key in keys if isinstance(key, str) and key.strip() != ''] - - def set_organizations(self, organizations): - self._organizations = [organization for organization in organizations if isinstance(organization, str) and organization.strip() != ''] + def __setitem__(self, idx: int, endpoint: Endpoint): + self._endpoints[idx] = endpoint - def get_base_url(self): - if len(self._base_urls) == 0: - return OpenAIAPI.get_api_base() - else: - base_url = self._base_urls[self._last_idx_url] - if self._last_idx_url == len(self._base_urls) - 1: - self._last_idx_url = 0 - else: - self._last_idx_url += 1 - return base_url + def __delitem__(self, idx: int): + del self._endpoints[idx] - def get_key(self): - if len(self._keys) == 0: - return OpenAIAPI.get_api_key() - else: - key = self._keys[self._last_idx_key] - if self._last_idx_key == len(self._keys) - 1: - self._last_idx_key = 0 - else: - self._last_idx_key += 1 - return key - - def get_organization(self): - if len(self._organizations) == 0: - return OpenAIAPI.get_organization() - else: - organization = self._organizations[self._last_idx_organization] - if self._last_idx_organization == len(self._keys) - 1: - self._last_idx_organization = 0 - else: - self._last_idx_organization += 1 - return organization - - def get_endpoint(self): + def insert(self, idx: int, endpoint: Endpoint): + self._endpoints.insert(idx, endpoint) + + def add_endpoint_by_info(self, **kwargs): + endpoint = Endpoint(**kwargs) + self.append(endpoint) + + def get_next_endpoint(self) -> Endpoint: with self._lock: - # compose full url - base_url = self.get_base_url() - # get API key - api_key = self.get_key() - # get organization - organization = self.get_organization() - return base_url, api_key, organization + endpoint = self._endpoints[self._last_idx_endpoint] + if self._last_idx_endpoint == len(self._endpoints) - 1: + self._last_idx_endpoint = 0 + else: + self._last_idx_endpoint += 1 + return endpoint diff --git a/src/handyllm/openai_api.py b/src/handyllm/openai_api.py index a9b885f..a5444cb 100644 --- a/src/handyllm/openai_api.py +++ b/src/handyllm/openai_api.py @@ -6,6 +6,7 @@ import json import copy +from .endpoint_manager import Endpoint, EndpointManager from .prompt_converter import PromptConverter from . import utils @@ -185,30 +186,46 @@ def stream_completions(response): def api_request_endpoint( cls, request_url, - endpoint_manager=None, **kwargs ): api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) - if endpoint_manager != None: - api_base, api_key, organization = endpoint_manager.get_endpoint() - url = api_base.rstrip('/') + '/' + request_url.lstrip('/') + url = utils.join_url(api_base, request_url) return cls._api_request(url, api_key, organization=organization, api_type=api_type, **kwargs) @classmethod def consume_kwargs(cls, kwargs): - api_key = cls.get_api_key(kwargs.pop('api_key', None)) - organization = cls.get_organization(kwargs.pop('organization', None)) - api_base = cls.get_api_base(kwargs.pop('api_base', None)) + api_key = organization = api_base = api_type = api_version = engine = None + + # read API info from endpoint_manager + endpoint_manager = kwargs.pop('endpoint_manager', None) + if endpoint_manager is not None: + if not isinstance(endpoint_manager, EndpointManager): + raise Exception("endpoint_manager must be an instance of EndpointManager") + # get_next_endpoint() will be called once for each request + api_key, organization, api_base, api_type, api_version = endpoint_manager.get_next_endpoint().get_api_info() + + # read API info from endpoint (override API info from endpoint_manager) + endpoint = kwargs.pop('endpoint', None) + if endpoint is not None: + if not isinstance(endpoint, Endpoint): + raise Exception("endpoint must be an instance of Endpoint") + api_key, organization, api_base, api_type, api_version = endpoint.get_api_info() + + # read API info from kwargs + api_key = cls.get_api_key(kwargs.pop('api_key', api_key)) + organization = cls.get_organization(kwargs.pop('organization', organization)) + api_base = cls.get_api_base(kwargs.pop('api_base', api_base)) api_type, api_version = cls.get_api_type_and_version( - kwargs.pop('api_type', None), - kwargs.pop('api_version', None) + kwargs.pop('api_type', api_type), + kwargs.pop('api_version', api_version) ) + deployment_id = kwargs.pop('deployment_id', None) engine = kwargs.pop('engine', deployment_id) return api_key, organization, api_base, api_type, api_version, engine @classmethod - def chat(cls, messages, timeout=None, endpoint_manager=None, logger=None, log_marks=[], **kwargs): + def chat(cls, messages, logger=None, log_marks=[], **kwargs): api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) if api_type and api_type.lower() in _API_TYPES_AZURE: if engine is None: @@ -239,8 +256,6 @@ def chat(cls, messages, timeout=None, endpoint_manager=None, logger=None, log_ma request_url, messages=messages, method='post', - timeout=timeout, - endpoint_manager=endpoint_manager, api_key=api_key, organization=organization, api_base=api_base, @@ -289,7 +304,7 @@ def wrapper(response): return response @classmethod - def completions(cls, prompt, timeout=None, endpoint_manager=None, logger=None, log_marks=[], **kwargs): + def completions(cls, prompt, logger=None, log_marks=[], **kwargs): api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) if api_type and api_type.lower() in _API_TYPES_AZURE: if engine is None: @@ -320,8 +335,6 @@ def completions(cls, prompt, timeout=None, endpoint_manager=None, logger=None, l request_url, prompt=prompt, method='post', - timeout=timeout, - endpoint_manager=endpoint_manager, api_key=api_key, organization=organization, api_base=api_base, @@ -365,12 +378,12 @@ def wrapper(response): return response @classmethod - def edits(cls, timeout=None, endpoint_manager=None, **kwargs): + def edits(cls, **kwargs): request_url = '/edits' - return cls.api_request_endpoint(request_url, method='post', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', **kwargs) @classmethod - def embeddings(cls, timeout=None, endpoint_manager=None, **kwargs): + def embeddings(cls, **kwargs): api_key, organization, api_base, api_type, api_version, engine = cls.consume_kwargs(kwargs) if api_type and api_type.lower() in _API_TYPES_AZURE: if engine is None: @@ -384,8 +397,6 @@ def embeddings(cls, timeout=None, endpoint_manager=None, **kwargs): return cls.api_request_endpoint( request_url, method='post', - timeout=timeout, - endpoint_manager=endpoint_manager, api_key=api_key, organization=organization, api_base=api_base, @@ -394,106 +405,106 @@ def embeddings(cls, timeout=None, endpoint_manager=None, **kwargs): ) @classmethod - def models_list(cls, timeout=None, endpoint_manager=None, **kwargs): + def models_list(cls, **kwargs): request_url = '/models' - return cls.api_request_endpoint(request_url, method='get', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='get', **kwargs) @classmethod - def models_retrieve(cls, model, timeout=None, endpoint_manager=None, **kwargs): + def models_retrieve(cls, model, **kwargs): request_url = f'/models/{model}' - return cls.api_request_endpoint(request_url, method='get', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='get', **kwargs) @classmethod - def moderations(cls, timeout=None, endpoint_manager=None, **kwargs): + def moderations(cls, **kwargs): request_url = '/moderations' - return cls.api_request_endpoint(request_url, method='post', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', **kwargs) @classmethod - def images_generations(cls, timeout=None, endpoint_manager=None, **kwargs): + def images_generations(cls, **kwargs): request_url = '/images/generations' - return cls.api_request_endpoint(request_url, method='post', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', **kwargs) @classmethod - def images_edits(cls, image, mask=None, timeout=None, endpoint_manager=None, **kwargs): + def images_edits(cls, image, mask=None, **kwargs): request_url = '/images/edits' files = { 'image': image } if mask: files['mask'] = mask - return cls.api_request_endpoint(request_url, method='post', files=files, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', files=files, **kwargs) @classmethod - def images_variations(cls, image, timeout=None, endpoint_manager=None, **kwargs): + def images_variations(cls, image, **kwargs): request_url = '/images/variations' files = { 'image': image } - return cls.api_request_endpoint(request_url, method='post', files=files, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', files=files, **kwargs) @classmethod - def audio_transcriptions(cls, file, timeout=None, endpoint_manager=None, **kwargs): + def audio_transcriptions(cls, file, **kwargs): request_url = '/audio/transcriptions' files = { 'file': file } - return cls.api_request_endpoint(request_url, method='post', files=files, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', files=files, **kwargs) @classmethod - def audio_translations(cls, file, timeout=None, endpoint_manager=None, **kwargs): + def audio_translations(cls, file, **kwargs): request_url = '/audio/translations' files = { 'file': file } - return cls.api_request_endpoint(request_url, method='post', files=files, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', files=files, **kwargs) @classmethod - def files_list(cls, timeout=None, endpoint_manager=None, **kwargs): + def files_list(cls, **kwargs): request_url = '/files' - return cls.api_request_endpoint(request_url, method='get', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='get', **kwargs) @classmethod - def files_upload(cls, file, timeout=None, endpoint_manager=None, **kwargs): + def files_upload(cls, file, **kwargs): request_url = '/files' files = { 'file': file } - return cls.api_request_endpoint(request_url, method='post', files=files, timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', files=files, **kwargs) @classmethod - def files_delete(cls, file_id, timeout=None, endpoint_manager=None, **kwargs): + def files_delete(cls, file_id, **kwargs): request_url = f'/files/{file_id}' - return cls.api_request_endpoint(request_url, method='delete', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='delete', **kwargs) @classmethod - def files_retrieve(cls, file_id, timeout=None, endpoint_manager=None, **kwargs): + def files_retrieve(cls, file_id, **kwargs): request_url = f'/files/{file_id}' - return cls.api_request_endpoint(request_url, method='get', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='get', **kwargs) @classmethod - def files_retrieve_content(cls, file_id, timeout=None, endpoint_manager=None, **kwargs): + def files_retrieve_content(cls, file_id, **kwargs): request_url = f'/files/{file_id}/content' - return cls.api_request_endpoint(request_url, method='get', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='get', **kwargs) @classmethod - def finetunes_create(cls, timeout=None, endpoint_manager=None, **kwargs): + def finetunes_create(cls, **kwargs): request_url = '/fine-tunes' - return cls.api_request_endpoint(request_url, method='post', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', **kwargs) @classmethod - def finetunes_list(cls, timeout=None, endpoint_manager=None, **kwargs): + def finetunes_list(cls, **kwargs): request_url = '/fine-tunes' - return cls.api_request_endpoint(request_url, method='get', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='get', **kwargs) @classmethod - def finetunes_retrieve(cls, fine_tune_id, timeout=None, endpoint_manager=None, **kwargs): + def finetunes_retrieve(cls, fine_tune_id, **kwargs): request_url = f'/fine-tunes/{fine_tune_id}' - return cls.api_request_endpoint(request_url, method='get', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='get', **kwargs) @classmethod - def finetunes_cancel(cls, fine_tune_id, timeout=None, endpoint_manager=None, **kwargs): + def finetunes_cancel(cls, fine_tune_id, **kwargs): request_url = f'/fine-tunes/{fine_tune_id}/cancel' - return cls.api_request_endpoint(request_url, method='post', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='post', **kwargs) @classmethod - def finetunes_list_events(cls, fine_tune_id, timeout=None, endpoint_manager=None, **kwargs): + def finetunes_list_events(cls, fine_tune_id, **kwargs): request_url = f'/fine-tunes/{fine_tune_id}/events' - return cls.api_request_endpoint(request_url, method='get', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='get', **kwargs) @classmethod - def finetunes_delete_model(cls, model, timeout=None, endpoint_manager=None, **kwargs): + def finetunes_delete_model(cls, model, **kwargs): request_url = f'/models/{model}' - return cls.api_request_endpoint(request_url, method='delete', timeout=timeout, endpoint_manager=endpoint_manager, **kwargs) + return cls.api_request_endpoint(request_url, method='delete', **kwargs) if __name__ == '__main__': diff --git a/src/handyllm/utils.py b/src/handyllm/utils.py index c2f01d0..a2f464b 100644 --- a/src/handyllm/utils.py +++ b/src/handyllm/utils.py @@ -29,3 +29,10 @@ def isiterable(arg): isinstance(arg, collections.abc.Iterable) and not isinstance(arg, str) ) + +def join_url(base_url, *args): + url = base_url.rstrip('/') + for arg in args: + url += '/' + arg.lstrip('/') + return url + diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py new file mode 100644 index 0000000..9dfcc9b --- /dev/null +++ b/tests/test_endpoint.py @@ -0,0 +1,64 @@ +from handyllm import OpenAIAPI, EndpointManager, Endpoint + +from dotenv import load_dotenv, find_dotenv +# load env parameters from file named .env +load_dotenv(find_dotenv()) + +import os + +## EndpointManager acts like a list +endpoint_manager = EndpointManager() + +endpoint_manager.add_endpoint_by_info( + api_key=os.environ.get('OPENAI_API_KEY'), +) +endpoint2 = Endpoint( + name='endpoint2', # name is not required + api_key=os.environ.get('OPENAI_API_KEY'), +) +endpoint_manager.append(endpoint2) + +assert isinstance(endpoint_manager[0], Endpoint) +assert endpoint2 == endpoint_manager[1] +print(f"total endpoints: {len(endpoint_manager)}") + +for endpoint in endpoint_manager: + print(endpoint) + # print(endpoint.get_api_info()) # WARNING: print endpoint info including api_key + + +# ----- EXAMPLE 1 ----- + +prompt = [{ + "role": "user", + "content": "please tell me a joke" + }] +response = OpenAIAPI.chat( + model="gpt-3.5-turbo", + messages=prompt, + temperature=0.2, + max_tokens=256, + top_p=1.0, + frequency_penalty=0.0, + presence_penalty=0.0, + timeout=10, + endpoint_manager=endpoint_manager + ) +print(response['choices'][0]['message']['content']) + + +print() +print("-----") + + +# ----- EXAMPLE 2 ----- + +response = OpenAIAPI.completions( + model="text-davinci-002", + prompt="count to 23 and stop: 1,2,3,", + timeout=10, + max_tokens=256, + echo=True, # Echo back the prompt in addition to the completion + endpoint=endpoint2 +) +print(response['choices'][0]['text']) From 355b0aae65157a9d5a56a01b0336ba45cbaf77d8 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Sun, 30 Jul 2023 14:50:39 +0800 Subject: [PATCH 2/5] Minor improvements --- src/handyllm/openai_api.py | 2 +- tests/test_azure.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/handyllm/openai_api.py b/src/handyllm/openai_api.py index a5444cb..8304c31 100644 --- a/src/handyllm/openai_api.py +++ b/src/handyllm/openai_api.py @@ -211,7 +211,7 @@ def consume_kwargs(cls, kwargs): raise Exception("endpoint must be an instance of Endpoint") api_key, organization, api_base, api_type, api_version = endpoint.get_api_info() - # read API info from kwargs + # read API info from kwargs, class variables, and environment variables api_key = cls.get_api_key(kwargs.pop('api_key', api_key)) organization = cls.get_organization(kwargs.pop('organization', organization)) api_base = cls.get_api_base(kwargs.pop('api_base', api_base)) diff --git a/tests/test_azure.py b/tests/test_azure.py index d83be08..0fbc42d 100644 --- a/tests/test_azure.py +++ b/tests/test_azure.py @@ -11,7 +11,7 @@ OpenAIAPI.api_type = 'azure' OpenAIAPI.api_base = os.getenv("AZURE_OPENAI_ENDPOINT") OpenAIAPI.api_key = os.getenv("AZURE_OPENAI_KEY") -OpenAIAPI.api_version = '2023-05-15' +OpenAIAPI.api_version = '2023-05-15' # can be None and default value will be used # ----- EXAMPLE 1 ----- From 4c01a6a1582196ab020f9f12aedc4a21a5fe43cd Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Sun, 30 Jul 2023 14:50:46 +0800 Subject: [PATCH 3/5] Update README --- README.md | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index aad6345..268b3fc 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,24 @@ Example scripts are placed in [tests](./tests) folder. ## OpenAI API Request +### Endpoints + +Each API request will connect to an endpoint along with some API configurations, which include: `api_key`, `organization`, `api_base`, `api_type` and `api_version`. + +An `Endpoint` object contains these information. An `EndpointManager` acts like a list and can be used to rotate the next endpoint. See [test_endpoint.py](./tests/test_endpoint.py). + +There are 5 methods for specifying endpoint info: + +1. (each API call) Pass these fields as keyword parameters. +2. (each API call) Pass an `endpoint` keyword parameter to specify an `Endpoint`. +3. (each API call) Pass an `endpoint_manager` keyword parameter to specify an `EndpointManager`. +4. (global) Set class variables: `OpenAIAPI.api_base`, `OpenAIAPI.api_key`, `OpenAIAPI.organization`, `OpenAIAPI.api_type`, `OpenAIAPI.api_version`. +5. (global) Set environment variables: `OPENAI_API_KEY`, `OPENAI_ORGANIZATION`, `OPENAI_API_BASE`, `OPENAI_API_TYPE`, `OPENAI_API_VERSION`. + +**Note**: If a field is set to `None` in the previous method, it will be replaced by the non-`None` value in the subsequent method, until a default value is used (OpenAI's endpoint information). + +**Azure OpenAI APIs are supported:** Specify `api_type='azure'`, and set `api_base` and `api_key` accordingly. See [test_azure.py](./tests/test_azure.py). Please refer to [Azure OpenAI Service Documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/) for details. + ### Logger You can pass custom `logger` and `log_marks` (a string or a collection of strings) to `chat`/`completions` to get input and output logging. @@ -50,17 +68,6 @@ response = OpenAIAPI.chat( print(response['choices'][0]['message']['content']) ``` -### Authorization - -API key and organization will be loaded using the environment variable `OPENAI_API_KEY` and `OPENAI_ORGANIZATION`, or you can set manually: - -```python -OpenAIAPI.api_key = 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' -OpenAIAPI.organization = '......' # default: None -``` - -Or, you can pass `api_key` and `organization` parameters in each API call. - ### Stream response Stream response of `chat`/`completions`/`finetunes_list_events` can be achieved using `steam` parameter: @@ -111,14 +118,6 @@ for text in OpenAIAPI.stream_chat(response): Please refer to [OpenAI official API reference](https://platform.openai.com/docs/api-reference) for details. -### Azure - -**Azure OpenAI APIs are supported!** - -Refer to [test_azure.py](./tests/test_azure.py) for example usage. - -Refer to [Azure OpenAI Service Documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/). - ## Prompt From e4f859da5613c199c621b0cd5ef8543e679aaf5e Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Sun, 30 Jul 2023 14:56:48 +0800 Subject: [PATCH 4/5] Update endpoint test script --- tests/test_endpoint.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 9dfcc9b..6748bc3 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -3,6 +3,7 @@ from dotenv import load_dotenv, find_dotenv # load env parameters from file named .env load_dotenv(find_dotenv()) +load_dotenv(find_dotenv('azure.env')) import os @@ -13,8 +14,11 @@ api_key=os.environ.get('OPENAI_API_KEY'), ) endpoint2 = Endpoint( - name='endpoint2', # name is not required - api_key=os.environ.get('OPENAI_API_KEY'), + name='azure', # name is not required + api_type='azure', + api_base=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_KEY"), + api_version='2023-05-15' # can be None and default value will be used ) endpoint_manager.append(endpoint2) @@ -53,12 +57,11 @@ # ----- EXAMPLE 2 ----- -response = OpenAIAPI.completions( - model="text-davinci-002", - prompt="count to 23 and stop: 1,2,3,", +response = OpenAIAPI.chat( + deployment_id="initial_deployment", + messages=prompt, timeout=10, max_tokens=256, - echo=True, # Echo back the prompt in addition to the completion endpoint=endpoint2 ) -print(response['choices'][0]['text']) +print(response['choices'][0]['message']['content']) From 86f2d724658d13fd9d7d368fa74af0ef6f56b554 Mon Sep 17 00:00:00 2001 From: Atomie CHEN Date: Sun, 30 Jul 2023 15:09:55 +0800 Subject: [PATCH 5/5] Bump version to 0.5.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e82ab02..a6fdb4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "HandyLLM" -version = "0.4.0" +version = "0.5.0" authors = [ { name="Atomie CHEN", email="atomic_cwh@163.com" }, ]