From 71272df9e8351c575b62395fc89808d80ea1187c Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Tue, 23 Jan 2024 09:41:42 +0400 Subject: [PATCH] Added printing openapi schemas for agents --- agency_swarm/agency/agency.py | 267 ++++++++++++++------------- agency_swarm/agents/agent.py | 66 +++++++ agency_swarm/threads/thread_async.py | 5 +- 3 files changed, 207 insertions(+), 131 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 5d631337..b5ea5a58 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -73,79 +73,6 @@ def __init__(self, agency_chart: List, shared_instructions: str = "", shared_fil self.user = User() self.main_thread = Thread(self.user, self.ceo) - def _init_agents(self): - """ - Initializes all agents in the agency with unique IDs, shared instructions, and OpenAI models. - - This method iterates through each agent in the agency, assigns a unique ID, adds shared instructions, and initializes the OpenAI models for each agent. - - There are no input parameters. - - There are no output parameters as this method is used for internal initialization purposes within the Agency class. - """ - if self.settings_callbacks: - loaded_settings = self.settings_callbacks["load"]() - with open(self.agents[0].get_settings_path(), 'w') as f: - json.dump(loaded_settings, f, indent=4) - - for agent in self.agents: - if "temp_id" in agent.id: - agent.id = None - agent.add_shared_instructions(self.shared_instructions) - - if self.shared_files: - if isinstance(agent.files_folder, str): - agent.files_folder = [agent.files_folder] - agent.files_folder += self.shared_files - elif isinstance(agent.files_folder, list): - agent.files_folder += self.shared_files - - agent.init_oai() - - if self.settings_callbacks: - with open(self.agents[0].get_settings_path(), 'r') as f: - settings = f.read() - settings = json.loads(settings) - self.settings_callbacks["save"](settings) - - def _init_threads(self): - """ - Initializes threads for communication between agents within the agency. - - This method creates Thread objects for each pair of interacting agents as defined in the agents_and_threads attribute of the Agency. Each thread facilitates communication and task execution between an agent and its designated recipient agent. - - No input parameters. - - Output Parameters: - This method does not return any value but updates the agents_and_threads attribute with initialized Thread objects. - """ - # load thread ids - loaded_thread_ids = {} - if self.threads_callbacks: - loaded_thread_ids = self.threads_callbacks["load"]() - - for agent_name, threads in self.agents_and_threads.items(): - for other_agent, items in threads.items(): - self.agents_and_threads[agent_name][other_agent] = self.ThreadType( - self.get_agent_by_name(items["agent"]), - self.get_agent_by_name( - items["recipient_agent"])) - - if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]: - self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent] - elif self.threads_callbacks: - self.agents_and_threads[agent_name][other_agent].init_thread() - - # save thread ids - if self.threads_callbacks: - loaded_thread_ids = {} - for agent_name, threads in self.agents_and_threads.items(): - loaded_thread_ids[agent_name] = {} - for other_agent, thread in threads.items(): - loaded_thread_ids[agent_name][other_agent] = thread.id - - self.threads_callbacks["save"](loaded_thread_ids) - def get_completion(self, message: str, message_files=None, yield_messages=True): """ Retrieves the completion for a given message from the main thread. @@ -244,6 +171,92 @@ def run_demo(self): except StopIteration as e: pass + def get_openapi_schema(self, url: str): + """Returns the OpenAPI schema for the agency from the CEO agent, that you can use to integrate with custom gpts. + + Parameters: + url (str): Your server url where the api will be hosted. + """ + + return self.ceo.get_openapi_schema(url) + + + def plot_agency_chart(self): + pass + + def _init_agents(self): + """ + Initializes all agents in the agency with unique IDs, shared instructions, and OpenAI models. + + This method iterates through each agent in the agency, assigns a unique ID, adds shared instructions, and initializes the OpenAI models for each agent. + + There are no input parameters. + + There are no output parameters as this method is used for internal initialization purposes within the Agency class. + """ + if self.settings_callbacks: + loaded_settings = self.settings_callbacks["load"]() + with open(self.agents[0].get_settings_path(), 'w') as f: + json.dump(loaded_settings, f, indent=4) + + for agent in self.agents: + if "temp_id" in agent.id: + agent.id = None + agent.add_shared_instructions(self.shared_instructions) + + if self.shared_files: + if isinstance(agent.files_folder, str): + agent.files_folder = [agent.files_folder] + agent.files_folder += self.shared_files + elif isinstance(agent.files_folder, list): + agent.files_folder += self.shared_files + + agent.init_oai() + + if self.settings_callbacks: + with open(self.agents[0].get_settings_path(), 'r') as f: + settings = f.read() + settings = json.loads(settings) + self.settings_callbacks["save"](settings) + + def _init_threads(self): + """ + Initializes threads for communication between agents within the agency. + + This method creates Thread objects for each pair of interacting agents as defined in the agents_and_threads attribute of the Agency. Each thread facilitates communication and task execution between an agent and its designated recipient agent. + + No input parameters. + + Output Parameters: + This method does not return any value but updates the agents_and_threads attribute with initialized Thread objects. + """ + # load thread ids + loaded_thread_ids = {} + if self.threads_callbacks: + loaded_thread_ids = self.threads_callbacks["load"]() + + for agent_name, threads in self.agents_and_threads.items(): + for other_agent, items in threads.items(): + self.agents_and_threads[agent_name][other_agent] = self.ThreadType( + self.get_agent_by_name(items["agent"]), + self.get_agent_by_name( + items["recipient_agent"])) + + if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]: + self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent] + elif self.threads_callbacks: + self.agents_and_threads[agent_name][other_agent].init_thread() + + # save thread ids + if self.threads_callbacks: + loaded_thread_ids = {} + for agent_name, threads in self.agents_and_threads.items(): + loaded_thread_ids[agent_name] = {} + for other_agent, thread in threads.items(): + loaded_thread_ids[agent_name][other_agent] = thread.id + + self.threads_callbacks["save"](loaded_thread_ids) + def _parse_agency_chart(self, agency_chart): """ Parses the provided agency chart to initialize and organize agents within the agency. @@ -311,57 +324,6 @@ def _add_agent(self, agent): else: return self.get_agent_ids().index(agent.id) - def get_agent_by_name(self, agent_name): - """ - Retrieves an agent from the agency based on the agent's name. - - Parameters: - agent_name (str): The name of the agent to be retrieved. - - Returns: - Agent: The agent object with the specified name. - - Raises: - Exception: If no agent with the given name is found in the agency. - """ - for agent in self.agents: - if agent.name == agent_name: - return agent - raise Exception(f"Agent {agent_name} not found.") - - def get_agents_by_names(self, agent_names): - """ - Retrieves a list of agent objects based on their names. - - Parameters: - agent_names: A list of strings representing the names of the agents to be retrieved. - - Returns: - A list of Agent objects corresponding to the given names. - """ - return [self.get_agent_by_name(agent_name) for agent_name in agent_names] - - def get_agent_ids(self): - """ - Retrieves the IDs of all agents currently in the agency. - - Returns: - List[str]: A list containing the unique IDs of all agents. - """ - return [agent.id for agent in self.agents] - - def get_agent_names(self): - """ - Retrieves the names of all agents in the agency. - - Parameters: - None - - Returns: - List[str]: A list of names of all agents currently part of the agency. - """ - return [agent.name for agent in self.agents] - def _read_instructions(self, path): """ Reads shared instructions from a specified file and stores them in the agency. @@ -375,9 +337,6 @@ def _read_instructions(self, path): with open(path, 'r') as f: self.shared_instructions = f.read() - def plot_agency_chart(self): - pass - def _create_special_tools(self): """ Creates and assigns 'SendMessage' tools to each agent based on the agency's structure. @@ -483,7 +442,8 @@ def _create_get_response_tool(self, agent: Agent, recipient_agents: List[Agent]) class GetResponse(BaseTool): """This tool allows you to check the status of a task or get a response from a specified recipient agent, if the task has been completed. You must always use 'SendMessage' tool first.""" - recipient: recipients = Field(..., description=f"Recipient agent that you want to check the status of. Valid recipients are: {recipient_names}") + recipient: recipients = Field(..., + description=f"Recipient agent that you want to check the status of. Valid recipients are: {recipient_names}") caller_agent_name: str = Field(default=agent.name, description="The agent calling this tool. Defaults to your name. Do not change it.") @@ -508,6 +468,57 @@ def run(self): return GetResponse + def get_agent_by_name(self, agent_name): + """ + Retrieves an agent from the agency based on the agent's name. + + Parameters: + agent_name (str): The name of the agent to be retrieved. + + Returns: + Agent: The agent object with the specified name. + + Raises: + Exception: If no agent with the given name is found in the agency. + """ + for agent in self.agents: + if agent.name == agent_name: + return agent + raise Exception(f"Agent {agent_name} not found.") + + def get_agents_by_names(self, agent_names): + """ + Retrieves a list of agent objects based on their names. + + Parameters: + agent_names: A list of strings representing the names of the agents to be retrieved. + + Returns: + A list of Agent objects corresponding to the given names. + """ + return [self.get_agent_by_name(agent_name) for agent_name in agent_names] + + def get_agent_ids(self): + """ + Retrieves the IDs of all agents currently in the agency. + + Returns: + List[str]: A list containing the unique IDs of all agents. + """ + return [agent.id for agent in self.agents] + + def get_agent_names(self): + """ + Retrieves the names of all agents in the agency. + + Parameters: + None + + Returns: + List[str]: A list of names of all agents currently part of the agency. + """ + return [agent.name for agent in self.agents] + def get_recipient_names(self): """ Retrieves the names of all agents in the agency. diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index 894f2626..1f6d39f5 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -316,6 +316,72 @@ def _parse_schemas(self): else: raise Exception("Schemas folder path must be a string or list of strings.") + def get_openapi_schema(self, url): + """Get openapi schema that contains all tools from the agent as different api paths. Make sure to call this after agency has been initialized.""" + if self.assistant is None: + raise Exception("Assistant is not initialized. Please initialize the agency first, before using this method") + + schema = { + "openapi": "3.1.0", + "info": { + "title": self.name, + "description": self.description if self.description else "", + "version": "v1.0.0" + }, + "servers": [ + { + "url": url, + } + ], + "paths": {}, + "components": { + "schemas": {}, + "securitySchemes": { + "apiKey": { + "type": "apiKey" + } + } + }, + } + + for tool in self.tools: + if issubclass(tool, BaseTool): + openai_schema = tool.openai_schema + defs = {} + if '$defs' in openai_schema['parameters']: + defs = openai_schema['parameters']['$defs'] + del openai_schema['parameters']['$defs'] + + schema['paths']["/" + tool.__name__] = { + "post": { + "description": openai_schema['description'], + "operationId": tool.__name__, + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/" + tool.__name__ + } + } + }, + "required": True, + }, + "deprecated": False, + "security": [ + { + "apiKey": [] + } + ] + } + } + + if defs: + schema['components']['schemas'][tool.__name__] = openai_schema['parameters'], + schema['components']['schemas'].update(defs) + + return json.dumps(schema, indent=2).replace("#/$defs/", "#/components/schemas/") + # --- Settings Methods --- def _check_parameters(self, assistant_settings): diff --git a/agency_swarm/threads/thread_async.py b/agency_swarm/threads/thread_async.py index ba046bcf..e9a95e47 100644 --- a/agency_swarm/threads/thread_async.py +++ b/agency_swarm/threads/thread_async.py @@ -1,10 +1,9 @@ -from agency_swarm.threads import Thread import threading from typing import Literal + from agency_swarm.agents import Agent -from agency_swarm.messages import MessageOutput +from agency_swarm.threads import Thread from agency_swarm.user import User -from agency_swarm.util.oai import get_openai_client class ThreadAsync(Thread):