diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 4e5001d6..50840c40 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -6,7 +6,6 @@ import uuid from enum import Enum from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, TypedDict, Union - from openai.lib._parsing._completions import type_to_response_format_param from openai.types.beta.threads import Message from openai.types.beta.threads.runs import RunStep @@ -24,7 +23,9 @@ from agency_swarm.messages import MessageOutput from agency_swarm.messages.message_output import MessageOutputLive from agency_swarm.threads import Thread +from agency_swarm.threads.thread_async import ThreadAsync from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch +from agency_swarm.tools.send_message import SendMessage, SendMessageBase from agency_swarm.user import User from agency_swarm.util.errors import RefusalError from agency_swarm.util.files import get_tools, get_file_purpose @@ -46,15 +47,12 @@ class ThreadsCallbacks(TypedDict): class Agency: - ThreadType = Thread - send_message_tool_description = """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time.""" - send_message_tool_description_async = """Use this tool for asynchronous communication with other agents within your agency. 
Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. Continue until task completion.""" - def __init__(self, agency_chart: List, shared_instructions: str = "", shared_files: Union[str, List[str]] = None, async_mode: Literal['threading', "tools_threading"] = None, + send_message_tool_class: Type[SendMessageBase] = SendMessage, settings_path: str = "./settings.json", settings_callbacks: SettingsCallbacks = None, threads_callbacks: ThreadsCallbacks = None, @@ -72,6 +70,7 @@ def __init__(self, shared_instructions (str, optional): A path to a file containing shared instructions for all agents. Defaults to an empty string. shared_files (Union[str, List[str]], optional): A path to a folder or a list of folders containing shared files for all agents. Defaults to None. async_mode (str, optional): Specifies the mode for asynchronous processing. In "threading" mode, all sub-agents run in separate threads. In "tools_threading" mode, all tools run in separate threads, but agents do not. Defaults to None. + send_message_tool_class (Type[SendMessageBase], optional): The class to use for the send_message tool. For async communication, use `SendMessageAsyncThreading`. Defaults to SendMessage. settings_path (str, optional): The path to the settings file for the agency. Must be json. If file does not exist, it will be created. Defaults to None. settings_callbacks (SettingsCallbacks, optional): A dictionary containing functions to load and save settings for the agency. The keys must be "load" and "save". Both values must be defined. Defaults to None. threads_callbacks (ThreadsCallbacks, optional): A dictionary containing functions to load and save threads for the agency. The keys must be "load" and "save". Both values must be defined. Defaults to None. 
@@ -92,6 +91,7 @@ def __init__(self, self.recipient_agents = None # for autocomplete self.shared_files = shared_files if shared_files else [] self.async_mode = async_mode + self.send_message_tool_class = send_message_tool_class self.settings_path = settings_path self.settings_callbacks = settings_callbacks self.threads_callbacks = threads_callbacks @@ -102,8 +102,9 @@ def __init__(self, self.truncation_strategy = truncation_strategy if self.async_mode == "threading": - from agency_swarm.threads.thread_async import ThreadAsync - self.ThreadType = ThreadAsync + from agency_swarm.tools.send_message import SendMessageAsyncThreading + print("Warning: 'threading' mode is deprecated. Please use send_message_tool_class = SendMessageAsyncThreading to use async communication.") + self.send_message_tool_class = SendMessageAsyncThreading elif self.async_mode == "tools_threading": Thread.async_mode = self.async_mode elif self.async_mode is None: @@ -121,9 +122,9 @@ def __init__(self, self.shared_state = SharedState() self._parse_agency_chart(agency_chart) + self._init_threads() self._create_special_tools() self._init_agents() - self._init_threads() def get_completion(self, message: str, message_files: List[str] = None, @@ -300,7 +301,7 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): images = [] message_file_names = None uploading_files = False - recipient_agents = [agent.name for agent in self.main_recipients] + recipient_agent_names = [agent.name for agent in self.main_recipients] recipient_agent = self.main_recipients[0] with gr.Blocks(js=js) as demo: @@ -308,7 +309,7 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs): chatbot = gr.Chatbot(height=height) with gr.Row(): with gr.Column(scale=9): - dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agents, + dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agent_names, value=recipient_agent.name) msg = gr.Textbox(label="Your Message", lines=4) with gr.Column(scale=1): 
@@ -412,9 +413,14 @@ def check_and_add_tools_in_attachments(attachments, recipient_agent): class GradioEventHandler(AgencyEventHandler): message_output = None + @classmethod + def change_recipient_agent(cls, recipient_agent_name): + nonlocal chatbot_queue + chatbot_queue.put("[change_recipient_agent]") + chatbot_queue.put(recipient_agent_name) + @override def on_message_created(self, message: Message) -> None: - if message.role == "user": full_content = "" for content in message.content: @@ -530,19 +536,20 @@ def on_all_streams_end(cls): chatbot_queue.put("[end]") def bot(original_message, history): - if not original_message: - return "", history - nonlocal attachments nonlocal message_file_names nonlocal recipient_agent + nonlocal recipient_agent_names nonlocal images nonlocal uploading_files + if not original_message: + return "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) + if uploading_files: history.append([None, "Uploading files... 
Please wait."]) - yield "", history - return "", history + yield "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) + return "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) print("Message files: ", attachments) print("Images: ", images) @@ -579,13 +586,19 @@ def bot(original_message, history): new_message = True continue + if bot_message == "[change_recipient_agent]": + new_agent_name = chatbot_queue.get(block=True) + recipient_agent = self._get_agent_by_name(new_agent_name) + yield "", history, gr.update(value=new_agent_name, choices=set([*recipient_agent_names, recipient_agent.name])) + continue + if new_message: history.append([None, bot_message]) new_message = False else: history[-1][1] += bot_message - yield "", history + yield "", history, gr.update(value=recipient_agent.name, choices=set([*recipient_agent_names, recipient_agent.name])) except queue.Empty: break @@ -594,12 +607,12 @@ def bot(original_message, history): inputs=[msg, chatbot], outputs=[msg, chatbot] ).then( - bot, [msg, chatbot], [msg, chatbot] + bot, [msg, chatbot, dropdown], [msg, chatbot, dropdown] ) dropdown.change(handle_dropdown_change, dropdown) file_upload.change(handle_file_upload, file_upload) msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( - bot, [msg, chatbot], [msg, chatbot] + bot, [msg, chatbot], [msg, chatbot, dropdown] ) # Enable queuing for streaming intermediate outputs @@ -647,6 +660,7 @@ def run_demo(self): """ Executes agency in the terminal with autocomplete for recipient agent names. 
""" + outer_self = self from agency_swarm import AgencyEventHandler class TermEventHandler(AgencyEventHandler): message_output = None @@ -714,7 +728,7 @@ def on_tool_call_done(self, snapshot): if snapshot.type != "function": return - if snapshot.function.name == "SendMessage": + if snapshot.function.name == "SendMessage" and not (hasattr(outer_self.send_message_tool_class.ToolConfig, 'output_as_result') and outer_self.send_message_tool_class.ToolConfig.output_as_result): try: args = eval(snapshot.function.arguments) recipient = args["recipient"] @@ -864,9 +878,14 @@ def _init_threads(self): else: self.main_thread.init_thread() + # Save main_thread into agents_and_threads + self.agents_and_threads["main_thread"] = self.main_thread + for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent, items in threads.items(): - self.agents_and_threads[agent_name][other_agent] = self.ThreadType( + self.agents_and_threads[agent_name][other_agent] = self.send_message_tool_class._thread_type( self._get_agent_by_name(items["agent"]), self._get_agent_by_name( items["recipient_agent"])) @@ -880,6 +899,8 @@ def _init_threads(self): if self.threads_callbacks: loaded_thread_ids = {} for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue loaded_thread_ids[agent_name] = {} for other_agent, thread in threads.items(): loaded_thread_ids[agent_name][other_agent] = thread.id @@ -1001,13 +1022,15 @@ def _create_special_tools(self): No output parameters; this method modifies the agents' toolset internally. 
""" for agent_name, threads in self.agents_and_threads.items(): + if agent_name == "main_thread": + continue recipient_names = list(threads.keys()) recipient_agents = self._get_agents_by_names(recipient_names) if len(recipient_agents) == 0: continue agent = self._get_agent_by_name(agent_name) agent.add_tool(self._create_send_message_tool(agent, recipient_agents)) - if self.async_mode == 'threading': + if self.send_message_tool_class._thread_type == ThreadAsync: agent.add_tool(self._create_get_response_tool(agent, recipient_agents)) def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]): @@ -1032,38 +1055,8 @@ def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]) agent_descriptions += recipient_agent.name + ": " agent_descriptions += recipient_agent.description + "\n" - outer_self = self - - class SendMessage(BaseTool): - my_primary_instructions: str = Field(..., - description="Please repeat your primary instructions step-by-step, including both completed " - "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down " - "into smaller steps yourself. Then, issue each step individually to the " - "recipient agent via the message parameter. Each identified step should be " - "sent in separate message. Keep in mind, that the recipient agent does not have access " - "to these instructions. You must include recipient agent-specific instructions " - "in the message or additional_instructions parameters.") + class SendMessage(self.send_message_tool_class): recipient: recipients = Field(..., description=agent_descriptions) - message: str = Field(..., - description="Specify the task required for the recipient agent to complete. Focus on " - "clarifying what the task entails, rather than providing exact " - "instructions.") - message_files: Optional[List[str]] = Field(default=None, - description="A list of file ids to be sent as attachments to this message. 
Only use this if you have the file id that starts with 'file-'.", - examples=["file-1234", "file-5678"]) - additional_instructions: Optional[str] = Field(default=None, - description="Additional context or instructions for the recipient agent about the task. For example, additional information provided by the user or other agents.") - - class ToolConfig: - strict = False - one_call_at_a_time = outer_self.async_mode != 'threading' - - @model_validator(mode='after') - def validate_files(self): - if "file-" in self.message or ( - self.additional_instructions and "file-" in self.additional_instructions): - if not self.message_files: - raise ValueError("You must include file ids in message_files parameter.") @field_validator('recipient') @classmethod @@ -1071,36 +1064,10 @@ def check_recipient(cls, value): if value.value not in recipient_names: raise ValueError(f"Recipient {value} is not valid. Valid recipients are: {recipient_names}") return value - - @field_validator('additional_instructions', mode='before') - @classmethod - def validate_additional_instructions(cls, value): - if isinstance(value, list): - return "\n".join(value) - return value - - def run(self): - thread = outer_self.agents_and_threads[self._caller_agent.name][self.recipient.value] - - if not outer_self.async_mode == 'threading': - message = thread.get_completion(message=self.message, - message_files=self.message_files, - event_handler=self._event_handler, - yield_messages=not self._event_handler, - additional_instructions=self.additional_instructions, - ) - else: - message = thread.get_completion_async(message=self.message, - message_files=self.message_files, - additional_instructions=self.additional_instructions) - - return message or "" SendMessage._caller_agent = agent - if self.async_mode == 'threading': - SendMessage.__doc__ = self.send_message_tool_description_async - else: - SendMessage.__doc__ = self.send_message_tool_description + SendMessage._agents_and_threads = self.agents_and_threads + 
SendMessage._agency = self return SendMessage diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 8d681f3d..8a9e49bb 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -42,6 +42,8 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.num_run_retries = 0 + self.terminal_states = ["cancelled", "completed", "failed", "expired", "incomplete"] + def init_thread(self): if self.id: self.thread = self.client.beta.threads.retrieve(self.id) @@ -57,7 +59,7 @@ def init_thread(self): ) def get_completion_stream(self, - message: str, + message: Union[str, List[dict], None], event_handler: type(AgencyEventHandler), message_files: List[str] = None, attachments: Optional[List[Attachment]] = None, @@ -77,10 +79,10 @@ def get_completion_stream(self, response_format=response_format) def get_completion(self, - message: str | List[dict], + message: Union[str, List[dict], None], message_files: List[str] = None, attachments: Optional[List[dict]] = None, - recipient_agent: Agent = None, + recipient_agent: Union[Agent, None] = None, additional_instructions: str = None, event_handler: type(AgencyEventHandler) = None, tool_choice: AssistantToolChoice = None, @@ -117,14 +119,15 @@ def get_completion(self, print(f'THREAD:[ {sender_name} -> {recipient_agent.name} ]: URL {self.thread_url}') # send message - message_obj = self.create_message( - message=message, - role="user", - attachments=attachments - ) + if message: + message_obj = self.create_message( + message=message, + role="user", + attachments=attachments + ) - if yield_messages: - yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj) + if yield_messages: + yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj) self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) @@ -138,8 +141,8 @@ def get_completion(self, if 
self.run.status == "requires_action": tool_calls = self.run.required_action.submit_tool_outputs.tool_calls tool_outputs_and_names = [] # list of tuples (name, tool_output) - sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name == "SendMessage"] - async_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name != "SendMessage"] + sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name.startswith("SendMessage")] + async_tool_calls = [tool_call for tool_call in tool_calls if not tool_call.function.name.startswith("SendMessage")] def handle_output(tool_call, output): if inspect.isgenerator(output): @@ -157,6 +160,8 @@ def handle_output(tool_call, output): for tool_output in tool_outputs_and_names: if tool_output[1]["tool_call_id"] == tool_call.id: tool_output[1]["output"] = output + + return output if len(async_tool_calls) > 0 and self.async_mode == "tools_threading": max_workers = min(self.max_workers, os.cpu_count() or 1) # Use at most 4 workers or the number of CPUs available @@ -170,8 +175,11 @@ def handle_output(tool_call, output): for future in as_completed(futures): tool_call = futures[future] - output = future.result() - yield from handle_output(tool_call, output) + output, output_as_result = future.result() + output = yield from handle_output(tool_call, output) + if output_as_result: + self._cancel_run() + return output else: sync_tool_calls += async_tool_calls @@ -179,10 +187,13 @@ def handle_output(tool_call, output): for tool_call in sync_tool_calls: if yield_messages: yield MessageOutput("function", recipient_agent.name, self.agent.name, str(tool_call.function), tool_call) - output = self.execute_tool(tool_call, recipient_agent, event_handler, tool_outputs_and_names) + output, output_as_result = self.execute_tool(tool_call, recipient_agent, event_handler, tool_outputs_and_names) tool_outputs_and_names.append((tool_call.function.name, {"tool_call_id": tool_call.id, "output": 
output})) - yield from handle_output(tool_call, output) - + output = yield from handle_output(tool_call, output) + if output_as_result: + self._cancel_run() + return output + # split names and outputs tool_outputs = [tool_output for _, tool_output in tool_outputs_and_names] tool_names = [name for name, _ in tool_outputs_and_names] @@ -340,8 +351,11 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t # poll_interval_ms=500, ) except APIError as e: - if "The server had an error processing your request" in e.message and self.num_run_retries < 3: - time.sleep(1 + self.num_run_retries) + match = re.search(r"Thread (\w+) already has an active run (\w+)", e.message) + if match: + self._cancel_run(thread_id=match.groups()[0], run_id=match.groups()[1], check_status=False) + elif "The server had an error processing your request" in e.message and self.num_run_retries < 3: + time.sleep(1) self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) self.num_run_retries += 1 else: @@ -355,22 +369,48 @@ def _run_until_done(self): run_id=self.run.id ) - def _submit_tool_outputs(self, tool_outputs, event_handler): - if not event_handler: - self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( + def _submit_tool_outputs(self, tool_outputs, event_handler=None, poll=True): + if not poll: + self.run = self.client.beta.threads.runs.submit_tool_outputs( thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs ) else: - with self.client.beta.threads.runs.submit_tool_outputs_stream( + if not event_handler: + self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( thread_id=self.thread.id, run_id=self.run.id, - tool_outputs=tool_outputs, - event_handler=event_handler() - ) as stream: - stream.until_done() - self.run = stream.get_final_run() + tool_outputs=tool_outputs + ) + else: + with self.client.beta.threads.runs.submit_tool_outputs_stream( + 
thread_id=self.thread.id, + run_id=self.run.id, + tool_outputs=tool_outputs, + event_handler=event_handler() + ) as stream: + stream.until_done() + self.run = stream.get_final_run() + + def _cancel_run(self, thread_id=None, run_id=None, check_status=True): + if check_status and self.run.status in self.terminal_states and not run_id: + return + + try: + self.run = self.client.beta.threads.runs.cancel( + thread_id=self.thread.id, + run_id=self.run.id + ) + except BadRequestError as e: + if "Cannot cancel run with status" in e.message: + self.run = self.client.beta.threads.runs.poll( + thread_id=thread_id or self.thread.id, + run_id=run_id or self.run.id, + poll_interval_ms=500, + ) + else: + raise e def _get_last_message_text(self): messages = self.client.beta.threads.messages.list( @@ -417,15 +457,9 @@ def create_message(self, message: str, role: str = "user", attachments: List[dic thread_id, run_id = match.groups() thread_id = f"thread_{thread_id}" run_id = f"run_{run_id}" - self.client.beta.threads.runs.cancel( - thread_id=thread_id, - run_id=run_id - ) - self.run = self.client.beta.threads.runs.poll( - thread_id=thread_id, - run_id=run_id, - poll_interval_ms=500, - ) + + self._cancel_run(thread_id=thread_id, run_id=run_id) + return self.client.beta.threads.messages.create( thread_id=thread_id, role=role, @@ -443,7 +477,7 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool tool = next((func for func in funcs if func.__name__ == tool_call.function.name), None) if not tool: - return f"Error: Function {tool_call.function.name} not found. Available functions: {[func.__name__ for func in funcs]}" + return f"Error: Function {tool_call.function.name} not found. 
Available functions: {[func.__name__ for func in funcs]}", False try: # init tool @@ -453,17 +487,18 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool for tool_name in [name for name, _ in tool_outputs_and_names]: if tool_name == tool_call.function.name and ( hasattr(tool, "ToolConfig") and hasattr(tool.ToolConfig, "one_call_at_a_time") and tool.ToolConfig.one_call_at_a_time): - return f"Error: Function {tool_call.function.name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again." + return f"Error: Function {tool_call.function.name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again.", False tool._caller_agent = recipient_agent tool._event_handler = event_handler + tool._tool_call = tool_call - return tool.run() + return tool.run(), tool.ToolConfig.output_as_result except Exception as e: error_message = f"Error: {e}" if "For further information visit" in error_message: error_message = error_message.split("For further information visit")[0] - return error_message + return error_message, False def _execute_async_tool_calls_outputs(self, tool_outputs): async_tool_calls = [] diff --git a/agency_swarm/tools/BaseTool.py b/agency_swarm/tools/BaseTool.py index fd6fd6a6..7d9ffacc 100644 --- a/agency_swarm/tools/BaseTool.py +++ b/agency_swarm/tools/BaseTool.py @@ -11,15 +11,29 @@ class BaseTool(BaseModel, ABC): _shared_state: ClassVar[SharedState] = None _caller_agent: Any = None _event_handler: Any = None + _tool_call: Any = None def __init__(self, **kwargs): if not self.__class__._shared_state: self.__class__._shared_state = SharedState() super().__init__(**kwargs) + + # Ensure all ToolConfig variables are initialized + config_defaults = { + 'strict': False, + 'one_call_at_a_time': False, + 'output_as_result': False + } + + for key, value in config_defaults.items(): + if not 
hasattr(self.ToolConfig, key): + setattr(self.ToolConfig, key, value) class ToolConfig: strict: bool = False one_call_at_a_time: bool = False + # return the tool output as assistant message + output_as_result: bool = False @classmethod @property @@ -76,5 +90,5 @@ def openai_schema(cls): return schema @abstractmethod - def run(self, **kwargs): + def run(self): pass diff --git a/agency_swarm/tools/__init__.py b/agency_swarm/tools/__init__.py index 8bca7b06..d77d38dd 100644 --- a/agency_swarm/tools/__init__.py +++ b/agency_swarm/tools/__init__.py @@ -2,4 +2,4 @@ from .oai.CodeInterpreter import CodeInterpreter from .oai.FileSearch import FileSearch from .oai.Retrieval import Retrieval -from .ToolFactory import ToolFactory +from .ToolFactory import ToolFactory \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessage.py b/agency_swarm/tools/send_message/SendMessage.py new file mode 100644 index 00000000..758b38ce --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessage.py @@ -0,0 +1,60 @@ +from agency_swarm.threads.thread import Thread +from typing import ClassVar, Optional, List, Type +from pydantic import Field, field_validator, model_validator +from .SendMessageBase import SendMessageBase + +class SendMessage(SendMessageBase): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. 
Do not send more than 1 message at a time.""" + message: str = Field( + ..., + description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions." + ) + my_primary_instructions: str = Field( + ..., + description=( + "Please repeat your primary instructions step-by-step, including both completed " + "and the following next steps that you need to perform. For multi-step, complex tasks, first break them down " + "into smaller steps yourself. Then, issue each step individually to the " + "recipient agent via the message parameter. Each identified step should be " + "sent in a separate message. Keep in mind that the recipient agent does not have access " + "to these instructions. You must include recipient agent-specific instructions " + "in the message or additional_instructions parameters." + ) + ) + message_files: Optional[List[str]] = Field( + default=None, + description="A list of file IDs to be sent as attachments to this message. Only use this if you have the file ID that starts with 'file-'.", + examples=["file-1234", "file-5678"] + ) + additional_instructions: Optional[str] = Field( + default=None, + description="Additional context or instructions from the conversation needed by the recipient agent to complete the task." 
+ ) + + + @model_validator(mode='after') + def validate_files(self): + if hasattr(self, 'message') and "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions): + if not self.message_files: + raise ValueError("You must include file IDs in message_files parameter.") + return self + + @field_validator('additional_instructions', mode='before') + @classmethod + def validate_additional_instructions(cls, value): + if isinstance(value, list): + return "\n".join(value) + return value + + + def run(self): + thread: Thread = self._agents_and_threads[self._caller_agent.name][self.recipient.value] + + message = thread.get_completion(message=self.message, + message_files=self.message_files, + event_handler=self._event_handler, + yield_messages=not self._event_handler, + additional_instructions=self.additional_instructions, + ) + + return message or "" \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageAsyncThreading.py b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py new file mode 100644 index 00000000..ffdfee49 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageAsyncThreading.py @@ -0,0 +1,16 @@ +from typing import ClassVar, Type +from agency_swarm.threads.thread_async import ThreadAsync +from .SendMessage import SendMessage + +class SendMessageAsyncThreading(SendMessage): + """Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. 
Continue until task completion.""" + _thread_type: ClassVar[Type[ThreadAsync]] = ThreadAsync + + def run(self): + thread: ThreadAsync = self._agents_and_threads[self._caller_agent.name][self.recipient.value] + + message = thread.get_completion_async(message=self.message, + message_files=self.message_files, + additional_instructions=self.additional_instructions) + + return message or "" \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageBase.py b/agency_swarm/tools/send_message/SendMessageBase.py new file mode 100644 index 00000000..37ffe3db --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageBase.py @@ -0,0 +1,19 @@ +from agency_swarm.threads.thread import Thread +from typing import ClassVar, Optional, List, Type +from pydantic import Field, field_validator, model_validator +from agency_swarm.tools import BaseTool +from abc import ABC + +class SendMessageBase(BaseTool, ABC): + """Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time.""" + + recipient: str = Field(..., description="Recipient agent that you want to send the message to. 
This field will be overriden inside the agency class.") + + _agents_and_threads: ClassVar = None + _thread_type: ClassVar[Type[Thread]] = Thread # thread type assigned by the agency class + + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not cls.__name__.startswith("SendMessage"): + raise TypeError(f"Class name '{cls.__name__}' must start with 'SendMessage'.") \ No newline at end of file diff --git a/agency_swarm/tools/send_message/SendMessageSwarm.py b/agency_swarm/tools/send_message/SendMessageSwarm.py new file mode 100644 index 00000000..0557e681 --- /dev/null +++ b/agency_swarm/tools/send_message/SendMessageSwarm.py @@ -0,0 +1,47 @@ +from openai import BadRequestError +from agency_swarm.threads.thread import Thread +from .SendMessage import SendMessageBase + +class SendMessageSwarm(SendMessageBase): + """Use this tool to route messages to other agents within your agency. After using this tool, you will be switched to the recipient agent. This tool can only be used once per message. Do not use any other tools together with this tool.""" + + class ToolConfig: + output_as_result: bool = True + one_call_at_a_time: bool = True + + def run(self): + # get main thread + thread: Thread = self._agents_and_threads["main_thread"] + + # get recipient agent from thread + recipient_agent = self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent + + # submit tool output + try: + thread._submit_tool_outputs( + tool_outputs=[{"tool_call_id": self._tool_call.id, "output": "The request has been routed. You are now a " + recipient_agent.name + " agent. Please assist the user further with their request."}], + poll=False + ) + except BadRequestError as e: + raise BadRequestError("You can only call this tool by itself. 
Do not use any other tools together with this tool.", response=e.response, body=e.body) from e + + try: + # cancel run + thread._cancel_run() + + # change recipient agent in thread + thread.recipient_agent = recipient_agent + + # change recipient agent in gradio dropdown + if self._event_handler: + if hasattr(self._event_handler, "change_recipient_agent"): + self._event_handler.change_recipient_agent(self.recipient.value) + + # continue conversation with the new recipient agent + message = thread.get_completion(message=None, recipient_agent=recipient_agent, yield_messages=not self._event_handler, event_handler=self._event_handler) + + return message or "" + except Exception as e: + # tool outputs are already submitted at this point, so report any failure as the tool result instead of raising + print("Error in SendMessageSwarm: ", e) + return str(e) diff --git a/agency_swarm/tools/send_message/__init__.py b/agency_swarm/tools/send_message/__init__.py new file mode 100644 index 00000000..180265b6 --- /dev/null +++ b/agency_swarm/tools/send_message/__init__.py @@ -0,0 +1,4 @@ +from .SendMessageAsyncThreading import SendMessageAsyncThreading +from .SendMessageBase import SendMessageBase +from .SendMessage import SendMessage +from .SendMessageSwarm import SendMessageSwarm \ No newline at end of file diff --git a/tests/test_agency.py b/tests/test_agency.py index 7ed46b85..0c258527 100644 --- a/tests/test_agency.py +++ b/tests/test_agency.py @@ -219,9 +219,11 @@ def test_4_agent_communication(self): message = self.__class__.agency.get_completion("Please tell TestAgent1 to say test to TestAgent2.", tool_choice={"type": "function", "function": {"name": "SendMessage"}}) - self.assertFalse('error' in message.lower(), f"Error found in message: {message}") + self.assertFalse('error' in message.lower(), f"Error found in message: {message}.
Thread url: {self.__class__.agency.main_thread.thread_url}") for agent_name, threads in self.__class__.agency.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent_name, thread in threads.items(): self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) @@ -312,6 +314,8 @@ def on_all_streams_end(cls): self.assertFalse(agent1_thread.run.parallel_tool_calls) for agent_name, threads in self.__class__.agency.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent_name, thread in threads.items(): self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) @@ -369,6 +373,8 @@ def test_6_load_from_db(self): # check that threads are the same for agent_name, threads in agency.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent_name, thread in threads.items(): self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) self.assertTrue(thread.id in previous_loaded_thread_ids[agent_name][other_agent_name]) @@ -450,6 +456,8 @@ def on_all_streams_end(cls): self.assertFalse('error' in message.lower(), self.__class__.agency.main_thread.thread_url) for agent_name, threads in self.__class__.agency.agents_and_threads.items(): + if agent_name == "main_thread": + continue for other_agent_name, thread in threads.items(): self.assertTrue(thread.id in self.__class__.loaded_thread_ids[agent_name][other_agent_name]) diff --git a/tests/test_send_message.py b/tests/test_send_message.py new file mode 100644 index 00000000..9123d938 --- /dev/null +++ b/tests/test_send_message.py @@ -0,0 +1,42 @@ +import unittest +from agency_swarm import Agent, Agency +from agency_swarm.tools.send_message import SendMessageSwarm +from agency_swarm.tools import BaseTool +from pydantic import Field + +class TestSendMessage(unittest.TestCase): + def setUp(self): + class PrintTool(BaseTool): + """ + A simple 
tool that prints a message. + """ + message: str = Field(..., description="The message to print.") + + def run(self): + print(self.message) + return f"Printed: {self.message}" + + self.ceo = Agent( + name="CEO", + description="Responsible for client communication, task planning and management.", + instructions="Your role is to route messages to other agents within your agency.", + tools=[PrintTool] + ) + + self.customer_support = Agent( + name="Customer Support", + description="Responsible for customer support.", + instructions="You are a Customer Support agent. Answer customer questions and help with issues.", + tools=[] + ) + + self.agency = Agency([self.ceo, [self.ceo, self.customer_support], [self.customer_support, self.ceo]], send_message_tool_class=SendMessageSwarm) + + def test_send_message_swarm(self): + response = self.agency.get_completion("Hello, can you send me to customer support? If there are any issues, please say 'error'") + self.assertFalse("error" in response.lower()) + response = self.agency.get_completion("Who are you?") + self.assertTrue("customer support" in response.lower()) + +if __name__ == '__main__': + unittest.main()