From f9664621d7131ed753e791d668ccd5674f9483c4 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Thu, 14 Nov 2024 11:56:45 +0400 Subject: [PATCH 1/9] Allow extending and modifying SendMessage tool in Agency class --- agency_swarm/agency/agency.py | 129 +++++++----------- agency_swarm/threads/thread.py | 117 ++++++++++------ agency_swarm/tools/BaseTool.py | 16 ++- agency_swarm/tools/__init__.py | 2 +- .../tools/send_message/SendMessage.py | 60 ++++++++ .../send_message/SendMessageAsyncThreading.py | 16 +++ .../tools/send_message/SendMessageBase.py | 19 +++ .../tools/send_message/SendMessageSwarm.py | 47 +++++++ agency_swarm/tools/send_message/__init__.py | 4 + tests/test_agency.py | 10 +- tests/test_send_message.py | 42 ++++++ 11 files changed, 337 insertions(+), 125 deletions(-) create mode 100644 agency_swarm/tools/send_message/SendMessage.py create mode 100644 agency_swarm/tools/send_message/SendMessageAsyncThreading.py create mode 100644 agency_swarm/tools/send_message/SendMessageBase.py create mode 100644 agency_swarm/tools/send_message/SendMessageSwarm.py create mode 100644 agency_swarm/tools/send_message/__init__.py create mode 100644 tests/test_send_message.py diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 4e5001d6..50840c40 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -6,7 +6,6 @@ import uuid from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, TypedDict, Union - from openai.lib._parsing._completions import type_to_response_format_param from openai.types.beta.threads import Message from openai.types.beta.threads.runs import RunStep @@ -24,7 +23,9 @@ from agency_swarm.messages import MessageOutput from agency_swarm.messages.message_output import MessageOutputLive from agency_swarm.threads import Thread +from agency_swarm.threads.thread_async import ThreadAsync from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch +from 
agency_swarm.tools.send_message import SendMessage, SendMessageBase from agency_swarm.user import User from agency_swarm.util.errors import RefusalError from agency_swarm.util.files import get_tools, get_file_purpose @@ -46,15 +47,12 @@ class ThreadsCallbacks(TypedDict): class Agency: - ThreadType = Thread - send_message_tool_description = """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time.""" - send_message_tool_description_async = """Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. Continue until task completion.""" - def __init__(self, agency_chart: List, shared_instructions: str = "", shared_files: Union[str, List[str]] = None, async_mode: Literal['threading', "tools_threading"] = None, + send_message_tool_class: Type[SendMessageBase] = SendMessage, settings_path: str = "./settings.json", settings_callbacks: SettingsCallbacks = None, threads_callbacks: ThreadsCallbacks = None, @@ -72,6 +70,7 @@ def __init__(self, shared_instructions (str, optional): A path to a file containing shared instructions for all agents. Defaults to an empty string. 
shared_files (Union[str, List[str]], optional): A path to a folder or a list of folders containing shared files for all agents. Defaults to None. async_mode (str, optional): Specifies the mode for asynchronous processing. In "threading" mode, all sub-agents run in separate threads. In "tools_threading" mode, all tools run in separate threads, but agents do not. Defaults to None. + send_message_tool_class (Type[SendMessageBase], optional): The class to use for the send_message tool. For async communication, use `SendMessageAsyncThreading`. Defaults to SendMessage. settings_path (str, optional): The path to the settings file for the agency. Must be json. If file does not exist, it will be created. Defaults to None. settings_callbacks (SettingsCallbacks, optional): A dictionary containing functions to load and save settings for the agency. The keys must be "load" and "save". Both values must be defined. Defaults to None. threads_callbacks (ThreadsCallbacks, optional): A dictionary containing functions to load and save threads for the agency. The keys must be "load" and "save". Both values must be defined. Defaults to None. @@ -92,6 +91,7 @@ def __init__(self, self.recipient_agents = None # for autocomplete self.shared_files = shared_files if shared_files else [] self.async_mode = async_mode + self.send_message_tool_class = send_message_tool_class self.settings_path = settings_path self.settings_callbacks = settings_callbacks self.threads_callbacks = threads_callbacks @@ -102,8 +102,9 @@ def __init__(self, self.truncation_strategy = truncation_strategy if self.async_mode == "threading": - from agency_swarm.threads.thread_async import ThreadAsync - self.ThreadType = ThreadAsync + from agency_swarm.tools.send_message import SendMessageAsyncThreading + print("Warning: 'threading' mode is deprecated. 
Please use send_message_tool_class = SendMessageAsyncThreading to use async communication.") + self.send_message_tool_class = SendMessageAsyncThreading elif self.async_mode == "tools_threading": Thread.async_mode = self.async_mode elif self.async_mode is None: @@ -121,9 +122,9 @@ def __init__(self, self.shared_state = SharedState() self._parse_agency_chart(agency_chart) + self._init_threads() self._create_special_tools() self._init_agents() - self._init_threads() def get_completion(self, message: str, message_files: List[str] = None, @@ -300,7 +301,7 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): images = [] message_file_names = None uploading_files = False - recipient_agents = [agent.name for agent in self.main_recipients] + recipient_agent_names = [agent.name for agent in self.main_recipients] recipient_agent = self.main_recipients[0] with gr.Blocks(js=js) as demo: @@ -308,7 +309,7 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): chatbot = gr.Chatbot(height=height) with gr.Row(): with gr.Column(scale=9): - dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agents, + dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agent_names, value=recipient_agent.name) msg = gr.Textbox(label="Your Message", lines=4) with gr.Column(scale=1): @@ -412,9 +413,14 @@ def check_and_add_tools_in_attachments(attachments, recipient_agent): class GradioEventHandler(AgencyEventHandler): message_output = None + @classmethod + def change_recipient_agent(cls, recipient_agent_name): + nonlocal chatbot_queue + chatbot_queue.put("[change_recipient_agent]") + chatbot_queue.put(recipient_agent_name) + @override def on_message_created(self, message: Message) -> None: - if message.role == "user": full_content = "" for content in message.content: @@ -530,19 +536,20 @@ def on_all_streams_end(cls): chatbot_queue.put("[end]") def bot(original_message, history): - if not original_message: - return "", history - nonlocal attachments nonlocal 
message_file_names nonlocal recipient_agent + nonlocal recipient_agent_names nonlocal images nonlocal uploading_files + if not original_message: + return "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) + if uploading_files: history.append([None, "Uploading files... Please wait."]) - yield "", history - return "", history + yield "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) + return "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) print("Message files: ", attachments) print("Images: ", images) @@ -579,13 +586,19 @@ def bot(original_message, history): new_message = True continue + if bot_message == "[change_recipient_agent]": + new_agent_name = chatbot_queue.get(block=True) + recipient_agent = self._get_agent_by_name(new_agent_name) + yield "", history, gr.update(value=new_agent_name, choices=set([*recipient_agent_names, recipient_agent.name])) + continue + if new_message: history.append([None, bot_message]) new_message = False else: history[-1][1] += bot_message - yield "", history + yield "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) except queue.Empty: break @@ -594,12 +607,12 @@ def bot(original_message, history): inputs=[msg, chatbot], outputs=[msg, chatbot] ).then( - bot, [msg, chatbot], [msg, chatbot] + bot, [msg, chatbot, dropdown], [msg, chatbot, dropdown] ) dropdown.change(handle_dropdown_change, dropdown) file_upload.change(handle_file_upload, file_upload) msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( - bot, [msg, chatbot], [msg, chatbot] + bot, [msg, chatbot], [msg, chatbot, dropdown] ) # Enable queuing for streaming intermediate outputs @@ -647,6 +660,7 @@ def run_demo(self): """ Executes agency in the terminal with autocomplete for recipient agent names. 
""" + outer_self = self from agency_swarm import AgencyEventHandler class TermEventHandler(AgencyEventHandler): message_output = None @@ -714,7 +728,7 @@ def on_tool_call_done(self, snapshot): if snapshot.type != "function": return - if snapshot.function.name == "SendMessage": + if snapshot.function.name == "SendMessage" and not (hasattr(outer_self.send_message_tool_class.ToolConfig, 'output_as_result') and outer_self.send_message_tool_class.ToolConfig.output_as_result): try: args = eval(snapshot.function.arguments) recipient = args["recipient"] @@ -864,9 +878,14 @@ def _init_threads(self): else: self.main_thread.init_thread() + # Save main_thread into agents_and_threads + self.agents_and_threads["main_thread"] = self.main_thread + for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent, items in threads.items(): - self.agents_and_threads[agent_name][other_agent] = self.ThreadType( + self.agents_and_threads[agent_name][other_agent] = self.send_message_tool_class._thread_type( self._get_agent_by_name(items["agent"]), self._get_agent_by_name( items["recipient_agent"])) @@ -880,6 +899,8 @@ def _init_threads(self): if self.threads_callbacks: loaded_thread_ids = {} for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue loaded_thread_ids[agent_name] = {} for other_agent, thread in threads.items(): loaded_thread_ids[agent_name][other_agent] = thread.id @@ -1001,13 +1022,15 @@ def _create_special_tools(self): No output parameters; this method modifies the agents' toolset internally. 
""" for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue recipient_names = list(threads.keys()) recipient_agents = self._get_agents_by_names(recipient_names) if len(recipient_agents) == 0: continue agent = self._get_agent_by_name(agent_name) agent.add_tool(self._create_send_message_tool(agent, recipient_agents)) - if self.async_mode == 'threading': + if self.send_message_tool_class._thread_type == ThreadAsync: agent.add_tool(self._create_get_response_tool(agent, recipient_agents)) def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]): @@ -1032,38 +1055,8 @@ def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]) agent_descriptions += recipient_agent.name + ": " agent_descriptions += recipient_agent.description + "\n" - outer_self = self - - class SendMessage(BaseTool): - my_primary_instructions: str = Field(..., - description="Please repeat your primary instructions step-by-step, including both completed " - "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down " - "into smaller steps yourself. Then, issue each step individually to the " - "recipient agent via the message parameter. Each identified step should be " - "sent in separate message. Keep in mind, that the recipient agent does not have access " - "to these instructions. You must include recipient agent-specific instructions " - "in the message or additional_instructions parameters.") + class SendMessage(self.send_message_tool_class): recipient: recipients = Field(..., description=agent_descriptions) - message: str = Field(..., - description="Specify the task required for the recipient agent to complete. Focus on " - "clarifying what the task entails, rather than providing exact " - "instructions.") - message_files: Optional[List[str]] = Field(default=None, - description="A list of file ids to be sent as attachments to this message. 
Only use this if you have the file id that starts with 'file-'.", - examples=["file-1234", "file-5678"]) - additional_instructions: Optional[str] = Field(default=None, - description="Additional context or instructions for the recipient agent about the task. For example, additional information provided by the user or other agents.") - - class ToolConfig: - strict = False - one_call_at_a_time = outer_self.async_mode != 'threading' - - @model_validator(mode='after') - def validate_files(self): - if "file-" in self.message or ( - self.additional_instructions and "file-" in self.additional_instructions): - if not self.message_files: - raise ValueError("You must include file ids in message_files parameter.") @field_validator('recipient') @classmethod @@ -1071,36 +1064,10 @@ def check_recipient(cls, value): if value.value not in recipient_names: raise ValueError(f"Recipient {value} is not valid. Valid recipients are: {recipient_names}") return value - - @field_validator('additional_instructions', mode='before') - @classmethod - def validate_additional_instructions(cls, value): - if isinstance(value, list): - return "\n".join(value) - return value - - def run(self): - thread = outer_self.agents_and_threads[self._caller_agent.name][self.recipient.value] - - if not outer_self.async_mode == 'threading': - message = thread.get_completion(message=self.message, - message_files=self.message_files, - event_handler=self._event_handler, - yield_messages=not self._event_handler, - additional_instructions=self.additional_instructions, - ) - else: - message = thread.get_completion_async(message=self.message, - message_files=self.message_files, - additional_instructions=self.additional_instructions) - - return message or "" SendMessage._caller_agent = agent - if self.async_mode == 'threading': - SendMessage.__doc__ = self.send_message_tool_description_async - else: - SendMessage.__doc__ = self.send_message_tool_description + SendMessage._agents_and_threads = self.agents_and_threads + 
SendMessage._agency = self return SendMessage diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 8d681f3d..8a9e49bb 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -42,6 +42,8 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.num_run_retries = 0 + self.terminal_states = ["cancelled", "completed", "failed", "expired", "incomplete"] + def init_thread(self): if self.id: self.thread = self.client.beta.threads.retrieve(self.id) @@ -57,7 +59,7 @@ def init_thread(self): ) def get_completion_stream(self, - message: str, + message: Union[str, List[dict], None], event_handler: type(AgencyEventHandler), message_files: List[str] = None, attachments: Optional[List[Attachment]] = None, @@ -77,10 +79,10 @@ def get_completion_stream(self, response_format=response_format) def get_completion(self, - message: str | List[dict], + message: Union[str, List[dict], None], message_files: List[str] = None, attachments: Optional[List[dict]] = None, - recipient_agent: Agent = None, + recipient_agent: Union[Agent, None] = None, additional_instructions: str = None, event_handler: type(AgencyEventHandler) = None, tool_choice: AssistantToolChoice = None, @@ -117,14 +119,15 @@ def get_completion(self, print(f'THREAD:[ {sender_name} -> {recipient_agent.name} ]: URL {self.thread_url}') # send message - message_obj = self.create_message( - message=message, - role="user", - attachments=attachments - ) + if message: + message_obj = self.create_message( + message=message, + role="user", + attachments=attachments + ) - if yield_messages: - yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj) + if yield_messages: + yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj) self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) @@ -138,8 +141,8 @@ def get_completion(self, if 
self.run.status == "requires_action": tool_calls = self.run.required_action.submit_tool_outputs.tool_calls tool_outputs_and_names = [] # list of tuples (name, tool_output) - sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name == "SendMessage"] - async_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name != "SendMessage"] + sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name.startswith("SendMessage")] + async_tool_calls = [tool_call for tool_call in tool_calls if not tool_call.function.name.startswith("SendMessage")] def handle_output(tool_call, output): if inspect.isgenerator(output): @@ -157,6 +160,8 @@ def handle_output(tool_call, output): for tool_output in tool_outputs_and_names: if tool_output[1]["tool_call_id"] == tool_call.id: tool_output[1]["output"] = output + + return output if len(async_tool_calls) > 0 and self.async_mode == "tools_threading": max_workers = min(self.max_workers, os.cpu_count() or 1) # Use at most 4 workers or the number of CPUs available @@ -170,8 +175,11 @@ def handle_output(tool_call, output): for future in as_completed(futures): tool_call = futures[future] - output = future.result() - yield from handle_output(tool_call, output) + output, output_as_result = future.result() + output = yield from handle_output(tool_call, output) + if output_as_result: + self._cancel_run() + return output else: sync_tool_calls += async_tool_calls @@ -179,10 +187,13 @@ def handle_output(tool_call, output): for tool_call in sync_tool_calls: if yield_messages: yield MessageOutput("function", recipient_agent.name, self.agent.name, str(tool_call.function), tool_call) - output = self.execute_tool(tool_call, recipient_agent, event_handler, tool_outputs_and_names) + output, output_as_result = self.execute_tool(tool_call, recipient_agent, event_handler, tool_outputs_and_names) tool_outputs_and_names.append((tool_call.function.name, {"tool_call_id": tool_call.id, "output": 
output})) - yield from handle_output(tool_call, output) - + output = yield from handle_output(tool_call, output) + if output_as_result: + self._cancel_run() + return output + # split names and outputs tool_outputs = [tool_output for _, tool_output in tool_outputs_and_names] tool_names = [name for name, _ in tool_outputs_and_names] @@ -340,8 +351,11 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t # poll_interval_ms=500, ) except APIError as e: - if "The server had an error processing your request" in e.message and self.num_run_retries < 3: - time.sleep(1 + self.num_run_retries) + match = re.search(r"Thread (\w+) already has an active run (\w+)", e.message) + if match: + self._cancel_run(thread_id=match.groups()[0], run_id=match.groups()[1], check_status=False) + elif "The server had an error processing your request" in e.message and self.num_run_retries < 3: + time.sleep(1) self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) self.num_run_retries += 1 else: @@ -355,22 +369,48 @@ def _run_until_done(self): run_id=self.run.id ) - def _submit_tool_outputs(self, tool_outputs, event_handler): - if not event_handler: - self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( + def _submit_tool_outputs(self, tool_outputs, event_handler=None, poll=True): + if not poll: + self.run = self.client.beta.threads.runs.submit_tool_outputs( thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs ) else: - with self.client.beta.threads.runs.submit_tool_outputs_stream( + if not event_handler: + self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( thread_id=self.thread.id, run_id=self.run.id, - tool_outputs=tool_outputs, - event_handler=event_handler() - ) as stream: - stream.until_done() - self.run = stream.get_final_run() + tool_outputs=tool_outputs + ) + else: + with self.client.beta.threads.runs.submit_tool_outputs_stream( + 
thread_id=self.thread.id, + run_id=self.run.id, + tool_outputs=tool_outputs, + event_handler=event_handler() + ) as stream: + stream.until_done() + self.run = stream.get_final_run() + + def _cancel_run(self, thread_id=None, run_id=None, check_status=True): + if check_status and self.run.status in self.terminal_states and not run_id: + return + + try: + self.run = self.client.beta.threads.runs.cancel( + thread_id=self.thread.id, + run_id=self.run.id + ) + except BadRequestError as e: + if "Cannot cancel run with status" in e.message: + self.run = self.client.beta.threads.runs.poll( + thread_id=thread_id or self.thread.id, + run_id=run_id or self.run.id, + poll_interval_ms=500, + ) + else: + raise e def _get_last_message_text(self): messages = self.client.beta.threads.messages.list( @@ -417,15 +457,9 @@ def create_message(self, message: str, role: str = "user", attachments: List[dic thread_id, run_id = match.groups() thread_id = f"thread_{thread_id}" run_id = f"run_{run_id}" - self.client.beta.threads.runs.cancel( - thread_id=thread_id, - run_id=run_id - ) - self.run = self.client.beta.threads.runs.poll( - thread_id=thread_id, - run_id=run_id, - poll_interval_ms=500, - ) + + self._cancel_run(thread_id=thread_id, run_id=run_id) + return self.client.beta.threads.messages.create( thread_id=thread_id, role=role, @@ -443,7 +477,7 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool tool = next((func for func in funcs if func.__name__ == tool_call.function.name), None) if not tool: - return f"Error: Function {tool_call.function.name} not found. Available functions: {[func.__name__ for func in funcs]}" + return f"Error: Function {tool_call.function.name} not found. 
Available functions: {[func.__name__ for func in funcs]}", False try: # init tool @@ -453,17 +487,18 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool for tool_name in [name for name, _ in tool_outputs_and_names]: if tool_name == tool_call.function.name and ( hasattr(tool, "ToolConfig") and hasattr(tool.ToolConfig, "one_call_at_a_time") and tool.ToolConfig.one_call_at_a_time): - return f"Error: Function {tool_call.function.name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again." + return f"Error: Function {tool_call.function.name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again.", False tool._caller_agent = recipient_agent tool._event_handler = event_handler + tool._tool_call = tool_call - return tool.run() + return tool.run(), tool.ToolConfig.output_as_result except Exception as e: error_message = f"Error: {e}" if "For further information visit" in error_message: error_message = error_message.split("For further information visit")[0] - return error_message + return error_message, False def _execute_async_tool_calls_outputs(self, tool_outputs): async_tool_calls = [] diff --git a/agency_swarm/tools/BaseTool.py b/agency_swarm/tools/BaseTool.py index fd6fd6a6..7d9ffacc 100644 --- a/agency_swarm/tools/BaseTool.py +++ b/agency_swarm/tools/BaseTool.py @@ -11,15 +11,29 @@ class BaseTool(BaseModel, ABC): _shared_state: ClassVar[SharedState] = None _caller_agent: Any = None _event_handler: Any = None + _tool_call: Any = None def __init__(self, **kwargs): if not self.__class__._shared_state: self.__class__._shared_state = SharedState() super().__init__(**kwargs) + + # Ensure all ToolConfig variables are initialized + config_defaults = { + 'strict': False, + 'one_call_at_a_time': False, + 'output_as_result': False + } + + for key, value in config_defaults.items(): + if not 
hasattr(self.ToolConfig, key): + setattr(self.ToolConfig, key, value) class ToolConfig: strict: bool = False one_call_at_a_time: bool = False + # return the tool output as assistant message + output_as_result: bool = False @classmethod @property @@ -76,5 +90,5 @@ def openai_schema(cls): return schema @abstractmethod - def run(self, **kwargs): + def run(self): pass diff --git a/agency_swarm/tools/__init__.py b/agency_swarm/tools/__init__.py index 8bca7b06..d77d38dd 100644 --- a/agency_swarm/tools/__init__.py +++ b/agency_swarm/tools/__init__.py @@ -2,4 +2,4 @@ from .oai.CodeInterpreter import CodeInterpreter from .oai.FileSearch import FileSearch from .oai.Retrieval import Retrieval -from .ToolFactory import ToolFactory +from .ToolFactory import ToolFactory \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessage.py b/agency_swarm/tools/send_message/SendMessage.py new file mode 100644 index 00000000..758b38ce --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessage.py @@ -0,0 +1,60 @@ +from agency_swarm.threads.thread import Thread +from typing import ClassVar, Optional, List, Type +from pydantic import Field, field_validator, model_validator +from .SendMessageBase import SendMessageBase + +class SendMessage(SendMessageBase): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. 
Do not send more than 1 message at a time.""" + message: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions." + ) + my_primary_instructions: str = Field( + ..., + description=( + "Please repeat your primary instructions step-by-step, including both completed " + "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down " + "into smaller steps yourself. Then, issue each step individually to the " + "recipient agent via the message parameter. Each identified step should be " + "sent in a separate message. Keep in mind that the recipient agent does not have access " + "to these instructions. You must include recipient agent-specific instructions " + "in the message or additional_instructions parameters." + ) + ) + message_files: Optional[List[str]] = Field( + default=None, + description="A list of file IDs to be sent as attachments to this message. Only use this if you have the file ID that starts with 'file-'.", + examples=["file-1234", "file-5678"] + ) + additional_instructions: Optional[str] = Field( + default=None, + description="Additional context or instructions from the conversation needed by the recipient agent to complete the task." 
+ ) + + + @model_validator(mode='after') + def validate_files(self): + if hasattr(self, 'message') and "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions): + if not self.message_files: + raise ValueError("You must include file IDs in message_files parameter.") + return self + + @field_validator('additional_instructions', mode='before') + @classmethod + def validate_additional_instructions(cls, value): + if isinstance(value, list): + return "\n".join(value) + return value + + + def run(self): + thread: Thread = self._agents_and_threads[self._caller_agent.name][self.recipient.value] + + message = thread.get_completion(message=self.message, + message_files=self.message_files, + event_handler=self._event_handler, + yield_messages=not self._event_handler, + additional_instructions=self.additional_instructions, + ) + + return message or "" \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageAsyncThreading.py b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py new file mode 100644 index 00000000..ffdfee49 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py @@ -0,0 +1,16 @@ +from typing import ClassVar, Type +from agency_swarm.threads.thread_async import ThreadAsync +from .SendMessage import SendMessage + +class SendMessageAsyncThreading(SendMessage): + """Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. 
Continue until task completion.""" + _thread_type: ClassVar[Type[ThreadAsync]] = ThreadAsync + + def run(self): + thread: ThreadAsync = self._agents_and_threads[self._caller_agent.name][self.recipient.value] + + message = thread.get_completion_async(message=self.message, + message_files=self.message_files, + additional_instructions=self.additional_instructions) + + return message or "" \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageBase.py b/agency_swarm/tools/send_message/SendMessageBase.py new file mode 100644 index 00000000..37ffe3db --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageBase.py @@ -0,0 +1,19 @@ +from agency_swarm.threads.thread import Thread +from typing import ClassVar, Optional, List, Type +from pydantic import Field, field_validator, model_validator +from agency_swarm.tools import BaseTool +from abc import ABC + +class SendMessageBase(BaseTool, ABC): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time.""" + + recipient: str = Field(..., description="Recipient agent that you want to send the message to. 
This field will be overriden inside the agency class.") + + _agents_and_threads: ClassVar = None + _thread_type: ClassVar[Type[Thread]] = Thread # thread type assigned by the agency class + + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not cls.__name__.startswith("SendMessage"): + raise TypeError(f"Class name '{cls.__name__}' must start with 'SendMessage'.") \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageSwarm.py b/agency_swarm/tools/send_message/SendMessageSwarm.py new file mode 100644 index 00000000..0557e681 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageSwarm.py @@ -0,0 +1,47 @@ +from openai import BadRequestError +from agency_swarm.threads.thread import Thread +from .SendMessage import SendMessageBase + +class SendMessageSwarm(SendMessageBase): + """Use this tool to route messages to other agents within your agency. After using this tool, you will be switched to the recipient agent. This tool can only be used once per message. Do not use any other tools together with this tool.""" + + class ToolConfig: + output_as_result: bool = True + one_call_at_a_time: bool = True + + def run(self): + # get main thread + thread: Thread = self._agents_and_threads["main_thread"] + + # get recipient agent from thread + recipient_agent = self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent + + # submit tool output + try: + thread._submit_tool_outputs( + tool_outputs=[{"tool_call_id": self._tool_call.id, "output": "The request has been routed. You are now a " + recipient_agent.name + " agent. Please assist the user further with their request."}], + poll=False + ) + except BadRequestError as e: + raise BadRequestError("You can only call this tool by itself. 
Do not use any other tools together with this tool.") + + try: + # cancel run + thread._cancel_run() + + # change recipient agent in thread + thread.recipient_agent = recipient_agent + + # change recipient agent in gradio dropdown + if self._event_handler: + if hasattr(self._event_handler, "change_recipient_agent"): + self._event_handler.change_recipient_agent(self.recipient.value) + + # continue conversation with the new recipient agent + message = thread.get_completion(message=None, recipient_agent=recipient_agent, yield_messages=not self._event_handler, event_handler=self._event_handler) + + return message or "" + except Exception as e: + # we need to catch errors beucase tool outputs are already submitted + print("Error in SendMessageSwarm: ", e) + return str(e) diff --git a/agency_swarm/tools/send_message/__init__.py b/agency_swarm/tools/send_message/__init__.py new file mode 100644 index 00000000..180265b6 --- /dev/null +++ b/agency_swarm/tools/send_message/__init__.py @@ -0,0 +1,4 @@ +from .SendMessageAsyncThreading import SendMessageAsyncThreading +from .SendMessageBase import SendMessageBase +from .SendMessage import SendMessage +from .SendMessageSwarm import SendMessageSwarm \ No newline at end of file diff --git a/tests/test_agency.py b/tests/test_agency.py index 7ed46b85..0c258527 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -219,9 +219,11 @@ def test_4_agent_communication(self): message = self.__class__.agency.get_completion("Please tell TestAgent1 to say test to TestAgent2.", tool_choice={"type": "function", "function": {"name": "SendMessage"}}) - self.assertFalse('error' in message.lower(), f"Error found in message: {message}") + self.assertFalse('error' in message.lower(), f"Error found in message: {message}. 
Thread url: {self.__class__.agency.main_thread.thread_url}") for agent_name, threads in self.__class__.agency.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent_name, thread in threads.items(): self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) @@ -312,6 +314,8 @@ def on_all_streams_end(cls): self.assertFalse(agent1_thread.run.parallel_tool_calls) for agent_name, threads in self.__class__.agency.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent_name, thread in threads.items(): self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) @@ -369,6 +373,8 @@ def test_6_load_from_db(self): # check that threads are the same for agent_name, threads in agency.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent_name, thread in threads.items(): self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) self.assertTrue(thread.id in previous_loaded_thread_ids[agent_name][other_agent_name]) @@ -450,6 +456,8 @@ def on_all_streams_end(cls): self.assertFalse('error' in message.lower(), self.__class__.agency.main_thread.thread_url) for agent_name, threads in self.__class__.agency.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent_name, thread in threads.items(): self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) diff --git a/tests/test_send_message.py b/tests/test_send_message.py new file mode 100644 index 00000000..9123d938 --- /dev/null +++ b/tests/test_send_message.py @@ -0,0 +1,42 @@ +import unittest +from agency_swarm import Agent, Agency +from agency_swarm.tools.send_message import SendMessageSwarm +from agency_swarm.tools import BaseTool +from pydantic import Field + +class TestSendMessage(unittest.TestCase): + def setUp(self): + class PrintTool(BaseTool): + """ + A simple 
tool that prints a message. + """ + message: str = Field(..., description="The message to print.") + + def run(self): + print(self.message) + return f"Printed: {self.message}" + + self.ceo = Agent( + name="CEO", + description="Responsible for client communication, task planning and management.", + instructions="Your role is to route messages to other agents within your agency.", + tools=[PrintTool] + ) + + self.customer_support = Agent( + name="Customer Support", + description="Responsible for customer support.", + instructions="You are a Customer Support agent. Answer customer questions and help with issues.", + tools=[] + ) + + self.agency = Agency([self.ceo, [self.ceo, self.customer_support], [self.customer_support, self.ceo]], send_message_tool_class=SendMessageSwarm) + + def test_send_message_swarm(self): + response = self.agency.get_completion("Hello, can you send me to customer support? If there are any issues, please say 'error'") + self.assertFalse("error" in response.lower()) + response = self.agency.get_completion("Who are you?") + self.assertTrue("customer support" in response.lower()) + +if __name__ == '__main__': + unittest.main() From 2646b8adb2164f94d3a7ea03b562647074308421 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Fri, 15 Nov 2024 09:49:59 +0400 Subject: [PATCH 2/9] Refactor thread and optimize thread initializetion --- agency_swarm/agency/agency.py | 4 + agency_swarm/threads/thread.py | 124 ++++++++++++++------------- agency_swarm/threads/thread_async.py | 5 +- 3 files changed, 72 insertions(+), 61 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 50840c40..3770d4aa 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -881,17 +881,21 @@ def _init_threads(self): # Save main_thread into agents_and_threads self.agents_and_threads["main_thread"] = self.main_thread + # initialize threads for agent_name, threads in self.agents_and_threads.items(): if agent_name == 
"main_thread": continue for other_agent, items in threads.items(): + # create thread class self.agents_and_threads[agent_name][other_agent] = self.send_message_tool_class._thread_type( self._get_agent_by_name(items["agent"]), self._get_agent_by_name( items["recipient_agent"])) + # load thread id if available if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]: self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent] + # init threads if threre are threads callbacks so the ids are saved for later use elif self.threads_callbacks: self.agents_and_threads[agent_name][other_agent].init_thread() diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 8a9e49bb..6fb9c9dd 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -28,6 +28,16 @@ class Thread: @property def thread_url(self): return f'https://platform.openai.com/playground/assistants?assistant={self.recipient_agent.id}&mode=assistant&thread={self.id}' + + @property + def thread(self): + self.init_thread() + + if not self._thread: + print("retrieving thread", self.id) + self._thread = self.client.beta.threads.retrieve(self.id) + + return self._thread def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.agent = agent @@ -36,28 +46,27 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.client = get_openai_client() self.id = None - self.thread = None - self.run = None - self.stream = None + self._thread = None + self._run = None + self._stream = None - self.num_run_retries = 0 + self._num_run_retries = 0 self.terminal_states = ["cancelled", "completed", "failed", "expired", "incomplete"] def init_thread(self): if self.id: - self.thread = self.client.beta.threads.retrieve(self.id) - else: - self.thread = self.client.beta.threads.create() - self.id = self.thread.id - - if self.recipient_agent.examples: - for example in 
self.recipient_agent.examples: - self.client.beta.threads.messages.create( - thread_id=self.id, - **example, - ) - + return + print("creating thread") + self._thread = self.client.beta.threads.create() + self.id = self._thread.id + if self.recipient_agent.examples: + for example in self.recipient_agent.examples: + self.client.beta.threads.messages.create( + thread_id=self.id, + **example, + ) + print("thread created", self.id) def get_completion_stream(self, message: Union[str, List[dict], None], event_handler: type(AgencyEventHandler), @@ -89,6 +98,8 @@ def get_completion(self, yield_messages: bool = False, response_format: Optional[dict] = None ): + self.init_thread() + if not recipient_agent: recipient_agent = self.recipient_agent @@ -107,9 +118,6 @@ def get_completion(self, attachments.append({"file_id": file_id, "tools": recipient_tools or [{"type": "file_search"}]}) - if not self.thread: - self.init_thread() - if event_handler: event_handler.set_agent(self.agent) event_handler.set_recipient_agent(recipient_agent) @@ -138,8 +146,8 @@ def get_completion(self, self._run_until_done() # function execution - if self.run.status == "requires_action": - tool_calls = self.run.required_action.submit_tool_outputs.tool_calls + if self._run.status == "requires_action": + tool_calls = self._run.required_action.submit_tool_outputs.tool_calls tool_outputs_and_names = [] # list of tuples (name, tool_output) sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name.startswith("SendMessage")] async_tool_calls = [tool_call for tool_call in tool_calls if not tool_call.function.name.startswith("SendMessage")] @@ -224,11 +232,11 @@ def handle_output(tool_call, output): self._create_run(recipient_agent, additional_instructions, event_handler, 'required', temperature=0) self._run_until_done() - if self.run.status != "requires_action": - raise Exception("Run Failed. 
Error: ", self.run.last_error or self.run.incomplete_details) + if self._run.status != "requires_action": + raise Exception("Run Failed. Error: ", self._run.last_error or self._run.incomplete_details) # change tool call ids - tool_calls = self.run.required_action.submit_tool_outputs.tool_calls + tool_calls = self._run.required_action.submit_tool_outputs.tool_calls if len(tool_calls) != len(tool_outputs): tool_outputs = [] @@ -246,10 +254,10 @@ def handle_output(tool_call, output): else: raise e # error - elif self.run.status == "failed": + elif self._run.status == "failed": full_message += self._get_last_message_text() common_errors = ["something went wrong", "the server had an error processing your request", "rate limit reached"] - error_message = self.run.last_error.message.lower() + error_message = self._run.last_error.message.lower() if error_attempts < 3 and any(error in error_message for error in common_errors): if error_attempts < 2: @@ -261,9 +269,9 @@ def handle_output(tool_call, output): tool_choice, response_format=response_format) error_attempts += 1 else: - raise Exception("OpenAI Run Failed. Error: ", self.run.last_error.message) - elif self.run.status == "incomplete": - raise Exception("OpenAI Run Incomplete. Details: ", self.run.incomplete_details) + raise Exception("OpenAI Run Failed. Error: ", self._run.last_error.message) + elif self._run.status == "incomplete": + raise Exception("OpenAI Run Incomplete. 
Details: ", self._run.incomplete_details) # return assistant message else: message_obj = self._get_last_assistant_message() @@ -318,7 +326,7 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t try: if event_handler: with self.client.beta.threads.runs.stream( - thread_id=self.thread.id, + thread_id=self.id, event_handler=event_handler(), assistant_id=recipient_agent.id, additional_instructions=additional_instructions, @@ -331,10 +339,10 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t response_format=response_format ) as stream: stream.until_done() - self.run = stream.get_final_run() + self._run = stream.get_final_run() else: - self.run = self.client.beta.threads.runs.create( - thread_id=self.thread.id, + self._run = self.client.beta.threads.runs.create( + thread_id=self.id, assistant_id=recipient_agent.id, additional_instructions=additional_instructions, tool_choice=tool_choice, @@ -345,68 +353,68 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t parallel_tool_calls=recipient_agent.parallel_tool_calls, response_format=response_format ) - self.run = self.client.beta.threads.runs.poll( - thread_id=self.thread.id, - run_id=self.run.id, + self._run = self.client.beta.threads.runs.poll( + thread_id=self.id, + run_id=self._run.id, # poll_interval_ms=500, ) except APIError as e: match = re.search(r"Thread (\w+) already has an active run (\w+)", e.message) if match: self._cancel_run(thread_id=match.groups()[0], run_id=match.groups()[1], check_status=False) - elif "The server had an error processing your request" in e.message and self.num_run_retries < 3: + elif "The server had an error processing your request" in e.message and self._num_run_retries < 3: time.sleep(1) self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) - self.num_run_retries += 1 + self._num_run_retries += 1 else: raise e def 
_run_until_done(self): - while self.run.status in ['queued', 'in_progress', "cancelling"]: + while self._run.status in ['queued', 'in_progress', "cancelling"]: time.sleep(0.5) - self.run = self.client.beta.threads.runs.retrieve( - thread_id=self.thread.id, - run_id=self.run.id + self._run = self.client.beta.threads.runs.retrieve( + thread_id=self.id, + run_id=self._run.id ) def _submit_tool_outputs(self, tool_outputs, event_handler=None, poll=True): if not poll: - self.run = self.client.beta.threads.runs.submit_tool_outputs( - thread_id=self.thread.id, - run_id=self.run.id, + self._run = self.client.beta.threads.runs.submit_tool_outputs( + thread_id=self.id, + run_id=self._run.id, tool_outputs=tool_outputs ) else: if not event_handler: - self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( - thread_id=self.thread.id, - run_id=self.run.id, + self._run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( + thread_id=self.id, + run_id=self._run.id, tool_outputs=tool_outputs ) else: with self.client.beta.threads.runs.submit_tool_outputs_stream( - thread_id=self.thread.id, - run_id=self.run.id, + thread_id=self.id, + run_id=self._run.id, tool_outputs=tool_outputs, event_handler=event_handler() ) as stream: stream.until_done() - self.run = stream.get_final_run() + self._run = stream.get_final_run() def _cancel_run(self, thread_id=None, run_id=None, check_status=True): - if check_status and self.run.status in self.terminal_states and not run_id: + if check_status and self._run.status in self.terminal_states and not run_id: return try: - self.run = self.client.beta.threads.runs.cancel( - thread_id=self.thread.id, - run_id=self.run.id + self._run = self.client.beta.threads.runs.cancel( + thread_id=self.id, + run_id=self._run.id ) except BadRequestError as e: if "Cannot cancel run with status" in e.message: - self.run = self.client.beta.threads.runs.poll( - thread_id=thread_id or self.thread.id, - run_id=run_id or self.run.id, + self._run = 
self.client.beta.threads.runs.poll( + thread_id=thread_id or self.id, + run_id=run_id or self._run.id, poll_interval_ms=500, ) else: diff --git a/agency_swarm/threads/thread_async.py b/agency_swarm/threads/thread_async.py index a6ae8f72..6f1ebd70 100644 --- a/agency_swarm/threads/thread_async.py +++ b/agency_swarm/threads/thread_async.py @@ -90,11 +90,10 @@ def check_status(self, run=None): return f"""{self.recipient_agent.name}'s Response: '{messages.data[0].content[0].text.value}'""" def get_last_run(self): - if not self.thread: - self.init_thread() + self.init_thread() runs = self.client.beta.threads.runs.list( - thread_id=self.thread.id, + thread_id=self.id, order="desc", ) From bac5cf8757221fcc53b2d3649c3801bed7345ec3 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Fri, 15 Nov 2024 11:03:44 +0400 Subject: [PATCH 3/9] Moved async_mode into ToolConfig --- agency_swarm/agency/agency.py | 13 +++- agency_swarm/threads/thread.py | 44 +++++++++++-- agency_swarm/tools/BaseTool.py | 6 +- .../tools/send_message/SendMessage.py | 12 ++-- .../send_message/SendMessageAsyncThreading.py | 7 ++- .../tools/send_message/SendMessageBase.py | 19 ++++-- .../tools/send_message/SendMessageSwarm.py | 4 +- tests/test_agency.py | 63 ++++++++++--------- tests/test_send_message.py | 8 +++ 9 files changed, 118 insertions(+), 58 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 3770d4aa..5da38f03 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -101,12 +101,19 @@ def __init__(self, self.max_completion_tokens = max_completion_tokens self.truncation_strategy = truncation_strategy + # set thread type based send_message_tool_class async mode + if hasattr(send_message_tool_class.ToolConfig, "async_mode") and send_message_tool_class.ToolConfig.async_mode: + self._thread_type = ThreadAsync + else: + self._thread_type = Thread + if self.async_mode == "threading": from agency_swarm.tools.send_message import 
SendMessageAsyncThreading print("Warning: 'threading' mode is deprecated. Please use send_message_tool_class = SendMessageAsyncThreading to use async communication.") self.send_message_tool_class = SendMessageAsyncThreading elif self.async_mode == "tools_threading": - Thread.async_mode = self.async_mode + Thread.async_mode = "tools_threading" + print("Warning: 'tools_threading' mode is deprecated. Use tool.ToolConfig.async_mode = 'threading' instead.") elif self.async_mode is None: pass else: @@ -887,7 +894,7 @@ def _init_threads(self): continue for other_agent, items in threads.items(): # create thread class - self.agents_and_threads[agent_name][other_agent] = self.send_message_tool_class._thread_type( + self.agents_and_threads[agent_name][other_agent] = self._thread_type( self._get_agent_by_name(items["agent"]), self._get_agent_by_name( items["recipient_agent"])) @@ -1034,7 +1041,7 @@ def _create_special_tools(self): continue agent = self._get_agent_by_name(agent_name) agent.add_tool(self._create_send_message_tool(agent, recipient_agents)) - if self.send_message_tool_class._thread_type == ThreadAsync: + if self._thread_type == ThreadAsync: agent.add_tool(self._create_get_response_tool(agent, recipient_agents)) def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]): diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 6fb9c9dd..482a7cdf 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -57,7 +57,7 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): def init_thread(self): if self.id: return - print("creating thread") + self._thread = self.client.beta.threads.create() self.id = self._thread.id if self.recipient_agent.examples: @@ -66,7 +66,7 @@ def init_thread(self): thread_id=self.id, **example, ) - print("thread created", self.id) + def get_completion_stream(self, message: Union[str, List[dict], None], event_handler: type(AgencyEventHandler), @@ -149,8 
+149,7 @@ def get_completion(self, if self._run.status == "requires_action": tool_calls = self._run.required_action.submit_tool_outputs.tool_calls tool_outputs_and_names = [] # list of tuples (name, tool_output) - sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name.startswith("SendMessage")] - async_tool_calls = [tool_call for tool_call in tool_calls if not tool_call.function.name.startswith("SendMessage")] + sync_tool_calls, async_tool_calls = self._get_sync_async_tool_calls(tool_calls, recipient_agent) def handle_output(tool_call, output): if inspect.isgenerator(output): @@ -207,7 +206,7 @@ def handle_output(tool_call, output): tool_names = [name for name, _ in tool_outputs_and_names] # await coroutines - tool_outputs = self._execute_async_tool_calls_outputs(tool_outputs) + tool_outputs = self._await_coroutines(tool_outputs) # convert all tool outputs to strings for tool_output in tool_outputs: @@ -508,7 +507,7 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool error_message = error_message.split("For further information visit")[0] return error_message, False - def _execute_async_tool_calls_outputs(self, tool_outputs): + def _await_coroutines(self, tool_outputs): async_tool_calls = [] for tool_output in tool_outputs: if inspect.iscoroutine(tool_output["output"]): @@ -530,6 +529,39 @@ def _execute_async_tool_calls_outputs(self, tool_outputs): tool_output["output"] = str(result) return tool_outputs + + def _get_sync_async_tool_calls(self, tool_calls, recipient_agent): + async_tool_calls = [] + sync_tool_calls = [] + for tool_call in tool_calls: + if tool_call.function.name.startswith("SendMessage"): + sync_tool_calls.append(tool_call) + continue + + tool = next((func for func in recipient_agent.functions if func.__name__ == tool_call.function.name), None) + + if (hasattr(tool.ToolConfig, "async_mode") and tool.ToolConfig.async_mode) or self.async_mode == "tools_threading": + 
async_tool_calls.append(tool_call) + else: + sync_tool_calls.append(tool_call) + + return sync_tool_calls, async_tool_calls + + def get_messages(self, limit=None): + all_messages = [] + after = None + while True: + response = self.client.beta.threads.messages.list(thread_id=self.id, limit=100, after=after) + messages = response.data + if not messages: + break + all_messages.extend(messages) + after = messages[-1].id # Set the 'after' cursor to the ID of the last message + + if limit and len(all_messages) >= limit: + break + + return all_messages diff --git a/agency_swarm/tools/BaseTool.py b/agency_swarm/tools/BaseTool.py index 7d9ffacc..51cdaeb3 100644 --- a/agency_swarm/tools/BaseTool.py +++ b/agency_swarm/tools/BaseTool.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal, Union from docstring_parser import parse @@ -22,7 +22,8 @@ def __init__(self, **kwargs): config_defaults = { 'strict': False, 'one_call_at_a_time': False, - 'output_as_result': False + 'output_as_result': False, + 'async_mode': None } for key, value in config_defaults.items(): @@ -34,6 +35,7 @@ class ToolConfig: one_call_at_a_time: bool = False # return the tool output as assistant message output_as_result: bool = False + async_mode: Union[Literal["threading"], None] = None @classmethod @property diff --git a/agency_swarm/tools/send_message/SendMessage.py b/agency_swarm/tools/send_message/SendMessage.py index 758b38ce..417422c3 100644 --- a/agency_swarm/tools/send_message/SendMessage.py +++ b/agency_swarm/tools/send_message/SendMessage.py @@ -4,11 +4,7 @@ from .SendMessageBase import SendMessageBase class SendMessage(SendMessageBase): - """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. 
To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time.""" - message: str = Field( - ..., - description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions." - ) + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message to the same recipient agent at the same time.""" my_primary_instructions: str = Field( ..., description=( @@ -21,6 +17,10 @@ class SendMessage(SendMessageBase): "in the message or additional_instructions parameters." ) ) + message: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information needed to complete the task." 
+ ) message_files: Optional[List[str]] = Field( default=None, description="A list of file IDs to be sent as attachments to this message. Only use this if you have the file ID that starts with 'file-'.", @@ -48,7 +48,7 @@ def validate_additional_instructions(cls, value): def run(self): - thread: Thread = self._agents_and_threads[self._caller_agent.name][self.recipient.value] + thread = self._get_thread() message = thread.get_completion(message=self.message, message_files=self.message_files, diff --git a/agency_swarm/tools/send_message/SendMessageAsyncThreading.py b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py index ffdfee49..1ef9b3a7 100644 --- a/agency_swarm/tools/send_message/SendMessageAsyncThreading.py +++ b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py @@ -4,10 +4,11 @@ class SendMessageAsyncThreading(SendMessage): """Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. 
Continue until task completion.""" - _thread_type: ClassVar[Type[ThreadAsync]] = ThreadAsync - + class ToolConfig: + async_mode = "threading" + def run(self): - thread: ThreadAsync = self._agents_and_threads[self._caller_agent.name][self.recipient.value] + thread = self._get_thread() message = thread.get_completion_async(message=self.message, message_files=self.message_files, diff --git a/agency_swarm/tools/send_message/SendMessageBase.py b/agency_swarm/tools/send_message/SendMessageBase.py index 37ffe3db..f9b9279c 100644 --- a/agency_swarm/tools/send_message/SendMessageBase.py +++ b/agency_swarm/tools/send_message/SendMessageBase.py @@ -1,6 +1,8 @@ +from agency_swarm.agents.agent import Agent from agency_swarm.threads.thread import Thread -from typing import ClassVar, Optional, List, Type -from pydantic import Field, field_validator, model_validator +from typing import ClassVar +from pydantic import Field +from agency_swarm.threads.thread_async import ThreadAsync from agency_swarm.tools import BaseTool from abc import ABC @@ -8,12 +10,19 @@ class SendMessageBase(BaseTool, ABC): """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time.""" recipient: str = Field(..., description="Recipient agent that you want to send the message to. 
This field will be overriden inside the agency class.") - _agents_and_threads: ClassVar = None - _thread_type: ClassVar[Type[Thread]] = Thread # thread type assigned by the agency class @classmethod def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if not cls.__name__.startswith("SendMessage"): - raise TypeError(f"Class name '{cls.__name__}' must start with 'SendMessage'.") \ No newline at end of file + raise TypeError(f"Class name '{cls.__name__}' must start with 'SendMessage'.") + + def _get_thread(self) -> Thread | ThreadAsync: + return self._agents_and_threads[self._caller_agent.name][self.recipient.value] + + def _get_main_thread(self) -> Thread | ThreadAsync: + return self._agents_and_threads["main_thread"] + + def _get_recipient_agent(self) -> Agent: + return self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageSwarm.py b/agency_swarm/tools/send_message/SendMessageSwarm.py index 0557e681..143e5bfe 100644 --- a/agency_swarm/tools/send_message/SendMessageSwarm.py +++ b/agency_swarm/tools/send_message/SendMessageSwarm.py @@ -11,10 +11,10 @@ class ToolConfig: def run(self): # get main thread - thread: Thread = self._agents_and_threads["main_thread"] + thread = self._get_main_thread() # get recipient agent from thread - recipient_agent = self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent + recipient_agent = self._get_recipient_agent() # submit tool output try: diff --git a/tests/test_agency.py b/tests/test_agency.py index 0c258527..88177118 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -13,6 +13,7 @@ from agency_swarm.tools import CodeInterpreter, FileSearch sys.path.insert(0, '../agency-swarm') +from agency_swarm.tools.send_message import SendMessageAsyncThreading from agency_swarm.util import create_agent_template from agency_swarm import set_openai_key, Agent, 
Agency, AgencyEventHandler, get_openai_client @@ -221,11 +222,9 @@ def test_4_agent_communication(self): self.assertFalse('error' in message.lower(), f"Error found in message: {message}. Thread url: {self.__class__.agency.main_thread.thread_url}") - for agent_name, threads in self.__class__.agency.agents_and_threads.items(): - if agent_name == "main_thread": - continue - for other_agent_name, thread in threads.items(): - self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) + self.assertTrue(self.__class__.agency.agents_and_threads['main_thread'].id) + self.assertTrue(self.__class__.agency.agents_and_threads['CEO']['TestAgent1'].id) + self.assertTrue(self.__class__.agency.agents_and_threads['TestAgent1']['TestAgent2'].id) for agent in self.__class__.agency.agents: self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) @@ -240,7 +239,7 @@ def test_4_agent_communication(self): self.assertTrue(thread_messages.data[0].content[0].text.value == "Hi!") - run = main_thread.run + run = main_thread._run self.assertTrue(run.max_prompt_tokens == self.__class__.ceo.max_prompt_tokens) self.assertTrue(run.max_completion_tokens == self.__class__.ceo.max_completion_tokens) self.assertTrue(run.tool_choice.type == "function") @@ -253,7 +252,7 @@ def test_4_agent_communication(self): self.assertTrue(len(agent1_thread_messages.data) == 2) - agent1_run = agent1_thread.run + agent1_run = agent1_thread._run self.assertTrue(agent1_run.truncation_strategy.type == "last_messages") self.assertTrue(agent1_run.truncation_strategy.last_messages == 10) @@ -311,13 +310,11 @@ def on_all_streams_end(cls): self.assertTrue(self.__class__.TestTool._shared_state.get("test_tool_used")) agent1_thread = self.__class__.agency.agents_and_threads[self.__class__.ceo.name][self.__class__.agent1.name] - self.assertFalse(agent1_thread.run.parallel_tool_calls) + self.assertFalse(agent1_thread._run.parallel_tool_calls) - for 
agent_name, threads in self.__class__.agency.agents_and_threads.items(): - if agent_name == "main_thread": - continue - for other_agent_name, thread in threads.items(): - self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) + self.assertTrue(self.__class__.agency.main_thread.id) + self.assertTrue(self.__class__.agency.agents_and_threads['CEO']['TestAgent1'].id) + self.assertTrue(self.__class__.agency.agents_and_threads['TestAgent1']['TestAgent2'].id) for agent in self.__class__.agency.agents: self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) @@ -326,8 +323,8 @@ def test_6_load_from_db(self): """it should load agents from db""" # os.rename("settings.json", "settings2.json") - previous_loaded_thread_ids = self.__class__.loaded_thread_ids - previous_loaded_agents_settings = self.__class__.loaded_agents_settings + previous_loaded_thread_ids = self.__class__.loaded_thread_ids.copy() + previous_loaded_agents_settings = self.__class__.loaded_agents_settings.copy() from test_agents.CEO import CEO from test_agents.TestAgent1 import TestAgent1 @@ -372,12 +369,18 @@ def test_6_load_from_db(self): self.check_all_agents_settings() # check that threads are the same - for agent_name, threads in agency.agents_and_threads.items(): - if agent_name == "main_thread": + print("previous_loaded_thread_ids", previous_loaded_thread_ids) + print("self.__class__.loaded_thread_ids", self.__class__.loaded_thread_ids) + # Start of Selection + for agent, threads in self.__class__.agency.agents_and_threads.items(): + if agent == "main_thread": + print("main_thread", threads) continue - for other_agent_name, thread in threads.items(): - self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) - self.assertTrue(thread.id in previous_loaded_thread_ids[agent_name][other_agent_name]) + for other_agent, thread in threads.items(): + print(f"Thread ID between {agent} and 
{other_agent}: {thread.id}") + self.assertTrue(self.__class__.agency.agents_and_threads['main_thread'].id == previous_loaded_thread_ids['main_thread'] == self.__class__.loaded_thread_ids['main_thread']) + self.assertTrue(self.__class__.agency.agents_and_threads['CEO']['TestAgent1'].id == previous_loaded_thread_ids['CEO']['TestAgent1'] == self.__class__.loaded_thread_ids['CEO']['TestAgent1']) + self.assertTrue(self.__class__.agency.agents_and_threads['TestAgent1']['TestAgent2'].id == previous_loaded_thread_ids['TestAgent1']['TestAgent2'] == self.__class__.loaded_thread_ids['TestAgent1']['TestAgent2']) # check that agents are the same for agent in agency.agents: @@ -385,7 +388,7 @@ def test_6_load_from_db(self): self.assertTrue(agent.id in [settings['id'] for settings in previous_loaded_agents_settings]) def test_7_init_async_agency(self): - """it should initialize agency with agents""" + """it should initialize async agency with agents""" # reset loaded thread ids self.__class__.loaded_thread_ids = {} @@ -403,7 +406,7 @@ def test_7_init_async_agency(self): shared_instructions="", settings_callbacks=self.__class__.settings_callbacks, threads_callbacks=self.__class__.threads_callbacks, - async_mode='threading', + send_message_tool_class=SendMessageAsyncThreading, temperature=0, ) @@ -411,7 +414,6 @@ def test_7_init_async_agency(self): def test_8_async_agent_communication(self): """it should communicate between agents asynchronously""" - print("TestAgent1 tools", self.__class__.agent1.tools) self.__class__.agency.get_completion("Please tell TestAgent2 hello.", tool_choice={"type": "function", "function": {"name": "SendMessage"}}, recipient_agent=self.__class__.agent1) @@ -455,11 +457,8 @@ def on_all_streams_end(cls): if 'error' in message.lower(): self.assertFalse('error' in message.lower(), self.__class__.agency.main_thread.thread_url) - for agent_name, threads in self.__class__.agency.agents_and_threads.items(): - if agent_name == "main_thread": - continue - for 
other_agent_name, thread in threads.items(): - self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) + self.assertTrue(self.__class__.agency.main_thread.id) + self.assertTrue(self.__class__.agency.agents_and_threads['TestAgent1']['TestAgent2'].id) for agent in self.__class__.agency.agents: self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) @@ -467,11 +466,16 @@ def on_all_streams_end(cls): def test_9_async_tool_calls(self): """it should execute tools asynchronously""" class PrintTool(BaseTool): + class ToolConfig: + async_mode = "threading" def run(self, **kwargs): time.sleep(2) # Simulate a delay return "Printed successfully." class AnotherPrintTool(BaseTool): + class ToolConfig: + async_mode = "threading" + def run(self, **kwargs): time.sleep(2) # Simulate a delay return "Another print successful." @@ -480,12 +484,9 @@ def run(self, **kwargs): agency = Agency( [ceo], - async_mode='tools_threading', temperature=0 ) - self.assertTrue(agency.main_thread.async_mode == 'tools_threading') - result = agency.get_completion("Use 2 print tools together at the same time and output the results exectly as they are. 
", yield_messages=False) self.assertIn("success", result.lower(), agency.main_thread.thread_url) diff --git a/tests/test_send_message.py b/tests/test_send_message.py index 9123d938..4ebc3a2c 100644 --- a/tests/test_send_message.py +++ b/tests/test_send_message.py @@ -38,5 +38,13 @@ def test_send_message_swarm(self): response = self.agency.get_completion("Who are you?") self.assertTrue("customer support" in response.lower()) + main_thread = self.agency.main_thread + + # check if recipient agent is correct + self.assertEqual(main_thread.recipient_agent, self.customer_support) + + #check if all messages in the same thread (this is how Swarm works) + self.assertTrue(len(main_thread.get_messages()) >= 4) # sometimes run does not cancel immediately, so there might be 5 messages + if __name__ == '__main__': unittest.main() From bcf1e1ca54af06d25afcea24b867debb84573371 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Mon, 18 Nov 2024 10:50:33 +0400 Subject: [PATCH 4/9] Added docs for communication flows --- agency_swarm/agency/agency.py | 1 - .../tools/send_message/SendMessage.py | 30 +-- .../send_message/SendMessageAsyncThreading.py | 11 +- .../tools/send_message/SendMessageBase.py | 31 ++- .../tools/send_message/SendMessageQuick.py | 14 ++ .../tools/send_message/SendMessageSwarm.py | 1 + agency_swarm/tools/send_message/__init__.py | 3 +- docs/advanced-usage/communication_flows.md | 216 ++++++++++++++++++ mkdocs.yml | 11 +- 9 files changed, 269 insertions(+), 49 deletions(-) create mode 100644 agency_swarm/tools/send_message/SendMessageQuick.py create mode 100644 docs/advanced-usage/communication_flows.md diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 5da38f03..69c38880 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -1078,7 +1078,6 @@ def check_recipient(cls, value): SendMessage._caller_agent = agent SendMessage._agents_and_threads = self.agents_and_threads - SendMessage._agency = self return 
SendMessage diff --git a/agency_swarm/tools/send_message/SendMessage.py b/agency_swarm/tools/send_message/SendMessage.py index 417422c3..7a1af529 100644 --- a/agency_swarm/tools/send_message/SendMessage.py +++ b/agency_swarm/tools/send_message/SendMessage.py @@ -1,5 +1,4 @@ -from agency_swarm.threads.thread import Thread -from typing import ClassVar, Optional, List, Type +from typing import Optional, List from pydantic import Field, field_validator, model_validator from .SendMessageBase import SendMessageBase @@ -14,12 +13,12 @@ class SendMessage(SendMessageBase): "recipient agent via the message parameter. Each identified step should be " "sent in a separate message. Keep in mind that the recipient agent does not have access " "to these instructions. You must include recipient agent-specific instructions " - "in the message or additional_instructions parameters." + "in the message or in the additional_instructions parameters." ) ) message: str = Field( ..., - description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information needed to complete the task." + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information from the conversation needed to complete the task." ) message_files: Optional[List[str]] = Field( default=None, @@ -31,30 +30,15 @@ class SendMessage(SendMessageBase): description="Additional context or instructions from the conversation needed by the recipient agent to complete the task." 
) - @model_validator(mode='after') def validate_files(self): + # prevent hallucinations with file IDs if the necessary parameters are provided if hasattr(self, 'message') and "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions): if not self.message_files: raise ValueError("You must include file IDs in message_files parameter.") return self - - @field_validator('additional_instructions', mode='before') - @classmethod - def validate_additional_instructions(cls, value): - if isinstance(value, list): - return "\n".join(value) - return value - def run(self): - thread = self._get_thread() - - message = thread.get_completion(message=self.message, - message_files=self.message_files, - event_handler=self._event_handler, - yield_messages=not self._event_handler, - additional_instructions=self.additional_instructions, - ) - - return message or "" \ No newline at end of file + return self._get_completion(message=self.message, + message_files=self.message_files, + additional_instructions=self.additional_instructions) \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageAsyncThreading.py b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py index 1ef9b3a7..38912e09 100644 --- a/agency_swarm/tools/send_message/SendMessageAsyncThreading.py +++ b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py @@ -5,13 +5,4 @@ class SendMessageAsyncThreading(SendMessage): """Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. 
Continue until task completion.""" class ToolConfig: - async_mode = "threading" - - def run(self): - thread = self._get_thread() - - message = thread.get_completion_async(message=self.message, - message_files=self.message_files, - additional_instructions=self.additional_instructions) - - return message or "" \ No newline at end of file + async_mode = "threading" \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageBase.py b/agency_swarm/tools/send_message/SendMessageBase.py index f9b9279c..26b61128 100644 --- a/agency_swarm/tools/send_message/SendMessageBase.py +++ b/agency_swarm/tools/send_message/SendMessageBase.py @@ -1,22 +1,24 @@ from agency_swarm.agents.agent import Agent from agency_swarm.threads.thread import Thread -from typing import ClassVar -from pydantic import Field +from typing import ClassVar, Union +from pydantic import Field, field_validator from agency_swarm.threads.thread_async import ThreadAsync from agency_swarm.tools import BaseTool from abc import ABC class SendMessageBase(BaseTool, ABC): - """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time.""" - recipient: str = Field(..., description="Recipient agent that you want to send the message to. 
This field will be overriden inside the agency class.") + _agents_and_threads: ClassVar = None + @field_validator('additional_instructions', mode='before', check_fields=False) @classmethod - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if not cls.__name__.startswith("SendMessage"): - raise TypeError(f"Class name '{cls.__name__}' must start with 'SendMessage'.") + def validate_additional_instructions(cls, value): + # previously the parameter was a list, now it's a string + # add compatibility for old code + if isinstance(value, list): + return "\n".join(value) + return value def _get_thread(self) -> Thread | ThreadAsync: return self._agents_and_threads[self._caller_agent.name][self.recipient.value] @@ -25,4 +27,15 @@ def _get_main_thread(self) -> Thread | ThreadAsync: return self._agents_and_threads["main_thread"] def _get_recipient_agent(self) -> Agent: - return self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent \ No newline at end of file + return self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent + + def _get_completion(self, message: Union[str, None] = None, **kwargs): + thread = self._get_thread() + + if self.ToolConfig.async_mode == "threading": + return thread.get_completion_async(message=message, **kwargs) + else: + return thread.get_completion(message=message, + event_handler=self._event_handler, + yield_messages=not self._event_handler, + **kwargs) \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageQuick.py b/agency_swarm/tools/send_message/SendMessageQuick.py new file mode 100644 index 00000000..f673f909 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageQuick.py @@ -0,0 +1,14 @@ +from agency_swarm.threads.thread import Thread +from typing import ClassVar, Optional, List, Type +from pydantic import Field, field_validator, model_validator +from .SendMessageBase import SendMessageBase + +class 
SendMessageQuick(SendMessageBase): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message to the same recipient agent at the same time.""" + message: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information from the conversation needed to complete the task." + ) + + def run(self): + return self._get_completion(message=self.message) \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageSwarm.py b/agency_swarm/tools/send_message/SendMessageSwarm.py index 143e5bfe..ec30288a 100644 --- a/agency_swarm/tools/send_message/SendMessageSwarm.py +++ b/agency_swarm/tools/send_message/SendMessageSwarm.py @@ -6,6 +6,7 @@ class SendMessageSwarm(SendMessageBase): """Use this tool to route messages to other agents within your agency. After using this tool, you will be switched to the recipient agent. This tool can only be used once per message. 
Do not use any other tools together with this tool.""" class ToolConfig: + # set output as result because the communication will be finished after this tool is called output_as_result: bool = True one_call_at_a_time: bool = True diff --git a/agency_swarm/tools/send_message/__init__.py b/agency_swarm/tools/send_message/__init__.py index 180265b6..d044d285 100644 --- a/agency_swarm/tools/send_message/__init__.py +++ b/agency_swarm/tools/send_message/__init__.py @@ -1,4 +1,5 @@ from .SendMessageAsyncThreading import SendMessageAsyncThreading from .SendMessageBase import SendMessageBase from .SendMessage import SendMessage -from .SendMessageSwarm import SendMessageSwarm \ No newline at end of file +from .SendMessageSwarm import SendMessageSwarm +from .SendMessageQuick import SendMessageQuick \ No newline at end of file diff --git a/docs/advanced-usage/communication_flows.md b/docs/advanced-usage/communication_flows.md new file mode 100644 index 00000000..02f179f7 --- /dev/null +++ b/docs/advanced-usage/communication_flows.md @@ -0,0 +1,216 @@ +# Advanced Communication Flows + +Multi-agent communication is the core functionality of any Multi-Agent System. Unlike in all other frameworks, Agency Swarm not only allows you to define communication flows in any way you want (uniform communication flows), but to also configure the underlying logic for this feature. This means that you can create entirely new types of communication, or adjust it to your own needs. Below you will find a guide on how to do all this, along with some common examples. + +**To use your own `SendMessage` calss**, simply put it in the `send_message_tool_class` parameter when initializing the `Agency` class: + +```python +from agency_swarm.tools.send_message import SendMessageQuick + +agency = Agency( + ... + send_message_tool_class=SendMessageQuick +) +``` + +That's it! Now, your agents will use your own custom `SendMessageQuick` class for communication. 
+ +## Pre-Made SendMessage Classes + +Agency Swarm contains multiple commonly requested classes for communication flows. Currently, the following classes are available: + +| Class Name | Description | When to Use | Code Link | +| --------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------- | +| `SendMessage` (default) | This is the default class for sending messages to other agents. It uses synchronous communication with basic COT (Chain of Thought) prompting and allows agents to relay files and modify system instructions for each other. | Suitable for most use cases. Balances speed and functionality. | | +| `SendMessageQuick` | A variant of the SendMessage class without Chain of Thought prompting, files, and additional instructions. It allows for faster communication without the overhead of COT. | Use for simpler use cases or when you want to save tokens and increase speed. | | +| `SendMessageAsyncThreading` | Similar to `SendMessage` but with `async_mode='threading'`. Each agent will execute asynchronously in a separate thread. In the meantime, the caller agent can continue the conversation with the user and check the results later. | Use for asynchronous applications or when sub-agents take singificant amounts of time to complete their tasks. | | +| `SendMessageSwarm` | Instead of sending a message to another agent, it replaces the caller agent with the recipient agent, similar to [OpenAI's Swarm](https://github.com/openai/swarm). The recipient agent will then have access to the entire conversation. | When you need more granular control. It is not able to handle complex multi-step, multi-agent tasks. 
| | + +## Creating Your Own Unique Communication Flows + +To create you own communication flow, you will first need to extend the `SendMessageBase` class. This class extends the `BaseTool` class, like any other tools in Agency Swarm, and contains the most basic parameters required for communication, such as the `recipient_agent`. + +### Default `SendMessage` Class + +By defualt, Agency Swarm uses the following tool for communication: + +```python +from typing import Optional, List +from pydantic import Field, field_validator, model_validator +from .SendMessageBase import SendMessageBase + +class SendMessage(SendMessageBase): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message to the same recipient agent at the same time.""" + my_primary_instructions: str = Field( + ..., + description=( + "Please repeat your primary instructions step-by-step, including both completed " + "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down " + "into smaller steps yourself. Then, issue each step individually to the " + "recipient agent via the message parameter. Each identified step should be " + "sent in a separate message. Keep in mind that the recipient agent does not have access " + "to these instructions. 
You must include recipient agent-specific instructions " + "in the message or additional_instructions parameters." + ) + ) + message: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information needed to complete the task." + ) + message_files: Optional[List[str]] = Field( + default=None, + description="A list of file IDs to be sent as attachments to this message. Only use this if you have the file ID that starts with 'file-'.", + examples=["file-1234", "file-5678"] + ) + additional_instructions: Optional[str] = Field( + default=None, + description="Additional context or instructions from the conversation needed by the recipient agent to complete the task." + ) + + @field_validator('additional_instructions', mode='before') + @classmethod + def validate_additional_instructions(cls, value): + # previously the parameter was a list, now it's a string + # add compatibility for old code + if isinstance(value, list): + return "\n".join(value) + return value + + + def run(self): + return self._get_completion(message=self.message, + message_files=self.message_files, + additional_instructions=self.additional_instructions) +``` + +Let's break down the code. + +In general, all `SendMessage` tools have the following components: + +1. **The Docstring**: This is used to generate a description of the tool for the agent. This part should clearly describe how your multi-agent communication works, along with some additional guidelines on how to use it. +2. **Parameters**: Parameters like `message`, `message_files`, `additional_instructions` are used to provide the recipient agent with the necessary information. +3. **The `run` method**: This is where the communication logic is implemented. 
Most of the time, you just need to map your parameters to `self._get_completion()` the same way you would call it in the `agency.get_completion()` method. + +When creating your own `SendMessage` tools, you can use the above components as a template. + +### Common Use Cases + +In the following sections, we'll look at some common use cases for extending the `SendMessageBase` tool and how to implement them, so you can learn how to create your own SendMessage tools and use them in your own applications. + +#### 1. Adjusting parameters and descriptions + +The most basic use case is if you want to have your own parameter descriptions, such as you want to change the docstring or the description of the `message` parameter. This can help you better customize how the agents communicate with each other and what information they relay. + +Let's say that instead of sending messages, I want my agents to send tasks to each other. In this case, I can change the docstring and the `message` parameter description to better fit the task-oriented nature of my application. + +```python +from pydantic import Field +from agency_swarm.tools.send_message import SendMessageBase + +class SendMessageTask(SendMessageBase): + """Use this tool to send tasks to other agents within your agency.""" + chain_of_thought: str = Field( + ..., + description="Please think step-by-step about how to solve your current task, provided by the user. Then, break down this task into smaller steps and issue each step individually to the recipient agent via the task parameter." + ) + task: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information needed to complete the task." + ) + + def run(self): + return self._get_completion(message=self.task) +``` + +To remove the chain of thought, you can simply remove the `chain_of_thought` parameter. 
+ +#### 2. Adding custom validation logic + +Now, let's say that I need to ensure that my message is sent to the correct recepient agent. (This is a very common hallucination in production.) In this case, I can add custom validator to the `recipient` parameter, which is defined in the `SendMessageBase` class. Since I don't want to change any other logic, I can inherit the `SendMessage` class and only add this new validation logic. + +```python +from agency_swarm.tools.send_message import SendMessage +from pydantic import model_validator + +class SendMessageValidation(SendMessage): + @model_validator(mode='after') + def validate_recipient(self): + if "customer support" not in self.message.lower() and self.recipient == "CustomerSupportAgent": + raise ValueError("Messages not related to customer support cannot be sent to the customer support agent.") + return self +``` + +You can, of course, also use GPT for this: + +```python +from agency_swarm.tools.send_message import SendMessage +from agency_swarm.util.validators import llm_validator + +class SendMessageLLMValidation(SendMessage): + @model_validator(mode='after') + def validate_recipient(self): + if self.recipient == "CustomerSupportAgent": + llm_validator( + statement="The message is related to customer support." + )(self.message) + return self +``` + +In this example, the `llm_validator` will throw an error if the message is not related to customer support. The caller agent will then have to fix the recipient or the message and send it again! + +#### 3. Summurizing previous conversations with other agents and adding to context + +Sometimes, when using default `SendMessage`, the agents might not relay all the neceessary details to the recipient agent. Especially, when the previous conversation is long. In this case, you can summarize the previous conversation with GPT and add it to the context, instead of the additional instructions. 
I will extend the `SendMessageQuick` class, which already contains the `message` parameter. + +```python +from agency_swarm.tools.send_message import SendMessageQuick +from agency_swarm.util.oai import get_openai_client + +class SendMessageSummary(SendMessageQuick): + def run(self): + client = get_openai_client() + thread = self._get_main_thread() # get the main thread (conversation with the user) + + # get the previous messages + previous_messages = thread.get_messages() + previous_messages_str = "\n".join([f"{m.role}: {m.content[0].text.value}" for m in previous_messages]) + + # summarize the previous conversation + summary = client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": "You are a world-class summarizer. Please summarize the following conversation in a few sentences:"}, + {"role": "user", "content": previous_messages_str} + ] + ) + + # send the message with the summary + return self._get_completion(message=self.message, additional_instructions=f"\n\nPrevious conversation summary: '{summary.choices[0].message.content}'") +``` + +#### 4. Running each agent in a separate API call + +If you are a PRO, and you have managed to deploy each agent in a separate API endpoint, instead of using `_get_completion()`, you can call your own API and let the agents communicate with each other over the internet. + +```python +import requests +from agency_swarm.tools.send_message import SendMessage + +class SendMessageAPI(SendMessage): + def run(self): + response = requests.post( + "https://your-api-endpoint.com/send-message", + json={"message": self.message, "recipient": self.recipient} + ) + return response.json()["message"] +``` + +This is very powerful, as you can even allow your agents to colloborate with agents outside your system. More on this is coming soon! + +!!! 
tip "Contributing" + + If you have any ideas for new communication flows, please either adjust this page in docs, or add your new send message tool in the `agency_swarm/tools/send_message` folder and open a PR! + +## Conclusion + +Agency Swarm is the only framework that gives you full control over your systems. With this new feature, now there is **not a single prompt, parameter or part of the system** that you cannot adjust or customize to your own needs! diff --git a/mkdocs.yml b/mkdocs.yml index 647e3e7b..58e118c4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -14,7 +14,7 @@ theme: - navigation.instant.prefetch - navigation.instant.progress - navigation.prune -# - navigation.sections + # - navigation.sections - navigation.tabs # - navigation.tabs.sticky - navigation.top @@ -43,10 +43,11 @@ nav: - Introduction: "index.md" - Quick Start: "quick_start.md" - Advanced Usage: - - Advanced Tools: "advanced-usage/tools.md" - - Agents: "advanced-usage/agents.md" - - Agencies: "advanced-usage/agencies.md" - - Azure OpenAI: "advanced-usage/azure-openai.md" + - Tools: "advanced-usage/tools.md" + - Agents: "advanced-usage/agents.md" + - Agencies: "advanced-usage/agencies.md" + - Communication: "advanced-usage/communication_flows.md" + - Azure OpenAI: "advanced-usage/azure-openai.md" - Deployment to Production: "deployment.md" - Open Source Models: "advanced-usage/open-source-models.md" - API Reference: "api.md" From 53f70504a587b9dbfef9f75531c31b17e9c182a6 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Mon, 18 Nov 2024 12:15:05 +0400 Subject: [PATCH 5/9] Prevent agents from calling another agent twice at the same time --- agency_swarm/threads/thread.py | 22 ++++++++++++++++--- .../tools/send_message/SendMessageSwarm.py | 2 +- docs/advanced-usage/communication_flows.md | 3 ++- ..._send_message.py => test_communication.py} | 20 ++++++++++++++--- 4 files changed, 39 insertions(+), 8 deletions(-) rename tests/{test_send_message.py => test_communication.py} (57%) diff --git 
a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 482a7cdf..07eaa313 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -51,10 +51,16 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self._stream = None self._num_run_retries = 0 + # names of recepient agents that were called in SendMessage tool + # needed to prevent agents calling the same recepient agent multiple times + self._called_recepients = [] self.terminal_states = ["cancelled", "completed", "failed", "expired", "incomplete"] def init_thread(self): + self._called_recepients = [] + self._num_run_retries = 0 + if self.id: return @@ -147,6 +153,7 @@ def get_completion(self, # function execution if self._run.status == "requires_action": + self._called_recepients = [] tool_calls = self._run.required_action.submit_tool_outputs.tool_calls tool_outputs_and_names = [] # list of tuples (name, tool_output) sync_tool_calls, async_tool_calls = self._get_sync_async_tool_calls(tool_calls, recipient_agent) @@ -480,8 +487,9 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool if not recipient_agent: recipient_agent = self.recipient_agent + tool_name = tool_call.function.name funcs = recipient_agent.functions - tool = next((func for func in funcs if func.__name__ == tool_call.function.name), None) + tool = next((func for func in funcs if func.__name__ == tool_name), None) if not tool: return f"Error: Function {tool_call.function.name} not found. 
Available functions: {[func.__name__ for func in funcs]}", False @@ -491,11 +499,19 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool args = tool_call.function.arguments args = json.loads(args) if args else {} tool = tool(**args) + + # check if the tool is already called for tool_name in [name for name, _ in tool_outputs_and_names]: - if tool_name == tool_call.function.name and ( + if tool_name == tool_name and ( hasattr(tool, "ToolConfig") and hasattr(tool.ToolConfig, "one_call_at_a_time") and tool.ToolConfig.one_call_at_a_time): - return f"Error: Function {tool_call.function.name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again.", False + return f"Error: Function {tool_name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again.", False + # for send message tools, don't allow calling the same recepient agent multiple times + if tool_name.startswith("SendMessage"): + if tool.recipient.value in self._called_recepients: + return f"Error: Agent {tool.recipient.value} has already been called. You can only call each agent once at a time. Please wait for the previous call to finish before calling it again.", False + self._called_recepients.append(tool.recipient.value) + tool._caller_agent = recipient_agent tool._event_handler = event_handler tool._tool_call = tool_call diff --git a/agency_swarm/tools/send_message/SendMessageSwarm.py b/agency_swarm/tools/send_message/SendMessageSwarm.py index ec30288a..868b9258 100644 --- a/agency_swarm/tools/send_message/SendMessageSwarm.py +++ b/agency_swarm/tools/send_message/SendMessageSwarm.py @@ -24,7 +24,7 @@ def run(self): poll=False ) except BadRequestError as e: - raise BadRequestError("You can only call this tool by itself. 
Do not use any other tools together with this tool.") + raise Exception("You can only call this tool by itself. Do not use any other tools together with this tool.") try: # cancel run diff --git a/docs/advanced-usage/communication_flows.md b/docs/advanced-usage/communication_flows.md index 02f179f7..f86ae6ae 100644 --- a/docs/advanced-usage/communication_flows.md +++ b/docs/advanced-usage/communication_flows.md @@ -101,7 +101,7 @@ In the following sections, we'll look at some common use cases for extending the The most basic use case is if you want to have your own parameter descriptions, such as you want to change the docstring or the description of the `message` parameter. This can help you better customize how the agents communicate with each other and what information they relay. -Let's say that instead of sending messages, I want my agents to send tasks to each other. In this case, I can change the docstring and the `message` parameter description to better fit the task-oriented nature of my application. +Let's say that instead of sending messages, I want my agents to send tasks to each other. In this case, I can change the docstring and the `message` parameter to a `task` parameter to better fit the task-oriented nature of my application. 
```python from pydantic import Field @@ -145,6 +145,7 @@ You can, of course, also use GPT for this: ```python from agency_swarm.tools.send_message import SendMessage from agency_swarm.util.validators import llm_validator +from pydantic import model_validator class SendMessageLLMValidation(SendMessage): @model_validator(mode='after') diff --git a/tests/test_send_message.py b/tests/test_communication.py similarity index 57% rename from tests/test_send_message.py rename to tests/test_communication.py index 4ebc3a2c..b068c97c 100644 --- a/tests/test_send_message.py +++ b/tests/test_communication.py @@ -30,13 +30,14 @@ def run(self): tools=[] ) - self.agency = Agency([self.ceo, [self.ceo, self.customer_support], [self.customer_support, self.ceo]], send_message_tool_class=SendMessageSwarm) + self.agency = Agency([self.ceo, [self.ceo, self.customer_support], [self.customer_support, self.ceo]], + temperature=0, send_message_tool_class=SendMessageSwarm) def test_send_message_swarm(self): response = self.agency.get_completion("Hello, can you send me to customer support? If there are any issues, please say 'error'") - self.assertFalse("error" in response.lower()) + self.assertFalse("error" in response.lower(), self.agency.main_thread.thread_url) response = self.agency.get_completion("Who are you?") - self.assertTrue("customer support" in response.lower()) + self.assertTrue("customer support" in response.lower(), self.agency.main_thread.thread_url) main_thread = self.agency.main_thread @@ -46,5 +47,18 @@ def test_send_message_swarm(self): #check if all messages in the same thread (this is how Swarm works) self.assertTrue(len(main_thread.get_messages()) >= 4) # sometimes run does not cancel immediately, so there might be 5 messages + def test_send_message_double_recepient_error(self): + ceo = Agent(name="CEO", + description="Responsible for client communication, task planning and management.", + instructions="You are an agent for testing. 
Route request AT THE SAME TIME as instructed. If there is an error in a single request, please say 'error'. If there are errors in both requests, please say 'fatal'. do not output anything else.", # can be a file like ./instructions.md + ) + test_agent = Agent(name="Test Agent1", + description="Responsible for testing.", + instructions="Test agent for testing.") + agency = Agency([ceo, [ceo, test_agent]], temperature=0) + response = agency.get_completion("Please route me to customer support TWICE at the same time. I am testing something.") + self.assertTrue("error" in response.lower(), agency.main_thread.thread_url) + self.assertTrue("fatal" not in response.lower(), agency.main_thread.thread_url) + if __name__ == '__main__': unittest.main() From c25bdb434c6e0b9d98f906c45e47288a32b6e75b Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 19 Nov 2024 08:22:35 +0400 Subject: [PATCH 6/9] Minor docs adjustments --- .../tools/send_message/SendMessage.py | 4 +- docs/advanced-usage/communication_flows.md | 42 ++++++++++++------- tests/test_communication.py | 4 +- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/agency_swarm/tools/send_message/SendMessage.py b/agency_swarm/tools/send_message/SendMessage.py index 7a1af529..d5c777ac 100644 --- a/agency_swarm/tools/send_message/SendMessage.py +++ b/agency_swarm/tools/send_message/SendMessage.py @@ -32,8 +32,8 @@ class SendMessage(SendMessageBase): @model_validator(mode='after') def validate_files(self): - # prevent hallucinations with file IDs if the necessary parameters are provided - if hasattr(self, 'message') and "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions): + # prevent hallucinations with agents sending file IDs into incorrect fields + if "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions): if not self.message_files: raise ValueError("You must include file IDs in message_files parameter.") 
return self diff --git a/docs/advanced-usage/communication_flows.md b/docs/advanced-usage/communication_flows.md index f86ae6ae..4f8c6a34 100644 --- a/docs/advanced-usage/communication_flows.md +++ b/docs/advanced-usage/communication_flows.md @@ -67,14 +67,13 @@ class SendMessage(SendMessageBase): description="Additional context or instructions from the conversation needed by the recipient agent to complete the task." ) - @field_validator('additional_instructions', mode='before') - @classmethod - def validate_additional_instructions(cls, value): - # previously the parameter was a list, now it's a string - # add compatibility for old code - if isinstance(value, list): - return "\n".join(value) - return value + @model_validator(mode='after') + def validate_files(self): + # prevent hallucinations with agents sending file IDs into incorrect fields + if "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions): + if not self.message_files: + raise ValueError("You must include file IDs in message_files parameter.") + return self def run(self): @@ -99,9 +98,9 @@ In the following sections, we'll look at some common use cases for extending the #### 1. Adjusting parameters and descriptions -The most basic use case is if you want to have your own parameter descriptions, such as you want to change the docstring or the description of the `message` parameter. This can help you better customize how the agents communicate with each other and what information they relay. +The most basic use case is if you want to use your own parameter descriptions, such as if you want to change the docstring or the description of the `message` parameter. This can help you better customize how the agents communicate with each other and what information they relay. -Let's say that instead of sending messages, I want my agents to send tasks to each other. 
In this case, I can change the docstring and the `message` parameter to a `task` parameter to better fit the task-oriented nature of my application. +Let's say that instead of sending messages, I want my agents to send tasks to each other. In this case, I can change the docstring and the `message` parameter to a `task` parameter to better fit the nature of my application. ```python from pydantic import Field @@ -126,7 +125,7 @@ To remove the chain of thought, you can simply remove the `chain_of_thought` par #### 2. Adding custom validation logic -Now, let's say that I need to ensure that my message is sent to the correct recepient agent. (This is a very common hallucination in production.) In this case, I can add custom validator to the `recipient` parameter, which is defined in the `SendMessageBase` class. Since I don't want to change any other logic, I can inherit the `SendMessage` class and only add this new validation logic. +Now, let's say that I need to ensure that my message is sent to the correct recipient agent. (This is a very common hallucination in production.) In this case, I can add a custom validator to the `recipient` parameter, which is defined in the `SendMessageBase` class. Since I don't want to change any other parameters or descriptions, I can inherit the default `SendMessage` class and only add this new validation logic. ```python from agency_swarm.tools.send_message import SendMessage @@ -157,11 +156,11 @@ class SendMessageLLMValidation(SendMessage): return self ``` -In this example, the `llm_validator` will throw an error if the message is not related to customer support. The caller agent will then have to fix the recipient or the message and send it again! +In this example, the `llm_validator` will throw an error if the message is not related to customer support. The caller agent will then have to fix the recipient or the message and send it again! This is extremely useful when you have a lot of agents. #### 3. 
Summurizing previous conversations with other agents and adding to context -Sometimes, when using default `SendMessage`, the agents might not relay all the neceessary details to the recipient agent. Especially, when the previous conversation is long. In this case, you can summarize the previous conversation with GPT and add it to the context, instead of the additional instructions. I will extend the `SendMessageQuick` class, which already contains the `message` parameter. +Sometimes, when using default `SendMessage`, the agents might not relay all the necessary details to the recipient agent, especially when the previous conversation is too long. In this case, you can summarize the previous conversation with GPT and add it to the context, instead of the additional instructions. I will extend the `SendMessageQuick` class, which already contains the `message` parameter, as I don't need chain of thought or files in this case. ```python from agency_swarm.tools.send_message import SendMessageQuick @@ -189,6 +188,8 @@ class SendMessageSummary(SendMessageQuick): return self._get_completion(message=self.message, additional_instructions=f"\n\nPrevious conversation summary: '{summary.choices[0].message.content}'") ``` +With this example, you can add your own custom logic to the `run` method. It does not have to be a summary; you can also use it to add any other information to the context. For example, you can even query a vector database or use an external API. + #### 4. Running each agent in a separate API call If you are a PRO, and you have managed to deploy each agent in a separate API endpoint, instead of using `_get_completion()`, you can call your own API and let the agents communicate with each other over the internet. 
@@ -212,6 +213,19 @@ This is very powerful, as you can even allow your agents to colloborate with age If you have any ideas for new communication flows, please either adjust this page in docs, or add your new send message tool in the `agency_swarm/tools/send_message` folder and open a PR! +**After implementing your own `SendMessage` tool**, simply pass it into the `send_message_tool_class` parameter when initializing the `Agency` class: + +```python +agency = Agency( + ... + send_message_tool_class=SendMessageAPI +) +``` + +Now, your agents will use your own custom `SendMessageAPI` class for communication! + ## Conclusion -Agency Swarm is the only framework that gives you full control over your systems. With this new feature, now there is **not a single prompt, parameter or part of the system** that you cannot adjust or customize to your own needs! +Agency Swarm has been designed to give you, the developer, full control over your systems. It is the only framework that does not hard-code any prompts, parameters, or even worse, agents for you. With this new feature, the last part of the system that you couldn't fully customize to your own needs is now gone! + +So, I want to encourage you to keep experimenting and designing your own unique communication flows. While the examples above should serve as a good starting point, they barely scratch the surface of what's possible here! I am looking forward to seeing what you will create. Please share it in our [Discord server](https://discord.gg/7HcABDpFPG) so we can all learn from each other. diff --git a/tests/test_communication.py b/tests/test_communication.py index b068c97c..216efd29 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -34,7 +34,7 @@ def run(self): temperature=0, send_message_tool_class=SendMessageSwarm) def test_send_message_swarm(self): - response = self.agency.get_completion("Hello, can you send me to customer support? 
If there are any issues, please say 'error'") + response = self.agency.get_completion("Hello, can you send me to customer support? If tool responds says that you have NOT been rerouted, or if there is another error, please say 'error'") self.assertFalse("error" in response.lower(), self.agency.main_thread.thread_url) response = self.agency.get_completion("Who are you?") self.assertTrue("customer support" in response.lower(), self.agency.main_thread.thread_url) @@ -50,7 +50,7 @@ def test_send_message_swarm(self): def test_send_message_double_recepient_error(self): ceo = Agent(name="CEO", description="Responsible for client communication, task planning and management.", - instructions="You are an agent for testing. Route request AT THE SAME TIME as instructed. If there is an error in a single request, please say 'error'. If there are errors in both requests, please say 'fatal'. do not output anything else.", # can be a file like ./instructions.md + instructions="You are an agent for testing. Route request AT THE SAME TIME as instructed. If there is an error in a single request, please say 'error'. If there are errors in both requests, please say 'fatal'. do not output anything else.", ) test_agent = Agent(name="Test Agent1", description="Responsible for testing.", From e750bcf7a23ddfb2d83ce588b075664431809f00 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 19 Nov 2024 08:25:49 +0400 Subject: [PATCH 7/9] doc minor adjustments --- docs/advanced-usage/communication_flows.md | 26 +++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/advanced-usage/communication_flows.md b/docs/advanced-usage/communication_flows.md index 4f8c6a34..c65d1a17 100644 --- a/docs/advanced-usage/communication_flows.md +++ b/docs/advanced-usage/communication_flows.md @@ -2,7 +2,18 @@ Multi-agent communication is the core functionality of any Multi-Agent System. 
Unlike in all other frameworks, Agency Swarm not only allows you to define communication flows in any way you want (uniform communication flows), but to also configure the underlying logic for this feature. This means that you can create entirely new types of communication, or adjust it to your own needs. Below you will find a guide on how to do all this, along with some common examples. -**To use your own `SendMessage` calss**, simply put it in the `send_message_tool_class` parameter when initializing the `Agency` class: +## Pre-Made SendMessage Classes + +Agency Swarm contains multiple commonly requested classes for communication flows. Currently, the following classes are available: + +| Class Name | Description | When to Use | Code Link | +| --------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------- | +| `SendMessage` (default) | This is the default class for sending messages to other agents. It uses synchronous communication with basic COT (Chain of Thought) prompting and allows agents to relay files and modify system instructions for each other. | Suitable for most use cases. Balances speed and functionality. | | +| `SendMessageQuick` | A variant of the SendMessage class without Chain of Thought prompting, files, and additional instructions. It allows for faster communication without the overhead of COT. | Use for simpler use cases or when you want to save tokens and increase speed. | | +| `SendMessageAsyncThreading` | Similar to `SendMessage` but with `async_mode='threading'`. Each agent will execute asynchronously in a separate thread. 
In the meantime, the caller agent can continue the conversation with the user and check the results later. | Use for asynchronous applications or when sub-agents take singificant amounts of time to complete their tasks. | | +| `SendMessageSwarm` | Instead of sending a message to another agent, it replaces the caller agent with the recipient agent, similar to [OpenAI's Swarm](https://github.com/openai/swarm). The recipient agent will then have access to the entire conversation. | When you need more granular control. It is not able to handle complex multi-step, multi-agent tasks. | | + +**To use any of the pre-made `SendMessage` classes**, simply put it in the `send_message_tool_class` parameter when initializing the `Agency` class: ```python from agency_swarm.tools.send_message import SendMessageQuick @@ -15,17 +26,6 @@ agency = Agency( That's it! Now, your agents will use your own custom `SendMessageQuick` class for communication. -## Pre-Made SendMessage Classes - -Agency Swarm contains multiple commonly requested classes for communication flows. Currently, the following classes are available: - -| Class Name | Description | When to Use | Code Link | -| --------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------- | -| `SendMessage` (default) | This is the default class for sending messages to other agents. It uses synchronous communication with basic COT (Chain of Thought) prompting and allows agents to relay files and modify system instructions for each other. | Suitable for most use cases. Balances speed and functionality. 
| | -| `SendMessageQuick` | A variant of the SendMessage class without Chain of Thought prompting, files, and additional instructions. It allows for faster communication without the overhead of COT. | Use for simpler use cases or when you want to save tokens and increase speed. | | -| `SendMessageAsyncThreading` | Similar to `SendMessage` but with `async_mode='threading'`. Each agent will execute asynchronously in a separate thread. In the meantime, the caller agent can continue the conversation with the user and check the results later. | Use for asynchronous applications or when sub-agents take singificant amounts of time to complete their tasks. | | -| `SendMessageSwarm` | Instead of sending a message to another agent, it replaces the caller agent with the recipient agent, similar to [OpenAI's Swarm](https://github.com/openai/swarm). The recipient agent will then have access to the entire conversation. | When you need more granular control. It is not able to handle complex multi-step, multi-agent tasks. | | - ## Creating Your Own Unique Communication Flows To create you own communication flow, you will first need to extend the `SendMessageBase` class. This class extends the `BaseTool` class, like any other tools in Agency Swarm, and contains the most basic parameters required for communication, such as the `recipient_agent`. @@ -222,7 +222,7 @@ agency = Agency( ) ``` -Now, your agents will use your own custom `SendMessageAPI` class for communication! +That's it! Now, your agents will use your own custom `SendMessageAPI` class for communication! 
## Conclusion From a303a6d5e76579a37a3f9c574a4574aff9fa96fb Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 19 Nov 2024 12:40:53 +0400 Subject: [PATCH 8/9] Bump versions --- pyproject.toml | 2 +- requirements.txt | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17254123..39241d38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", ] dependencies = [ - "openai==1.54.3", + "openai==1.54.4", "docstring_parser==0.16", "pydantic==2.8.2", "datamodel-code-generator==0.26.1", diff --git a/requirements.txt b/requirements.txt index 6d7b06c5..c66fce38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -openai==1.54.3 +openai==1.54.4 docstring_parser==0.16 pydantic==2.8.2 datamodel-code-generator==0.26.1 diff --git a/setup.py b/setup.py index 33f796a7..2dc7a02f 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name='agency-swarm', - version='0.3.1', + version='0.4.0', author='VRSEN', author_email='me@vrsen.ai', description='An opensource agent orchestration framework built on top of the latest OpenAI Assistants API.', From 23cd1fc18697e805b3fa1ab195ce2ce3e83087d6 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 19 Nov 2024 12:41:39 +0400 Subject: [PATCH 9/9] Replace email --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 39241d38..4055e836 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "agency-swarm" dynamic = ["version"] -authors = [{ name = "VRSEN", email = "arseny9795@gmail.com" }] +authors = [{ name = "VRSEN", email = "me@vrsen.ai" }] description = "An open source agent orchestration framework built on top of the latest OpenAI Assistants API." readme = "README.md" license = { file = "LICENSE" }