Skip to content

Commit

Permalink
Moved async_mode into ToolConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Nov 15, 2024
1 parent 2646b8a commit bac5cf8
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 58 deletions.
13 changes: 10 additions & 3 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,19 @@ def __init__(self,
self.max_completion_tokens = max_completion_tokens
self.truncation_strategy = truncation_strategy

# set thread type based send_message_tool_class async mode
if hasattr(send_message_tool_class.ToolConfig, "async_mode") and send_message_tool_class.ToolConfig.async_mode:
self._thread_type = ThreadAsync
else:
self._thread_type = Thread

if self.async_mode == "threading":
from agency_swarm.tools.send_message import SendMessageAsyncThreading
print("Warning: 'threading' mode is deprecated. Please use send_message_tool_class = SendMessageAsyncThreading to use async communication.")
self.send_message_tool_class = SendMessageAsyncThreading
elif self.async_mode == "tools_threading":
Thread.async_mode = self.async_mode
Thread.async_mode = "tools_threading"
print("Warning: 'tools_threading' mode is deprecated. Use tool.ToolConfig.async_mode = 'threading' instead.")
elif self.async_mode is None:
pass
else:
Expand Down Expand Up @@ -887,7 +894,7 @@ def _init_threads(self):
continue
for other_agent, items in threads.items():
# create thread class
self.agents_and_threads[agent_name][other_agent] = self.send_message_tool_class._thread_type(
self.agents_and_threads[agent_name][other_agent] = self._thread_type(
self._get_agent_by_name(items["agent"]),
self._get_agent_by_name(
items["recipient_agent"]))
Expand Down Expand Up @@ -1034,7 +1041,7 @@ def _create_special_tools(self):
continue
agent = self._get_agent_by_name(agent_name)
agent.add_tool(self._create_send_message_tool(agent, recipient_agents))
if self.send_message_tool_class._thread_type == ThreadAsync:
if self._thread_type == ThreadAsync:
agent.add_tool(self._create_get_response_tool(agent, recipient_agents))

def _create_send_message_tool(self, agent: Agent, recipient_agents: List[Agent]):
Expand Down
44 changes: 38 additions & 6 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent):
def init_thread(self):
if self.id:
return
print("creating thread")

self._thread = self.client.beta.threads.create()
self.id = self._thread.id
if self.recipient_agent.examples:
Expand All @@ -66,7 +66,7 @@ def init_thread(self):
thread_id=self.id,
**example,
)
print("thread created", self.id)

