Skip to content

Commit

Permalink
Merge pull request #346 from n3d1117/feature/support-functions
Browse files Browse the repository at this point in the history
Support functions (aka plugins)
n3d1117 authored Aug 4, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
2 parents 3d2231a + 3e4c228 commit 30d441d
Showing 23 changed files with 1,429 additions and 142 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -4,3 +4,4 @@ __pycache__
.DS_Store
/usage_logs
venv
/.cache
115 changes: 79 additions & 36 deletions README.md

Large diffs are not rendered by default.

21 changes: 16 additions & 5 deletions bot/main.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,8 @@

from dotenv import load_dotenv

from openai_helper import OpenAIHelper, default_max_tokens
from plugin_manager import PluginManager
from openai_helper import OpenAIHelper, default_max_tokens, are_functions_available
from telegram_bot import ChatGPTTelegramBot


@@ -27,6 +28,7 @@ def main():

# Setup configurations
model = os.environ.get('OPENAI_MODEL', 'gpt-3.5-turbo')
functions_available = are_functions_available(model=model)
max_tokens_default = default_max_tokens(model=model)
openai_config = {
'api_key': os.environ['OPENAI_API_KEY'],
@@ -41,14 +43,18 @@ def main():
'temperature': float(os.environ.get('TEMPERATURE', 1.0)),
'image_size': os.environ.get('IMAGE_SIZE', '512x512'),
'model': model,
'enable_functions': os.environ.get('ENABLE_FUNCTIONS', str(functions_available)).lower() == 'true',
'functions_max_consecutive_calls': int(os.environ.get('FUNCTIONS_MAX_CONSECUTIVE_CALLS', 10)),
'presence_penalty': float(os.environ.get('PRESENCE_PENALTY', 0.0)),
'frequency_penalty': float(os.environ.get('FREQUENCY_PENALTY', 0.0)),
'bot_language': os.environ.get('BOT_LANGUAGE', 'en'),
'show_plugins_used': os.environ.get('SHOW_PLUGINS_USED', 'false').lower() == 'true',
}

# log deprecation warning for old budget variable names
# old variables are caught in the telegram_config definition for now
# remove support for old budget names at some point in the future
if openai_config['enable_functions'] and not functions_available:
logging.error(f'ENABLE_FUNCTIONS is set to true, but the model {model} does not support it. '
f'Please set ENABLE_FUNCTIONS to false or use a model that supports it.')
exit(1)
if os.environ.get('MONTHLY_USER_BUDGETS') is not None:
logging.warning('The environment variable MONTHLY_USER_BUDGETS is deprecated. '
'Please use USER_BUDGETS with BUDGET_PERIOD instead.')
@@ -78,8 +84,13 @@ def main():
'bot_language': os.environ.get('BOT_LANGUAGE', 'en'),
}

plugin_config = {
'plugins': os.environ.get('PLUGINS', '').split(',')
}

# Setup and run ChatGPT and Telegram bot
openai_helper = OpenAIHelper(config=openai_config)
plugin_manager = PluginManager(config=plugin_config)
openai_helper = OpenAIHelper(config=openai_config, plugin_manager=plugin_manager)
telegram_bot = ChatGPTTelegramBot(config=telegram_config, openai=openai_helper)
telegram_bot.run()

134 changes: 122 additions & 12 deletions bot/openai_helper.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,9 @@

from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type

from utils import is_direct_result
from plugin_manager import PluginManager

# Models can be found here: https://platform.openai.com/docs/models/overview
GPT_3_MODELS = ("gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613")
GPT_3_16K_MODELS = ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613")
@@ -39,6 +42,19 @@ def default_max_tokens(model: str) -> int:
return base * 8


def are_functions_available(model: str) -> bool:
"""
Whether the given model supports functions
"""
# Deprecated models
if model in ("gpt-3.5-turbo-0301", "gpt-4-0314", "gpt-4-32k-0314"):
return False
# Stable models will be updated to support functions on June 27, 2023
if model in ("gpt-3.5-turbo", "gpt-4", "gpt-4-32k"):
return datetime.date.today() > datetime.date(2023, 6, 27)
return True


# Load translations
parent_dir_path = os.path.join(os.path.dirname(__file__), os.pardir)
translations_file_path = os.path.join(parent_dir_path, 'translations.json')
@@ -69,14 +85,16 @@ class OpenAIHelper:
ChatGPT helper class.
"""

def __init__(self, config: dict):
def __init__(self, config: dict, plugin_manager: PluginManager):
"""
Initializes the OpenAI helper class with the given configuration.
:param config: A dictionary containing the GPT configuration
:param plugin_manager: The plugin manager
"""
openai.api_key = config['api_key']
openai.proxy = config['proxy']
self.config = config
self.plugin_manager = plugin_manager
self.conversations: dict[int: list] = {} # {chat_id: history}
self.last_updated: dict[int: datetime] = {} # {chat_id: last_update_timestamp}

@@ -97,7 +115,13 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:
:param query: The query to send to the model
:return: The answer from the model and the number of tokens used
"""
plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query)
if self.config['enable_functions']:
response, plugins_used = await self.__handle_function_call(chat_id, response)
if is_direct_result(response):
return response, '0'

