diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py index b330baec..e895a68a 100644 --- a/backend/src/agents/web_agent.py +++ b/backend/src/agents/web_agent.py @@ -20,7 +20,7 @@ async def web_general_search_core(search_query, llm, model) -> str: try: - search_result = perform_search(search_query, num_results=15) + search_result = await perform_search(search_query, num_results=15) if search_result["status"] == "error": return "No relevant information found on the internet for the given query." urls = search_result["urls"] @@ -32,7 +32,8 @@ async def web_general_search_core(search_query, llm, model) -> str: summary = await perform_summarization(search_query, content, llm, model) if not summary: continue - if await is_valid_answer(summary, search_query): + is_valid = await is_valid_answer(summary, search_query) + if is_valid: response = { "content": summary, "ignore_validation": "false" @@ -109,9 +110,9 @@ async def is_valid_answer(answer, task) -> bool: return is_valid -def perform_search(search_query: str, num_results: int) -> Dict[str, Any]: +async def perform_search(search_query: str, num_results: int) -> Dict[str, Any]: try: - search_result_json = search_urls(search_query, num_results=num_results) + search_result_json = await search_urls(search_query, num_results=num_results) return json.loads(search_result_json) except Exception as e: logger.error(f"Error during web search: {e}") diff --git a/backend/src/utils/web_utils.py b/backend/src/utils/web_utils.py index 2edaed60..90082be7 100644 --- a/backend/src/utils/web_utils.py +++ b/backend/src/utils/web_utils.py @@ -13,7 +13,7 @@ engine = PromptEngine() -def search_urls(search_query, num_results=10) -> str: +async def search_urls(search_query, num_results=10) -> str: logger.info(f"Searching the web for: {search_query}") urls = [] try: diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py index b50dc36e..c18d2070 100644 --- a/backend/tests/agents/web_agent_test.py +++ b/backend/tests/agents/web_agent_test.py @@ -1,76 +1,67 @@ -import unittest -from unittest.mock import AsyncMock, patch - import pytest +from unittest.mock import patch, AsyncMock +import json from src.agents.web_agent import web_general_search_core +@pytest.mark.asyncio +@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) +@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) +async def test_web_general_search_core( + mock_is_valid_answer, + mock_perform_summarization, + mock_perform_scrape, + mock_perform_search, +): + llm = AsyncMock() + model = "mock_model" -class TestWebAgentCore(unittest.TestCase): - def setUp(self): - self.llm = AsyncMock() - self.model = "mock_model" - - @patch("src.agents.web_agent.perform_search") - @patch("src.agents.web_agent.perform_scrape") - @patch("src.agents.web_agent.perform_summarization") - @patch("src.agents.web_agent.is_valid_answer") - @pytest.mark.asyncio - async def test_web_general_search_core( - self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search - ): - mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]} - mock_perform_scrape.return_value = "Example scraped content." - mock_perform_summarization.return_value = "Example summary." - mock_is_valid_answer.return_value = True - - result = await web_general_search_core("example query", self.llm, self.model) - self.assertEqual(result, "Example summary.") - mock_perform_search.assert_called_once_with("example query", num_results=15) - mock_perform_scrape.assert_called_once_with("http://example.com") - mock_perform_summarization.assert_called_once_with( - "example query", "Example scraped content.", self.llm, self.model - ) - mock_is_valid_answer.assert_called_once_with("Example summary.", "example query") - - @patch("src.agents.web_agent.perform_search") - @patch("src.agents.web_agent.perform_scrape") - @patch("src.agents.web_agent.perform_summarization") - @patch("src.agents.web_agent.is_valid_answer") - @pytest.mark.asyncio - async def test_web_general_search_core_no_results( - self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search - ): - mock_perform_search.return_value = {"status": "error", "urls": []} - - result = await web_general_search_core("example query", self.llm, self.model) - self.assertEqual(result, "No relevant information found on the internet for the given query.") - mock_perform_search.assert_called_once_with("example query", num_results=15) - mock_perform_scrape.assert_not_called() - mock_perform_summarization.assert_not_called() - mock_is_valid_answer.assert_not_called() - - @patch("src.agents.web_agent.perform_search") - @patch("src.agents.web_agent.perform_scrape") - @patch("src.agents.web_agent.perform_summarization") - @patch("src.agents.web_agent.is_valid_answer") - @pytest.mark.asyncio - async def test_web_general_search_core_invalid_summary( - self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search - ): - mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]} - mock_perform_scrape.return_value = "Example scraped content." - mock_perform_summarization.return_value = "Example invalid summary." - mock_is_valid_answer.return_value = False + mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]} + mock_perform_scrape.return_value = "Example scraped content." + mock_perform_summarization.return_value = "Example summary." + mock_is_valid_answer.return_value = True + result = await web_general_search_core("example query", llm, model) + expected_response = { + "content": "Example summary.", + "ignore_validation": "false" + } + assert json.loads(result) == expected_response - result = await web_general_search_core("example query", self.llm, self.model) - self.assertEqual(result, "No relevant information found on the internet for the given query.") - mock_perform_search.assert_called_once_with("example query", num_results=15) - mock_perform_scrape.assert_called_once_with("http://example.com") - mock_perform_summarization.assert_called_once_with( - "example query", "Example scraped content.", self.llm, self.model - ) - mock_is_valid_answer.assert_called_once_with("Example invalid summary.", "example query") +@pytest.mark.asyncio +@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) +@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) +async def test_web_general_search_core_no_results( + mock_is_valid_answer, + mock_perform_summarization, + mock_perform_scrape, + mock_perform_search, +): + llm = AsyncMock() + model = "mock_model" + mock_perform_search.return_value = {"status": "error", "urls": []} + result = await web_general_search_core("example query", llm, model) + assert result == "No relevant information found on the internet for the given query." -if __name__ == "__main__": - unittest.main() +@pytest.mark.asyncio +@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock) +@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock) +@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock) +async def test_web_general_search_core_invalid_summary( + mock_is_valid_answer, + mock_perform_summarization, + mock_perform_scrape, + mock_perform_search +): + llm = AsyncMock() + model = "mock_model" + mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]} + mock_perform_scrape.return_value = "Example scraped content." + mock_perform_summarization.return_value = "Example invalid summary." + mock_is_valid_answer.return_value = False + result = await web_general_search_core("example query", llm, model) + assert result == "No relevant information found on the internet for the given query."