From 8c320435445c0d17def735e3ac408c76569b9f45 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Thu, 29 Feb 2024 08:01:00 +0400 Subject: [PATCH] Added tool factory to api reference --- agency_swarm/tools/ToolFactory.py | 72 +++++++++++++++++++++++-------- docs/api.md | 4 +- 2 files changed, 56 insertions(+), 20 deletions(-) diff --git a/agency_swarm/tools/ToolFactory.py b/agency_swarm/tools/ToolFactory.py index a26c2fab..595478f7 100644 --- a/agency_swarm/tools/ToolFactory.py +++ b/agency_swarm/tools/ToolFactory.py @@ -16,11 +16,15 @@ class ToolFactory: @staticmethod - def from_langchain_tools(tools: List): + def from_langchain_tools(tools: List) -> List[Type[BaseTool]]: """ Converts a list of langchain tools into a list of BaseTools. - :param tools: A list of langchain tools. - :return: A list of BaseTools. + + Parameters: + tools: The langchain tools to convert. + + Returns: + A list of BaseTools. """ converted_tools = [] for tool in tools: @@ -29,11 +33,15 @@ def from_langchain_tools(tools: List): return converted_tools @staticmethod - def from_langchain_tool(tool): + def from_langchain_tool(tool) -> Type[BaseTool]: """ Converts a langchain tool into a BaseTool. - :param tool: A langchain tool. - :return: A BaseTool. + + Parameters: + tool: The langchain tool to convert. + + Returns: + A BaseTool. """ try: from langchain.tools import format_tool_to_openai_function @@ -61,12 +69,16 @@ def callback(self): @staticmethod - def from_openai_schema(schema: Dict[str, Any], callback: Any): + def from_openai_schema(schema: Dict[str, Any], callback: Any) -> Type[BaseTool]: """ Converts an OpenAI schema into a BaseTool. Nested propoerties without refs are not supported yet. - :param schema: - :param callback: - :return: + + Parameters: + schema: The OpenAI schema to convert. + callback: The function to run when the tool is called. + + Returns: + A BaseTool. """ def resolve_ref(ref: str, defs: Dict[str, Any]) -> Any: # Extract the key from the reference @@ -167,7 +179,19 @@ def create_fields(schema: Dict[str, Any], type_mapping: Dict[str, Type[Any]], re return tool @staticmethod - def from_openapi_schema(schema: Union[str, dict], headers: Dict[str, str] = None, params: Dict[str, Any] = None): + def from_openapi_schema(schema: Union[str, dict], headers: Dict[str, str] = None, params: Dict[str, Any] = None) \ + -> List[Type[BaseTool]]: + """ + Converts an OpenAPI schema into a list of BaseTools. + + Parameters: + schema: The OpenAPI schema to convert. + headers: The headers to use for requests. + params: The parameters to use for requests. + + Returns: + A list of BaseTools. + """ if isinstance(schema, dict): openapi_spec = schema openapi_spec = jsonref.JsonRef.replace_refs(openapi_spec) @@ -260,8 +284,16 @@ def callback(self): return tools @staticmethod - def from_file(file_path: str): - """Dynamically imports a class from a Python file, ensuring BaseTool itself is not imported.""" + def from_file(file_path: str) -> Type[BaseTool]: + """Dynamically imports a BaseTool from a Python file. The file must be named the same as the class. + + Parameters: + file_path: The file path to the Python file containing the BaseTool class. + + Returns: + The BaseTool class from the given file path. + + """ # Extract class name from file path (assuming class name matches file name without .py extension) class_name = os.path.basename(file_path) if class_name.endswith('.py'): @@ -283,16 +315,18 @@ def from_file(file_path: str): @staticmethod def get_openapi_schema(tools: List[Type[BaseTool]], url: str, title="Agent Tools", - description="A collection of tools."): + description="A collection of tools.") -> str: """ Generates an OpenAPI schema from a list of BaseTools. - :param tools: BaseTools to generate the schema from. - :param url: The base URL for the schema. - :param title: The title of the schema. - :param description: The description of the schema. + Parameters: + tools: BaseTools to generate the schema from. + url: The base URL for the schema. + title: The title of the schema. + description: The description of the schema. - :return: A JSON string representing the OpenAPI schema with all the tools combined as separate endpoints. + Returns: + A JSON string representing the OpenAPI schema with all the tools combined as separate endpoints. """ schema = { "openapi": "3.1.0", diff --git a/docs/api.md b/docs/api.md index 9523a1d9..bb38f313 100644 --- a/docs/api.md +++ b/docs/api.md @@ -2,4 +2,6 @@ ::: agency_swarm.agents.agent -::: agency_swarm.agency.agency \ No newline at end of file +::: agency_swarm.agency + +::: agency_swarm.tools.ToolFactory \ No newline at end of file