Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WonderPG committed Jan 10, 2025
1 parent 4d9549f commit f3f179a
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 28 deletions.
7 changes: 4 additions & 3 deletions src/neuroagent/agent_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def get_chat_completion(
"stream": stream,
}
if stream:
create_params["stream_options"] = ({"include_usage": True},)
create_params["stream_options"] = {"include_usage": True}

if tools:
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
Expand Down Expand Up @@ -224,6 +224,7 @@ async def arun(
model_override=model_override,
stream=False,
)

message = completion.choices[0].message # type: ignore
message.sender = active_agent.name

Expand Down Expand Up @@ -269,7 +270,7 @@ async def arun(
# If the tool call response contains HIL validation, do not update anything and return
if partial_response.hil_messages:
return Response(
messages=history[init_len:],
messages=[],
agent=active_agent,
context_variables=context_variables,
hil_messages=partial_response.hil_messages,
Expand Down Expand Up @@ -424,7 +425,7 @@ async def astream(
# If the tool call response contains HIL validation, do not update anything and return
if partial_response.hil_messages:
yield Response(
messages=history[init_len:],
messages=[],
agent=active_agent,
context_variables=context_variables,
hil_messages=partial_response.hil_messages,
Expand Down
165 changes: 165 additions & 0 deletions tests/app/routers/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,168 @@ async def test_get_tool_output(
tool_output = app_client.get(f"/tools/output/{thread_id}/{tool_call_id}")

assert tool_output.json() == [json.dumps({"assistant": agent_2.name})]


@pytest.mark.asyncio
async def test_get_required_validation(
    patch_required_env,
    app_client,
    httpx_mock,
    db_connection,
    mock_openai_client,
    agent_handoff_tool,
):
    """List pending HIL validations for a thread, then clear them by validating."""
    # Scripted LLM behavior: first turn requests the handoff tool, second turn
    # replies with plain text.
    mock_openai_client.set_sequential_responses(
        [
            create_mock_response(
                message={"role": "assistant", "content": ""},
                function_calls=[{"name": "agent_handoff_tool", "args": {}}],
            ),
            create_mock_response(
                {"role": "assistant", "content": "sample response content"}
            ),
        ]
    )

    # Mark the tool as requiring human-in-the-loop validation.
    agent_handoff_tool.hil = True
    starting_agent = Agent(name="Test agent 1", tools=[agent_handoff_tool])
    handoff_target = Agent(name="Test agent 2", tools=[])

    agents_routine = AgentsRoutine(client=mock_openai_client)
    settings = Settings(
        db={"prefix": db_connection},
    )

    # Wire the app to the mocked routine, agents and settings.
    app.dependency_overrides[get_agents_routine] = lambda: agents_routine
    app.dependency_overrides[get_starting_agent] = lambda: starting_agent
    app.dependency_overrides[get_context_variables] = lambda: {"to_agent": handoff_target}
    app.dependency_overrides[get_settings] = lambda: settings

    httpx_mock.add_response(
        url=f"{settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
    )

    with app_client as app_client:
        # Unknown thread id -> 404.
        missing_thread = app_client.get("/tools/validation/test/")
        assert missing_thread.status_code == 404
        assert missing_thread.json() == {"detail": {"detail": "Thread not found."}}

        # Create a thread
        thread_id = app_client.post(
            "/threads/?virtual_lab_id=test_vlab&project_id=test_project"
        ).json()["thread_id"]

        # A fresh thread has nothing awaiting validation.
        pending = app_client.get(f"/tools/validation/{thread_id}/")
        assert pending.json() == []

        # Fill the thread
        app_client.post(
            f"/qa/chat/{thread_id}",
            json={"query": "This is my query"},
            params={"thread_id": thread_id},
            headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
        )

        # The HIL tool call is now waiting for validation.
        pending = app_client.get(f"/tools/validation/{thread_id}/")
        assert pending.status_code == 200
        assert pending.json() == [
            {
                "message": "Please validate the following inputs before proceeding.",
                "name": "agent_handoff_tool",
                "inputs": {},
                "tool_call_id": "mock_tc_id",
            }
        ]

        # Validate the tool call
        app_client.patch(
            f"/tools/validation/{thread_id}/{pending.json()[0]['tool_call_id']}",
            json={"is_validated": True},
        )
        # Validation list should now be empty
        pending = app_client.get(f"/tools/validation/{thread_id}/")
        assert pending.json() == []


@pytest.mark.asyncio
async def test_validate_input(
    patch_required_env,
    app_client,
    httpx_mock,
    db_connection,
    mock_openai_client,
    agent_handoff_tool,
):
    """Validate a pending HIL tool call and check it cannot be validated twice.

    Fix: this coroutine test was missing the ``@pytest.mark.asyncio`` decorator
    that its sibling ``test_get_required_validation`` carries; without it the
    async test is never actually awaited by pytest.
    """
    routine = AgentsRoutine(client=mock_openai_client)

    # Scripted LLM behavior: first turn requests the handoff tool, second turn
    # replies with plain text.
    mock_openai_client.set_sequential_responses(
        [
            create_mock_response(
                message={"role": "assistant", "content": ""},
                function_calls=[{"name": "agent_handoff_tool", "args": {}}],
            ),
            create_mock_response(
                {"role": "assistant", "content": "sample response content"}
            ),
        ]
    )
    # Mark the tool as requiring human-in-the-loop validation.
    agent_handoff_tool.hil = True
    agent_1 = Agent(name="Test agent 1", tools=[agent_handoff_tool])
    agent_2 = Agent(name="Test agent 2", tools=[])

    # Wire the app to the mocked routine, agents and settings.
    app.dependency_overrides[get_agents_routine] = lambda: routine
    app.dependency_overrides[get_starting_agent] = lambda: agent_1
    app.dependency_overrides[get_context_variables] = lambda: {"to_agent": agent_2}
    test_settings = Settings(
        db={"prefix": db_connection},
    )
    app.dependency_overrides[get_settings] = lambda: test_settings
    httpx_mock.add_response(
        url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
    )

    with app_client as app_client:
        # Unknown thread id -> 404.
        wrong_response = app_client.get("/tools/validation/test/123")
        assert wrong_response.status_code == 404
        assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

        # Create a thread
        create_output = app_client.post(
            "/threads/?virtual_lab_id=test_vlab&project_id=test_project"
        ).json()
        thread_id = create_output["thread_id"]

        # Fill the thread
        response = app_client.post(
            f"/qa/chat/{thread_id}",
            json={"query": "This is my query"},
            params={"thread_id": thread_id},
            headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
        )

        assert response.status_code == 200
        assert response.json() == [
            {
                "message": "Please validate the following inputs before proceeding.",
                "name": "agent_handoff_tool",
                "inputs": {},
                "tool_call_id": "mock_tc_id",
            }
        ]

        # Validate the tool call
        to_validate_list = response.json()
        validated = app_client.patch(
            f"/tools/validation/{thread_id}/{to_validate_list[0]['tool_call_id']}",
            json={"is_validated": True},
        )
        assert validated.status_code == 200
        # NOTE(review): asserting "arguments" equals the tool_call_id looks
        # suspicious — presumably it should be the tool's input arguments;
        # confirm against the endpoint's response model.
        assert validated.json() == {
            "tool_call_id": to_validate_list[0]["tool_call_id"],
            "name": to_validate_list[0]["name"],
            "arguments": to_validate_list[0]["tool_call_id"],
        }

        # Check that it has been validated and cannot be validated anymore
        validated = app_client.patch(
            f"/tools/validation/{thread_id}/{to_validate_list[0]['tool_call_id']}",
            json={"is_validated": True},
        )
        assert validated.status_code == 403
        assert validated.content == b"The tool call has already been validated."
59 changes: 34 additions & 25 deletions tests/test_agent_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from neuroagent.agent_routine import AgentsRoutine
from neuroagent.app.database.sql_schemas import Entity, Messages, ToolCalls
from neuroagent.new_types import Agent, Response, Result
from neuroagent.new_types import Agent, HILResponse, Response, Result
from tests.mock_client import create_mock_response


Expand All @@ -39,7 +39,6 @@ async def test_get_chat_completion_simple_message(self, mock_openai_client):
"tools": None,
"tool_choice": None,
"stream": False,
"stream_options": {"include_usage": True},
}
)

Expand Down Expand Up @@ -75,7 +74,6 @@ def agent_instruction(context_variables):
"tools": None,
"tool_choice": None,
"stream": False,
"stream_options": {"include_usage": True},
}
)

Expand Down Expand Up @@ -124,7 +122,6 @@ async def test_get_chat_completion_tools(
"tool_choice": None,
"stream": False,
"parallel_tool_calls": True,
"stream_options": {"include_usage": True},
}
)

