Skip to content

Commit

Permalink
Allow extending and modifying SendMessage tool in Agency class
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Nov 14, 2024
1 parent 3de9f27 commit f966462
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 125 deletions.
129 changes: 48 additions & 81 deletions agency_swarm/agency/agency.py

Large diffs are not rendered by default.

117 changes: 76 additions & 41 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent):

self.num_run_retries = 0

self.terminal_states = ["cancelled", "completed", "failed", "expired", "incomplete"]

def init_thread(self):
if self.id:
self.thread = self.client.beta.threads.retrieve(self.id)
Expand All @@ -57,7 +59,7 @@ def init_thread(self):
)

def get_completion_stream(self,
message: str,
message: Union[str, List[dict], None],
event_handler: type(AgencyEventHandler),
message_files: List[str] = None,
attachments: Optional[List[Attachment]] = None,
Expand All @@ -77,10 +79,10 @@ def get_completion_stream(self,
response_format=response_format)

def get_completion(self,
message: str | List[dict],
message: Union[str, List[dict], None],
message_files: List[str] = None,
attachments: Optional[List[dict]] = None,
recipient_agent: Agent = None,
recipient_agent: Union[Agent, None] = None,
additional_instructions: str = None,
event_handler: type(AgencyEventHandler) = None,
tool_choice: AssistantToolChoice = None,
Expand Down Expand Up @@ -117,14 +119,15 @@ def get_completion(self,
print(f'THREAD:[ {sender_name} -> {recipient_agent.name} ]: URL {self.thread_url}')

# send message
message_obj = self.create_message(
message=message,
role="user",
attachments=attachments
)
if message:
message_obj = self.create_message(
message=message,
role="user",
attachments=attachments
)

if yield_messages:
yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj)
if yield_messages:
yield MessageOutput("text", self.agent.name, recipient_agent.name, message, message_obj)

self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)

Expand All @@ -138,8 +141,8 @@ def get_completion(self,
if self.run.status == "requires_action":
tool_calls = self.run.required_action.submit_tool_outputs.tool_calls
tool_outputs_and_names = [] # list of tuples (name, tool_output)
sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name == "SendMessage"]
async_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name != "SendMessage"]
sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name.startswith("SendMessage")]
async_tool_calls = [tool_call for tool_call in tool_calls if not tool_call.function.name.startswith("SendMessage")]

def handle_output(tool_call, output):
if inspect.isgenerator(output):
Expand All @@ -157,6 +160,8 @@ def handle_output(tool_call, output):
for tool_output in tool_outputs_and_names:
if tool_output[1]["tool_call_id"] == tool_call.id:
tool_output[1]["output"] = output

return output

if len(async_tool_calls) > 0 and self.async_mode == "tools_threading":
max_workers = min(self.max_workers, os.cpu_count() or 1) # Use at most 4 workers or the number of CPUs available
Expand All @@ -170,19 +175,25 @@ def handle_output(tool_call, output):

for future in as_completed(futures):
tool_call = futures[future]
output = future.result()
yield from handle_output(tool_call, output)
output, output_as_result = future.result()
output = yield from handle_output(tool_call, output)
if output_as_result:
self._cancel_run()
return output
else:
sync_tool_calls += async_tool_calls

# execute sync tool calls
for tool_call in sync_tool_calls:
if yield_messages:
yield MessageOutput("function", recipient_agent.name, self.agent.name, str(tool_call.function), tool_call)
output = self.execute_tool(tool_call, recipient_agent, event_handler, tool_outputs_and_names)
output, output_as_result = self.execute_tool(tool_call, recipient_agent, event_handler, tool_outputs_and_names)
tool_outputs_and_names.append((tool_call.function.name, {"tool_call_id": tool_call.id, "output": output}))
yield from handle_output(tool_call, output)

output = yield from handle_output(tool_call, output)
if output_as_result:
self._cancel_run()
return output

# split names and outputs
tool_outputs = [tool_output for _, tool_output in tool_outputs_and_names]
tool_names = [name for name, _ in tool_outputs_and_names]
Expand Down Expand Up @@ -340,8 +351,11 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t
# poll_interval_ms=500,
)
except APIError as e:
if "The server had an error processing your request" in e.message and self.num_run_retries < 3:
time.sleep(1 + self.num_run_retries)
match = re.search(r"Thread (\w+) already has an active run (\w+)", e.message)
if match:
self._cancel_run(thread_id=match.groups()[0], run_id=match.groups()[1], check_status=False)
elif "The server had an error processing your request" in e.message and self.num_run_retries < 3:
time.sleep(1)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)
self.num_run_retries += 1
else:
Expand All @@ -355,22 +369,48 @@ def _run_until_done(self):
run_id=self.run.id
)

