Skip to content

Commit

Permalink
Integrate AgentOps; Minor fixes in thread.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Dec 24, 2024
1 parent f53429f commit a0e926a
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 149 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ celerybeat.pid

# Environments
.env
.venv
.venv/
env/
venv/
ENV/
Expand Down
133 changes: 57 additions & 76 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Type, Union
from typing import Any, Type, Union
from uuid import UUID

from openai import APIError, BadRequestError
Expand Down Expand Up @@ -89,10 +89,10 @@ def init_thread(self):

def get_completion_stream(
self,
message: Union[str, List[dict], None],
message: Union[str, list[dict], None],
event_handler: Type[AgencyEventHandler] | None = None,
message_files: List[str] = None,
attachments: List[Attachment] | None = None,
message_files: list[str] = None,
attachments: list[Attachment] | None = None,
recipient_agent: Agent = None,
additional_instructions: str = None,
tool_choice: AssistantToolChoice = None,
Expand All @@ -112,9 +112,9 @@ def get_completion_stream(

def get_completion(
self,
message: Union[str, List[dict], None],
message_files: List[str] = None,
attachments: List[dict] | None = None,
message: Union[str, list[dict], None],
message_files: list[str] = None,
attachments: list[dict] | None = None,
recipient_agent: Union[Agent, None] = None,
additional_instructions: str = None,
event_handler: Type[AgencyEventHandler] | None = None,
Expand Down Expand Up @@ -178,6 +178,7 @@ def get_completion(
)

chain_run_id = self._run.id
final_output = None # Track the final output

# Chain start
if self.callback_handler:
Expand All @@ -198,29 +199,24 @@ def get_completion(

# chat model start callback
if self.callback_handler:
chat_messages = []
if isinstance(message, str):
chat_messages = [[HumanMessage(content=message)]]

kwargs = {
"invocation_params": {
"_type": "openai",
"model": self._run.model,
"temperature": self._run.temperature,
},
"name": recipient_agent.name if recipient_agent else "Unknown",
}
agent_name = recipient_agent.name if recipient_agent else "Unknown"

self.callback_handler.on_chat_model_start(
serialized={"name": kwargs["name"], "id": [self._run.id]},
messages=chat_messages,
serialized={"name": agent_name, "id": [self._run.id]},
messages=[[HumanMessage(content=message)]],
run_id=self._run.id,
parent_run_id=chain_run_id,
metadata={
"agent_name": self.agent.name,
"recipient_agent_name": recipient_agent.name,
},
**kwargs,
invocation_params={
"_type": "openai",
"model": self._run.model,
"temperature": self._run.temperature,
},
name=agent_name,
)

try:
Expand All @@ -230,13 +226,12 @@ def get_completion(
while True:
self._run_until_done()

# function execution
if self._run.status == "requires_action":
self._called_recepients = []
tool_calls = (
self._run.required_action.submit_tool_outputs.tool_calls
)
tool_outputs_and_names = [] # list of tuples (name, tool_output)
tool_outputs_and_names: list[tuple[str, Any]] = []

self._track_tool_calls(tool_calls, chain_run_id)

Expand Down Expand Up @@ -309,23 +304,8 @@ def handle_output(tool_call, output):
output = yield from handle_output(tool_call, output)
if output_as_result:
self._cancel_run()
# chain end
if self.callback_handler:
self.callback_handler.on_chain_end(
outputs={"response": output},
run_id=chain_run_id,
parent_run_id=parent_run_id,
)
finish = AgentFinish(
return_values={"response": output},
log=output,
)
self.callback_handler.on_agent_finish(
finish=finish,
run_id=self._run.id,
parent_run_id=chain_run_id,
)
return output
final_output = output
break
else:
sync_tool_calls += async_tool_calls

Expand Down Expand Up @@ -354,22 +334,11 @@ def handle_output(tool_call, output):
output = yield from handle_output(tool_call, output)
if output_as_result:
self._cancel_run()
# chain end
if self.callback_handler:
self.callback_handler.on_chain_end(
outputs={"response": output},
run_id=chain_run_id,
parent_run_id=parent_run_id,
)
finish = AgentFinish(
return_values={"response": output}, log=output
)
self.callback_handler.on_agent_finish(
finish=finish,
run_id=self._run.id,
parent_run_id=chain_run_id,
)
return output
final_output = output
break

if final_output is not None:
break

tool_outputs = [t for _, t in tool_outputs_and_names]
tool_names = [n for n, _ in tool_outputs_and_names]
Expand Down Expand Up @@ -547,23 +516,35 @@ def handle_output(tool_call, output):
)
continue

# chain end
if self.callback_handler:
self.callback_handler.on_chain_end(
outputs={"response": last_message},
run_id=chain_run_id,
parent_run_id=parent_run_id,
)
# agent finish
finish = AgentFinish(
return_values={"response": last_message}, log=last_message
)
self.callback_handler.on_agent_finish(
finish=finish,
run_id=self._run.id,
parent_run_id=chain_run_id,
)
return last_message
if final_output is None:
final_output = last_message
break

# Ensure final_output is a string
if inspect.isgenerator(final_output):
final_output = "".join(list(final_output))
elif not isinstance(final_output, str):
final_output = str(final_output)

# Only fire callbacks once at the end
if self.callback_handler:
self.callback_handler.on_chain_end(
outputs={"response": final_output},
run_id=chain_run_id,
parent_run_id=parent_run_id,
)
finish = AgentFinish(
return_values={"response": final_output},
log=final_output,
)
self.callback_handler.on_agent_finish(
finish=finish,
run_id=self._run.id,
parent_run_id=chain_run_id,
)

return final_output

except Exception as e:
# chain error
if self.callback_handler:
Expand Down Expand Up @@ -720,7 +701,7 @@ def _get_last_assistant_message(self):
raise Exception("No assistant message found in the thread")

def create_message(
self, message: str, role: str = "user", attachments: List[dict] = None
self, message: str, role: str = "user", attachments: list[dict] = None
):
try:
return self.client.beta.threads.messages.create(
Expand Down Expand Up @@ -921,7 +902,7 @@ def get_messages(self, limit=None):

return all_messages

def _track_tool_calls(self, tool_calls: List[ToolCall], chain_run_id: str) -> None:
def _track_tool_calls(self, tool_calls: list[ToolCall], parent_run_id: str) -> None:
"""Send agent_action before each tool call"""
if not self.callback_handler:
return
Expand All @@ -936,5 +917,5 @@ def _track_tool_calls(self, tool_calls: List[ToolCall], chain_run_id: str) -> No
self.callback_handler.on_agent_action(
action=action,
run_id=self._run.id,
parent_run_id=chain_run_id,
parent_run_id=parent_run_id,
)
21 changes: 16 additions & 5 deletions agency_swarm/util/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import threading
from typing import Callable, Literal

Expand All @@ -6,9 +7,11 @@
_callback_handler = None
_lock = threading.Lock()

logger = logging.getLogger(__name__)

SUPPORTED_TRACKERS = ["langfuse", "local"]
SUPPORTED_TRACKERS_TYPE = Literal["langfuse", "local"]

SUPPORTED_TRACKERS = ["agentops", "langfuse", "local"]
SUPPORTED_TRACKERS_TYPE = Literal["agentops", "langfuse", "local"]


def get_callback_handler():
Expand All @@ -27,16 +30,24 @@ def init_tracking(tracker_name: SUPPORTED_TRACKERS_TYPE, **kwargs):
if tracker_name not in SUPPORTED_TRACKERS:
raise ValueError(f"Invalid tracker name: {tracker_name}")

logger.info(f"Initializing tracking with {tracker_name}...")

use_langchain_types()

if tracker_name == "local":
from .local_callback_handler import LocalCallbackHandler

set_callback_handler(lambda: LocalCallbackHandler(**kwargs))
handler_class = LocalCallbackHandler
elif tracker_name == "agentops":
from agentops import LangchainCallbackHandler

handler_class = LangchainCallbackHandler
elif tracker_name == "langfuse":
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
from langfuse.callback import CallbackHandler

handler_class = CallbackHandler

set_callback_handler(lambda: LangfuseCallbackHandler(**kwargs))
set_callback_handler(lambda: handler_class(**kwargs))


__all__ = [
Expand Down
5 changes: 0 additions & 5 deletions agency_swarm/util/tracking/langchain_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ def use_langchain_types() -> None:
from langchain.schema import AgentFinish as LangchainAgentFinish
from langchain.schema import HumanMessage as LangchainHumanMessage

# Call model_rebuild on these imported classes to resolve forward references
LangchainAgentAction.model_rebuild()
LangchainAgentFinish.model_rebuild()
LangchainHumanMessage.model_rebuild()

AgentAction.set_implementation(LangchainAgentAction)
AgentFinish.set_implementation(LangchainAgentFinish)
HumanMessage.set_implementation(LangchainHumanMessage)
2 changes: 1 addition & 1 deletion agency_swarm/util/tracking/local_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def on_retriever_end(
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
docs_json = json.dumps([doc.dict() for doc in documents])
docs_json = json.dumps([doc.model_dump() for doc in documents])
self._update_event(
run_id, set_end_time=True, parent_run_id=parent_run_id, documents=docs_json
)
Expand Down
Loading

0 comments on commit a0e926a

Please sign in to comment.