Expand Down Expand Up @@ -445,8 +442,11 @@ async def test_handle_tool_call_handoff(
)

@pytest.mark.asyncio
async def test_arun(self, mock_openai_client, get_weather_tool, agent_handoff_tool):
agent_1 = Agent(name="Test Agent", tools=[agent_handoff_tool])
@pytest.mark.parametrize("hil", [True, False])
async def test_arun(
self, hil, mock_openai_client, get_weather_tool, agent_handoff_tool
):
agent_1 = Agent(name="Test Agent", tools=[agent_handoff_tool, get_weather_tool])
agent_2 = Agent(name="Test Agent", tools=[get_weather_tool])
messages = [
Messages(
Expand All @@ -466,16 +466,14 @@ async def test_arun(self, mock_openai_client, get_weather_tool, agent_handoff_to
]
context_variables = {"to_agent": agent_2, "planet": "Mars"}
# set mock to return a response that triggers function call
get_weather_tool.hil = hil
mock_openai_client.set_sequential_responses(
[
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[{"name": "agent_handoff_tool", "args": {}}],
),
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[
{"name": "get_weather", "args": {"location": "Montreux"}}
{"name": "agent_handoff_tool", "args": {}},
{"name": "get_weather", "args": {"location": "Montreux"}},
],
),
create_mock_response(
Expand All @@ -484,25 +482,36 @@ async def test_arun(self, mock_openai_client, get_weather_tool, agent_handoff_to
]
)

# set up client and run
# Set up client and run
client = AgentsRoutine(client=mock_openai_client)
response = await client.arun(
agent=agent_1, messages=messages, context_variables=context_variables
)
if hil:
assert response.messages == []
assert response.hil_messages == [
HILResponse(
message="Please validate the following inputs before proceeding.",
name="get_weather",
inputs={"location": "Montreux"},
tool_call_id="mock_tc_id",
)
]

assert response.messages[2]["role"] == "tool"
assert response.messages[2]["content"] == json.dumps(
{"assistant": agent_1.name}
)
assert response.messages[-2]["role"] == "tool"
assert (
response.messages[-2]["content"]
== "It's sunny today in Montreux from planet Mars."
)
assert response.messages[-1]["role"] == "assistant"
assert response.messages[-1]["content"] == "sample response content"
assert response.agent == agent_2
assert response.context_variables == context_variables
else:
assert response.messages[2]["role"] == "tool"
assert response.messages[2]["content"] == json.dumps(
{"assistant": agent_1.name}
)
assert response.messages[-2]["role"] == "tool"
assert (
response.messages[-2]["content"]
== "It's sunny today in Montreux from planet Mars."
)
assert response.messages[-1]["role"] == "assistant"
assert response.messages[-1]["content"] == "sample response content"
assert response.agent == agent_2
assert response.context_variables == context_variables

@pytest.mark.asyncio
async def test_astream(
Expand Down

0 comments on commit f3f179a

Please sign in to comment.