answer = ''

if len(response.choices) > 1 and self.config['n_choices'] > 1:
@@ -113,11 +137,17 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:
self.__add_to_history(chat_id, role="assistant", content=answer)

bot_language = self.config['bot_language']
show_plugins_used = len(plugins_used) > 0 and self.config['show_plugins_used']
plugin_names = tuple(self.plugin_manager.get_plugin_source_name(plugin) for plugin in plugins_used)
if self.config['show_usage']:
answer += "\n\n---\n" \
f"💰 {str(response.usage['total_tokens'])} {localized_text('stats_tokens', bot_language)}" \
f" ({str(response.usage['prompt_tokens'])} {localized_text('prompt', bot_language)}," \
f" {str(response.usage['completion_tokens'])} {localized_text('completion', bot_language)})"
if show_plugins_used:
answer += f"\n🔌 {', '.join(plugin_names)}"
elif show_plugins_used:
answer += f"\n\n---\n🔌 {', '.join(plugin_names)}"

return answer, response.usage['total_tokens']

@@ -128,22 +158,34 @@ async def get_chat_response_stream(self, chat_id: int, query: str):
:param query: The query to send to the model
:return: The answer from the model and the number of tokens used, or 'not_finished'
"""
plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query, stream=True)
if self.config['enable_functions']:
response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True)
if is_direct_result(response):
yield response, '0'
return

answer = ''
async for item in response:
if 'choices' not in item or len(item.choices) == 0:
continue
delta = item.choices[0].delta
if 'content' in delta:
if 'content' in delta and delta.content is not None:
answer += delta.content
yield answer, 'not_finished'
answer = answer.strip()
self.__add_to_history(chat_id, role="assistant", content=answer)
tokens_used = str(self.__count_tokens(self.conversations[chat_id]))

show_plugins_used = len(plugins_used) > 0 and self.config['show_plugins_used']
plugin_names = tuple(self.plugin_manager.get_plugin_source_name(plugin) for plugin in plugins_used)
if self.config['show_usage']:
answer += f"\n\n---\n💰 {tokens_used} {localized_text('stats_tokens', self.config['bot_language'])}"
if show_plugins_used:
answer += f"\n🔌 {', '.join(plugin_names)}"
elif show_plugins_used:
answer += f"\n\n---\n🔌 {', '.join(plugin_names)}"

yield answer, tokens_used

@@ -186,16 +228,24 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals
logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...')
self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]

return await openai.ChatCompletion.acreate(
model=self.config['model'],
messages=self.conversations[chat_id],
temperature=self.config['temperature'],
n=self.config['n_choices'],
max_tokens=self.config['max_tokens'],
presence_penalty=self.config['presence_penalty'],
frequency_penalty=self.config['frequency_penalty'],
stream=stream
)
common_args = {
'model': self.config['model'],
'messages': self.conversations[chat_id],
'temperature': self.config['temperature'],
'n': self.config['n_choices'],
'max_tokens': self.config['max_tokens'],
'presence_penalty': self.config['presence_penalty'],
'frequency_penalty': self.config['frequency_penalty'],
'stream': stream
}

if self.config['enable_functions']:
functions = self.plugin_manager.get_functions_specs()
if len(functions) > 0:
common_args['functions'] = self.plugin_manager.get_functions_specs()
common_args['function_call'] = 'auto'

return await openai.ChatCompletion.acreate(**common_args)

except openai.error.RateLimitError as e:
raise e
@@ -206,6 +256,60 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals
except Exception as e:
raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e

async def __handle_function_call(self, chat_id, response, stream=False, times=0, plugins_used=()):
function_name = ''
arguments = ''
if stream:
async for item in response:
if 'choices' in item and len(item.choices) > 0:
first_choice = item.choices[0]
if 'delta' in first_choice \
and 'function_call' in first_choice.delta:
if 'name' in first_choice.delta.function_call:
function_name += first_choice.delta.function_call.name
if 'arguments' in first_choice.delta.function_call:
arguments += str(first_choice.delta.function_call.arguments)
elif 'finish_reason' in first_choice and first_choice.finish_reason == 'function_call':
break
else:
return response, plugins_used
else:
return response, plugins_used
else:
if 'choices' in response and len(response.choices) > 0:
first_choice = response.choices[0]
if 'function_call' in first_choice.message:
if 'name' in first_choice.message.function_call:
function_name += first_choice.message.function_call.name
if 'arguments' in first_choice.message.function_call:
arguments += str(first_choice.message.function_call.arguments)
else:
return response, plugins_used
else:
return response, plugins_used

logging.info(f'Calling function {function_name} with arguments {arguments}')
function_response = await self.plugin_manager.call_function(function_name, arguments)

if function_name not in plugins_used:
plugins_used += (function_name,)

if is_direct_result(function_response):
self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name,
content=json.dumps({'result': 'Done, the content has been sent'
'to the user.'}))
return function_response, plugins_used

self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name, content=function_response)
response = await openai.ChatCompletion.acreate(
model=self.config['model'],
messages=self.conversations[chat_id],
functions=self.plugin_manager.get_functions_specs(),
function_call='auto' if times < self.config['functions_max_consecutive_calls'] else 'none',
stream=stream
)
return await self.__handle_function_call(chat_id, response, stream, times + 1, plugins_used)

async def generate_image(self, prompt: str) -> tuple[str, str]:
"""
Generates an image from the given prompt using DALL·E model.
@@ -264,6 +368,12 @@ def __max_age_reached(self, chat_id) -> bool:
max_age_minutes = self.config['max_conversation_age_minutes']
return last_updated < now - datetime.timedelta(minutes=max_age_minutes)

