diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index e582562e..59141112 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -377,29 +377,19 @@ def get_id_from_file(f_path): def add_tool(self, tool): if not isinstance(tool, type): raise Exception("Tool must not be initialized.") - if issubclass(tool, FileSearch): - # check that tools name is not already in tools - for t in self.tools: - if issubclass(t, FileSearch): - return - self.tools.append(tool) - elif issubclass(tool, CodeInterpreter): - for t in self.tools: - if issubclass(t, CodeInterpreter): - return - self.tools.append(tool) - elif issubclass(tool, Retrieval): - for t in self.tools: - if issubclass(t, Retrieval): - return - self.tools.append(tool) - elif issubclass(tool, BaseTool): + + subclasses = [FileSearch, CodeInterpreter, Retrieval] + for subclass in subclasses: + if issubclass(tool, subclass): + if not any(issubclass(t, subclass) for t in self.tools): + self.tools.append(tool) + return + + if issubclass(tool, BaseTool): if tool.__name__ == "ExampleTool": print("Skipping importing ExampleTool...") return - for t in self.tools: - if t.__name__ == tool.__name__: - self.tools.remove(t) + self.tools = [t for t in self.tools if t.__name__ != tool.__name__] self.tools.append(tool) else: raise Exception("Invalid tool type.")