From ed4535bf066a785c47a07f8662e1c4be462e482e Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 23 Jan 2024 08:52:51 +0400 Subject: [PATCH] Async agency --- agency_swarm/agency/__init__.py | 2 +- agency_swarm/agency/agency.py | 82 ++++++++++-- agency_swarm/threads/thread_async.py | 53 ++++++++ notebooks/agency_async.ipynb | 190 +++++++++++++++++++++++++++ tests/test_agency.py | 62 +++++++-- 5 files changed, 367 insertions(+), 22 deletions(-) create mode 100644 agency_swarm/threads/thread_async.py create mode 100644 notebooks/agency_async.ipynb diff --git a/agency_swarm/agency/__init__.py b/agency_swarm/agency/__init__.py index 6ffeebd3..5720fc28 100644 --- a/agency_swarm/agency/__init__.py +++ b/agency_swarm/agency/__init__.py @@ -1 +1 @@ -from .agency import Agency \ No newline at end of file +from .agency import Agency diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index c59c9258..5d631337 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -3,7 +3,7 @@ import os import uuid from enum import Enum -from typing import List, TypedDict, Callable, Any, Dict +from typing import List, TypedDict, Callable, Any, Dict, Literal from pydantic import Field, field_validator from rich.console import Console @@ -20,13 +20,19 @@ class SettingsCallbacks(TypedDict): load: Callable[[], List[Dict]] save: Callable[[List[Dict]], Any] + class ThreadsCallbacks(TypedDict): load: Callable[[], Dict] save: Callable[[Dict], Any] class Agency: + ThreadType = Thread + send_message_tool_description = """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as they do not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved.""" + send_message_tool_description_async = """Use this tool to facilitate direct, asynchronous communication between specialized agents within your agency. When you send a message using this tool, you initiate the task with the recepient agent. To check the status of the task and recieve a response, please invoke the 'GetResponse' tool with the same recipient agent. You are responsible for relaying the recipient agent's responses back to the user, as they do not have direct access to these replies. Remember that you can't check the status yourself later, user needs to tell you when to do so. Keep engaging with this tool until the task is fully resolved.""" + def __init__(self, agency_chart: List, shared_instructions: str = "", shared_files: List = None, + async_mode: Literal['threading'] = None, settings_callbacks: SettingsCallbacks = None, threads_callbacks: ThreadsCallbacks = None): """ Initializes the Agency object, setting up agents, threads, and core functionalities. @@ -40,6 +46,11 @@ def __init__(self, agency_chart: List, shared_instructions: str = "", shared_fil This constructor initializes various components of the Agency, including CEO, agents, threads, and user interactions. It parses the agency chart to set up the organizational structure and initializes the messaging tools, agents, and threads necessary for the operation of the agency. Additionally, it prepares a main thread for user interactions. """ + self.async_mode = async_mode + if self.async_mode == "threading": + from agency_swarm.threads.thread_async import ThreadAsync + self.ThreadType = ThreadAsync + self.ceo = None self.agents = [] self.agents_and_threads = {} @@ -55,7 +66,7 @@ def __init__(self, agency_chart: List, shared_instructions: str = "", shared_fil self.shared_instructions = shared_instructions self._parse_agency_chart(agency_chart) - self._create_send_message_tools() + self._create_special_tools() self._init_agents() self._init_threads() @@ -115,9 +126,10 @@ def _init_threads(self): for agent_name, threads in self.agents_and_threads.items(): for other_agent, items in threads.items(): - self.agents_and_threads[agent_name][other_agent] = Thread(self.get_agent_by_name(items["agent"]), - self.get_agent_by_name( - items["recipient_agent"])) + self.agents_and_threads[agent_name][other_agent] = self.ThreadType( + self.get_agent_by_name(items["agent"]), + self.get_agent_by_name( + items["recipient_agent"])) if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]: self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent] @@ -366,7 +378,7 @@ def _read_instructions(self, path): def plot_agency_chart(self): pass - def _create_send_message_tools(self): + def _create_special_tools(self): """ Creates and assigns 'SendMessage' tools to each agent based on the agency's structure. @@ -381,6 +393,8 @@ def _create_send_message_tools(self): recipient_agents = self.get_agents_by_names(recipient_names) agent = self.get_agent_by_name(agent_name) agent.add_tool(self._create_send_message_tool(agent, recipient_agents)) + if self.async_mode: + agent.add_tool(self._create_get_response_tool(agent, recipient_agents)) def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]): """ @@ -407,7 +421,6 @@ def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]) outer_self = self class SendMessage(BaseTool): - """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as they do not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved.""" instructions: str = Field(..., description="Please repeat your instructions step-by-step, including both completed " "and the following next steps that you need to perfrom. For multi-step complex tasks, first break them down " @@ -439,19 +452,62 @@ def check_caller_agent_name(cls, value): def run(self): thread = outer_self.agents_and_threads[self.caller_agent_name][self.recipient.value] - gen = thread.get_completion(message=self.message, message_files=self.message_files) - try: - while True: - yield next(gen) - except StopIteration as e: - message = e.value + if not outer_self.async_mode: + gen = thread.get_completion(message=self.message, message_files=self.message_files) + try: + while True: + yield next(gen) + except StopIteration as e: + message = e.value + else: + message = thread.get_completion_async(message=self.message, message_files=self.message_files) return message or "" SendMessage.caller_agent = agent + if self.async_mode: + SendMessage.__doc__ = self.send_message_tool_description_async + else: + SendMessage.__doc__ = self.send_message_tool_description return SendMessage + def _create_get_response_tool(self, agent: Agent, recipient_agents: List[Agent]): + """ + Creates a CheckStatus tool to enable an agent to check the status of a task with a specified recipient agent. + """ + recipient_names = [agent.name for agent in recipient_agents] + recipients = Enum("recipient", {name: name for name in recipient_names}) + + outer_self = self + + class GetResponse(BaseTool): + """This tool allows you to check the status of a task or get a response from a specified recipient agent, if the task has been completed. You must always use 'SendMessage' tool first.""" + recipient: recipients = Field(..., description=f"Recipient agent that you want to check the status of. Valid recipients are: {recipient_names}") + caller_agent_name: str = Field(default=agent.name, + description="The agent calling this tool. Defaults to your name. Do not change it.") + + @field_validator('recipient') + def check_recipient(cls, value): + if value.value not in recipient_names: + raise ValueError(f"Recipient {value} is not valid. Valid recipients are: {recipient_names}") + return value + + @field_validator('caller_agent_name') + def check_caller_agent_name(cls, value): + if value != agent.name: + raise ValueError(f"Caller agent name must be {agent.name}.") + return value + + def run(self): + thread = outer_self.agents_and_threads[self.caller_agent_name][self.recipient.value] + + return thread.check_status() + + GetResponse.caller_agent = agent + + return GetResponse + def get_recipient_names(self): """ Retrieves the names of all agents in the agency. diff --git a/agency_swarm/threads/thread_async.py b/agency_swarm/threads/thread_async.py new file mode 100644 index 00000000..ba046bcf --- /dev/null +++ b/agency_swarm/threads/thread_async.py @@ -0,0 +1,53 @@ +from agency_swarm.threads import Thread +import threading +from typing import Literal +from agency_swarm.agents import Agent +from agency_swarm.messages import MessageOutput +from agency_swarm.user import User +from agency_swarm.util.oai import get_openai_client + + +class ThreadAsync(Thread): + def __init__(self, agent: Literal[Agent, User], recipient_agent: Agent): + super().__init__(agent, recipient_agent) + self.pythread = None + self.response = None + + def worker(self, message: str, message_files=None): + gen = super().get_completion(message=message, message_files=message_files, + yield_messages=False) # yielding is not supported in async mode + while True: + try: + next(gen) + except StopIteration as e: + self.response = f"""{self.recipient_agent.name} Response: '{e.value}'""" + break + + return + + def get_completion_async(self, message: str, message_files=None): + if self.pythread and self.pythread.is_alive(): + return "System Notification: 'Agent is busy, so your message was not recived. Please always use 'GetResponse' tool to check for status first, before using 'SendMessage' tool again for the same agent.'" + elif self.pythread and not self.pythread.is_alive(): + self.pythread.join() + self.pythread = None + return self.response + + self.response = None + + self.pythread = threading.Thread(target=self.worker, + args=(message, message_files)) + + self.pythread.start() + + return "System Notification: 'Task has started. Please notify the user that they can tell you to check the status later. You can do this with the 'GetResponse' tool, but don't mention this tool to the user. " + + def check_status(self): + if self.pythread and self.pythread.is_alive(): + return "System Notification: 'Agent is busy. Please tell the user that they need to wait and ask you to check for status again later.'" + elif self.pythread and not self.pythread.is_alive(): + self.pythread.join() + self.pythread = None + return self.response + else: + return "System Notification: 'Agent is available. Please use 'SendMessage' tool to send a message.'" \ No newline at end of file diff --git a/notebooks/agency_async.ipynb b/notebooks/agency_async.ipynb new file mode 100644 index 00000000..531eeb17 --- /dev/null +++ b/notebooks/agency_async.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-01-23T04:16:59.442026Z", + "start_time": "2024-01-23T04:16:59.427181Z" + } + }, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('../agency_swarm') \n", + "\n", + "%load_ext autoreload\n", + "\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "from agency_swarm import Agent, Agency\n", + "\n", + "ceo = Agent(name=\"CEO\",\n", + " description=\"Responsible for client communication, task planning and management.\",\n", + " instructions=\"You must converse with other agents to ensure complete task execution.\", # can be a file like ./instructions.md\n", + " tools=[])\n", + "\n", + "test = Agent(name=\"Test Agent\",\n", + " description=\"Test agent\",\n", + " instructions=\"Please always respond with 'test complete'\", # can be a file like ./instructions.md\n", + " tools=[])\n", + "\n", + "agency = Agency([ceo, [ceo, test]], async_mode='threading')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-23T04:17:01.253161Z", + "start_time": "2024-01-23T04:16:59.679906Z" + } + }, + "id": "a16ee4220f5ab03a", + "execution_count": 2 + }, + { + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "THREAD:[ user -> CEO ]: URL https://platform.openai.com/playground?assistant=asst_HQ3kpb9SzhEgo0ya4IvSFoQ8&mode=assistant&thread=thread_cn2VhmbuYcr0EZgcGVIEOSfw\n", + "THREAD:[ CEO -> Test Agent ]: URL https://platform.openai.com/playground?assistant=asst_cml8LF575HVYy7cWePbEQDgy&mode=assistant&thread=thread_YEsfiNS8gOyGOXXMLoZEBJZz\n" + ] + }, + { + "data": { + "text/plain": "\"I've sent a greeting to the Test Agent. You can ask me to check for a response later, and I'll be happy to do so!\"" + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agency.get_completion(\"Say hi to test agent\", yield_messages=False)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-23T04:17:09.976706Z", + "start_time": "2024-01-23T04:17:01.254030Z" + } + }, + "id": "f578f37d8b261559", + "execution_count": 3 + }, + { + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "THREAD:[ user -> CEO ]: URL https://platform.openai.com/playground?assistant=asst_HQ3kpb9SzhEgo0ya4IvSFoQ8&mode=assistant&thread=thread_cn2VhmbuYcr0EZgcGVIEOSfw\n" + ] + }, + { + "data": { + "text/plain": "'The Test Agent has completed the task. If you have any more requests or tasks, feel free to let me know!'" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agency.get_completion(\"Check status\", yield_messages=False)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-23T04:17:37.434522Z", + "start_time": "2024-01-23T04:17:30.450008Z" + } + }, + "id": "fc798bc6c58c9c16", + "execution_count": 4 + }, + { + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on local URL: http://127.0.0.1:7861\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/plain": "", + "text/html": "
" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "Gradio Blocks instance: 2 backend functions\n-------------------------------------------\nfn_index=0\n inputs:\n |-textbox\n |-chatbot\n outputs:\n |-textbox\n |-chatbot\nfn_index=1\n inputs:\n |-chatbot\n outputs:\n |-chatbot" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agency.demo_gradio()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-21T03:36:13.946078Z", + "start_time": "2024-01-21T03:36:13.843445Z" + } + }, + "id": "6ef9050ecc718655", + "execution_count": 7 + }, + { + "cell_type": "code", + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "94e0b7064cfb2c62" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_agency.py b/tests/test_agency.py index 18799898..79f55d45 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -3,6 +3,7 @@ import os import shutil import sys +import time import unittest sys.path.insert(0, '../agency-swarm') @@ -196,12 +197,50 @@ def test_5_load_from_db(self): self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) self.assertTrue(agent.id in [settings['id'] for settings in previous_loaded_agents_settings]) + def test_6_init_async_agency(self): + """it should initialize agency with agents""" + # reset loaded thread ids + self.__class__.loaded_thread_ids = {} + + self.__class__.agency = Agency([ + self.__class__.ceo, + [self.__class__.ceo, self.__class__.agent1], + [self.__class__.agent1, self.__class__.agent2]], + shared_instructions="This is a shared instruction", + settings_callbacks=self.__class__.settings_callbacks, + threads_callbacks=self.__class__.threads_callbacks, + async_mode='threading', + ) + + self.check_all_agents_settings(True) + + def test_7_async_agent_communication(self): + """it should communicate between agents asynchronously""" + print("TestAgent1 tools", self.__class__.agent1.tools) + self.__class__.agency.get_completion("Please tell TestAgent1 to say test to TestAgent2.", + yield_messages=False) + + time.sleep(10) + + message = self.__class__.agency.get_completion("Please check status. If the agent responds, say 'success', if the agent does not respond, or if you get a system update say 'error'.", + yield_messages=False) + + self.assertFalse('error' in message.lower()) + + for agent_name, threads in self.__class__.agency.agents_and_threads.items(): + for other_agent_name, thread in threads.items(): + self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) + + for agent in self.__class__.agency.agents: + self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) + + # --- Helper methods --- def get_class_folder_path(self): return os.path.abspath(os.path.dirname(inspect.getfile(self.__class__))) - def check_agent_settings(self, agent): + def check_agent_settings(self, agent, async_mode=False): try: settings_path = agent.get_settings_path() self.assertTrue(os.path.exists(settings_path)) @@ -215,31 +254,38 @@ def check_agent_settings(self, agent): self.assertTrue(assistant) self.assertTrue(agent._check_parameters(assistant.model_dump())) if agent.name == "TestAgent1": + num_tools = 2 if not async_mode else 3 self.assertTrue(len(assistant.file_ids) == self.__class__.num_files) for file_id in assistant.file_ids: self.assertTrue(file_id in agent.file_ids) # check retrieval tools is there - self.assertTrue(len(assistant.tools) == 2) - self.assertTrue(len(agent.tools) == 2) + self.assertTrue(len(assistant.tools) == num_tools) + self.assertTrue(len(agent.tools) == num_tools) self.assertTrue(assistant.tools[0].type == "retrieval") self.assertTrue(assistant.tools[1].type == "function") self.assertTrue(assistant.tools[1].function.name == "SendMessage") + if async_mode: + self.assertTrue(assistant.tools[2].type == "function") + self.assertTrue(assistant.tools[2].function.name == "GetResponse") elif agent.name == "TestAgent2": self.assertTrue(len(assistant.tools) == self.__class__.num_schemas) for tool in assistant.tools: self.assertTrue(tool.type == "function") self.assertTrue(tool.function.name in [tool.__name__ for tool in agent.tools]) elif agent.name == "CEO": + num_tools = 1 if not async_mode else 2 self.assertTrue(len(assistant.file_ids) == 0) - self.assertTrue(len(assistant.tools) == 1) + self.assertTrue(len(assistant.tools) == num_tools) + else: + raise Exception("Unknown agent name") except Exception as e: print("Error checking agent settings ", agent.name) raise e - def check_all_agents_settings(self): - self.check_agent_settings(self.__class__.ceo) - self.check_agent_settings(self.__class__.agent1) - self.check_agent_settings(self.__class__.agent2) + def check_all_agents_settings(self, async_mode=False): + self.check_agent_settings(self.__class__.ceo, async_mode=async_mode) + self.check_agent_settings(self.__class__.agent1, async_mode=async_mode) + self.check_agent_settings(self.__class__.agent2, async_mode=async_mode) @classmethod def tearDownClass(cls):