Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/create file agent #94

Merged
merged 5 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ NEO4J_URI=bolt://localhost:7687
NEO4J_HTTP_PORT=7474
NEO4J_BOLT_PORT=7687

# files location
FILES_DIRECTORY=files

# backend LLM properties
MISTRAL_KEY=my-api-key

Expand Down Expand Up @@ -42,6 +45,7 @@ MATHS_AGENT_LLM="openai"
WEB_AGENT_LLM="openai"
CHART_GENERATOR_LLM="openai"
ROUTER_LLM="openai"
FILE_AGENT_LLM="openai"

# llm model
ANSWER_AGENT_MODEL="gpt-4o mini"
Expand All @@ -52,3 +56,4 @@ MATHS_AGENT_MODEL="gpt-4o mini"
WEB_AGENT_MODEL="gpt-4o mini"
CHART_GENERATOR_MODEL="gpt-4o mini"
ROUTER_MODEL="gpt-4o mini"
FILE_AGENT_MODEL="gpt-4o mini"
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ celerybeat.pid
# Environments
.env
.venv
files
env/
venv/
ENV/
Expand Down
2 changes: 2 additions & 0 deletions backend/src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .validator_agent import ValidatorAgent
from .answer_agent import AnswerAgent
from .chart_generator_agent import ChartGeneratorAgent
from .file_agent import FileAgent

config = Config()

Expand All @@ -32,6 +33,7 @@ def get_available_agents() -> List[Agent]:
return [DatastoreAgent(config.datastore_agent_llm, config.datastore_agent_model),
WebAgent(config.web_agent_llm, config.web_agent_model),
ChartGeneratorAgent(config.chart_generator_llm, config.chart_generator_model),
FileAgent(config.file_agent_llm, config.file_agent_model),
]


Expand Down
92 changes: 92 additions & 0 deletions backend/src/agents/file_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import logging
from .agent_types import Parameter
from .agent import Agent, agent
from .tool import tool
import json
import os
from src.utils.config import Config

logger = logging.getLogger(__name__)
config = Config()

FILES_DIRECTORY = f"/app/{config.files_directory}"

# Constants for response status
IGNORE_VALIDATION = "true"
STATUS_SUCCESS = "success"
STATUS_ERROR = "error"

# Utility function for error responses
def create_response(content: str, status: str = STATUS_SUCCESS) -> str:
return json.dumps({
"content": content,
"ignore_validation": IGNORE_VALIDATION,
"status": status
}, indent=4)

async def read_file_core(file_path: str) -> str:
full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
try:
with open(full_path, 'r') as file:
content = file.read()
return create_response(content)
except FileNotFoundError:
error_message = f"File {file_path} not found."
logger.error(error_message)
return create_response(error_message, STATUS_ERROR)
except Exception as e:
logger.error(f"Error reading file {full_path}: {e}")
return create_response(f"Error reading file: {file_path}", STATUS_ERROR)


async def write_file_core(file_path: str, content: str) -> str:
full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
try:
with open(full_path, 'w') as file:
file.write(content)
logger.info(f"Content written to file {full_path} successfully.")
return create_response(f"Content written to file {file_path}.")
except Exception as e:
logger.error(f"Error writing to file {full_path}: {e}")
return create_response(f"Error writing to file: {file_path}", STATUS_ERROR)


@tool(
name="read_file",
description="Read the content of a text file.",
parameters={
"file_path": Parameter(
type="string",
description="The path to the file to be read."
),
},
)
async def read_file(file_path: str, llm, model) -> str:
return await read_file_core(file_path)


@tool(
name="write_file",
description="Write content to a text file.",
parameters={
"file_path": Parameter(
type="string",
description="The path to the file where the content will be written."
),
"content": Parameter(
type="string",
description="The content to write to the file."
),
},
)
async def write_file(file_path: str, content: str, llm, model) -> str:
return await write_file_core(file_path, content)