def __add_function_call_to_history(self, chat_id, function_name, content):
"""
Adds a function call to the conversation history
"""
self.conversations[chat_id].append({"role": "function", "name": function_name, "content": content})

def __add_to_history(self, chat_id, role, content):
"""
Adds a message to the conversation history.
68 changes: 68 additions & 0 deletions bot/plugin_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import json

from plugins.gtts_text_to_speech import GTTSTextToSpeech
from plugins.dice import DicePlugin
from plugins.youtube_audio_extractor import YouTubeAudioExtractorPlugin
from plugins.ddg_image_search import DDGImageSearchPlugin
from plugins.ddg_translate import DDGTranslatePlugin
from plugins.spotify import SpotifyPlugin
from plugins.crypto import CryptoPlugin
from plugins.weather import WeatherPlugin
from plugins.ddg_web_search import DDGWebSearchPlugin
from plugins.wolfram_alpha import WolframAlphaPlugin
from plugins.deepl import DeeplTranslatePlugin
from plugins.worldtimeapi import WorldTimeApiPlugin
from plugins.whois_ import WhoisPlugin


class PluginManager:
"""
A class to manage the plugins and call the correct functions
"""

def __init__(self, config):
enabled_plugins = config.get('plugins', [])
plugin_mapping = {
'wolfram': WolframAlphaPlugin,
'weather': WeatherPlugin,
'crypto': CryptoPlugin,
'ddg_web_search': DDGWebSearchPlugin,
'ddg_translate': DDGTranslatePlugin,
'ddg_image_search': DDGImageSearchPlugin,
'spotify': SpotifyPlugin,
'worldtimeapi': WorldTimeApiPlugin,
'youtube_audio_extractor': YouTubeAudioExtractorPlugin,
'dice': DicePlugin,
'deepl_translate': DeeplTranslatePlugin,
'gtts_text_to_speech': GTTSTextToSpeech,
'whois': WhoisPlugin,
}
self.plugins = [plugin_mapping[plugin]() for plugin in enabled_plugins if plugin in plugin_mapping]

def get_functions_specs(self):
"""
Return the list of function specs that can be called by the model
"""
return [spec for specs in map(lambda plugin: plugin.get_spec(), self.plugins) for spec in specs]

async def call_function(self, function_name, arguments):
"""
Call a function based on the name and parameters provided
"""
plugin = self.__get_plugin_by_function_name(function_name)
if not plugin:
return json.dumps({'error': f'Function {function_name} not found'})
return json.dumps(await plugin.execute(function_name, **json.loads(arguments)), default=str)

def get_plugin_source_name(self, function_name) -> str:
"""
Return the source name of the plugin
"""
plugin = self.__get_plugin_by_function_name(function_name)
if not plugin:
return ''
return plugin.get_source_name()

def __get_plugin_by_function_name(self, function_name):
return next((plugin for plugin in self.plugins
if function_name in map(lambda spec: spec.get('name'), plugin.get_spec())), None)
30 changes: 30 additions & 0 deletions bot/plugins/crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Dict

import requests

from .plugin import Plugin


# Author: https://github.com/stumpyfr
class CryptoPlugin(Plugin):
"""
A plugin to fetch the current rate of various cryptocurrencies
"""
def get_source_name(self) -> str:
return "CoinCap"

def get_spec(self) -> [Dict]:
return [{
"name": "get_crypto_rate",
"description": "Get the current rate of various crypto currencies",
"parameters": {
"type": "object",
"properties": {
"asset": {"type": "string", "description": "Asset of the crypto"}
},
"required": ["asset"],
},
}]

async def execute(self, function_name, **kwargs) -> Dict:
return requests.get(f"https://api.coincap.io/v2/rates/{kwargs['asset']}").json()
Loading

0 comments on commit 30d441d

Please sign in to comment.