Skip to content

Commit

Permalink
Bug fixes and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Dec 20, 2024
1 parent 8612c3b commit 64f2b54
Show file tree
Hide file tree
Showing 12 changed files with 60 additions and 40 deletions.
2 changes: 1 addition & 1 deletion agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from openai.lib._parsing._completions import type_to_response_format_param
from openai.types.beta.assistant import ToolResources

from agency_swarm.constants import DEFAULT_MODEL
from agency_swarm.tools import (
BaseTool,
CodeInterpreter,
Expand All @@ -17,7 +18,6 @@
ToolFactory,
)
from agency_swarm.tools.oai.FileSearch import FileSearchConfig
from agency_swarm.util.constants import DEFAULT_MODEL
from agency_swarm.util.oai import get_openai_client
from agency_swarm.util.openapi import validate_openapi_spec
from agency_swarm.util.shared_state import SharedState
Expand Down
File renamed without changes.
27 changes: 14 additions & 13 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,21 @@ def get_completion(
)

# chat model start callback
chat_messages = (
[[HumanMessage(content=message)]] if isinstance(message, str) else []
)
if self.callback_handler and chat_messages:
self.callback_handler.on_chat_model_start(
serialized={"name": self._run.model},
messages=chat_messages,
run_id=self._run.id,
parent_run_id=chain_run_id,
metadata={
"agent_name": self.agent.name,
"recipient_agent_name": recipient_agent.name,
},
if self.callback_handler:
chat_messages = (
[[HumanMessage(content=message)]] if isinstance(message, str) else []
)
if chat_messages:
self.callback_handler.on_chat_model_start(
serialized={"name": self._run.model},
messages=chat_messages,
run_id=self._run.id,
parent_run_id=chain_run_id,
metadata={
"agent_name": self.agent.name,
"recipient_agent_name": recipient_agent.name,
},
)

try:
error_attempts = 0
Expand Down
5 changes: 3 additions & 2 deletions agency_swarm/util/tracking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
_lock = threading.Lock()


SUPPORTED_TRACKERS = Literal["langfuse", "local"]
SUPPORTED_TRACKERS = ["langfuse", "local"]
SUPPORTED_TRACKERS_TYPE = Literal["langfuse", "local"]


def get_callback_handler():
Expand All @@ -20,7 +21,7 @@ def set_callback_handler(handler: Callable):
_callback_handler = handler()


def init_tracking(tracker_name: SUPPORTED_TRACKERS, **kwargs):
def init_tracking(tracker_name: SUPPORTED_TRACKERS_TYPE, **kwargs):
if tracker_name not in SUPPORTED_TRACKERS:
raise ValueError(f"Invalid tracker name: {tracker_name}")

Expand Down
29 changes: 24 additions & 5 deletions agency_swarm/util/tracking/langchain_types.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Any, Dict

from pydantic import BaseModel

if TYPE_CHECKING:
from langchain.schema import AgentAction as LangchainAgentAction
from langchain.schema import AgentFinish as LangchainAgentFinish
from langchain.schema import HumanMessage as LangchainHumanMessage

# Define TypeVars that can be either our placeholder or langchain types
AgentAction = TypeVar("AgentAction", bound="LangchainAgentAction")
AgentFinish = TypeVar("AgentFinish", bound="LangchainAgentFinish")
HumanMessage = TypeVar("HumanMessage", bound="LangchainHumanMessage")

# Create base classes that match langchain's structure
class BaseAgentAction(BaseModel):
tool: str
tool_input: Dict[str, Any] | str
log: str


class BaseAgentFinish(BaseModel):
return_values: Dict[str, Any]
log: str


class BaseHumanMessage(BaseModel):
content: str


# Initialize with our base implementations first
AgentAction = BaseAgentAction
AgentFinish = BaseAgentFinish
HumanMessage = BaseHumanMessage


def use_langchain_types():
Expand Down
4 changes: 1 addition & 3 deletions agency_swarm/util/tracking/local_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@

from langchain.schema import AgentAction, AgentFinish, BaseMessage, Document, LLMResult

from agency_swarm.util.tracking.callbacks import CallbackHandler


class LocalCallbackHandler(CallbackHandler):
class LocalCallbackHandler:
def __init__(self, db_path: str = "usage.db"):
self.conn = sqlite3.connect(db_path, check_same_thread=False)
self.lock = threading.Lock()
Expand Down
2 changes: 1 addition & 1 deletion agency_swarm/util/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from openai import OpenAI
from pydantic import BaseModel, Field

from agency_swarm.util.constants import DEFAULT_MODEL_MINI
from agency_swarm.constants import DEFAULT_MODEL_MINI
from agency_swarm.util.oai import get_openai_client


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"rich==13.7.1",
"termcolor==2.4.0",
]
requires-python = ">=3.10,<3.13"
requires-python = ">=3.10"
urls = { homepage = "https://github.com/VRSEN/agency-swarm" }

