Skip to content

Commit

Permalink
refactoring: Memory class was fixed for future updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Jan 25, 2025
1 parent 610da92 commit b0ed72e
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 108 deletions.
17 changes: 9 additions & 8 deletions src/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def query(
raise NotImplementedError()

@abstractmethod
def new_session(self, session_id: int):
def new_session(self, session_id: int, name: str):
raise NotImplementedError()


Expand Down Expand Up @@ -85,27 +85,28 @@ def query(self, session_id: int, user_input: str):
# consumption (arch_name, sid, context_length)
yield from self.agent.query(session_id, user_input)

def new_session(self, sid: int):
def new_session(self, sid: int, name: str):
"""Initializes a new conversation"""
self.agent.new_session(sid)
self.agent.new_session(sid, name)

def get_session(self, sid: int):
"""Open existing conversation"""
return self.agent.memory.get_conversation(sid)
return self.agent.memory[sid]

def get_sessions(self):
"""Returns list of Session objects"""
return self.agent.memory.get_conversations()
return self.agent.memory.conversations

def save_session(self, sid: int):
"""Saves the specified session to JSON"""
self.agent.memory.save_conversation(sid)
self.agent.memory.save(sid)

def delete_session(self, sid: int):
"""Deletes the specified session"""
self.agent.memory.delete_conversation(sid)
self.agent.memory.delete(sid)

def rename_session(self, sid: int, session_name: str):
"""Rename the specified session"""
self.agent.memory.rename_conversation(sid, session_name)
if sid in self.agent.memory:
self.agent.memory[sid].name = session_name

61 changes: 32 additions & 29 deletions src/agent/architectures/default/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def state(self, c: str):


class DefaultArchitecture(AgentArchitecture):
"""
TODO: fix architecture since its kinda broken
- search adds two user messages ???
The overall code sucks. My bad.
"""
model: str
architecture_name = 'default_architecture'

Expand Down Expand Up @@ -125,7 +131,7 @@ def query(
:returns: Generator with response text in chunks."""
# TODO: yield (chunk, context_length)
# create a new conversation if not exists
if not self.memory.get_conversation(session_id):
if not self.memory[session_id]:
self.new_session(session_id)

# route query
Expand All @@ -144,7 +150,7 @@ def query(
tool_call_str: str | None = None
for tool_call_execution in self.__tool_call(
user_input,
self.memory.get_conversation(session_id),
self.memory[session_id],
):
tool_call_state = tool_call_execution['state']
if tool_call_state == 'error':
Expand All @@ -164,23 +170,21 @@ def query(
assistant_index = 1

# Replace system prompt with the one built for specific assistant type
history = self.memory.get_conversation(session_id)
history.messages[0] = Message(role=Role.SYS, content=prompt)
history.add(
Message(
role=Role.USER,
content=user_input_with_tool_call
)
conversation = self.memory[session_id]
conversation.messages[0] = Message(role=Role.SYS, content=prompt)
conversation += Message(
role=Role.USER,
content=user_input_with_tool_call
)

# note: history.message_dict doesn't care about context length
# note: conversation.message_dict doesn't care about context length
response = ''
# yes, I called ass_tokens the assistant tokens
response_tokens = 0
for chunk, usr_tokens, ass_tokens in self.llm.query(history):
for chunk, usr_tokens, ass_tokens in self.llm.query(conversation):
if usr_tokens:
# set last message (usr) token usage
history.messages[-1].set_tokens(usr_tokens)
conversation.messages[-1].set_tokens(usr_tokens)
response_tokens = ass_tokens
break
if assistant_index == 1:
Expand All @@ -195,25 +199,24 @@ def query(
yield c
# add thinking yield

# remove tool call result from user input and add response to history
history.messages[-1].content = user_input
history.add(
Message(
role=Role.ASSISTANT,
content=response,
)
# remove tool call result from user input and add response to conversation
conversation.messages[-1].content = user_input
conversation += Message(
role=Role.ASSISTANT,
content=response,
)
history.messages[-1].set_tokens(response_tokens)

def new_session(self, session_id: int):
conversation.messages[-1].set_tokens(response_tokens)
logger.debug(f'CONVERSATION: {conversation}')

def new_session(self, session_id: int, name: str):
"""Create a new conversation if not exists"""
# logger.debug('Creating new session')
self.memory.store_message(
session_id,
Message(
role=Role.SYS,
content=self.__prompts['general']
)
if session_id not in self.memory:
self.memory[session_id] = Conversation(name=name)
self.memory[session_id] += Message(
role=Role.SYS,
content=self.__prompts['general']
)

def __get_assistant_index(
Expand Down Expand Up @@ -261,10 +264,10 @@ def __tool_call(
role='system',
content=self.__prompts['tool']
)
conversation.add(Message(
conversation += Message(
role='user',
content=user_input
))
)

tool_call_response = ''
for chunk, _, _ in self.llm.query(conversation):
Expand Down
128 changes: 62 additions & 66 deletions src/core/memory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict

from pydantic import validate_call
from src.core.memory.schema import Message, Conversation
from src.core.memory.schema import Conversation
from src.utils import get_logger


Expand All @@ -16,92 +16,88 @@


class Memory:
"""
Contains the chat history for each session.
"""Manages in-memory conversations and provides persistence.
Conversations are accessible via Python mapping protocol (`memory[key]`).
Conversations can be saved to persistent JSON files and will be
loaded on Memory initialization.
# TODO: Implement conversation management for max context length
"""

def __init__(self):
self.sessions: Dict[int: Conversation] = {}
self.load_conversations()

@validate_call
def store_message(self, sid: int, message: Message):
"""Add a message to a session identified by session id.
Creates a new session if the specified do not exist."""
if sid not in self.sessions:
self.sessions[sid] = Conversation(name='New Session', messages=[])
self.sessions[sid].add(message)

def get_conversation(self, sid: int) -> Conversation:
"""
:return: a session identified by session id or None
"""
return self.sessions[sid] if sid in self.sessions else None
self.__conversation_map: Dict[int: Conversation] = {}
self.__load_conversations()

@validate_call
def replace_system_prompt(self, sid: int, message: Message):
session = self.sessions[sid]
session.messages[0] = message
def __setitem__(self, key, value: Conversation):
self.__conversation_map[key] = value

def get_conversations(self) -> dict:
"""Returns all loaded sessions as id: session"""
return self.sessions
def __getitem__(self, key) -> Conversation:
return self.__conversation_map[key] \
if key in self.__conversation_map \
else None

def save_conversation(self, sid: int):
"""Saves the current session state to a JSON file at SESSION_PATH"""
if sid not in self.sessions:
logger.error(f'\tError in {self.__name__}: session not exists.')
raise ValueError(f'Session {sid} does not exist')
def __contains__(self, key):
return key in self.__conversation_map

session: Conversation = self.sessions[sid]
self.delete_conversation(sid)
@property
def conversations(self) -> Dict[int, Conversation]:
return self.__conversation_map

path = f'{SESSIONS_PATH}/{sid}__{session.name}.json'
with open(path, 'w+', encoding='utf-8') as fp:
def save(self, conversation_id: int):
"""Saves a conversation to a persistent JSON file in SESSION_PATH.
In the case `conversation_id` is not in memory, the error is logged.
"""
if conversation_id not in self:
logger.error(f'[save]: {conversation_id} not available.')
return

conversation = self.__conversation_map[conversation_id]
conversation_path = (
SESSIONS_PATH
/ f'{conversation_id}__{conversation.name}.json'
)
with open(str(conversation_path), 'w+', encoding='utf-8') as fp:
try:
data = {
'id': sid,
'name': session.name,
'messages': session.model_dump(),
'id': conversation_id,
'name': conversation.name,
'messages': conversation.model_dump(),
}
json.dump(data, fp, indent='\t')
except (
UnicodeDecodeError,
json.JSONDecodeError,
OverflowError
) as save_error:
except Exception as save_error:
logger.error(
f'Failed saving session {sid}. {save_error}'
f'[save]: failed saving conversation {conversation_id}. {save_error}'
)

def delete_conversation(self, sid: int):
"""Deletes a session from SESSION_PATH"""
# TODO: should also delete session from sessions dictionary
if sid not in self.sessions:
raise ValueError(f'Session {sid} does not exist')

# delete file from ~/.aiops/sessions
for path in SESSIONS_PATH.iterdir():
if path.is_file() and path.suffix == '.json' and \
path.name.startswith(f'{sid}__'):
path.unlink()

# delete session from memory
self.sessions.pop(sid, None)

def rename_conversation(self, sid: int, session_name: str):
"""Renames a session identified by session id or creates a new one"""
if sid not in self.sessions:
self.sessions[sid] = Conversation(name=session_name, messages=[])
def delete(self, conversation_id: int):
"""Removes Conversation object from internal map and from persistent
files. In case `conversation_id` is not in memory, the error is logged.
"""
if conversation_id not in self:
logger.error(f'[delete]: {conversation_id} not available.')
return

conversation_path = (
SESSIONS_PATH
/ f'{conversation_id}__{self[conversation_id].name}.json'
)
if conversation_path.exists():
conversation_path.unlink()
self.__conversation_map.pop(conversation_id, None)
else:
self.sessions[sid].name = session_name
logger.error(f'[delete]: {conversation_path} not found.')

def load_conversations(self):
"""Loads the saved sessions at SESSION_PATH"""
def __load_conversations(self):
"""
Populates the internal map from persistent JSON files at SESSION_PATH.
"""
for path in SESSIONS_PATH.iterdir():
if path.is_file() and path.suffix == '.json':
sid, session = Conversation.from_json(str(path))
if sid == -1:
logger.error(f"\tFailed loading session {path}")
logger.info(f"\tLoaded session {path}")
self.sessions[sid] = session
self.__conversation_map[sid] = session
9 changes: 4 additions & 5 deletions src/core/memory/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ class Conversation(BaseModel):
_tokens: int = 0

@validate_call
def add(self, message: Message):
"""Append a message"""
self.messages.append(message)
def __iadd__(self, other):
self.messages.append(other)
return self

def model_dump(self, **kwargs):
# return only a list of messages when converting to list[dict]
Expand All @@ -62,8 +62,7 @@ def __len__(self):

@staticmethod
def from_json(path: str):
"""
Get a session from a JSON file.
"""Get a session from a JSON file.
Reason for not using model_validate_json: saved JSON file contains
ID, instead Conversation doesn't have an ID field.
"""
Expand Down

0 comments on commit b0ed72e

Please sign in to comment.