From 64f2b54df0191339d6cdf8ca705db4cf12a8d34f Mon Sep 17 00:00:00 2001 From: Nick Bobrowski <39348559+bonk1t@users.noreply.github.com> Date: Fri, 20 Dec 2024 01:11:14 +0000 Subject: [PATCH] Bug fixes and improvements --- agency_swarm/agents/agent.py | 2 +- agency_swarm/{util => }/constants.py | 0 agency_swarm/threads/thread.py | 27 ++++++++--------- agency_swarm/util/tracking/__init__.py | 5 ++-- agency_swarm/util/tracking/langchain_types.py | 29 +++++++++++++++---- .../util/tracking/local_callback_handler.py | 4 +-- agency_swarm/util/validators.py | 2 +- pyproject.toml | 2 +- requirements_test.txt | 1 + tests/demos/demo_observability.py | 11 ++++--- tests/demos/streaming_demo.py | 10 +++---- tests/test_communication.py | 7 +++-- 12 files changed, 60 insertions(+), 40 deletions(-) rename agency_swarm/{util => }/constants.py (100%) diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index befd5c81..b1f73df3 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -9,6 +9,7 @@ from openai.lib._parsing._completions import type_to_response_format_param from openai.types.beta.assistant import ToolResources +from agency_swarm.constants import DEFAULT_MODEL from agency_swarm.tools import ( BaseTool, CodeInterpreter, @@ -17,7 +18,6 @@ ToolFactory, ) from agency_swarm.tools.oai.FileSearch import FileSearchConfig -from agency_swarm.util.constants import DEFAULT_MODEL from agency_swarm.util.oai import get_openai_client from agency_swarm.util.openapi import validate_openapi_spec from agency_swarm.util.shared_state import SharedState diff --git a/agency_swarm/util/constants.py b/agency_swarm/constants.py similarity index 100% rename from agency_swarm/util/constants.py rename to agency_swarm/constants.py diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 6b6eb567..0ca27891 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -194,20 +194,21 @@ def get_completion( ) # chat model start callback - chat_messages = ( - [[HumanMessage(content=message)]] if isinstance(message, str) else [] - ) - if self.callback_handler and chat_messages: - self.callback_handler.on_chat_model_start( - serialized={"name": self._run.model}, - messages=chat_messages, - run_id=self._run.id, - parent_run_id=chain_run_id, - metadata={ - "agent_name": self.agent.name, - "recipient_agent_name": recipient_agent.name, - }, + if self.callback_handler: + chat_messages = ( + [[HumanMessage(content=message)]] if isinstance(message, str) else [] ) + if chat_messages: + self.callback_handler.on_chat_model_start( + serialized={"name": self._run.model}, + messages=chat_messages, + run_id=self._run.id, + parent_run_id=chain_run_id, + metadata={ + "agent_name": self.agent.name, + "recipient_agent_name": recipient_agent.name, + }, + ) try: error_attempts = 0 diff --git a/agency_swarm/util/tracking/__init__.py b/agency_swarm/util/tracking/__init__.py index 434749b2..0a995dbb 100644 --- a/agency_swarm/util/tracking/__init__.py +++ b/agency_swarm/util/tracking/__init__.py @@ -5,7 +5,8 @@ _lock = threading.Lock() -SUPPORTED_TRACKERS = Literal["langfuse", "local"] +SUPPORTED_TRACKERS = ["langfuse", "local"] +SUPPORTED_TRACKERS_TYPE = Literal["langfuse", "local"] def get_callback_handler(): @@ -20,7 +21,7 @@ def set_callback_handler(handler: Callable): _callback_handler = handler() -def init_tracking(tracker_name: SUPPORTED_TRACKERS, **kwargs): +def init_tracking(tracker_name: SUPPORTED_TRACKERS_TYPE, **kwargs): if tracker_name not in SUPPORTED_TRACKERS: raise ValueError(f"Invalid tracker name: {tracker_name}") diff --git a/agency_swarm/util/tracking/langchain_types.py b/agency_swarm/util/tracking/langchain_types.py index 22834080..0a8f4bfd 100644 --- a/agency_swarm/util/tracking/langchain_types.py +++ b/agency_swarm/util/tracking/langchain_types.py @@ -1,14 +1,33 @@ -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, Dict + +from pydantic import BaseModel if TYPE_CHECKING: from langchain.schema import AgentAction as LangchainAgentAction from langchain.schema import AgentFinish as LangchainAgentFinish from langchain.schema import HumanMessage as LangchainHumanMessage -# Define TypeVars that can be either our placeholder or langchain types -AgentAction = TypeVar("AgentAction", bound="LangchainAgentAction") -AgentFinish = TypeVar("AgentFinish", bound="LangchainAgentFinish") -HumanMessage = TypeVar("HumanMessage", bound="LangchainHumanMessage") + +# Create base classes that match langchain's structure +class BaseAgentAction(BaseModel): + tool: str + tool_input: Dict[str, Any] | str + log: str + + +class BaseAgentFinish(BaseModel): + return_values: Dict[str, Any] + log: str + + +class BaseHumanMessage(BaseModel): + content: str + + +# Initialize with our base implementations first +AgentAction = BaseAgentAction +AgentFinish = BaseAgentFinish +HumanMessage = BaseHumanMessage def use_langchain_types(): diff --git a/agency_swarm/util/tracking/local_callback_handler.py b/agency_swarm/util/tracking/local_callback_handler.py index c0422a76..8fcef20c 100644 --- a/agency_swarm/util/tracking/local_callback_handler.py +++ b/agency_swarm/util/tracking/local_callback_handler.py @@ -6,10 +6,8 @@ from langchain.schema import AgentAction, AgentFinish, BaseMessage, Document, LLMResult -from agency_swarm.util.tracking.callbacks import CallbackHandler - -class LocalCallbackHandler(CallbackHandler): +class LocalCallbackHandler: def __init__(self, db_path: str = "usage.db"): self.conn = sqlite3.connect(db_path, check_same_thread=False) self.lock = threading.Lock() diff --git a/agency_swarm/util/validators.py b/agency_swarm/util/validators.py index 3ac2f4bc..0b710501 100644 --- a/agency_swarm/util/validators.py +++ b/agency_swarm/util/validators.py @@ -3,7 +3,7 @@ from openai import OpenAI from pydantic import BaseModel, Field -from agency_swarm.util.constants import DEFAULT_MODEL_MINI +from agency_swarm.constants import DEFAULT_MODEL_MINI from agency_swarm.util.oai import get_openai_client diff --git a/pyproject.toml b/pyproject.toml index 5ed92e80..ecea1ac9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "rich==13.7.1", "termcolor==2.4.0", ] -requires-python = ">=3.10,<3.13" +requires-python = ">=3.10" urls = { homepage = "https://github.com/VRSEN/agency-swarm" } [project.scripts] diff --git a/requirements_test.txt b/requirements_test.txt index fe29eb20..55cb356d 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,2 +1,3 @@ langchain==0.3.13 +langchain-community==0.3.13 pytest diff --git a/tests/demos/demo_observability.py b/tests/demos/demo_observability.py index 41a454dc..fc455118 100644 --- a/tests/demos/demo_observability.py +++ b/tests/demos/demo_observability.py @@ -1,16 +1,15 @@ from dotenv import load_dotenv -load_dotenv() +from agency_swarm import Agency, Agent +from agency_swarm.util import init_tracking -from agency_swarm import Agency, Agent # noqa -from agency_swarm.util import init_tracking # noqa +load_dotenv() def main(): # Set the tracker type - TRACKER = "local" - # To use Langfuse, uncomment the next line - # TRACKER = "langfuse" + # TRACKER = "local" + TRACKER = "langfuse" # Initialize tracking based on the selected tracker init_tracking(TRACKER) diff --git a/tests/demos/streaming_demo.py b/tests/demos/streaming_demo.py index 66b0fe53..8cf7ed95 100644 --- a/tests/demos/streaming_demo.py +++ b/tests/demos/streaming_demo.py @@ -1,11 +1,9 @@ -import sys import time import unittest from agency_swarm import Agent, BaseTool from agency_swarm.agency.agency import Agency - -sys.path.insert(0, "../agency-swarm") +from agency_swarm.constants import DEFAULT_MODEL_MINI class StreamingTest(unittest.TestCase): @@ -19,12 +17,12 @@ def run(self): self.ceo = Agent( name="ceo", instructions="You are a CEO of an agency made for testing purposes.", - model="gpt-4o-mini", + model=DEFAULT_MODEL_MINI, ) self.test_agent1 = Agent( - name="test_agent1", tools=[TestTool], model="gpt-4o-mini" + name="test_agent1", tools=[TestTool], model=DEFAULT_MODEL_MINI ) - self.test_agent2 = Agent(name="test_agent2", model="gpt-4o-mini") + self.test_agent2 = Agent(name="test_agent2", model=DEFAULT_MODEL_MINI) self.agency = Agency( [ diff --git a/tests/test_communication.py b/tests/test_communication.py index be62b02b..7dd757b6 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -70,7 +70,10 @@ def test_send_message_double_recipient_error(self): ceo = Agent( name="CEO", description="Responsible for client communication, task planning and management.", - instructions="You are an agent for testing. Route request AT THE SAME TIME as instructed. If there is an error in a single request, please say 'error'. If there are errors in both requests, please say 'fatal'. do not output anything else.", + instructions="""You are an agent for testing. When asked to route requests AT THE SAME TIME: + 1. If you detect multiple simultaneous routing requests, respond with 'error' + 2. If you detect errors in all routing attempts, respond with 'fatal' + 3. Do not output anything else besides these exact words.""", ) test_agent = Agent( name="Test Agent1", @@ -79,7 +82,7 @@ def test_send_message_double_recipient_error(self): ) agency = Agency([ceo, [ceo, test_agent]], temperature=0) response = agency.get_completion( - "Please route me to customer support TWICE at the same time. I am testing something." + "Route me to customer support TWICE simultaneously (at the exact same time). This is a test of concurrent routing." ) self.assertTrue("error" in response.lower(), agency.main_thread.thread_url) self.assertTrue("fatal" not in response.lower(), agency.main_thread.thread_url)