[project.scripts]
Expand Down
1 change: 1 addition & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
langchain==0.3.13
langchain-community==0.3.13
pytest
11 changes: 5 additions & 6 deletions tests/demos/demo_observability.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from dotenv import load_dotenv

load_dotenv()
from agency_swarm import Agency, Agent
from agency_swarm.util import init_tracking

from agency_swarm import Agency, Agent # noqa
from agency_swarm.util import init_tracking # noqa
load_dotenv()


def main():
# Set the tracker type
TRACKER = "local"
# To use Langfuse, uncomment the next line
# TRACKER = "langfuse"
# TRACKER = "local"
TRACKER = "langfuse"

# Initialize tracking based on the selected tracker
init_tracking(TRACKER)
Expand Down
10 changes: 4 additions & 6 deletions tests/demos/streaming_demo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import sys
import time
import unittest

from agency_swarm import Agent, BaseTool
from agency_swarm.agency.agency import Agency

sys.path.insert(0, "../agency-swarm")
from agency_swarm.constants import DEFAULT_MODEL_MINI


class StreamingTest(unittest.TestCase):
Expand All @@ -19,12 +17,12 @@ def run(self):
self.ceo = Agent(
name="ceo",
instructions="You are a CEO of an agency made for testing purposes.",
model="gpt-4o-mini",
model=DEFAULT_MODEL_MINI,
)
self.test_agent1 = Agent(
name="test_agent1", tools=[TestTool], model="gpt-4o-mini"
name="test_agent1", tools=[TestTool], model=DEFAULT_MODEL_MINI
)
self.test_agent2 = Agent(name="test_agent2", model="gpt-4o-mini")
self.test_agent2 = Agent(name="test_agent2", model=DEFAULT_MODEL_MINI)

self.agency = Agency(
[
Expand Down
7 changes: 5 additions & 2 deletions tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def test_send_message_double_recipient_error(self):
ceo = Agent(
name="CEO",
description="Responsible for client communication, task planning and management.",
instructions="You are an agent for testing. Route request AT THE SAME TIME as instructed. If there is an error in a single request, please say 'error'. If there are errors in both requests, please say 'fatal'. do not output anything else.",
instructions="""You are an agent for testing. When asked to route requests AT THE SAME TIME:
1. If you detect multiple simultaneous routing requests, respond with 'error'
2. If you detect errors in all routing attempts, respond with 'fatal'
3. Do not output anything else besides these exact words.""",
)
test_agent = Agent(
name="Test Agent1",
Expand All @@ -79,7 +82,7 @@ def test_send_message_double_recipient_error(self):
)
agency = Agency([ceo, [ceo, test_agent]], temperature=0)
response = agency.get_completion(
"Please route me to customer support TWICE at the same time. I am testing something."
"Route me to customer support TWICE simultaneously (at the exact same time). This is a test of concurrent routing."
)
self.assertTrue("error" in response.lower(), agency.main_thread.thread_url)
self.assertTrue("fatal" not in response.lower(), agency.main_thread.thread_url)
Expand Down

0 comments on commit 64f2b54

Please sign in to comment.