def _submit_tool_outputs(self, tool_outputs, event_handler):
if not event_handler:
self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll(
def _submit_tool_outputs(self, tool_outputs, event_handler=None, poll=True):
if not poll:
self.run = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=self.thread.id,
run_id=self.run.id,
tool_outputs=tool_outputs
)
else:
with self.client.beta.threads.runs.submit_tool_outputs_stream(
if not event_handler:
self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll(
thread_id=self.thread.id,
run_id=self.run.id,
tool_outputs=tool_outputs,
event_handler=event_handler()
) as stream:
stream.until_done()
self.run = stream.get_final_run()
tool_outputs=tool_outputs
)
else:
with self.client.beta.threads.runs.submit_tool_outputs_stream(
thread_id=self.thread.id,
run_id=self.run.id,
tool_outputs=tool_outputs,
event_handler=event_handler()
) as stream:
stream.until_done()
self.run = stream.get_final_run()

def _cancel_run(self, thread_id=None, run_id=None, check_status=True):
if check_status and self.run.status in self.terminal_states and not run_id:
return

try:
self.run = self.client.beta.threads.runs.cancel(
thread_id=self.thread.id,
run_id=self.run.id
)
except BadRequestError as e:
if "Cannot cancel run with status" in e.message:
self.run = self.client.beta.threads.runs.poll(
thread_id=thread_id or self.thread.id,
run_id=run_id or self.run.id,
poll_interval_ms=500,
)
else:
raise e

