From 2646b8adb2164f94d3a7ea03b562647074308421 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Fri, 15 Nov 2024 09:49:59 +0400 Subject: [PATCH] Refactor thread and optimize thread initialization --- agency_swarm/agency/agency.py | 4 + agency_swarm/threads/thread.py | 124 ++++++++++++++------------- agency_swarm/threads/thread_async.py | 5 +- 3 files changed, 72 insertions(+), 61 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 50840c40..3770d4aa 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -881,17 +881,21 @@ def _init_threads(self): # Save main_thread into agents_and_threads self.agents_and_threads["main_thread"] = self.main_thread + # initialize threads for agent_name, threads in self.agents_and_threads.items(): if agent_name == "main_thread": continue for other_agent, items in threads.items(): + # create thread class self.agents_and_threads[agent_name][other_agent] = self.send_message_tool_class._thread_type( self._get_agent_by_name(items["agent"]), self._get_agent_by_name( items["recipient_agent"])) + # load thread id if available if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]: self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent] + # init threads if there are threads callbacks so the ids are saved for later use elif self.threads_callbacks: self.agents_and_threads[agent_name][other_agent].init_thread() diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 8a9e49bb..6fb9c9dd 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -28,6 +28,16 @@ class Thread: @property def thread_url(self): return f'https://platform.openai.com/playground/assistants?assistant={self.recipient_agent.id}&mode=assistant&thread={self.id}' + + @property + def thread(self): + self.init_thread() + + if not self._thread: + print("retrieving thread", self.id) 
self._thread = self.client.beta.threads.retrieve(self.id) + + return self._thread def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.agent = agent @@ -36,28 +46,27 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.client = get_openai_client() self.id = None - self.thread = None - self.run = None - self.stream = None + self._thread = None + self._run = None + self._stream = None - self.num_run_retries = 0 + self._num_run_retries = 0 self.terminal_states = ["cancelled", "completed", "failed", "expired", "incomplete"] def init_thread(self): if self.id: - self.thread = self.client.beta.threads.retrieve(self.id) - else: - self.thread = self.client.beta.threads.create() - self.id = self.thread.id - - if self.recipient_agent.examples: - for example in self.recipient_agent.examples: - self.client.beta.threads.messages.create( - thread_id=self.id, - **example, - ) - + return + print("creating thread") + self._thread = self.client.beta.threads.create() + self.id = self._thread.id + if self.recipient_agent.examples: + for example in self.recipient_agent.examples: + self.client.beta.threads.messages.create( + thread_id=self.id, + **example, + ) + print("thread created", self.id) def get_completion_stream(self, message: Union[str, List[dict], None], event_handler: type(AgencyEventHandler), @@ -89,6 +98,8 @@ def get_completion(self, yield_messages: bool = False, response_format: Optional[dict] = None ): + self.init_thread() + if not recipient_agent: recipient_agent = self.recipient_agent @@ -107,9 +118,6 @@ def get_completion(self, attachments.append({"file_id": file_id, "tools": recipient_tools or [{"type": "file_search"}]}) - if not self.thread: - self.init_thread() - if event_handler: event_handler.set_agent(self.agent) event_handler.set_recipient_agent(recipient_agent) @@ -138,8 +146,8 @@ def get_completion(self, self._run_until_done() # function execution - if self.run.status == "requires_action": - tool_calls = 
self.run.required_action.submit_tool_outputs.tool_calls + if self._run.status == "requires_action": + tool_calls = self._run.required_action.submit_tool_outputs.tool_calls tool_outputs_and_names = [] # list of tuples (name, tool_output) sync_tool_calls = [tool_call for tool_call in tool_calls if tool_call.function.name.startswith("SendMessage")] async_tool_calls = [tool_call for tool_call in tool_calls if not tool_call.function.name.startswith("SendMessage")] @@ -224,11 +232,11 @@ def handle_output(tool_call, output): self._create_run(recipient_agent, additional_instructions, event_handler, 'required', temperature=0) self._run_until_done() - if self.run.status != "requires_action": - raise Exception("Run Failed. Error: ", self.run.last_error or self.run.incomplete_details) + if self._run.status != "requires_action": + raise Exception("Run Failed. Error: ", self._run.last_error or self._run.incomplete_details) # change tool call ids - tool_calls = self.run.required_action.submit_tool_outputs.tool_calls + tool_calls = self._run.required_action.submit_tool_outputs.tool_calls if len(tool_calls) != len(tool_outputs): tool_outputs = [] @@ -246,10 +254,10 @@ def handle_output(tool_call, output): else: raise e # error - elif self.run.status == "failed": + elif self._run.status == "failed": full_message += self._get_last_message_text() common_errors = ["something went wrong", "the server had an error processing your request", "rate limit reached"] - error_message = self.run.last_error.message.lower() + error_message = self._run.last_error.message.lower() if error_attempts < 3 and any(error in error_message for error in common_errors): if error_attempts < 2: @@ -261,9 +269,9 @@ def handle_output(tool_call, output): tool_choice, response_format=response_format) error_attempts += 1 else: - raise Exception("OpenAI Run Failed. Error: ", self.run.last_error.message) - elif self.run.status == "incomplete": - raise Exception("OpenAI Run Incomplete. 
Details: ", self.run.incomplete_details) + raise Exception("OpenAI Run Failed. Error: ", self._run.last_error.message) + elif self._run.status == "incomplete": + raise Exception("OpenAI Run Incomplete. Details: ", self._run.incomplete_details) # return assistant message else: message_obj = self._get_last_assistant_message() @@ -318,7 +326,7 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t try: if event_handler: with self.client.beta.threads.runs.stream( - thread_id=self.thread.id, + thread_id=self.id, event_handler=event_handler(), assistant_id=recipient_agent.id, additional_instructions=additional_instructions, @@ -331,10 +339,10 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t response_format=response_format ) as stream: stream.until_done() - self.run = stream.get_final_run() + self._run = stream.get_final_run() else: - self.run = self.client.beta.threads.runs.create( - thread_id=self.thread.id, + self._run = self.client.beta.threads.runs.create( + thread_id=self.id, assistant_id=recipient_agent.id, additional_instructions=additional_instructions, tool_choice=tool_choice, @@ -345,68 +353,68 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t parallel_tool_calls=recipient_agent.parallel_tool_calls, response_format=response_format ) - self.run = self.client.beta.threads.runs.poll( - thread_id=self.thread.id, - run_id=self.run.id, + self._run = self.client.beta.threads.runs.poll( + thread_id=self.id, + run_id=self._run.id, # poll_interval_ms=500, ) except APIError as e: match = re.search(r"Thread (\w+) already has an active run (\w+)", e.message) if match: self._cancel_run(thread_id=match.groups()[0], run_id=match.groups()[1], check_status=False) - elif "The server had an error processing your request" in e.message and self.num_run_retries < 3: + elif "The server had an error processing your request" in e.message and self._num_run_retries < 3: time.sleep(1) 
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) - self.num_run_retries += 1 + self._num_run_retries += 1 else: raise e def _run_until_done(self): - while self.run.status in ['queued', 'in_progress', "cancelling"]: + while self._run.status in ['queued', 'in_progress', "cancelling"]: time.sleep(0.5) - self.run = self.client.beta.threads.runs.retrieve( - thread_id=self.thread.id, - run_id=self.run.id + self._run = self.client.beta.threads.runs.retrieve( + thread_id=self.id, + run_id=self._run.id ) def _submit_tool_outputs(self, tool_outputs, event_handler=None, poll=True): if not poll: - self.run = self.client.beta.threads.runs.submit_tool_outputs( - thread_id=self.thread.id, - run_id=self.run.id, + self._run = self.client.beta.threads.runs.submit_tool_outputs( + thread_id=self.id, + run_id=self._run.id, tool_outputs=tool_outputs ) else: if not event_handler: - self.run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( - thread_id=self.thread.id, - run_id=self.run.id, + self._run = self.client.beta.threads.runs.submit_tool_outputs_and_poll( + thread_id=self.id, + run_id=self._run.id, tool_outputs=tool_outputs ) else: with self.client.beta.threads.runs.submit_tool_outputs_stream( - thread_id=self.thread.id, - run_id=self.run.id, + thread_id=self.id, + run_id=self._run.id, tool_outputs=tool_outputs, event_handler=event_handler() ) as stream: stream.until_done() - self.run = stream.get_final_run() + self._run = stream.get_final_run() def _cancel_run(self, thread_id=None, run_id=None, check_status=True): - if check_status and self.run.status in self.terminal_states and not run_id: + if check_status and self._run.status in self.terminal_states and not run_id: return try: - self.run = self.client.beta.threads.runs.cancel( - thread_id=self.thread.id, - run_id=self.run.id + self._run = self.client.beta.threads.runs.cancel( + thread_id=self.id, + run_id=self._run.id ) except BadRequestError as 
e: if "Cannot cancel run with status" in e.message: - self.run = self.client.beta.threads.runs.poll( - thread_id=thread_id or self.thread.id, - run_id=run_id or self.run.id, + self._run = self.client.beta.threads.runs.poll( + thread_id=thread_id or self.id, + run_id=run_id or self._run.id, poll_interval_ms=500, ) else: diff --git a/agency_swarm/threads/thread_async.py b/agency_swarm/threads/thread_async.py index a6ae8f72..6f1ebd70 100644 --- a/agency_swarm/threads/thread_async.py +++ b/agency_swarm/threads/thread_async.py @@ -90,11 +90,10 @@ def check_status(self, run=None): return f"""{self.recipient_agent.name}'s Response: '{messages.data[0].content[0].text.value}'""" def get_last_run(self): - if not self.thread: - self.init_thread() + self.init_thread() runs = self.client.beta.threads.runs.list( - thread_id=self.thread.id, + thread_id=self.id, order="desc", )