From 4d9549f1336012a29b693824b1244a4dda6f6460 Mon Sep 17 00:00:00 2001 From: Nicolas Frank Date: Fri, 10 Jan 2025 14:57:18 +0100 Subject: [PATCH] Add get validation endpoint --- src/neuroagent/agent_routine.py | 4 ++- src/neuroagent/app/routers/tools.py | 44 +++++++++++++++++++++++++++-- src/neuroagent/new_types.py | 1 + 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/neuroagent/agent_routine.py b/src/neuroagent/agent_routine.py index b5f43b4..6334a98 100644 --- a/src/neuroagent/agent_routine.py +++ b/src/neuroagent/agent_routine.py @@ -53,9 +53,10 @@ async def get_chat_completion( "messages": messages, "tools": tools or None, "tool_choice": agent.tool_choice, - "stream_options": {"include_usage": True}, "stream": stream, } + if stream: + create_params["stream_options"] = ({"include_usage": True},) if tools: create_params["parallel_tool_calls"] = agent.parallel_tool_calls @@ -156,6 +157,7 @@ async def handle_tool_call( if tool_call.validated is None: return HILResponse( message="Please validate the following inputs before proceeding.", + name=tool_call.name, inputs=input_schema.model_dump(), tool_call_id=tool_call.tool_call_id, ), None diff --git a/src/neuroagent/app/routers/tools.py b/src/neuroagent/app/routers/tools.py index f43fa76..b0ddf39 100644 --- a/src/neuroagent/app/routers/tools.py +++ b/src/neuroagent/app/routers/tools.py @@ -13,7 +13,7 @@ from neuroagent.app.database.schemas import ToolCallSchema from neuroagent.app.database.sql_schemas import Entity, Messages, Threads, ToolCalls from neuroagent.app.dependencies import get_session, get_starting_agent -from neuroagent.new_types import Agent, HILValidation +from neuroagent.new_types import Agent, HILResponse, HILValidation logger = logging.getLogger(__name__) @@ -113,7 +113,47 @@ async def get_tool_returns( return tool_output -@router.patch("/validate/{thread_id}/{tool_call_id}") +@router.get("/validation/{thread_id}/") +async def get_required_validation( + _: Annotated[Threads, Depends(get_thread)], + thread_id: str, + session: Annotated[AsyncSession, Depends(get_session)], + starting_agent: Annotated[Agent, Depends(get_starting_agent)], +) -> list[HILResponse]: + """List tool calls currently requiring validation in a thread.""" + message_query = await session.execute( + select(Messages) + .where(Messages.thread_id == thread_id) + .order_by(desc(Messages.order)) + .limit(1) + ) + message = message_query.scalar_one_or_none() + if not message or message.entity != Entity.AI_TOOL: + return [] + + else: + tool_calls = await message.awaitable_attrs.tool_calls + need_validation = [] + for tool_call in tool_calls: + tool = next( + tool for tool in starting_agent.tools if tool.name == tool_call.name + ) + if tool.hil and tool_call.validated is None: + input_schema = tool.__annotations__["input_schema"]( + **json.loads(tool_call.arguments) + ) + need_validation.append( + HILResponse( + message="Please validate the following inputs before proceeding.", + name=tool_call.name, + inputs=input_schema.model_dump(), + tool_call_id=tool_call.tool_call_id, + ) + ) + return need_validation + + +@router.patch("/validation/{thread_id}/{tool_call_id}") async def validate_input( user_request: HILValidation, _: Annotated[Threads, Depends(get_thread)], diff --git a/src/neuroagent/new_types.py b/src/neuroagent/new_types.py index f5ec3c4..2a22a46 100644 --- a/src/neuroagent/new_types.py +++ b/src/neuroagent/new_types.py @@ -23,6 +23,7 @@ class HILResponse(BaseModel): """Response for tools that require HIL validation.""" message: str + name: str inputs: dict[str, Any] tool_call_id: str