def get_completion_stream(self,
message: Union[str, List[dict], None],
event_handler: type(AgencyEventHandler),
Expand Down Expand Up @@ -149,8 +149,7 @@ def get_completion(self,
if self._run.status == "requires_action":
tool_calls = self._run.required_action.submit_tool_outputs.tool_calls
tool_outputs_and_names = [] # list of tuples (name, tool_output)
sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name.startswith("SendMessage")]
async_tool_calls = [tool_call for tool_call in tool_calls if not tool_call.function.name.startswith("SendMessage")]
sync_tool_calls, async_tool_calls = self._get_sync_async_tool_calls(tool_calls, recipient_agent)

def handle_output(tool_call, output):
if inspect.isgenerator(output):
Expand Down Expand Up @@ -207,7 +206,7 @@ def handle_output(tool_call, output):
tool_names = [name for name, _ in tool_outputs_and_names]

# await coroutines
tool_outputs = self._execute_async_tool_calls_outputs(tool_outputs)
tool_outputs = self._await_coroutines(tool_outputs)

# convert all tool outputs to strings
for tool_output in tool_outputs:
Expand Down Expand Up @@ -508,7 +507,7 @@ def execute_tool(self, tool_call, recipient_agent=None, event_handler=None, tool
error_message = error_message.split("For further information visit")[0]
return error_message, False

def _execute_async_tool_calls_outputs(self, tool_outputs):
def _await_coroutines(self, tool_outputs):
async_tool_calls = []
for tool_output in tool_outputs:
if inspect.iscoroutine(tool_output["output"]):
Expand All @@ -530,6 +529,39 @@ def _execute_async_tool_calls_outputs(self, tool_outputs):
tool_output["output"] = str(result)

return tool_outputs

def _get_sync_async_tool_calls(self, tool_calls, recipient_agent):
async_tool_calls = []
sync_tool_calls = []
for tool_call in tool_calls:
if tool_call.function.name.startswith("SendMessage"):
sync_tool_calls.append(tool_call)
continue

tool = next((func for func in recipient_agent.functions if func.__name__ == tool_call.function.name), None)

if (hasattr(tool.ToolConfig, "async_mode") and tool.ToolConfig.async_mode) or self.async_mode == "tools_threading":
async_tool_calls.append(tool_call)
else:
sync_tool_calls.append(tool_call)

return sync_tool_calls, async_tool_calls

def get_messages(self, limit=None):
all_messages = []
after = None
while True:
response = self.client.beta.threads.messages.list(thread_id=self.id, limit=100, after=after)
messages = response.data
if not messages:
break
all_messages.extend(messages)
after = messages[-1].id # Set the 'after' cursor to the ID of the last message

if limit and len(all_messages) >= limit:
break

return all_messages



Expand Down
6 changes: 4 additions & 2 deletions agency_swarm/tools/BaseTool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, ClassVar
from typing import Any, ClassVar, Literal, Union

from docstring_parser import parse

Expand All @@ -22,7 +22,8 @@ def __init__(self, **kwargs):
config_defaults = {
'strict': False,
'one_call_at_a_time': False,
'output_as_result': False
'output_as_result': False,
'async_mode': None
}

for key, value in config_defaults.items():
Expand All @@ -34,6 +35,7 @@ class ToolConfig:
one_call_at_a_time: bool = False
# return the tool output as assistant message
output_as_result: bool = False
async_mode: Union[Literal["threading"], None] = None

@classmethod
@property
Expand Down
12 changes: 6 additions & 6 deletions agency_swarm/tools/send_message/SendMessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
from .SendMessageBase import SendMessageBase

class SendMessage(SendMessageBase):
"""Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time."""
message: str = Field(
...,
description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions."
)
"""Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message to the same recipient agent at the same time."""
my_primary_instructions: str = Field(
...,
description=(
Expand All @@ -21,6 +17,10 @@ class SendMessage(SendMessageBase):
"in the message or additional_instructions parameters."
)
)
message: str = Field(
...,
description="Specify the task required for the recipient agent to complete. Focus on clarifying what the task entails, rather than providing exact instructions. Make sure to inlcude all the relevant information needed to complete the task."
)
message_files: Optional[List[str]] = Field(
default=None,
description="A list of file IDs to be sent as attachments to this message. Only use this if you have the file ID that starts with 'file-'.",
Expand Down Expand Up @@ -48,7 +48,7 @@ def validate_additional_instructions(cls, value):


def run(self):
thread: Thread = self._agents_and_threads[self._caller_agent.name][self.recipient.value]
thread = self._get_thread()

message = thread.get_completion(message=self.message,
message_files=self.message_files,
Expand Down
7 changes: 4 additions & 3 deletions agency_swarm/tools/send_message/SendMessageAsyncThreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

class SendMessageAsyncThreading(SendMessage):
"""Use this tool for asynchronous communication with other agents within your agency. Initiate tasks by messaging, and check status and responses later with the 'GetResponse' tool. Relay responses to the user, who instructs on status checks. Continue until task completion."""
_thread_type: ClassVar[Type[ThreadAsync]] = ThreadAsync

class ToolConfig:
async_mode = "threading"

def run(self):
thread: ThreadAsync = self._agents_and_threads[self._caller_agent.name][self.recipient.value]
thread = self._get_thread()

message = thread.get_completion_async(message=self.message,
message_files=self.message_files,
Expand Down
19 changes: 14 additions & 5 deletions agency_swarm/tools/send_message/SendMessageBase.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
from agency_swarm.agents.agent import Agent
from agency_swarm.threads.thread import Thread
from typing import ClassVar, Optional, List, Type
from pydantic import Field, field_validator, model_validator
from typing import ClassVar
from pydantic import Field
from agency_swarm.threads.thread_async import ThreadAsync
from agency_swarm.tools import BaseTool
from abc import ABC

class SendMessageBase(BaseTool, ABC):
"""Use this tool to facilitate direct, synchronous communication between specialized agents within your agency. When you send a message using this tool, you receive a response exclusively from the designated recipient agent. To continue the dialogue, invoke this tool again with the desired recipient agent and your follow-up message. Remember, communication here is synchronous; the recipient agent won't perform any tasks post-response. You are responsible for relaying the recipient agent's responses back to the user, as the user does not have direct access to these replies. Keep engaging with the tool for continuous interaction until the task is fully resolved. Do not send more than 1 message at a time."""

recipient: str = Field(..., description="Recipient agent that you want to send the message to. This field will be overriden inside the agency class.")

_agents_and_threads: ClassVar = None
_thread_type: ClassVar[Type[Thread]] = Thread # thread type assigned by the agency class

@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not cls.__name__.startswith("SendMessage"):
raise TypeError(f"Class name '{cls.__name__}' must start with 'SendMessage'.")
raise TypeError(f"Class name '{cls.__name__}' must start with 'SendMessage'.")

def _get_thread(self) -> Thread | ThreadAsync:
return self._agents_and_threads[self._caller_agent.name][self.recipient.value]

def _get_main_thread(self) -> Thread | ThreadAsync:
return self._agents_and_threads["main_thread"]

def _get_recipient_agent(self) -> Agent:
return self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent
4 changes: 2 additions & 2 deletions agency_swarm/tools/send_message/SendMessageSwarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ class ToolConfig:

def run(self):
# get main thread
thread: Thread = self._agents_and_threads["main_thread"]
thread = self._get_main_thread()

# get recipient agent from thread
recipient_agent = self._agents_and_threads[self._caller_agent.name][self.recipient.value].recipient_agent
recipient_agent = self._get_recipient_agent()

# submit tool output
try:
Expand Down
Loading

0 comments on commit bac5cf8

Please sign in to comment.