Skip to content

Commit

Permalink
Implemented basic memory mechanism + basic chat agent
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Jun 17, 2024
1 parent a1405ee commit 35643a7
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from src.agent.agent import Agent
50 changes: 34 additions & 16 deletions src/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,44 @@
from src.agent.llm import LLM
from src.agent.plan import Plan
from src.agent.memory import Memory, Message, Role
from src.agent.prompts import PROMPTS
from src.agent.tools import TOOLS


class Agent:
pass
def __init__(self, model: str, tools_docs):
self.llm = LLM(model=model)
self.mem = Memory()
self.system_prompt = 'You are an assistant.' # PROMPTS[ollama_model]['system']['plan'].format(tools=tools_docs)
self.user_prompt = PROMPTS[model]['user']['plan']

def new_session(self, sid: int):
self.mem.store_message(sid, Message(Role.SYS, self.system_prompt))

if __name__ == "__main__":
model = 'gemma:2b'
llm = LLM(model=model)
tools_documentation = '\n'.join([tool.get_documentation() for tool in TOOLS])
def get_session(self, sid: int):
return self.mem.get_session(sid)

sys_prompt = PROMPTS[model]['system']['plan'].format(tools=tools_documentation)
usr_prompt = ''
def query(self, sid: int, user_in: str):
self.mem.store_message(
sid,
Message(Role.USER, self.user_prompt.format(user_input=user_in))
)
messages = self.mem.get_session(sid).messages_to_dict_list()

while True:
user_input = input("Enter: ")
if user_input == "-1":
break
elif user_input == "exec":
pass # execute plan
else:
pass # make query
response = ''
for chunk in self.llm.query(messages):
yield chunk['message']['content']
response += chunk['message']['content']

self.mem.store_message(
sid,
Message(Role.ASSISTANT, response)
)

def save_session(self, sid: int):
self.mem.save_session(sid)

def delete_session(self, sid: int):
self.mem.delete_session(sid)

def rename_session(self, sid: int, session_name: str):
self.mem.rename_session(sid, session_name)
6 changes: 2 additions & 4 deletions src/agent/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path
from src.agent.memory.base import Memory, Message, Role


SESSIONS_PATH = Path(Path.home() / '.aiops' / 'sessions')
if not SESSIONS_PATH.exists():
SESSIONS_PATH.mkdir(parents=True, exist_ok=True)
Binary file not shown.
Binary file added src/agent/memory/__pycache__/base.cpython-311.pyc
Binary file not shown.
125 changes: 125 additions & 0 deletions src/agent/memory/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import json
from dataclasses import dataclass
from typing import List
from enum import StrEnum
from pathlib import Path

from src.agent.plan import Plan

SESSIONS_PATH = Path(Path.home() / '.aiops' / 'sessions')
if not SESSIONS_PATH.exists():
SESSIONS_PATH.mkdir(parents=True, exist_ok=True)


class Role(StrEnum):
SYS = 'system'
USER = 'user'
ASSISTANT = 'assistant'

@staticmethod
def from_str(item):
if item == 'user':
return Role.USER
elif item == 'assistant':
return Role.ASSISTANT
elif item == 'system':
return Role.SYS
else:
return None


@dataclass
class Message:
role: Role
content: str


@dataclass
class Session:
name: str
messages: List[Message]

# plan

def messages_to_dict_list(self):
"""Converts the message list into a format compatible with Ollama"""
return [{'role': str(msg.role), 'content': msg.content} for msg in self.messages]

@staticmethod
def from_json(path: str):
"""Get a session from a JSON file"""
with open(str(path), 'r', encoding='utf-8') as fp:
data = json.load(fp)
return data['id'], Session(
name=data['name'],
messages=[Message(Role.from_str(msg['role']), msg['content']) for msg in data['messages']]
)


class Memory:
"""
Contains the chat history for each session, it is bounded to the Agent class.
"""

def __init__(self):
self.sessions = {}
self.load_sessions()

def store_message(self, sid: int, message: Message):
"""Add a message to a session identified by session id or creates a new one"""
if not isinstance(message, Message):
raise ValueError(f'Not a message: {message}')
if sid not in self.sessions:
self.sessions[sid] = Session(name='New Session', messages=[])

self.sessions[sid].messages.append(message)

def store_plan(self, sid: int, plan: Plan):
pass

def get_session(self, sid: int) -> Session:
"""
:return: a session identified by session id or None
"""
return self.sessions[sid] if sid in self.sessions else None

def get_sessions(self) -> dict:
"""Returns all loaded sessions as id: session"""
return self.sessions

def save_session(self, sid: int):
"""Saves the current session state to a JSON file at SESSION_PATH"""
if sid not in self.sessions:
raise ValueError(f'Session {sid} does not exist')

session = self.sessions[sid]
with open(f'{SESSIONS_PATH}/{sid}__{session.name}.json', 'w+', encoding='utf-8') as fp:
data = {
'id': sid,
'name': session.name,
'messages': session.messages_to_dict_list()
}
json.dump(data, fp)

def delete_session(self, sid: int):
"""Deletes a session from SESSION_PATH"""
if sid not in self.sessions:
raise ValueError(f'Session {sid} does not exist')

for path in SESSIONS_PATH.iterdir():
if path.is_file() and path.suffix == '.json' and path.name.startswith(f'{sid}__'):
path.unlink()

def rename_session(self, sid: int, session_name: str):
"""Renames a session identified by session id or creates a new one"""
if sid not in self.sessions:
self.sessions[sid] = Session(name=session_name, messages=[])
else:
self.sessions[sid].name = session_name

def load_sessions(self):
"""Loads the saved sessions at SESSION_PATH"""
for path in SESSIONS_PATH.iterdir():
if path.is_file() and path.suffix == '.json':
sid, session = Session.from_json(str(path))
self.sessions[sid] = session
3 changes: 3 additions & 0 deletions src/agent/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
Utilize only the provided TOOLS and follow their usage examples strictly, the available TOOLS are as follows:
{tools}
"""),
},
'user': {
'plan': '{user_input}'
}
}
}
36 changes: 36 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,39 @@
from src.agent import Agent
from src.agent.tools import TOOLS


def cli_test():
"""testing Agent"""
ollama_model = 'gemma:2b'
tools_documentation = '\n'.join([tool.get_documentation() for tool in TOOLS])

agent = Agent(model=ollama_model, tools_docs=tools_documentation)
current_session = 0
while True:
user_input = input("Enter: ")
if user_input == "-1":
break
elif user_input == "exec": # execute plan
pass

elif user_input.split(" ")[0] == "new": # create session
agent.new_session(int(user_input.split(" ")[1]))
current_session = int(user_input.split(" ")[1])

elif user_input.split(" ")[0] == "save": # save session
agent.save_session(int(user_input.split(" ")[1]))

elif user_input.split(" ")[0] == "load": # load session
current_session = int(user_input.split(" ")[1])
session_history = agent.get_session(current_session)
for msg in session_history.messages_to_dict_list():
print(f'\n> {msg["role"]}: {msg["content"]}')

else: # query
for chunk in agent.query(current_session, user_input):
print(chunk, end='')
print()


if __name__ == "__main__":
pass

0 comments on commit 35643a7

Please sign in to comment.