-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented basic memory mechanism + basic chat agent
- Loading branch information
1 parent
a1405ee
commit 35643a7
Showing
8 changed files
with
201 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from src.agent.agent import Agent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,44 @@ | ||
from src.agent.llm import LLM | ||
from src.agent.plan import Plan | ||
from src.agent.memory import Memory, Message, Role | ||
from src.agent.prompts import PROMPTS | ||
from src.agent.tools import TOOLS | ||
|
||
|
||
class Agent: | ||
pass | ||
def __init__(self, model: str, tools_docs): | ||
self.llm = LLM(model=model) | ||
self.mem = Memory() | ||
self.system_prompt = 'You are an assistant.' # PROMPTS[ollama_model]['system']['plan'].format(tools=tools_docs) | ||
self.user_prompt = PROMPTS[model]['user']['plan'] | ||
|
||
def new_session(self, sid: int): | ||
self.mem.store_message(sid, Message(Role.SYS, self.system_prompt)) | ||
|
||
if __name__ == "__main__": | ||
model = 'gemma:2b' | ||
llm = LLM(model=model) | ||
tools_documentation = '\n'.join([tool.get_documentation() for tool in TOOLS]) | ||
def get_session(self, sid: int): | ||
return self.mem.get_session(sid) | ||
|
||
sys_prompt = PROMPTS[model]['system']['plan'].format(tools=tools_documentation) | ||
usr_prompt = '' | ||
def query(self, sid: int, user_in: str): | ||
self.mem.store_message( | ||
sid, | ||
Message(Role.USER, self.user_prompt.format(user_input=user_in)) | ||
) | ||
messages = self.mem.get_session(sid).messages_to_dict_list() | ||
|
||
while True: | ||
user_input = input("Enter: ") | ||
if user_input == "-1": | ||
break | ||
elif user_input == "exec": | ||
pass # execute plan | ||
else: | ||
pass # make query | ||
response = '' | ||
for chunk in self.llm.query(messages): | ||
yield chunk['message']['content'] | ||
response += chunk['message']['content'] | ||
|
||
self.mem.store_message( | ||
sid, | ||
Message(Role.ASSISTANT, response) | ||
) | ||
|
||
def save_session(self, sid: int): | ||
self.mem.save_session(sid) | ||
|
||
def delete_session(self, sid: int): | ||
self.mem.delete_session(sid) | ||
|
||
def rename_session(self, sid: int, session_name: str): | ||
self.mem.rename_session(sid, session_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,3 @@ | ||
from pathlib import Path | ||
from src.agent.memory.base import Memory, Message, Role | ||
|
||
|
||
SESSIONS_PATH = Path(Path.home() / '.aiops' / 'sessions') | ||
if not SESSIONS_PATH.exists(): | ||
SESSIONS_PATH.mkdir(parents=True, exist_ok=True) |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import json | ||
from dataclasses import dataclass | ||
from typing import List | ||
from enum import StrEnum | ||
from pathlib import Path | ||
|
||
from src.agent.plan import Plan | ||
|
||
SESSIONS_PATH = Path(Path.home() / '.aiops' / 'sessions') | ||
if not SESSIONS_PATH.exists(): | ||
SESSIONS_PATH.mkdir(parents=True, exist_ok=True) | ||
|
||
|
||
class Role(StrEnum): | ||
SYS = 'system' | ||
USER = 'user' | ||
ASSISTANT = 'assistant' | ||
|
||
@staticmethod | ||
def from_str(item): | ||
if item == 'user': | ||
return Role.USER | ||
elif item == 'assistant': | ||
return Role.ASSISTANT | ||
elif item == 'system': | ||
return Role.SYS | ||
else: | ||
return None | ||
|
||
|
||
@dataclass | ||
class Message: | ||
role: Role | ||
content: str | ||
|
||
|
||
@dataclass | ||
class Session: | ||
name: str | ||
messages: List[Message] | ||
|
||
# plan | ||
|
||
def messages_to_dict_list(self): | ||
"""Converts the message list into a format compatible with Ollama""" | ||
return [{'role': str(msg.role), 'content': msg.content} for msg in self.messages] | ||
|
||
@staticmethod | ||
def from_json(path: str): | ||
"""Get a session from a JSON file""" | ||
with open(str(path), 'r', encoding='utf-8') as fp: | ||
data = json.load(fp) | ||
return data['id'], Session( | ||
name=data['name'], | ||
messages=[Message(Role.from_str(msg['role']), msg['content']) for msg in data['messages']] | ||
) | ||
|
||
|
||
class Memory: | ||
""" | ||
Contains the chat history for each session, it is bounded to the Agent class. | ||
""" | ||
|
||
def __init__(self): | ||
self.sessions = {} | ||
self.load_sessions() | ||
|
||
def store_message(self, sid: int, message: Message): | ||
"""Add a message to a session identified by session id or creates a new one""" | ||
if not isinstance(message, Message): | ||
raise ValueError(f'Not a message: {message}') | ||
if sid not in self.sessions: | ||
self.sessions[sid] = Session(name='New Session', messages=[]) | ||
|
||
self.sessions[sid].messages.append(message) | ||
|
||
def store_plan(self, sid: int, plan: Plan): | ||
pass | ||
|
||
def get_session(self, sid: int) -> Session: | ||
""" | ||
:return: a session identified by session id or None | ||
""" | ||
return self.sessions[sid] if sid in self.sessions else None | ||
|
||
def get_sessions(self) -> dict: | ||
"""Returns all loaded sessions as id: session""" | ||
return self.sessions | ||
|
||
def save_session(self, sid: int): | ||
"""Saves the current session state to a JSON file at SESSION_PATH""" | ||
if sid not in self.sessions: | ||
raise ValueError(f'Session {sid} does not exist') | ||
|
||
session = self.sessions[sid] | ||
with open(f'{SESSIONS_PATH}/{sid}__{session.name}.json', 'w+', encoding='utf-8') as fp: | ||
data = { | ||
'id': sid, | ||
'name': session.name, | ||
'messages': session.messages_to_dict_list() | ||
} | ||
json.dump(data, fp) | ||
|
||
def delete_session(self, sid: int): | ||
"""Deletes a session from SESSION_PATH""" | ||
if sid not in self.sessions: | ||
raise ValueError(f'Session {sid} does not exist') | ||
|
||
for path in SESSIONS_PATH.iterdir(): | ||
if path.is_file() and path.suffix == '.json' and path.name.startswith(f'{sid}__'): | ||
path.unlink() | ||
|
||
def rename_session(self, sid: int, session_name: str): | ||
"""Renames a session identified by session id or creates a new one""" | ||
if sid not in self.sessions: | ||
self.sessions[sid] = Session(name=session_name, messages=[]) | ||
else: | ||
self.sessions[sid].name = session_name | ||
|
||
def load_sessions(self): | ||
"""Loads the saved sessions at SESSION_PATH""" | ||
for path in SESSIONS_PATH.iterdir(): | ||
if path.is_file() and path.suffix == '.json': | ||
sid, session = Session.from_json(str(path)) | ||
self.sessions[sid] = session |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,39 @@ | ||
from src.agent import Agent | ||
from src.agent.tools import TOOLS | ||
|
||
|
||
def cli_test(): | ||
"""testing Agent""" | ||
ollama_model = 'gemma:2b' | ||
tools_documentation = '\n'.join([tool.get_documentation() for tool in TOOLS]) | ||
|
||
agent = Agent(model=ollama_model, tools_docs=tools_documentation) | ||
current_session = 0 | ||
while True: | ||
user_input = input("Enter: ") | ||
if user_input == "-1": | ||
break | ||
elif user_input == "exec": # execute plan | ||
pass | ||
|
||
elif user_input.split(" ")[0] == "new": # create session | ||
agent.new_session(int(user_input.split(" ")[1])) | ||
current_session = int(user_input.split(" ")[1]) | ||
|
||
elif user_input.split(" ")[0] == "save": # save session | ||
agent.save_session(int(user_input.split(" ")[1])) | ||
|
||
elif user_input.split(" ")[0] == "load": # load session | ||
current_session = int(user_input.split(" ")[1]) | ||
session_history = agent.get_session(current_session) | ||
for msg in session_history.messages_to_dict_list(): | ||
print(f'\n> {msg["role"]}: {msg["content"]}') | ||
|
||
else: # query | ||
for chunk in agent.query(current_session, user_input): | ||
print(chunk, end='') | ||
print() | ||
|
||
|
||
if __name__ == "__main__": | ||
pass |