Skip to content

Commit

Permalink
Refactor thread and optimize thread initializetion
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Nov 15, 2024
1 parent f966462 commit 2646b8a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 61 deletions.
4 changes: 4 additions & 0 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,17 +881,21 @@ def _init_threads(self):
# Save main_thread into agents_and_threads
self.agents_and_threads["main_thread"] = self.main_thread

# initialize threads
for agent_name, threads in self.agents_and_threads.items():
if agent_name == "main_thread":
continue
for other_agent, items in threads.items():
# create thread class
self.agents_and_threads[agent_name][other_agent] = self.send_message_tool_class._thread_type(
self._get_agent_by_name(items["agent"]),
self._get_agent_by_name(
items["recipient_agent"]))

# load thread id if available
if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]:
self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent]
# init threads if threre are threads callbacks so the ids are saved for later use
elif self.threads_callbacks:
self.agents_and_threads[agent_name][other_agent].init_thread()

Expand Down
124 changes: 66 additions & 58 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ class Thread:
@property
def thread_url(self):
return f'https://platform.openai.com/playground/assistants?assistant={self.recipient_agent.id}&mode=assistant&thread={self.id}'

@property
def thread(self):
self.init_thread()

if not self._thread:
print("retrieving thread", self.id)
self._thread = self.client.beta.threads.retrieve(self.id)

return self._thread

def __init__(self, agent: Union[Agent, User], recipient_agent: Agent):
self.agent = agent
Expand All @@ -36,28 +46,27 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent):
self.client = get_openai_client()

self.id = None
self.thread = None
self.run = None
self.stream = None
self._thread = None
self._run = None
self._stream = None

self.num_run_retries = 0
self._num_run_retries = 0

self.terminal_states = ["cancelled", "completed", "failed", "expired", "incomplete"]

def init_thread(self):
if self.id:
self.thread = self.client.beta.threads.retrieve(self.id)
else:
self.thread = self.client.beta.threads.create()
self.id = self.thread.id

if self.recipient_agent.examples:
for example in self.recipient_agent.examples:
self.client.beta.threads.messages.create(
thread_id=self.id,
**example,
)