@agent(
name="FileAgent",
description="This agent is responsible for reading from and writing to files.",
tools=[read_file, write_file],
)
class FileAgent(Agent):
pass
9 changes: 6 additions & 3 deletions backend/src/agents/web_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ async def web_general_search_core(search_query, llm, model) -> str:
content = await perform_scrape(url)
if not content:
continue
summary = await perform_summarization(search_query, content, llm, model)
if not summary:
summarisation = await perform_summarization(search_query, content, llm, model)
if not summarisation:
continue
is_valid = await is_valid_answer(summary, search_query)
is_valid = await is_valid_answer(summarisation, search_query)
parsed_json = json.loads(summarisation)
summary = parsed_json.get('summary', '')
if is_valid:
response = {
"content": summary,
Expand Down Expand Up @@ -137,6 +139,7 @@ async def perform_summarization(search_query: str, content: str, llm: Any, model
summarise_result = json.loads(summarise_result_json)
if summarise_result["status"] == "error":
return ""
logger.info(f"Content summarized successfully: {summarise_result['response']}")
return summarise_result["response"]
except Exception as e:
logger.error(f"Error summarizing content: {e}")
Expand Down
4 changes: 4 additions & 0 deletions backend/src/prompts/templates/intent.j2
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ Response:
Q: Show me a chart of different subscription prices with Netflix?
Response:
{"query": "Show me a chart of different subscription prices with Netflix?", "user_intent": "retrieve and visualize subscription data", "questions": [{"query": "What are the different subscription prices with Netflix?", "question_intent": "retrieve subscription pricing information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Netflix"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Show me the results in a chart", "question_intent": "display subscription pricing information in a chart", "operation": "data visualization", "question_category": "data presentation", "parameters": [], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}

Q: Read the file called file_to_read.txt and write its content to a file called output.txt.
Response:
{"query": "Read the file called {{ file_name }} and write its content to a file called {{ output_file_name }}.", "user_intent": "read and write file content", "questions": [{"query": "Read the file called {{ file_name }} using fileagent.", "question_intent": "read file content", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "file", "value": "{{ file_name }}"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Write the content to a file called {{ output_file_name }} using fileagent.", "question_intent": "write file content", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "file", "value": "{{ output_file_name }}"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}
10 changes: 1 addition & 9 deletions backend/src/prompts/templates/summariser.j2
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,15 @@ You will be passed a user query and the content scraped from the web. You need t

Ensure the summary is clear, well-structured, and directly addresses the user's query.


User's question is:
{{ question }}

Below is the content scraped from the web:
{{ content }}
{{ content | replace("\n\n", "\n") }} # Adding this will introduce breaks between paragraphs

Reply only in json with the following format:

{
"summary": "The summary of the content that answers the user's query",
"reasoning": "A sentence on why you chose that summary"
}

e.g.
Task: What is the capital of England
{
"summary": "The capital of England is London.",
"reasoning": "London is widely known as the capital of England, a fact mentioned in various authoritative sources and geographical references."
}
10 changes: 7 additions & 3 deletions backend/src/supervisors/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,29 @@ async def solve_all(intent_json) -> None:

for question in questions:
try:
(agent_name, answer) = await solve_task(question, get_scratchpad())
(agent_name, answer, status) = await solve_task(question, get_scratchpad())
update_scratchpad(agent_name, question, answer)
if status == "error":
raise Exception(answer)
except Exception as error:
update_scratchpad(error=error)


async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str]:
async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]:
if attempt == 5:
raise Exception(unsolvable_response)

agent = await get_agent_for_task(task, scratchpad)
logger.info(f"Agent selected: {agent}")
if agent is None:
raise Exception(no_agent_response)
answer = await agent.invoke(task)
parsed_json = json.loads(answer)
status = parsed_json.get('status', 'success')
ignore_validation = parsed_json.get('ignore_validation', '')
answer_content = parsed_json.get('content', '')
if(ignore_validation == 'true') or await is_valid_answer(answer_content, task):
return (agent.name, answer_content)
return (agent.name, answer_content, status)
return await solve_task(task, scratchpad, attempt + 1)


Expand Down
7 changes: 7 additions & 0 deletions backend/src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

default_frontend_url = "http://localhost:8650"
default_neo4j_uri = "bolt://localhost:7687"
default_files_directory = "files"


class Config(object):
Expand All @@ -25,6 +26,7 @@ def __init__(self):
self.maths_agent_llm = None
self.web_agent_llm = None
self.chart_generator_llm = None
self.file_agent_llm = None
self.router_llm = None
self.validator_agent_model = None
self.intent_agent_model = None
Expand All @@ -33,6 +35,8 @@ def __init__(self):
self.chart_generator_model = None
self.web_agent_model = None
self.router_model = None
self.files_directory = default_files_directory
self.file_agent_model = None
self.load_env()

def load_env(self):
Expand All @@ -49,6 +53,7 @@ def load_env(self):
self.neo4j_uri = os.getenv("NEO4J_URI", default_neo4j_uri)
self.neo4j_user = os.getenv("NEO4J_USERNAME")
self.neo4j_password = os.getenv("NEO4J_PASSWORD")
self.files_directory = os.getenv("FILES_DIRECTORY", default_files_directory)
self.azure_storage_connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
self.azure_storage_container_name = os.getenv("AZURE_STORAGE_CONTAINER_NAME")
self.azure_initial_data_filename = os.getenv("AZURE_INITIAL_DATA_FILENAME")
Expand All @@ -57,6 +62,7 @@ def load_env(self):
self.validator_agent_llm = os.getenv("VALIDATOR_AGENT_LLM")
self.datastore_agent_llm = os.getenv("DATASTORE_AGENT_LLM")
self.chart_generator_llm = os.getenv("CHART_GENERATOR_LLM")
self.file_agent_llm = os.getenv("FILE_AGENT_LLM")
self.web_agent_llm = os.getenv("WEB_AGENT_LLM")
self.maths_agent_llm = os.getenv("MATHS_AGENT_LLM")
self.router_llm = os.getenv("ROUTER_LLM")
Expand All @@ -68,6 +74,7 @@ def load_env(self):
self.chart_generator_model = os.getenv("CHART_GENERATOR_MODEL")
self.maths_agent_model = os.getenv("MATHS_AGENT_MODEL")
self.router_model = os.getenv("ROUTER_MODEL")
self.file_agent_model = os.getenv("FILE_AGENT_MODEL")
except FileNotFoundError:
raise FileNotFoundError("Please provide a .env file. See the Getting Started guide on the README.md")
except Exception:
Expand Down
53 changes: 53 additions & 0 deletions backend/tests/agents/file_agent_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from unittest.mock import patch, mock_open
import json
import os
from src.agents.file_agent import read_file_core, write_file_core, create_response

# Mocking config for the test
@pytest.fixture(autouse=True)
def mock_config(monkeypatch):
monkeypatch.setattr('src.agents.file_agent.config.files_directory', 'files')

@pytest.mark.asyncio
@patch("builtins.open", new_callable=mock_open, read_data="Example file content.")
async def test_read_file_core_success(mock_file):
file_path = "example.txt"
result = await read_file_core(file_path)
expected_response = create_response("Example file content.")
assert json.loads(result) == json.loads(expected_response)
expected_full_path = os.path.normpath("/app/files/example.txt")
mock_file.assert_called_once_with(expected_full_path, 'r')

@pytest.mark.asyncio
@patch("builtins.open", side_effect=FileNotFoundError)
async def test_read_file_core_file_not_found(mock_file):
file_path = "missing_file.txt"
result = await read_file_core(file_path)
expected_response = create_response(f"File {file_path} not found.", "error")
assert json.loads(result) == json.loads(expected_response)
expected_full_path = os.path.normpath("/app/files/missing_file.txt")
mock_file.assert_called_once_with(expected_full_path, 'r')

@pytest.mark.asyncio
@patch("builtins.open", new_callable=mock_open)
async def test_write_file_core_success(mock_file):
file_path = "example_write.txt"
content = "This is test content to write."
result = await write_file_core(file_path, content)
expected_response = create_response(f"Content written to file {file_path}.")
assert json.loads(result) == json.loads(expected_response)
expected_full_path = os.path.normpath("/app/files/example_write.txt")
mock_file.assert_called_once_with(expected_full_path, 'w')
mock_file().write.assert_called_once_with(content)

@pytest.mark.asyncio
@patch("builtins.open", side_effect=Exception("Unexpected error"))
async def test_write_file_core_error(mock_file):
file_path = "error_file.txt"
content = "Content with error."
result = await write_file_core(file_path, content)
expected_response = create_response(f"Error writing to file: {file_path}", "error")
assert json.loads(result) == json.loads(expected_response)
expected_full_path = os.path.normpath("/app/files/error_file.txt")
mock_file.assert_called_once_with(expected_full_path, 'w')
4 changes: 2 additions & 2 deletions backend/tests/agents/web_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def test_web_general_search_core(

mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
mock_perform_scrape.return_value = "Example scraped content."
mock_perform_summarization.return_value = "Example summary."
mock_perform_summarization.return_value = json.dumps({"summary": "Example summary."})
mock_is_valid_answer.return_value = True
result = await web_general_search_core("example query", llm, model)
expected_response = {
Expand Down Expand Up @@ -61,7 +61,7 @@ async def test_web_general_search_core_invalid_summary(
model = "mock_model"
mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
mock_perform_scrape.return_value = "Example scraped content."
mock_perform_summarization.return_value = "Example invalid summary."
mock_perform_summarization.return_value = json.dumps({"summary": "Example invalid summary."})
mock_is_valid_answer.return_value = False
result = await web_general_search_core("example query", llm, model)
assert result == "No relevant information found on the internet for the given query."
4 changes: 2 additions & 2 deletions backend/tests/supervisors/supervisor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def test_solve_task_first_attempt_solves(mocker):
mock_answer_json = json.loads(mock_answer)

# Ensure that the result is returned directly without validation
assert answer == (agent.name, mock_answer_json.get('content', ''))
assert answer == (agent.name, mock_answer_json.get('content', ''), "success")


@pytest.mark.asyncio
Expand All @@ -83,7 +83,7 @@ async def test_solve_task_ignore_validation(mocker):
mock_answer_json = json.loads(mock_answer)

# Ensure that the result is returned directly without validation
assert answer == (agent.name, mock_answer_json.get('content', ''))
assert answer == (agent.name, mock_answer_json.get('content', ''), "success")
mock_is_valid_answer.assert_not_called() # Validation should not be called

@pytest.mark.asyncio
Expand Down
Loading
Loading