diff --git a/src/neuroagent/agent_routine.py b/src/neuroagent/agent_routine.py
index 6334a98..c123201 100644
--- a/src/neuroagent/agent_routine.py
+++ b/src/neuroagent/agent_routine.py
@@ -56,7 +56,7 @@ async def get_chat_completion(
             "stream": stream,
         }
         if stream:
-            create_params["stream_options"] = ({"include_usage": True},)
+            create_params["stream_options"] = {"include_usage": True}
         if tools:
             create_params["parallel_tool_calls"] = agent.parallel_tool_calls

@@ -224,6 +224,7 @@ async def arun(
                 model_override=model_override,
                 stream=False,
             )
+
             message = completion.choices[0].message  # type: ignore
             message.sender = active_agent.name

@@ -269,7 +270,7 @@ async def arun(
             # If the tool call response contains HIL validation, do not update anything and return
             if partial_response.hil_messages:
                 return Response(
-                    messages=history[init_len:],
+                    messages=[],
                     agent=active_agent,
                     context_variables=context_variables,
                     hil_messages=partial_response.hil_messages,
@@ -424,7 +425,7 @@ async def astream(
             # If the tool call response contains HIL validation, do not update anything and return
             if partial_response.hil_messages:
                 yield Response(
-                    messages=history[init_len:],
+                    messages=[],
                     agent=active_agent,
                     context_variables=context_variables,
                     hil_messages=partial_response.hil_messages,
diff --git a/tests/app/routers/test_tools.py b/tests/app/routers/test_tools.py
index b5ed083..7826b97 100644
--- a/tests/app/routers/test_tools.py
+++ b/tests/app/routers/test_tools.py
@@ -161,3 +161,169 @@ async def test_get_tool_output(
     tool_output = app_client.get(f"/tools/output/{thread_id}/{tool_call_id}")

     assert tool_output.json() == [json.dumps({"assistant": agent_2.name})]
+
+
+@pytest.mark.asyncio
+async def test_get_required_validation(
+    patch_required_env,
+    app_client,
+    httpx_mock,
+    db_connection,
+    mock_openai_client,
+    agent_handoff_tool,
+):
+    routine = AgentsRoutine(client=mock_openai_client)
+
+    mock_openai_client.set_sequential_responses(
+        [
+            create_mock_response(
+                message={"role": "assistant", "content": ""},
+                function_calls=[{"name": "agent_handoff_tool", "args": {}}],
+            ),
+            create_mock_response(
+                {"role": "assistant", "content": "sample response content"}
+            ),
+        ]
+    )
+    agent_handoff_tool.hil = True
+    agent_1 = Agent(name="Test agent 1", tools=[agent_handoff_tool])
+    agent_2 = Agent(name="Test agent 2", tools=[])
+
+    app.dependency_overrides[get_agents_routine] = lambda: routine
+    app.dependency_overrides[get_starting_agent] = lambda: agent_1
+    app.dependency_overrides[get_context_variables] = lambda: {"to_agent": agent_2}
+    test_settings = Settings(
+        db={"prefix": db_connection},
+    )
+    app.dependency_overrides[get_settings] = lambda: test_settings
+    httpx_mock.add_response(
+        url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
+    )
+
+    with app_client as app_client:
+        wrong_response = app_client.get("/tools/validation/test/")
+        assert wrong_response.status_code == 404
+        assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}
+
+        # Create a thread
+        create_output = app_client.post(
+            "/threads/?virtual_lab_id=test_vlab&project_id=test_project"
+        ).json()
+        thread_id = create_output["thread_id"]
+        validation_list = app_client.get(f"/tools/validation/{thread_id}/")
+        assert validation_list.json() == []
+        # Fill the thread
+        app_client.post(
+            f"/qa/chat/{thread_id}",
+            json={"query": "This is my query"},
+            params={"thread_id": thread_id},
+            headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
+        )
+
+        validation_list = app_client.get(f"/tools/validation/{thread_id}/")
+        assert validation_list.status_code == 200
+        assert validation_list.json() == [
+            {
+                "message": "Please validate the following inputs before proceeding.",
+                "name": "agent_handoff_tool",
+                "inputs": {},
+                "tool_call_id": "mock_tc_id",
+            }
+        ]
+
+        # Validate the tool call
+        app_client.patch(
+            f"/tools/validation/{thread_id}/{validation_list.json()[0]['tool_call_id']}",
+            json={"is_validated": True},
+        )
+        # Validation list should now be empty
+        validation_list = app_client.get(f"/tools/validation/{thread_id}/")
+        assert validation_list.json() == []
+
+
+@pytest.mark.asyncio
+async def test_validate_input(
+    patch_required_env,
+    app_client,
+    httpx_mock,
+    db_connection,
+    mock_openai_client,
+    agent_handoff_tool,
+):
+    routine = AgentsRoutine(client=mock_openai_client)
+
+    mock_openai_client.set_sequential_responses(
+        [
+            create_mock_response(
+                message={"role": "assistant", "content": ""},
+                function_calls=[{"name": "agent_handoff_tool", "args": {}}],
+            ),
+            create_mock_response(
+                {"role": "assistant", "content": "sample response content"}
+            ),
+        ]
+    )
+    agent_handoff_tool.hil = True
+    agent_1 = Agent(name="Test agent 1", tools=[agent_handoff_tool])
+    agent_2 = Agent(name="Test agent 2", tools=[])
+
+    app.dependency_overrides[get_agents_routine] = lambda: routine
+    app.dependency_overrides[get_starting_agent] = lambda: agent_1
+    app.dependency_overrides[get_context_variables] = lambda: {"to_agent": agent_2}
+    test_settings = Settings(
+        db={"prefix": db_connection},
+    )
+    app.dependency_overrides[get_settings] = lambda: test_settings
+    httpx_mock.add_response(
+        url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
+    )
+
+    with app_client as app_client:
+        wrong_response = app_client.get("/tools/validation/test/123")
+        assert wrong_response.status_code == 404
+        assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}
+
+        # Create a thread
+        create_output = app_client.post(
+            "/threads/?virtual_lab_id=test_vlab&project_id=test_project"
+        ).json()
+        thread_id = create_output["thread_id"]
+
+        # Fill the thread
+        response = app_client.post(
+            f"/qa/chat/{thread_id}",
+            json={"query": "This is my query"},
+            params={"thread_id": thread_id},
+            headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
+        )
+
+        assert response.status_code == 200
+        assert response.json() == [
+            {
+                "message": "Please validate the following inputs before proceeding.",
+                "name": "agent_handoff_tool",
+                "inputs": {},
+                "tool_call_id": "mock_tc_id",
+            }
+        ]
+
+        # Validate the tool call
+        to_validate_list = response.json()
+        validated = app_client.patch(
+            f"/tools/validation/{thread_id}/{to_validate_list[0]['tool_call_id']}",
+            json={"is_validated": True},
+        )
+        assert validated.status_code == 200
+        assert validated.json() == {
+            "tool_call_id": to_validate_list[0]["tool_call_id"],
+            "name": to_validate_list[0]["name"],
+            "arguments": to_validate_list[0]["tool_call_id"],
+        }
+
+        # Check that it has been validated and cannot be validated anymore
+        validated = app_client.patch(
+            f"/tools/validation/{thread_id}/{to_validate_list[0]['tool_call_id']}",
+            json={"is_validated": True},
+        )
+        assert validated.status_code == 403
+        assert validated.content == b"The tool call has already been validated."
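Note on the payloads asserted above: the validation endpoint and the chat response both carry objects with exactly four fields (message, name, inputs, tool_call_id). For readers outside the codebase, here is a minimal sketch of how neuroagent.new_types.HILResponse could be declared; only the field names and sample values come from the assertions above, while the Pydantic base class and type annotations are assumptions:

    from typing import Any

    from pydantic import BaseModel


    class HILResponse(BaseModel):
        """One pending human-in-the-loop (HIL) validation request (sketch)."""

        message: str  # e.g. "Please validate the following inputs before proceeding."
        name: str  # name of the tool call awaiting validation
        inputs: dict[str, Any]  # arguments the tool would be invoked with
        tool_call_id: str  # id used by PATCH /tools/validation/{thread_id}/{tool_call_id}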
diff --git a/tests/test_agent_routine.py b/tests/test_agent_routine.py
index 1886531..fe036b6 100644
--- a/tests/test_agent_routine.py
+++ b/tests/test_agent_routine.py
@@ -13,7 +13,7 @@

 from neuroagent.agent_routine import AgentsRoutine
 from neuroagent.app.database.sql_schemas import Entity, Messages, ToolCalls
-from neuroagent.new_types import Agent, Response, Result
+from neuroagent.new_types import Agent, HILResponse, Response, Result
 from tests.mock_client import create_mock_response

@@ -39,7 +39,6 @@ async def test_get_chat_completion_simple_message(self, mock_openai_client):
                 "tools": None,
                 "tool_choice": None,
                 "stream": False,
-                "stream_options": {"include_usage": True},
             }
         )

@@ -75,7 +74,6 @@ def agent_instruction(context_variables):
                 "tools": None,
                 "tool_choice": None,
                 "stream": False,
-                "stream_options": {"include_usage": True},
             }
         )

@@ -124,7 +122,6 @@ async def test_get_chat_completion_tools(
                 "tool_choice": None,
                 "stream": False,
                 "parallel_tool_calls": True,
-                "stream_options": {"include_usage": True},
             }
         )

@@ -445,8 +442,11 @@ async def test_handle_tool_call_handoff(
         )

     @pytest.mark.asyncio
-    async def test_arun(self, mock_openai_client, get_weather_tool, agent_handoff_tool):
-        agent_1 = Agent(name="Test Agent", tools=[agent_handoff_tool])
+    @pytest.mark.parametrize("hil", [True, False])
+    async def test_arun(
+        self, hil, mock_openai_client, get_weather_tool, agent_handoff_tool
+    ):
+        agent_1 = Agent(name="Test Agent", tools=[agent_handoff_tool, get_weather_tool])
         agent_2 = Agent(name="Test Agent", tools=[get_weather_tool])
         messages = [
             Messages(
@@ -466,16 +466,14 @@ async def test_arun(self, mock_openai_client, get_weather_tool, agent_handoff_to
         ]
         context_variables = {"to_agent": agent_2, "planet": "Mars"}
         # set mock to return a response that triggers function call
+        get_weather_tool.hil = hil
         mock_openai_client.set_sequential_responses(
             [
-                create_mock_response(
-                    message={"role": "assistant", "content": ""},
-                    function_calls=[{"name": "agent_handoff_tool", "args": {}}],
-                ),
                 create_mock_response(
                     message={"role": "assistant", "content": ""},
                     function_calls=[
-                        {"name": "get_weather", "args": {"location": "Montreux"}}
+                        {"name": "agent_handoff_tool", "args": {}},
+                        {"name": "get_weather", "args": {"location": "Montreux"}},
                     ],
                 ),
                 create_mock_response(
@@ -484,25 +482,36 @@ async def test_arun(self, mock_openai_client, get_weather_tool, agent_handoff_to
             ]
         )

-        # set up client and run
+        # Set up client and run
         client = AgentsRoutine(client=mock_openai_client)
        response = await client.arun(
             agent=agent_1, messages=messages, context_variables=context_variables
         )
+
+        if hil:
+            assert response.messages == []
+            assert response.hil_messages == [
+                HILResponse(
+                    message="Please validate the following inputs before proceeding.",
+                    name="get_weather",
+                    inputs={"location": "Montreux"},
+                    tool_call_id="mock_tc_id",
+                )
+            ]
-
-        assert response.messages[2]["role"] == "tool"
-        assert response.messages[2]["content"] == json.dumps(
-            {"assistant": agent_1.name}
-        )
-        assert response.messages[-2]["role"] == "tool"
-        assert (
-            response.messages[-2]["content"]
-            == "It's sunny today in Montreux from planet Mars."
-        )
-        assert response.messages[-1]["role"] == "assistant"
-        assert response.messages[-1]["content"] == "sample response content"
-        assert response.agent == agent_2
-        assert response.context_variables == context_variables
+        else:
+            assert response.messages[2]["role"] == "tool"
+            assert response.messages[2]["content"] == json.dumps(
+                {"assistant": agent_1.name}
+            )
+            assert response.messages[-2]["role"] == "tool"
+            assert (
+                response.messages[-2]["content"]
+                == "It's sunny today in Montreux from planet Mars."
+            )
+            assert response.messages[-1]["role"] == "assistant"
+            assert response.messages[-1]["content"] == "sample response content"
+            assert response.agent == agent_2
+            assert response.context_variables == context_variables

     @pytest.mark.asyncio
     async def test_astream(
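Aside on the first hunk in src/neuroagent/agent_routine.py: the old line ended with a trailing comma after the parenthesized dict, so create_params["stream_options"] was a one-element tuple rather than the dict the OpenAI client expects. The removed stream_options expectations in the test file match the if stream: guard shown in that hunk's context: all three updated tests call with stream=False, so no stream_options key is expected. A minimal, self-contained reproduction of the tuple pitfall (names are illustrative only):

    # The trailing comma, not the parentheses, is what builds a tuple.
    create_params: dict[str, object] = {}

    create_params["stream_options"] = ({"include_usage": True},)  # 1-tuple: the bug
    assert isinstance(create_params["stream_options"], tuple)

    create_params["stream_options"] = {"include_usage": True}  # plain dict: the fix
    assert isinstance(create_params["stream_options"], dict)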