return
print("creating thread")
self._thread = self.client.beta.threads.create()
self.id = self._thread.id
if self.recipient_agent.examples:
for example in self.recipient_agent.examples:
self.client.beta.threads.messages.create(
thread_id=self.id,
**example,
)
print("thread created", self.id)
def get_completion_stream(self,
message: Union[str, List[dict], None],
event_handler: type(AgencyEventHandler),
Expand Down Expand Up @@ -89,6 +98,8 @@ def get_completion(self,
yield_messages: bool = False,
response_format: Optional[dict] = None
):
self.init_thread()

if not recipient_agent:
recipient_agent = self.recipient_agent

Expand All @@ -107,9 +118,6 @@ def get_completion(self,
attachments.append({"file_id": file_id,
"tools": recipient_tools or [{"type": "file_search"}]})

if not self.thread:
self.init_thread()

if event_handler:
event_handler.set_agent(self.agent)
event_handler.set_recipient_agent(recipient_agent)
Expand Down Expand Up @@ -138,8 +146,8 @@ def get_completion(self,
self._run_until_done()

# function execution
if self.run.status == "requires_action":
tool_calls = self.run.required_action.submit_tool_outputs.tool_calls
if self._run.status == "requires_action":
tool_calls = self._run.required_action.submit_tool_outputs.tool_calls
tool_outputs_and_names = [] # list of tuples (name, tool_output)
sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name.startswith("SendMessage")]
async_tool_calls = [tool_call for tool_call in tool_calls if not tool_call.function.name.startswith("SendMessage")]
Expand Down Expand Up @@ -224,11 +232,11 @@ def handle_output(tool_call, output):
self._create_run(recipient_agent, additional_instructions, event_handler, 'required', temperature=0)
self._run_until_done()

if self.run.status != "requires_action":
raise Exception("Run Failed. Error: ", self.run.last_error or self.run.incomplete_details)
if self._run.status != "requires_action":
raise Exception("Run Failed. Error: ", self._run.last_error or self._run.incomplete_details)

# change tool call ids
tool_calls = self.run.required_action.submit_tool_outputs.tool_calls
tool_calls = self._run.required_action.submit_tool_outputs.tool_calls

if len(tool_calls) != len(tool_outputs):
tool_outputs = []
Expand All @@ -246,10 +254,10 @@ def handle_output(tool_call, output):
else:
raise e
# error
elif self.run.status == "failed":
elif self._run.status == "failed":
full_message += self._get_last_message_text()
common_errors = ["something went wrong", "the server had an error processing your request", "rate limit reached"]
error_message = self.run.last_error.message.lower()
error_message = self._run.last_error.message.lower()

if error_attempts < 3 and any(error in error_message for error in common_errors):
if error_attempts < 2:
Expand All @@ -261,9 +269,9 @@ def handle_output(tool_call, output):
tool_choice, response_format=response_format)
error_attempts += 1
else:
raise Exception("OpenAI Run Failed. Error: ", self.run.last_error.message)
elif self.run.status == "incomplete":
raise Exception("OpenAI Run Incomplete. Details: ", self.run.incomplete_details)
raise Exception("OpenAI Run Failed. Error: ", self._run.last_error.message)
elif self._run.status == "incomplete":
raise Exception("OpenAI Run Incomplete. Details: ", self._run.incomplete_details)
# return assistant message
else:
message_obj = self._get_last_assistant_message()
Expand Down Expand Up @@ -318,7 +326,7 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t
try:
if event_handler:
with self.client.beta.threads.runs.stream(
thread_id=self.thread.id,
thread_id=self.id,
event_handler=event_handler(),
assistant_id=recipient_agent.id,
additional_instructions=additional_instructions,
Expand All @@ -331,10 +339,10 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t
response_format=response_format
) as stream:
stream.until_done()
self.run = stream.get_final_run()
self._run = stream.get_final_run()
else:
self.run = self.client.beta.threads.runs.create(
thread_id=self.thread.id,
self._run = self.client.beta.threads.runs.create(
thread_id=self.id,
assistant_id=recipient_agent.id,
additional_instructions=additional_instructions,
tool_choice=tool_choice,
Expand All @@ -345,68 +353,68 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t
parallel_tool_calls=recipient_agent.parallel_tool_calls,
response_format=response_format
)
self.run = self.client.beta.threads.runs.poll(
thread_id=self.thread.id,
run_id=self.run.id,
self._run = self.client.beta.threads.runs.poll(
thread_id=self.id,
run_id=self._run.id,
# poll_interval_ms=500,
)
except APIError as e:
match = re.search(r"Thread (\w+) already has an active run (\w+)", e.message)
if match:
self._cancel_run(thread_id=match.groups()[0], run_id=match.groups()[1], check_status=False)
elif "The server had an error processing your request" in e.message and self.num_run_retries < 3:
elif "The server had an error processing your request" in e.message and self._num_run_retries < 3:
time.sleep(1)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)
self.num_run_retries += 1
self._num_run_retries += 1
else:
raise e

def _run_until_done(self):
while self.run.status in ['queued', 'in_progress', "cancelling"]:
while self._run.status in ['queued', 'in_progress', "cancelling"]:
time.sleep(0.5)
self.run = self.client.beta.threads.runs.retrieve(
thread_id=self.thread.id,
run_id=self.run.id
self._run = self.client.beta.threads.runs.retrieve(
thread_id=self.id,
run_id=self._run.id
)

def _submit_tool_outputs(self, tool_outputs, event_handler=None, poll=True):
if not poll:
self.run = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=self.thread.id,
run_id=self.run.id,
self._run = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=self.id,
run_id=self._run.id,
tool_outputs=tool_outputs
)
else:
if not event_handler:
self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll(
thread_id=self.thread.id,
run_id=self.run.id,
self._run = self.client.beta.threads.runs.submit_tool_outputs_and_poll(
thread_id=self.id,
run_id=self._run.id,
tool_outputs=tool_outputs
)
else:
with self.client.beta.threads.runs.submit_tool_outputs_stream(
thread_id=self.thread.id,
run_id=self.run.id,
thread_id=self.id,
run_id=self._run.id,
tool_outputs=tool_outputs,
event_handler=event_handler()
) as stream:
stream.until_done()
self.run = stream.get_final_run()
self._run = stream.get_final_run()

def _cancel_run(self, thread_id=None, run_id=None, check_status=True):
if check_status and self.run.status in self.terminal_states and not run_id:
if check_status and self._run.status in self.terminal_states and not run_id:
return

try:
self.run = self.client.beta.threads.runs.cancel(
thread_id=self.thread.id,
run_id=self.run.id
self._run = self.client.beta.threads.runs.cancel(
thread_id=self.id,
run_id=self._run.id
)
except BadRequestError as e:
if "Cannot cancel run with status" in e.message:
self.run = self.client.beta.threads.runs.poll(
thread_id=thread_id or self.thread.id,
run_id=run_id or self.run.id,
self._run = self.client.beta.threads.runs.poll(
thread_id=thread_id or self.id,
run_id=run_id or self._run.id,
poll_interval_ms=500,
)
else:
Expand Down
5 changes: 2 additions & 3 deletions agency_swarm/threads/thread_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,10 @@ def check_status(self, run=None):
return f"""{self.recipient_agent.name}'s Response: '{messages.data[0].content[0].text.value}'"""

def get_last_run(self):
if not self.thread:
self.init_thread()
self.init_thread()

runs = self.client.beta.threads.runs.list(
thread_id=self.thread.id,
thread_id=self.id,
order="desc",
)

Expand Down

0 comments on commit 2646b8a

Please sign in to comment.