def _get_last_message_text(self):
messages = self.client.beta.threads.messages.list(
Expand Down Expand Up @@ -417,15 +457,9 @@ def create_message(self, message: str, role: str = "user", attachments: List[dic
thread_id, run_id = match.groups()
thread_id = f"thread_{thread_id}"
run_id = f"run_{run_id}"
self.client.beta.threads.runs.cancel(
thread_id=thread_id,
run_id=run_id
)
self.run = self.client.beta.threads.runs.poll(
thread_id=thread_id,
run_id=run_id,
poll_interval_ms=500,
)

self._cancel_run(thread_id=thread_id, run_id=run_id)

return self.client.beta.threads.messages.create(
thread_id=thread_id,
role=role,
Expand All @@ -443,7 +477,7 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool
tool = next((func for func in funcs if func.__name__ == tool_call.function.name), None)

if not tool:
return f"Error: Function {tool_call.function.name} not found. Available functions: {[func.__name__ for func in funcs]}"
return f"Error: Function {tool_call.function.name} not found. Available functions: {[func.__name__ for func in funcs]}", False

try:
# init tool
Expand All @@ -453,17 +487,18 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool
for tool_name in [name for name, _ in tool_outputs_and_names]:
if tool_name == tool_call.function.name and (
hasattr(tool, "ToolConfig") and hasattr(tool.ToolConfig, "one_call_at_a_time") and tool.ToolConfig.one_call_at_a_time):
return f"Error: Function {tool_call.function.name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again."
return f"Error: Function {tool_call.function.name} is already called. You can only call this function once at a time. Please wait for the previous call to finish before calling it again.", False

tool._caller_agent = recipient_agent
tool._event_handler = event_handler
tool._tool_call = tool_call

return tool.run()
return tool.run(), tool.ToolConfig.output_as_result
except Exception as e:
error_message = f"Error: {e}"
if "For further information visit" in error_message:
error_message = error_message.split("For further information visit")[0]
return error_message
return error_message, False

def _execute_async_tool_calls_outputs(self, tool_outputs):
async_tool_calls = []
Expand Down
16 changes: 15 additions & 1 deletion agency_swarm/tools/BaseTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,29 @@ class BaseTool(BaseModel, ABC):
_shared_state: ClassVar[SharedState] = None
_caller_agent: Any = None
_event_handler: Any = None
_tool_call: Any = None

def __init__(self, **kwargs):
if not self.__class__._shared_state:
self.__class__._shared_state = SharedState()
super().__init__(**kwargs)

# Ensure all ToolConfig variables are initialized
config_defaults = {
'strict': False,
'one_call_at_a_time': False,
'output_as_result': False
}

for key, value in config_defaults.items():
if not hasattr(self.ToolConfig, key):
setattr(self.ToolConfig, key, value)

class ToolConfig:
strict: bool = False
one_call_at_a_time: bool = False
# return the tool output as assistant message
output_as_result: bool = False

@classmethod
@property
Expand Down Expand Up @@ -76,5 +90,5 @@ def openai_schema(cls):
return schema

@abstractmethod
def run(self, **kwargs):
def run(self):
pass
2 changes: 1 addition & 1 deletion agency_swarm/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .oai.CodeInterpreter import CodeInterpreter
from .oai.FileSearch import FileSearch
from .oai.Retrieval import Retrieval
from .ToolFactory import ToolFactory
from .ToolFactory import ToolFactory
60 changes: 60 additions & 0 deletions agency_swarm/tools/send_message/SendMessage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from agency_swarm.threads.thread import Thread
from typing import ClassVar, Optional, List, Type
from pydantic import Field, field_validator, model_validator
from .SendMessageBase import SendMessageBase

class SendMessage(SendMessageBase):
"""Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time."""
message: str = Field(
...,
description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions."
)
my_primary_instructions: str = Field(
...,
description=(
"Please repeat your primary instructions step-by-step, including both completed "
"and the following next steps that you need to perform. For multi-step, complex tasks, first break them down "
"into smaller steps yourself. Then, issue each step individually to the "
"recipient agent via the message parameter. Each identified step should be "
"sent in a separate message. Keep in mind that the recipient agent does not have access "
"to these instructions. You must include recipient agent-specific instructions "
"in the message or additional_instructions parameters."
)
)
message_files: Optional[List[str]] = Field(
default=None,
description="A list of file IDs to be sent as attachments to this message. Only use this if you have the file ID that starts with 'file-'.",
examples=["file-1234", "file-5678"]
)
additional_instructions: Optional[str] = Field(
default=None,
description="Additional context or instructions from the conversation needed by the recipient agent to complete the task."
)


@model_validator(mode='after')
def validate_files(self):
if hasattr(self, 'message') and "file-" in self.message or (self.additional_instructions and "file-" in self.additional_instructions):
if not self.message_files:
raise ValueError("You must include file IDs in message_files parameter.")
return self

@field_validator('additional_instructions', mode='before')
@classmethod
def validate_additional_instructions(cls, value):
if isinstance(value, list):
return "\n".join(value)
return value


def run(self):
thread: Thread = self._agents_and_threads[self._caller_agent.name][self.recipient.value]

message = thread.get_completion(message=self.message,
message_files=self.message_files,
event_handler=self._event_handler,
yield_messages=not self._event_handler,
additional_instructions=self.additional_instructions,
)

return message or ""
16 changes: 16 additions & 0 deletions agency_swarm/tools/send_message/SendMessageAsyncThreading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import ClassVar, Type
from agency_swarm.threads.thread_async import ThreadAsync
from .SendMessage import SendMessage

class SendMessageAsyncThreading(SendMessage):
"""Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. Continue until task completion."""
_thread_type: ClassVar[Type[ThreadAsync]] = ThreadAsync

def run(self):
thread: ThreadAsync = self._agents_and_threads[self._caller_agent.name][self.recipient.value]

message = thread.get_completion_async(message=self.message,
message_files=self.message_files,
additional_instructions=self.additional_instructions)

return message or ""
19 changes: 19 additions & 0 deletions agency_swarm/tools/send_message/SendMessageBase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from agency_swarm.threads.thread import Thread
from typing import ClassVar, Optional, List, Type
from pydantic import Field, field_validator, model_validator
from agency_swarm.tools import BaseTool
from abc import ABC

class SendMessageBase(BaseTool, ABC):
"""Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time."""

recipient: str = Field(..., description="Recipient agent that you want to send the message to. This field will be overriden inside the agency class.")

_agents_and_threads: ClassVar = None
_thread_type: ClassVar[Type[Thread]] = Thread # thread type assigned by the agency class

@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not cls.__name__.startswith("SendMessage"):
raise TypeError(f"Class name '{cls.__name__}' must start with 'SendMessage'.")
Loading

0 comments on commit f966462

Please sign in to comment.