diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 4e5001d6..69c38880 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -6,7 +6,6 @@ import uuid from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, TypedDict, Union - from openai.lib._parsing._completions import type_to_response_format_param from openai.types.beta.threads import Message from openai.types.beta.threads.runs import RunStep @@ -24,7 +23,9 @@ from agency_swarm.messages import MessageOutput from agency_swarm.messages.message_output import MessageOutputLive from agency_swarm.threads import Thread +from agency_swarm.threads.thread_async import ThreadAsync from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch +from agency_swarm.tools.send_message import SendMessage, SendMessageBase from agency_swarm.user import User from agency_swarm.util.errors import RefusalError from agency_swarm.util.files import get_tools, get_file_purpose @@ -46,15 +47,12 @@ class ThreadsCallbacks(TypedDict): class Agency: - ThreadType = Thread - send_message_tool_description = """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time.""" - send_message_tool_description_async = """Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. Continue until task completion.""" - def __init__(self, agency_chart: List, shared_instructions: str = "", shared_files: Union[str, List[str]] = None, async_mode: Literal['threading', "tools_threading"] = None, + send_message_tool_class: Type[SendMessageBase] = SendMessage, settings_path: str = "./settings.json", settings_callbacks: SettingsCallbacks = None, threads_callbacks: ThreadsCallbacks = None, @@ -72,6 +70,7 @@ def __init__(self, shared_instructions (str, optional): A path to a file containing shared instructions for all agents. Defaults to an empty string. shared_files (Union[str, List[str]], optional): A path to a folder or a list of folders containing shared files for all agents. Defaults to None. async_mode (str, optional): Specifies the mode for asynchronous processing. In "threading" mode, all sub-agents run in separate threads. In "tools_threading" mode, all tools run in separate threads, but agents do not. Defaults to None. + send_message_tool_class (Type[SendMessageBase], optional): The class to use for the send_message tool. For async communication, use `SendMessageAsyncThreading`. Defaults to SendMessage. settings_path (str, optional): The path to the settings file for the agency. Must be json. If file does not exist, it will be created. Defaults to None. settings_callbacks (SettingsCallbacks, optional): A dictionary containing functions to load and save settings for the agency. The keys must be "load" and "save". Both values must be defined. Defaults to None. threads_callbacks (ThreadsCallbacks, optional): A dictionary containing functions to load and save threads for the agency. The keys must be "load" and "save". Both values must be defined. Defaults to None. @@ -92,6 +91,7 @@ def __init__(self, self.recipient_agents = None # for autocomplete self.shared_files = shared_files if shared_files else [] self.async_mode = async_mode + self.send_message_tool_class = send_message_tool_class self.settings_path = settings_path self.settings_callbacks = settings_callbacks self.threads_callbacks = threads_callbacks @@ -101,11 +101,19 @@ def __init__(self, self.max_completion_tokens = max_completion_tokens self.truncation_strategy = truncation_strategy + # set thread type based send_message_tool_class async mode + if hasattr(send_message_tool_class.ToolConfig, "async_mode") and send_message_tool_class.ToolConfig.async_mode: + self._thread_type = ThreadAsync + else: + self._thread_type = Thread + if self.async_mode == "threading": - from agency_swarm.threads.thread_async import ThreadAsync - self.ThreadType = ThreadAsync + from agency_swarm.tools.send_message import SendMessageAsyncThreading + print("Warning: 'threading' mode is deprecated. Please use send_message_tool_class = SendMessageAsyncThreading to use async communication.") + self.send_message_tool_class = SendMessageAsyncThreading elif self.async_mode == "tools_threading": - Thread.async_mode = self.async_mode + Thread.async_mode = "tools_threading" + print("Warning: 'tools_threading' mode is deprecated. Use tool.ToolConfig.async_mode = 'threading' instead.") elif self.async_mode is None: pass else: @@ -121,9 +129,9 @@ def __init__(self, self.shared_state = SharedState() self._parse_agency_chart(agency_chart) + self._init_threads() self._create_special_tools() self._init_agents() - self._init_threads() def get_completion(self, message: str, message_files: List[str] = None, @@ -300,7 +308,7 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): images = [] message_file_names = None uploading_files = False - recipient_agents = [agent.name for agent in self.main_recipients] + recipient_agent_names = [agent.name for agent in self.main_recipients] recipient_agent = self.main_recipients[0] with gr.Blocks(js=js) as demo: @@ -308,7 +316,7 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): chatbot = gr.Chatbot(height=height) with gr.Row(): with gr.Column(scale=9): - dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agents, + dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agent_names, value=recipient_agent.name) msg = gr.Textbox(label="Your Message", lines=4) with gr.Column(scale=1): @@ -412,9 +420,14 @@ def check_and_add_tools_in_attachments(attachments, recipient_agent): class GradioEventHandler(AgencyEventHandler): message_output = None + @classmethod + def change_recipient_agent(cls, recipient_agent_name): + nonlocal chatbot_queue + chatbot_queue.put("[change_recipient_agent]") + chatbot_queue.put(recipient_agent_name) + @override def on_message_created(self, message: Message) -> None: - if message.role == "user": full_content = "" for content in message.content: @@ -530,19 +543,20 @@ def on_all_streams_end(cls): chatbot_queue.put("[end]") def bot(original_message, history): - if not original_message: - return "", history - nonlocal attachments nonlocal message_file_names nonlocal recipient_agent + nonlocal recipient_agent_names nonlocal images nonlocal uploading_files + if not original_message: + return "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) + if uploading_files: history.append([None, "Uploading files... Please wait."]) - yield "", history - return "", history + yield "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) + return "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) print("Message files: ", attachments) print("Images: ", images) @@ -579,13 +593,19 @@ def bot(original_message, history): new_message = True continue + if bot_message == "[change_recipient_agent]": + new_agent_name = chatbot_queue.get(block=True) + recipient_agent = self._get_agent_by_name(new_agent_name) + yield "", history, gr.update(value=new_agent_name, choices=set([*recipient_agent_names, recipient_agent.name])) + continue + if new_message: history.append([None, bot_message]) new_message = False else: history[-1][1] += bot_message - yield "", history + yield "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) except queue.Empty: break @@ -594,12 +614,12 @@ def bot(original_message, history): inputs=[msg, chatbot], outputs=[msg, chatbot] ).then( - bot, [msg, chatbot], [msg, chatbot] + bot, [msg, chatbot, dropdown], [msg, chatbot, dropdown] ) dropdown.change(handle_dropdown_change, dropdown) file_upload.change(handle_file_upload, file_upload) msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( - bot, [msg, chatbot], [msg, chatbot] + bot, [msg, chatbot], [msg, chatbot, dropdown] ) # Enable queuing for streaming intermediate outputs @@ -647,6 +667,7 @@ def run_demo(self): """ Executes agency in the terminal with autocomplete for recipient agent names. """ + outer_self = self from agency_swarm import AgencyEventHandler class TermEventHandler(AgencyEventHandler): message_output = None @@ -714,7 +735,7 @@ def on_tool_call_done(self, snapshot): if snapshot.type != "function": return - if snapshot.function.name == "SendMessage": + if snapshot.function.name == "SendMessage" and not (hasattr(outer_self.send_message_tool_class.ToolConfig, 'output_as_result') and outer_self.send_message_tool_class.ToolConfig.output_as_result): try: args = eval(snapshot.function.arguments) recipient = args["recipient"] @@ -864,15 +885,24 @@ def _init_threads(self): else: self.main_thread.init_thread() + # Save main_thread into agents_and_threads + self.agents_and_threads["main_thread"] = self.main_thread + + # initialize threads for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent, items in threads.items(): - self.agents_and_threads[agent_name][other_agent] = self.ThreadType( + # create thread class + self.agents_and_threads[agent_name][other_agent] = self._thread_type( self._get_agent_by_name(items["agent"]), self._get_agent_by_name( items["recipient_agent"])) + # load thread id if available if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]: self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent] + # init threads if threre are threads callbacks so the ids are saved for later use elif self.threads_callbacks: self.agents_and_threads[agent_name][other_agent].init_thread() @@ -880,6 +910,8 @@ def _init_threads(self): if self.threads_callbacks: loaded_thread_ids = {} for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue loaded_thread_ids[agent_name] = {} for other_agent, thread in threads.items(): loaded_thread_ids[agent_name][other_agent] = thread.id @@ -1001,13 +1033,15 @@ def _create_special_tools(self): No output parameters; this method modifies the agents' toolset internally. """ for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue recipient_names = list(threads.keys()) recipient_agents = self._get_agents_by_names(recipient_names) if len(recipient_agents) == 0: continue agent = self._get_agent_by_name(agent_name) agent.add_tool(self._create_send_message_tool(agent, recipient_agents)) - if self.async_mode == 'threading': + if self._thread_type == ThreadAsync: agent.add_tool(self._create_get_response_tool(agent, recipient_agents)) def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]): @@ -1032,38 +1066,8 @@ def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]) agent_descriptions += recipient_agent.name + ": " agent_descriptions += recipient_agent.description + "\n" - outer_self = self - - class SendMessage(BaseTool): - my_primary_instructions: str = Field(..., - description="Please repeat your primary instructions step-by-step, including both completed " - "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down " - "into smaller steps yourself. Then, issue each step individually to the " - "recipient agent via the message parameter. Each identified step should be " - "sent in separate message. Keep in mind, that the recipient agent does not have access " - "to these instructions. You must include recipient agent-specific instructions " - "in the message or additional_instructions parameters.") + class SendMessage(self.send_message_tool_class): recipient: recipients = Field(..., description=agent_descriptions) - message: str = Field(..., - description="Specify the task required for the recipient agent to complete. Focus on " - "clarifying what the task entails, rather than providing exact " - "instructions.") - message_files: Optional[List[str]] = Field(default=None, - description="A list of file ids to be sent as attachments to this message. Only use this if you have the file id that starts with 'file-'.", - examples=["file-1234", "file-5678"]) - additional_instructions: Optional[str] = Field(default=None, - description="Additional context or instructions for the recipient agent about the task. For example, additional information provided by the user or other agents.") - - class ToolConfig: - strict = False - one_call_at_a_time = outer_self.async_mode != 'threading' - - @model_validator(mode='after') - def validate_files(self): - if "file-" in self.message or ( - self.additional_instructions and "file-" in self.additional_instructions): - if not self.message_files: - raise ValueError("You must include file ids in message_files parameter.") @field_validator('recipient') @classmethod @@ -1071,36 +1075,9 @@ def check_recipient(cls, value): if value.value not in recipient_names: raise ValueError(f"Recipient {value} is not valid. Valid recipients are: {recipient_names}") return value - - @field_validator('additional_instructions', mode='before') - @classmethod - def validate_additional_instructions(cls, value): - if isinstance(value, list): - return "\n".join(value) - return value - - def run(self): - thread = outer_self.agents_and_threads[self._caller_agent.name][self.recipient.value] - - if not outer_self.async_mode == 'threading': - message = thread.get_completion(message=self.message, - message_files=self.message_files, - event_handler=self._event_handler, - yield_messages=not self._event_handler, - additional_instructions=self.additional_instructions, - ) - else: - message = thread.get_completion_async(message=self.message, - message_files=self.message_files, - additional_instructions=self.additional_instructions) - - return message or "" SendMessage._caller_agent = agent - if self.async_mode == 'threading': - SendMessage.__doc__ = self.send_message_tool_description_async - else: - SendMessage.__doc__ = self.send_message_tool_description + SendMessage._agents_and_threads = self.agents_and_threads return SendMessage diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 8d681f3d..07eaa313 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -28,6 +28,16 @@ class Thread: @property def thread_url(self): return f'https://platform.openai.com/playground/assistants?assistant={self.recipient_agent.id}&mode=assistant&thread={self.id}' + + @property + def thread(self): + self.init_thread() + + if not self._thread: + print("retrieving thread", self.id) + self._thread = self.client.beta.threads.retrieve(self.id) + + return self._thread def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.agent = agent @@ -36,28 +46,35 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.client = get_openai_client() self.id = None - self.thread = None - self.run = None - self.stream = None + self._thread = None + self._run = None + self._stream = None + + self._num_run_retries = 0 + # names of recepient agents that were called in SendMessage tool + # needed to prevent agents calling the same recepient agent multiple times + self._called_recepients = [] - self.num_run_retries = 0 + self.terminal_states = ["cancelled", "completed", "failed", "expired", "incomplete"] def init_thread(self): - if self.id: - self.thread = self.client.beta.threads.retrieve(self.id) - else: - self.thread = self.client.beta.threads.create() - self.id = self.thread.id + self._called_recepients = [] + self._num_run_retries = 0 - if self.recipient_agent.examples: - for example in self.recipient_agent.examples: - self.client.beta.threads.messages.create( - thread_id=self.id, - **example, - ) + if self.id: + return + + self._thread = self.client.beta.threads.create() + self.id = self._thread.id + if self.recipient_agent.examples: + for example in self.recipient_agent.examples: + self.client.beta.threads.messages.create( + thread_id=self.id, + **example, + ) def get_completion_stream(self, - message: str, + message: Union[str, List[dict], None], event_handler: type(AgencyEventHandler), message_files: List[str] = None, attachments: Optional[List[Attachment]] = None, @@ -77,16 +94,18 @@ def get_completion_stream(self, response_format=response_format) def get_completion(self, - message: str | List[dict], + message: Union[str, List[dict], None], message_files: List[str] = None, attachments: Optional[List[dict]] = None, - recipient_agent: Agent = None, + recipient_agent: Union[Agent, None] = None, additional_instructions: str = None, event_handler: type(AgencyEventHandler) = None, tool_choice: AssistantToolChoice = None, yield_messages: bool = False, response_format: Optional[dict] = None ): + self.init_thread() + if not recipient_agent: recipient_agent = self.recipient_agent @@ -105,9 +124,6 @@ def get_completion(self, attachments.append({"file_id": file_id, "tools": recipient_tools or [{"type": "file_search"}]}) - if not self.thread: - self.init_thread() - if event_handler: event_handler.set_agent(self.agent) event_handler.set_recipient_agent(recipient_agent) @@ -117,14 +133,15 @@ def get_completion(self, print(f'THREAD:[ {sender_name} -> {recipient_agent.name} ]: URL {self.thread_url}') # send message - message_obj = self.create_message( - message=message, - role="user", - attachments=attachments - ) + if message: + message_obj = self.create_message( + message=message, + role="user", + attachments=attachments + ) - if yield_messages: - yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj) + if yield_messages: + yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj) self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) @@ -135,11 +152,11 @@ def get_completion(self, self._run_until_done() # function execution - if self.run.status == "requires_action": - tool_calls = self.run.required_action.submit_tool_outputs.tool_calls + if self._run.status == "requires_action": + self._called_recepients = [] + tool_calls = self._run.required_action.submit_tool_outputs.tool_calls tool_outputs_and_names = [] # list of tuples (name, tool_output) - sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name == "SendMessage"] - async_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name != "SendMessage"] + sync_tool_calls, async_tool_calls = self._get_sync_async_tool_calls(tool_calls, recipient_agent) def handle_output(tool_call, output): if inspect.isgenerator(output): @@ -157,6 +174,8 @@ def handle_output(tool_call, output): for tool_output in tool_outputs_and_names: if tool_output[1]["tool_call_id"] == tool_call.id: tool_output[1]["output"] = output + + return output if len(async_tool_calls) > 0 and self.async_mode == "tools_threading": max_workers = min(self.max_workers, os.cpu_count() or 1) # Use at most 4 workers or the number of CPUs available @@ -170,8 +189,11 @@ def handle_output(tool_call, output): for future in as_completed(futures): tool_call = futures[future] - output = future.result() - yield from handle_output(tool_call, output) + output, output_as_result = future.result() + output = yield from handle_output(tool_call, output) + if output_as_result: + self._cancel_run() + return output else: sync_tool_calls += async_tool_calls @@ -179,16 +201,19 @@ def handle_output(tool_call, output): for tool_call in sync_tool_calls: if yield_messages: yield MessageOutput("function", recipient_agent.name, self.agent.name, str(tool_call.function), tool_call) - output = self.execute_tool(tool_call, recipient_agent, event_handler, tool_outputs_and_names) + output, output_as_result = self.execute_tool(tool_call, recipient_agent, event_handler, tool_outputs_and_names) tool_outputs_and_names.append((tool_call.function.name, {"tool_call_id": tool_call.id, "output": output})) - yield from handle_output(tool_call, output) - + output = yield from handle_output(tool_call, output) + if output_as_result: + self._cancel_run() + return output + # split names and outputs tool_outputs = [tool_output for _, tool_output in tool_outputs_and_names] tool_names = [name for name, _ in tool_outputs_and_names] # await coroutines - tool_outputs = self._execute_async_tool_calls_outputs(tool_outputs) + tool_outputs = self._await_coroutines(tool_outputs) # convert all tool outputs to strings for tool_output in tool_outputs: @@ -213,11 +238,11 @@ def handle_output(tool_call, output): self._create_run(recipient_agent, additional_instructions, event_handler, 'required', temperature=0) self._run_until_done() - if self.run.status != "requires_action": - raise Exception("Run Failed. Error: ", self.run.last_error or self.run.incomplete_details) + if self._run.status != "requires_action": + raise Exception("Run Failed. Error: ", self._run.last_error or self._run.incomplete_details) # change tool call ids - tool_calls = self.run.required_action.submit_tool_outputs.tool_calls + tool_calls = self._run.required_action.submit_tool_outputs.tool_calls if len(tool_calls) != len(tool_outputs): tool_outputs = [] @@ -235,10 +260,10 @@ def handle_output(tool_call, output): else: raise e # error - elif self.run.status == "failed": + elif self._run.status == "failed": full_message += self._get_last_message_text() common_errors = ["something went wrong", "the server had an error processing your request", "rate limit reached"] - error_message = self.run.last_error.message.lower() + error_message = self._run.last_error.message.lower() if error_attempts < 3 and any(error in error_message for error in common_errors): if error_attempts < 2: @@ -250,9 +275,9 @@ def handle_output(tool_call, output): tool_choice, response_format=response_format) error_attempts += 1 else: - raise Exception("OpenAI Run Failed. Error: ", self.run.last_error.message) - elif self.run.status == "incomplete": - raise Exception("OpenAI Run Incomplete. Details: ", self.run.incomplete_details) + raise Exception("OpenAI Run Failed. Error: ", self._run.last_error.message) + elif self._run.status == "incomplete": + raise Exception("OpenAI Run Incomplete. Details: ", self._run.incomplete_details) # return assistant message else: message_obj = self._get_last_assistant_message() @@ -307,7 +332,7 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t try: if event_handler: with self.client.beta.threads.runs.stream( - thread_id=self.thread.id, + thread_id=self.id, event_handler=event_handler(), assistant_id=recipient_agent.id, additional_instructions=additional_instructions, @@ -320,10 +345,10 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t response_format=response_format ) as stream: stream.until_done() - self.run = stream.get_final_run() + self._run = stream.get_final_run() else: - self.run = self.client.beta.threads.runs.create( - thread_id=self.thread.id, + self._run = self.client.beta.threads.runs.create( + thread_id=self.id, assistant_id=recipient_agent.id, additional_instructions=additional_instructions, tool_choice=tool_choice, @@ -334,43 +359,72 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t parallel_tool_calls=recipient_agent.parallel_tool_calls, response_format=response_format ) - self.run = self.client.beta.threads.runs.poll( - thread_id=self.thread.id, - run_id=self.run.id, + self._run = self.client.beta.threads.runs.poll( + thread_id=self.id, + run_id=self._run.id, # poll_interval_ms=500, ) except APIError as e: - if "The server had an error processing your request" in e.message and self.num_run_retries < 3: - time.sleep(1 + self.num_run_retries) + match = re.search(r"Thread (\w+) already has an active run (\w+)", e.message) + if match: + self._cancel_run(thread_id=match.groups()[0], run_id=match.groups()[1], check_status=False) + elif "The server had an error processing your request" in e.message and self._num_run_retries < 3: + time.sleep(1) self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) - self.num_run_retries += 1 + self._num_run_retries += 1 else: raise e def _run_until_done(self): - while self.run.status in ['queued', 'in_progress', "cancelling"]: + while self._run.status in ['queued', 'in_progress', "cancelling"]: time.sleep(0.5) - self.run = self.client.beta.threads.runs.retrieve( - thread_id=self.thread.id, - run_id=self.run.id + self._run = self.client.beta.threads.runs.retrieve( + thread_id=self.id, + run_id=self._run.id ) - def _submit_tool_outputs(self, tool_outputs, event_handler): - if not event_handler: - self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( - thread_id=self.thread.id, - run_id=self.run.id, + def _submit_tool_outputs(self, tool_outputs, event_handler=None, poll=True): + if not poll: + self._run = self.client.beta.threads.runs.submit_tool_outputs( + thread_id=self.id, + run_id=self._run.id, tool_outputs=tool_outputs ) else: - with self.client.beta.threads.runs.submit_tool_outputs_stream( - thread_id=self.thread.id, - run_id=self.run.id, - tool_outputs=tool_outputs, - event_handler=event_handler() - ) as stream: - stream.until_done() - self.run = stream.get_final_run() + if not event_handler: + self._run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( + thread_id=self.id, + run_id=self._run.id, + tool_outputs=tool_outputs + ) + else: + with self.client.beta.threads.runs.submit_tool_outputs_stream( + thread_id=self.id, + run_id=self._run.id, + tool_outputs=tool_outputs, + event_handler=event_handler() + ) as stream: + stream.until_done() + self._run = stream.get_final_run() + + def _cancel_run(self, thread_id=None, run_id=None, check_status=True): + if check_status and self._run.status in self.terminal_states and not run_id: + return + + try: + self._run = self.client.beta.threads.runs.cancel( + thread_id=self.id, + run_id=self._run.id + ) + except BadRequestError as e: + if "Cannot cancel run with status" in e.message: + self._run = self.client.beta.threads.runs.poll( + thread_id=thread_id or self.id, + run_id=run_id or self._run.id, + poll_interval_ms=500, + ) + else: + raise e def _get_last_message_text(self): messages = self.client.beta.threads.messages.list( @@ -417,15 +471,9 @@ def create_message(self, message: str, role: str = "user", attachments: List[dic thread_id, run_id = match.groups() thread_id = f"thread_{thread_id}" run_id = f"run_{run_id}" - self.client.beta.threads.runs.cancel( - thread_id=thread_id, - run_id=run_id - ) - self.run = self.client.beta.threads.runs.poll( - thread_id=thread_id, - run_id=run_id, - poll_interval_ms=500, - ) + + self._cancel_run(thread_id=thread_id, run_id=run_id) + return self.client.beta.threads.messages.create( thread_id=thread_id, role=role, @@ -439,33 +487,43 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool if not recipient_agent: recipient_agent = self.recipient_agent + tool_name = tool_call.function.name funcs = recipient_agent.functions - tool = next((func for func in funcs if func.__name__ == tool_call.function.name), None) + tool = next((func for func in funcs if func.__name__ == tool_name), None) if not tool: - return f"Error: Function {tool_call.function.name} not found. Available functions: {[func.__name__ for func in funcs]}" + return f"Error: Function {tool_call.function.name} not found. Available functions: {[func.__name__ for func in funcs]}", False try: # init tool args = tool_call.function.arguments args = json.loads(args) if args else {} tool = tool(**args) + + # check if the tool is already called for tool_name in [name for name, _ in tool_outputs_and_names]: - if tool_name == tool_call.function.name and ( + if tool_name == tool_name and ( hasattr(tool, "ToolConfig") and hasattr(tool.ToolConfig, "one_call_at_a_time") and tool.ToolConfig.one_call_at_a_time): - return f"Error: Function {tool_call.function.name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again." + return f"Error: Function {tool_name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again.", False + # for send message tools, don't allow calling the same recepient agent multiple times + if tool_name.startswith("SendMessage"): + if tool.recipient.value in self._called_recepients: + return f"Error: Agent {tool.recipient.value} has already been called. You can only call each agent once at a time. Please wait for the previous call to finish before calling it again.", False + self._called_recepients.append(tool.recipient.value) + tool._caller_agent = recipient_agent tool._event_handler = event_handler + tool._tool_call = tool_call - return tool.run() + return tool.run(), tool.ToolConfig.output_as_result except Exception as e: error_message = f"Error: {e}" if "For further information visit" in error_message: error_message = error_message.split("For further information visit")[0] - return error_message + return error_message, False - def _execute_async_tool_calls_outputs(self, tool_outputs): + def _await_coroutines(self, tool_outputs): async_tool_calls = [] for tool_output in tool_outputs: if inspect.iscoroutine(tool_output["output"]): @@ -487,6 +545,39 @@ def _execute_async_tool_calls_outputs(self, tool_outputs): tool_output["output"] = str(result) return tool_outputs + + def _get_sync_async_tool_calls(self, tool_calls, recipient_agent): + async_tool_calls = [] + sync_tool_calls = [] + for tool_call in tool_calls: + if tool_call.function.name.startswith("SendMessage"): + sync_tool_calls.append(tool_call) + continue + + tool = next((func for func in recipient_agent.functions if func.__name__ == tool_call.function.name), None) + + if (hasattr(tool.ToolConfig, "async_mode") and tool.ToolConfig.async_mode) or self.async_mode == "tools_threading": + async_tool_calls.append(tool_call) + else: + sync_tool_calls.append(tool_call) + + return sync_tool_calls, async_tool_calls + + def get_messages(self, limit=None): + all_messages = [] + after = None + while True: + response = self.client.beta.threads.messages.list(thread_id=self.id, limit=100, after=after) + messages = response.data + if not messages: + break + all_messages.extend(messages) + after = messages[-1].id # Set the 'after' cursor to the ID of the last message + + if limit and len(all_messages) >= limit: + break + + return all_messages diff --git a/agency_swarm/threads/thread_async.py b/agency_swarm/threads/thread_async.py index a6ae8f72..6f1ebd70 100644 --- a/agency_swarm/threads/thread_async.py +++ b/agency_swarm/threads/thread_async.py @@ -90,11 +90,10 @@ def check_status(self, run=None): return f"""{self.recipient_agent.name}'s Response: '{messages.data[0].content[0].text.value}'""" def get_last_run(self): - if not self.thread: - self.init_thread() + self.init_thread() runs = self.client.beta.threads.runs.list( - thread_id=self.thread.id, + thread_id=self.id, order="desc", ) diff --git a/agency_swarm/tools/BaseTool.py b/agency_swarm/tools/BaseTool.py index fd6fd6a6..51cdaeb3 100644 --- a/agency_swarm/tools/BaseTool.py +++ b/agency_swarm/tools/BaseTool.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal, Union from docstring_parser import parse @@ -11,15 +11,31 @@ class BaseTool(BaseModel, ABC): _shared_state: ClassVar[SharedState] = None _caller_agent: Any = None _event_handler: Any = None + _tool_call: Any = None def __init__(self, **kwargs): if not self.__class__._shared_state: self.__class__._shared_state = SharedState() super().__init__(**kwargs) + + # Ensure all ToolConfig variables are initialized + config_defaults = { + 'strict': False, + 'one_call_at_a_time': False, + 'output_as_result': False, + 'async_mode': None + } + + for key, value in config_defaults.items(): + if not hasattr(self.ToolConfig, key): + setattr(self.ToolConfig, key, value) class ToolConfig: strict: bool = False one_call_at_a_time: bool = False + # return the tool output as assistant message + output_as_result: bool = False + async_mode: Union[Literal["threading"], None] = None @classmethod @property @@ -76,5 +92,5 @@ def openai_schema(cls): return schema @abstractmethod - def run(self, **kwargs): + def run(self): pass diff --git a/agency_swarm/tools/__init__.py b/agency_swarm/tools/__init__.py index 8bca7b06..d77d38dd 100644 --- a/agency_swarm/tools/__init__.py +++ b/agency_swarm/tools/__init__.py @@ -2,4 +2,4 @@ from .oai.CodeInterpreter import CodeInterpreter from .oai.FileSearch import FileSearch from .oai.Retrieval import Retrieval -from .ToolFactory import ToolFactory +from .ToolFactory import ToolFactory \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessage.py b/agency_swarm/tools/send_message/SendMessage.py new file mode 100644 index 00000000..d5c777ac --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessage.py @@ -0,0 +1,44 @@ +from typing import Optional, List +from pydantic import Field, field_validator, model_validator +from .SendMessageBase import SendMessageBase + +class SendMessage(SendMessageBase): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message to the same recipient agent at the same time.""" + my_primary_instructions: str = Field( + ..., + description=( + "Please repeat your primary instructions step-by-step, including both completed " + "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down " + "into smaller steps yourself. Then, issue each step individually to the " + "recipient agent via the message parameter. Each identified step should be " + "sent in a separate message. Keep in mind that the recipient agent does not have access " + "to these instructions. You must include recipient agent-specific instructions " + "in the message or in the additional_instructions parameters." + ) + ) + message: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information from the conversation needed to complete the task." + ) + message_files: Optional[List[str]] = Field( + default=None, + description="A list of file IDs to be sent as attachments to this message. Only use this if you have the file ID that starts with 'file-'.", + examples=["file-1234", "file-5678"] + ) + additional_instructions: Optional[str] = Field( + default=None, + description="Additional context or instructions from the conversation needed by the recipient agent to complete the task." + ) + + @model_validator(mode='after') + def validate_files(self): + # prevent hallucinations with agents sending file IDs into incorrect fields + if "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions): + if not self.message_files: + raise ValueError("You must include file IDs in message_files parameter.") + return self + + def run(self): + return self._get_completion(message=self.message, + message_files=self.message_files, + additional_instructions=self.additional_instructions) \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageAsyncThreading.py b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py new file mode 100644 index 00000000..38912e09 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py @@ -0,0 +1,8 @@ +from typing import ClassVar, Type +from agency_swarm.threads.thread_async import ThreadAsync +from .SendMessage import SendMessage + +class SendMessageAsyncThreading(SendMessage): + """Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. Continue until task completion.""" + class ToolConfig: + async_mode = "threading" \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageBase.py b/agency_swarm/tools/send_message/SendMessageBase.py new file mode 100644 index 00000000..26b61128 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageBase.py @@ -0,0 +1,41 @@ +from agency_swarm.agents.agent import Agent +from agency_swarm.threads.thread import Thread +from typing import ClassVar, Union +from pydantic import Field, field_validator +from agency_swarm.threads.thread_async import ThreadAsync +from agency_swarm.tools import BaseTool +from abc import ABC + +class SendMessageBase(BaseTool, ABC): + recipient: str = Field(..., description="Recipient agent that you want to send the message to. This field will be overriden inside the agency class.") + + _agents_and_threads: ClassVar = None + + @field_validator('additional_instructions', mode='before', check_fields=False) + @classmethod + def validate_additional_instructions(cls, value): + # previously the parameter was a list, now it's a string + # add compatibility for old code + if isinstance(value, list): + return "\n".join(value) + return value + + def _get_thread(self) -> Thread | ThreadAsync: + return self._agents_and_threads[self._caller_agent.name][self.recipient.value] + + def _get_main_thread(self) -> Thread | ThreadAsync: + return self._agents_and_threads["main_thread"] + + def _get_recipient_agent(self) -> Agent: + return self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent + + def _get_completion(self, message: Union[str, None] = None, **kwargs): + thread = self._get_thread() + + if self.ToolConfig.async_mode == "threading": + return thread.get_completion_async(message=message, **kwargs) + else: + return thread.get_completion(message=message, + event_handler=self._event_handler, + yield_messages=not self._event_handler, + **kwargs) \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageQuick.py b/agency_swarm/tools/send_message/SendMessageQuick.py new file mode 100644 index 00000000..f673f909 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageQuick.py @@ -0,0 +1,14 @@ +from agency_swarm.threads.thread import Thread +from typing import ClassVar, Optional, List, Type +from pydantic import Field, field_validator, model_validator +from .SendMessageBase import SendMessageBase + +class SendMessageQuick(SendMessageBase): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message to the same recipient agent at the same time.""" + message: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information from the conversation needed to complete the task." + ) + + def run(self): + return self._get_completion(message=self.message) \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageSwarm.py b/agency_swarm/tools/send_message/SendMessageSwarm.py new file mode 100644 index 00000000..868b9258 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageSwarm.py @@ -0,0 +1,48 @@ +from openai import BadRequestError +from agency_swarm.threads.thread import Thread +from .SendMessage import SendMessageBase + +class SendMessageSwarm(SendMessageBase): + """Use this tool to route messages to other agents within your agency. After using this tool, you will be switched to the recipient agent. This tool can only be used once per message. Do not use any other tools together with this tool.""" + + class ToolConfig: + # set output as result because the communication will be finished after this tool is called + output_as_result: bool = True + one_call_at_a_time: bool = True + + def run(self): + # get main thread + thread = self._get_main_thread() + + # get recipient agent from thread + recipient_agent = self._get_recipient_agent() + + # submit tool output + try: + thread._submit_tool_outputs( + tool_outputs=[{"tool_call_id": self._tool_call.id, "output": "The request has been routed. You are now a " + recipient_agent.name + " agent. Please assist the user further with their request."}], + poll=False + ) + except BadRequestError as e: + raise Exception("You can only call this tool by itself. Do not use any other tools together with this tool.") + + try: + # cancel run + thread._cancel_run() + + # change recipient agent in thread + thread.recipient_agent = recipient_agent + + # change recipient agent in gradio dropdown + if self._event_handler: + if hasattr(self._event_handler, "change_recipient_agent"): + self._event_handler.change_recipient_agent(self.recipient.value) + + # continue conversation with the new recipient agent + message = thread.get_completion(message=None, recipient_agent=recipient_agent, yield_messages=not self._event_handler, event_handler=self._event_handler) + + return message or "" + except Exception as e: + # we need to catch errors beucase tool outputs are already submitted + print("Error in SendMessageSwarm: ", e) + return str(e) diff --git a/agency_swarm/tools/send_message/__init__.py b/agency_swarm/tools/send_message/__init__.py new file mode 100644 index 00000000..d044d285 --- /dev/null +++ b/agency_swarm/tools/send_message/__init__.py @@ -0,0 +1,5 @@ +from .SendMessageAsyncThreading import SendMessageAsyncThreading +from .SendMessageBase import SendMessageBase +from .SendMessage import SendMessage +from .SendMessageSwarm import SendMessageSwarm +from .SendMessageQuick import SendMessageQuick \ No newline at end of file diff --git a/docs/advanced-usage/communication_flows.md b/docs/advanced-usage/communication_flows.md new file mode 100644 index 00000000..c65d1a17 --- /dev/null +++ b/docs/advanced-usage/communication_flows.md @@ -0,0 +1,231 @@ +# Advanced Communication Flows + +Multi-agent communication is the core functionality of any Multi-Agent System. Unlike in all other frameworks, Agency Swarm not only allows you to define communication flows in any way you want (uniform communication flows), but to also configure the underlying logic for this feature. This means that you can create entirely new types of communication, or adjust it to your own needs. Below you will find a guide on how to do all this, along with some common examples. + +## Pre-Made SendMessage Classes + +Agency Swarm contains multiple commonly requested classes for communication flows. Currently, the following classes are available: + +| Class Name | Description | When to Use | Code Link | +| --------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------- | +| `SendMessage` (default) | This is the default class for sending messages to other agents. It uses synchronous communication with basic COT (Chain of Thought) prompting and allows agents to relay files and modify system instructions for each other. | Suitable for most use cases. Balances speed and functionality. | | +| `SendMessageQuick` | A variant of the SendMessage class without Chain of Thought prompting, files, and additional instructions. It allows for faster communication without the overhead of COT. | Use for simpler use cases or when you want to save tokens and increase speed. | | +| `SendMessageAsyncThreading` | Similar to `SendMessage` but with `async_mode='threading'`. Each agent will execute asynchronously in a separate thread. In the meantime, the caller agent can continue the conversation with the user and check the results later. | Use for asynchronous applications or when sub-agents take singificant amounts of time to complete their tasks. | | +| `SendMessageSwarm` | Instead of sending a message to another agent, it replaces the caller agent with the recipient agent, similar to [OpenAI's Swarm](https://github.com/openai/swarm). The recipient agent will then have access to the entire conversation. | When you need more granular control. It is not able to handle complex multi-step, multi-agent tasks. | | + +**To use any of the pre-made `SendMessage` classes**, simply put it in the `send_message_tool_class` parameter when initializing the `Agency` class: + +```python +from agency_swarm.tools.send_message import SendMessageQuick + +agency = Agency( + ... + send_message_tool_class=SendMessageQuick +) +``` + +That's it! Now, your agents will use your own custom `SendMessageQuick` class for communication. + +## Creating Your Own Unique Communication Flows + +To create you own communication flow, you will first need to extend the `SendMessageBase` class. This class extends the `BaseTool` class, like any other tools in Agency Swarm, and contains the most basic parameters required for communication, such as the `recipient_agent`. + +### Default `SendMessage` Class + +By defualt, Agency Swarm uses the following tool for communication: + +```python +from typing import Optional, List +from pydantic import Field, field_validator, model_validator +from .SendMessageBase import SendMessageBase + +class SendMessage(SendMessageBase): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message to the same recipient agent at the same time.""" + my_primary_instructions: str = Field( + ..., + description=( + "Please repeat your primary instructions step-by-step, including both completed " + "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down " + "into smaller steps yourself. Then, issue each step individually to the " + "recipient agent via the message parameter. Each identified step should be " + "sent in a separate message. Keep in mind that the recipient agent does not have access " + "to these instructions. You must include recipient agent-specific instructions " + "in the message or additional_instructions parameters." + ) + ) + message: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information needed to complete the task." + ) + message_files: Optional[List[str]] = Field( + default=None, + description="A list of file IDs to be sent as attachments to this message. Only use this if you have the file ID that starts with 'file-'.", + examples=["file-1234", "file-5678"] + ) + additional_instructions: Optional[str] = Field( + default=None, + description="Additional context or instructions from the conversation needed by the recipient agent to complete the task." + ) + + @model_validator(mode='after') + def validate_files(self): + # prevent hallucinations with agents sending file IDs into incorrect fields + if "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions): + if not self.message_files: + raise ValueError("You must include file IDs in message_files parameter.") + return self + + + def run(self): + return self._get_completion(message=self.message, + message_files=self.message_files, + additional_instructions=self.additional_instructions) +``` + +Let's break down the code. + +In general, all `SendMessage` tools have the following components: + +1. **The Docstring**: This is used to generate a description of the tool for the agent. This part should clearly describe how your multi-agent communication works, along with some additional guidelines on how to use it. +2. **Parameters**: Parameters like `message`, `message_files`, `additional_instructions` are used to provide the recipient agent with the necessary information. +3. **The `run` method**: This is where the communication logic is implemented. Most of the time, you just need to map your parameters to `self._get_completion()` the same way you would call it in the `agency.get_completion()` method. + +When creating your own `SendMessage` tools, you can use the above components as a template. + +### Common Use Cases + +In the following sections, we'll look at some common use cases for extending the `SendMessageBase` tool and how to implement them, so you can learn how to create your own SendMessage tools and use them in your own applications. + +#### 1. Adjusting parameters and descriptions + +The most basic use case is if you want to use your own parameter descriptions, such as if you want to change the docstring or the description of the `message` parameter. This can help you better customize how the agents communicate with each other and what information they relay. + +Let's say that instead of sending messages, I want my agents to send tasks to each other. In this case, I can change the docstring and the `message` parameter to a `task` parameter to better fit the nature of my application. + +```python +from pydantic import Field +from agency_swarm.tools.send_message import SendMessageBase + +class SendMessageTask(SendMessageBase): + """Use this tool to send tasks to other agents within your agency.""" + chain_of_thought: str = Field( + ..., + description="Please think step-by-step about how to solve your current task, provided by the user. Then, break down this task into smaller steps and issue each step individually to the recipient agent via the task parameter." + ) + task: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information needed to complete the task." + ) + + def run(self): + return self._get_completion(message=self.task) +``` + +To remove the chain of thought, you can simply remove the `chain_of_thought` parameter. + +#### 2. Adding custom validation logic + +Now, let's say that I need to ensure that my message is sent to the correct recepient agent. (This is a very common hallucination in production.) In this case, I can add custom validator to the `recipient` parameter, which is defined in the `SendMessageBase` class. Since I don't want to change any other parameters or descriptions, I can inherit the default `SendMessage` class and only add this new validation logic. + +```python +from agency_swarm.tools.send_message import SendMessage +from pydantic import model_validator + +class SendMessageValidation(SendMessage): + @model_validator(mode='after') + def validate_recipient(self): + if "customer support" not in self.message.lower() and self.recipient == "CustomerSupportAgent": + raise ValueError("Messages not related to customer support cannot be sent to the customer support agent.") + return self +``` + +You can, of course, also use GPT for this: + +```python +from agency_swarm.tools.send_message import SendMessage +from agency_swarm.util.validators import llm_validator +from pydantic import model_validator + +class SendMessageLLMValidation(SendMessage): + @model_validator(mode='after') + def validate_recipient(self): + if self.recipient == "CustomerSupportAgent": + llm_validator( + statement="The message is related to customer support." + )(self.message) + return self +``` + +In this example, the `llm_validator` will throw an error if the message is not related to customer support. The caller agent will then have to fix the recipient or the message and send it again! This is extremely useful when you have a lot of agents. + +#### 3. Summurizing previous conversations with other agents and adding to context + +Sometimes, when using default `SendMessage`, the agents might not relay all the neceessary details to the recipient agent. Especially, when the previous conversation is too long. In this case, you can summarize the previous conversation with GPT and add it to the context, instead of the additional instructions. I will extend the `SendMessageQuick` class, which already contains the `message` parameter, as I don't need chain of thought or files in this case. + +```python +from agency_swarm.tools.send_message import SendMessageQuick +from agency_swarm.util.oai import get_openai_client + +class SendMessageSummary(SendMessageQuick): + def run(self): + client = get_openai_client() + thread = self._get_main_thread() # get the main thread (conversation with the user) + + # get the previous messages + previous_messages = thread.get_messages() + previous_messages_str = "\n".join([f"{m.role}: {m.content[0].text.value}" for m in previous_messages]) + + # summarize the previous conversation + summary = client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": "You are a world-class summarizer. Please summarize the following conversation in a few sentences:"}, + {"role": "user", "content": previous_messages_str} + ] + ) + + # send the message with the summary + return self._get_completion(message=self.message, additional_instructions=f"\n\nPrevious conversation summary: '{summary.choices[0].message.content}'") +``` + +With this example, you can add your own custom logic to the `run` method. It does not have to be a summary; you can also use it to add any other information to the context. For example, you can even query a vector database or use an external API. + +#### 4. Running each agent in a separate API call + +If you are a PRO, and you have managed to deploy each agent in a separate API endpoint, instead of using `_get_completion()`, you can call your own API and let the agents communicate with each other over the internet. + +```python +import requests +from agency_swarm.tools.send_message import SendMessage + +class SendMessageAPI(SendMessage): + def run(self): + response = requests.post( + "https://your-api-endpoint.com/send-message", + json={"message": self.message, "recipient": self.recipient} + ) + return response.json()["message"] +``` + +This is very powerful, as you can even allow your agents to colloborate with agents outside your system. More on this is coming soon! + +!!! tip "Contributing" + + If you have any ideas for new communication flows, please either adjust this page in docs, or add your new send message tool in the `agency_swarm/tools/send_message` folder and open a PR! + +**After implementing your own `SendMessage` tool**, simply pass it into the `send_message_tool_class` parameter when initializing the `Agency` class: + +```python +agency = Agency( + ... + send_message_tool_class=SendMessageAPI +) +``` + +That's it! Now, your agents will use your own custom `SendMessageAPI` class for communication! + +## Conclusion + +Agency Swarm has been designed to give you, the developer, full control over your systems. It is the only framework that does not hard-code any prompts, parameters, or even worse, agents for you. With this new feature, the last part of the system that you couldn't fully customize to your own needs is now gone! + +So, I want to encourage you to keep experimenting and designing your own unique communication flows. While the examples above should serve as a good starting point, they do not even merely scratch the surface of what's possible here! I am looking forward to seeing what you will create. Please share it in our [Discord server](https://discord.gg/7HcABDpFPG) so we can all learn from each other. diff --git a/mkdocs.yml b/mkdocs.yml index 647e3e7b..58e118c4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -14,7 +14,7 @@ theme: - navigation.instant.prefetch - navigation.instant.progress - navigation.prune -# - navigation.sections + # - navigation.sections - navigation.tabs # - navigation.tabs.sticky - navigation.top @@ -43,10 +43,11 @@ nav: - Introduction: "index.md" - Quick Start: "quick_start.md" - Advanced Usage: - - Advanced Tools: "advanced-usage/tools.md" - - Agents: "advanced-usage/agents.md" - - Agencies: "advanced-usage/agencies.md" - - Azure OpenAI: "advanced-usage/azure-openai.md" + - Tools: "advanced-usage/tools.md" + - Agents: "advanced-usage/agents.md" + - Agencies: "advanced-usage/agencies.md" + - Communication: "advanced-usage/communication_flows.md" + - Azure OpenAI: "advanced-usage/azure-openai.md" - Deployment to Production: "deployment.md" - Open Source Models: "advanced-usage/open-source-models.md" - API Reference: "api.md" diff --git a/pyproject.toml b/pyproject.toml index 17254123..4055e836 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "agency-swarm" dynamic = ["version"] -authors = [{ name = "VRSEN", email = "arseny9795@gmail.com" }] +authors = [{ name = "VRSEN", email = "me@vrsen.ai" }] description = "An open source agent orchestration framework built on top of the latest OpenAI Assistants API." readme = "README.md" license = { file = "LICENSE" } @@ -15,7 +15,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", ] dependencies = [ - "openai==1.54.3", + "openai==1.54.4", "docstring_parser==0.16", "pydantic==2.8.2", "datamodel-code-generator==0.26.1", diff --git a/requirements.txt b/requirements.txt index 6d7b06c5..c66fce38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -openai==1.54.3 +openai==1.54.4 docstring_parser==0.16 pydantic==2.8.2 datamodel-code-generator==0.26.1 diff --git a/setup.py b/setup.py index 33f796a7..2dc7a02f 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name='agency-swarm', - version='0.3.1', + version='0.4.0', author='VRSEN', author_email='me@vrsen.ai', description='An opensource agent orchestration framework built on top of the latest OpenAI Assistants API.', diff --git a/tests/test_agency.py b/tests/test_agency.py index 7ed46b85..88177118 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -13,6 +13,7 @@ from agency_swarm.tools import CodeInterpreter, FileSearch sys.path.insert(0, '../agency-swarm') +from agency_swarm.tools.send_message import SendMessageAsyncThreading from agency_swarm.util import create_agent_template from agency_swarm import set_openai_key, Agent, Agency, AgencyEventHandler, get_openai_client @@ -219,11 +220,11 @@ def test_4_agent_communication(self): message = self.__class__.agency.get_completion("Please tell TestAgent1 to say test to TestAgent2.", tool_choice={"type": "function", "function": {"name": "SendMessage"}}) - self.assertFalse('error' in message.lower(), f"Error found in message: {message}") + self.assertFalse('error' in message.lower(), f"Error found in message: {message}. Thread url: {self.__class__.agency.main_thread.thread_url}") - for agent_name, threads in self.__class__.agency.agents_and_threads.items(): - for other_agent_name, thread in threads.items(): - self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) + self.assertTrue(self.__class__.agency.agents_and_threads['main_thread'].id) + self.assertTrue(self.__class__.agency.agents_and_threads['CEO']['TestAgent1'].id) + self.assertTrue(self.__class__.agency.agents_and_threads['TestAgent1']['TestAgent2'].id) for agent in self.__class__.agency.agents: self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) @@ -238,7 +239,7 @@ def test_4_agent_communication(self): self.assertTrue(thread_messages.data[0].content[0].text.value == "Hi!") - run = main_thread.run + run = main_thread._run self.assertTrue(run.max_prompt_tokens == self.__class__.ceo.max_prompt_tokens) self.assertTrue(run.max_completion_tokens == self.__class__.ceo.max_completion_tokens) self.assertTrue(run.tool_choice.type == "function") @@ -251,7 +252,7 @@ def test_4_agent_communication(self): self.assertTrue(len(agent1_thread_messages.data) == 2) - agent1_run = agent1_thread.run + agent1_run = agent1_thread._run self.assertTrue(agent1_run.truncation_strategy.type == "last_messages") self.assertTrue(agent1_run.truncation_strategy.last_messages == 10) @@ -309,11 +310,11 @@ def on_all_streams_end(cls): self.assertTrue(self.__class__.TestTool._shared_state.get("test_tool_used")) agent1_thread = self.__class__.agency.agents_and_threads[self.__class__.ceo.name][self.__class__.agent1.name] - self.assertFalse(agent1_thread.run.parallel_tool_calls) + self.assertFalse(agent1_thread._run.parallel_tool_calls) - for agent_name, threads in self.__class__.agency.agents_and_threads.items(): - for other_agent_name, thread in threads.items(): - self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) + self.assertTrue(self.__class__.agency.main_thread.id) + self.assertTrue(self.__class__.agency.agents_and_threads['CEO']['TestAgent1'].id) + self.assertTrue(self.__class__.agency.agents_and_threads['TestAgent1']['TestAgent2'].id) for agent in self.__class__.agency.agents: self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) @@ -322,8 +323,8 @@ def test_6_load_from_db(self): """it should load agents from db""" # os.rename("settings.json", "settings2.json") - previous_loaded_thread_ids = self.__class__.loaded_thread_ids - previous_loaded_agents_settings = self.__class__.loaded_agents_settings + previous_loaded_thread_ids = self.__class__.loaded_thread_ids.copy() + previous_loaded_agents_settings = self.__class__.loaded_agents_settings.copy() from test_agents.CEO import CEO from test_agents.TestAgent1 import TestAgent1 @@ -368,10 +369,18 @@ def test_6_load_from_db(self): self.check_all_agents_settings() # check that threads are the same - for agent_name, threads in agency.agents_and_threads.items(): - for other_agent_name, thread in threads.items(): - self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) - self.assertTrue(thread.id in previous_loaded_thread_ids[agent_name][other_agent_name]) + print("previous_loaded_thread_ids", previous_loaded_thread_ids) + print("self.__class__.loaded_thread_ids", self.__class__.loaded_thread_ids) + # Start of Selection + for agent, threads in self.__class__.agency.agents_and_threads.items(): + if agent == "main_thread": + print("main_thread", threads) + continue + for other_agent, thread in threads.items(): + print(f"Thread ID between {agent} and {other_agent}: {thread.id}") + self.assertTrue(self.__class__.agency.agents_and_threads['main_thread'].id == previous_loaded_thread_ids['main_thread'] == self.__class__.loaded_thread_ids['main_thread']) + self.assertTrue(self.__class__.agency.agents_and_threads['CEO']['TestAgent1'].id == previous_loaded_thread_ids['CEO']['TestAgent1'] == self.__class__.loaded_thread_ids['CEO']['TestAgent1']) + self.assertTrue(self.__class__.agency.agents_and_threads['TestAgent1']['TestAgent2'].id == previous_loaded_thread_ids['TestAgent1']['TestAgent2'] == self.__class__.loaded_thread_ids['TestAgent1']['TestAgent2']) # check that agents are the same for agent in agency.agents: @@ -379,7 +388,7 @@ def test_6_load_from_db(self): self.assertTrue(agent.id in [settings['id'] for settings in previous_loaded_agents_settings]) def test_7_init_async_agency(self): - """it should initialize agency with agents""" + """it should initialize async agency with agents""" # reset loaded thread ids self.__class__.loaded_thread_ids = {} @@ -397,7 +406,7 @@ def test_7_init_async_agency(self): shared_instructions="", settings_callbacks=self.__class__.settings_callbacks, threads_callbacks=self.__class__.threads_callbacks, - async_mode='threading', + send_message_tool_class=SendMessageAsyncThreading, temperature=0, ) @@ -405,7 +414,6 @@ def test_7_init_async_agency(self): def test_8_async_agent_communication(self): """it should communicate between agents asynchronously""" - print("TestAgent1 tools", self.__class__.agent1.tools) self.__class__.agency.get_completion("Please tell TestAgent2 hello.", tool_choice={"type": "function", "function": {"name": "SendMessage"}}, recipient_agent=self.__class__.agent1) @@ -449,9 +457,8 @@ def on_all_streams_end(cls): if 'error' in message.lower(): self.assertFalse('error' in message.lower(), self.__class__.agency.main_thread.thread_url) - for agent_name, threads in self.__class__.agency.agents_and_threads.items(): - for other_agent_name, thread in threads.items(): - self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) + self.assertTrue(self.__class__.agency.main_thread.id) + self.assertTrue(self.__class__.agency.agents_and_threads['TestAgent1']['TestAgent2'].id) for agent in self.__class__.agency.agents: self.assertTrue(agent.id in [settings['id'] for settings in self.__class__.loaded_agents_settings]) @@ -459,11 +466,16 @@ def on_all_streams_end(cls): def test_9_async_tool_calls(self): """it should execute tools asynchronously""" class PrintTool(BaseTool): + class ToolConfig: + async_mode = "threading" def run(self, **kwargs): time.sleep(2) # Simulate a delay return "Printed successfully." class AnotherPrintTool(BaseTool): + class ToolConfig: + async_mode = "threading" + def run(self, **kwargs): time.sleep(2) # Simulate a delay return "Another print successful." @@ -472,12 +484,9 @@ def run(self, **kwargs): agency = Agency( [ceo], - async_mode='tools_threading', temperature=0 ) - self.assertTrue(agency.main_thread.async_mode == 'tools_threading') - result = agency.get_completion("Use 2 print tools together at the same time and output the results exectly as they are. ", yield_messages=False) self.assertIn("success", result.lower(), agency.main_thread.thread_url) diff --git a/tests/test_communication.py b/tests/test_communication.py new file mode 100644 index 00000000..216efd29 --- /dev/null +++ b/tests/test_communication.py @@ -0,0 +1,64 @@ +import unittest +from agency_swarm import Agent, Agency +from agency_swarm.tools.send_message import SendMessageSwarm +from agency_swarm.tools import BaseTool +from pydantic import Field + +class TestSendMessage(unittest.TestCase): + def setUp(self): + class PrintTool(BaseTool): + """ + A simple tool that prints a message. + """ + message: str = Field(..., description="The message to print.") + + def run(self): + print(self.message) + return f"Printed: {self.message}" + + self.ceo = Agent( + name="CEO", + description="Responsible for client communication, task planning and management.", + instructions="Your role is to route messages to other agents within your agency.", + tools=[PrintTool] + ) + + self.customer_support = Agent( + name="Customer Support", + description="Responsible for customer support.", + instructions="You are a Customer Support agent. Answer customer questions and help with issues.", + tools=[] + ) + + self.agency = Agency([self.ceo, [self.ceo, self.customer_support], [self.customer_support, self.ceo]], + temperature=0, send_message_tool_class=SendMessageSwarm) + + def test_send_message_swarm(self): + response = self.agency.get_completion("Hello, can you send me to customer support? If tool responds says that you have NOT been rerouted, or if there is another error, please say 'error'") + self.assertFalse("error" in response.lower(), self.agency.main_thread.thread_url) + response = self.agency.get_completion("Who are you?") + self.assertTrue("customer support" in response.lower(), self.agency.main_thread.thread_url) + + main_thread = self.agency.main_thread + + # check if recipient agent is correct + self.assertEqual(main_thread.recipient_agent, self.customer_support) + + #check if all messages in the same thread (this is how Swarm works) + self.assertTrue(len(main_thread.get_messages()) >= 4) # sometimes run does not cancel immediately, so there might be 5 messages + + def test_send_message_double_recepient_error(self): + ceo = Agent(name="CEO", + description="Responsible for client communication, task planning and management.", + instructions="You are an agent for testing. Route request AT THE SAME TIME as instructed. If there is an error in a single request, please say 'error'. If there are errors in both requests, please say 'fatal'. do not output anything else.", + ) + test_agent = Agent(name="Test Agent1", + description="Responsible for testing.", + instructions="Test agent for testing.") + agency = Agency([ceo, [ceo, test_agent]], temperature=0) + response = agency.get_completion("Please route me to customer support TWICE at the same time. I am testing something.") + self.assertTrue("error" in response.lower(), agency.main_thread.thread_url) + self.assertTrue("fatal" not in response.lower(), agency.main_thread.thread_url) + +if __name__ == '__main__': + unittest.main()