Skip to content

Commit

Permalink
refactoring: updated usage of the LLM and memory modules.
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Jan 22, 2025
1 parent 7b20b8e commit 3fa3751
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 28 deletions.
10 changes: 5 additions & 5 deletions src/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,21 @@ def new_session(self, sid: int):

def get_session(self, sid: int):
"""Open existing conversation"""
return self.agent.memory.get_session(sid)
return self.agent.memory.get_conversation(sid)

def get_sessions(self):
"""Returns list of Session objects"""
return self.agent.memory.get_sessions()
return self.agent.memory.get_conversations()

def save_session(self, sid: int):
"""Saves the specified session to JSON"""
self.agent.memory.save_session(sid)
self.agent.memory.save_conversation(sid)

def delete_session(self, sid: int):
"""Deletes the specified session"""
self.agent.memory.delete_session(sid)
self.agent.memory.delete_conversation(sid)

def rename_session(self, sid: int, session_name: str):
"""Rename the specified session"""
self.agent.memory.rename_session(sid, session_name)
self.agent.memory.rename_conversation(sid, session_name)

48 changes: 26 additions & 22 deletions src/agent/architectures/default/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tool_parse import ToolRegistry

from src.agent import AgentArchitecture
from src.core import Conversation
from src.core.llm import LLM
from src.core.memory import Message, Role
from src.utils import get_logger, LOGS_PATH
Expand Down Expand Up @@ -124,7 +125,7 @@ def query(
:returns: Generator with response text in chunks."""
# TODO: yield (chunk, context_length)
# create a new conversation if not exists
if not self.memory.get_session(session_id):
if not self.memory.get_conversation(session_id):
self.new_session(session_id)

# route query
Expand All @@ -143,7 +144,7 @@ def query(
tool_call_str: str | None = None
for tool_call_execution in self.__tool_call(
user_input,
self.memory.get_session(session_id).message_dict,
self.memory.get_conversation(session_id),
):
tool_call_state = tool_call_execution['state']
if tool_call_state == 'error':
Expand All @@ -163,9 +164,9 @@ def query(
assistant_index = 1

# Replace system prompt with the one built for specific assistant type
history = self.memory.get_session(session_id)
history = self.memory.get_conversation(session_id)
history.messages[0] = Message(role=Role.SYS, content=prompt)
history.add_message(
history.add(
Message(
role=Role.USER,
content=user_input_with_tool_call
Expand All @@ -174,7 +175,7 @@ def query(

# note: history.message_dict doesn't care about context length
response = ''
for chunk, ctx_length in self.llm.query(history.message_dict):
for chunk, ctx_length in self.llm.query(history):
if ctx_length:
self.token_logger.info(
f'Session: {session_id}; Tokens: {ctx_length}'
Expand All @@ -194,9 +195,9 @@ def query(

# remove tool call result from user input and add response to history
history.messages[-1].content = user_input
history.add_message(
history.add(
Message(
Role.ASSISTANT,
role=Role.ASSISTANT,
content=response
)
)
Expand All @@ -221,10 +222,13 @@ def __get_assistant_index(
:param user_input: The user's input query.
:return: An index to choose the proper prompt.
"""
route_messages = [
{'role': 'system', 'content': self.__prompts['router']},
{'role': 'user', 'content': user_input}
]
route_messages = Conversation(
name='get_assistant_index',
messages=[
{'role': 'system', 'content': self.__prompts['router']},
{'role': 'user', 'content': user_input}
]
)
assistant_index_buffer = ''
for chunk, _ in self.llm.query(route_messages):
if not chunk:
Expand All @@ -241,26 +245,26 @@ def __get_assistant_index(
def __tool_call(
self,
user_input: str,
message_history: list
conversation: Conversation
):
"""Query a LLM for a tool call and executes it.
:param user_input: The user's input query.
:param message_history: The conversation history.
:param conversation: The conversation history.
:returns: Result of the tool execution."""
# replace system prompt and generate tool call
message_history[0] = {
'role': 'system',
'content': self.__prompts['tool']
}
message_history.append({
'role': 'user',
'content': user_input
})
conversation.messages[0] = Message(
role='system',
content=self.__prompts['tool']
)
conversation.add(Message(
role='user',
content=user_input
))

tool_call_response = ''
for chunk, _ in self.llm.query(message_history):
for chunk, _ in self.llm.query(conversation):
tool_call_response += chunk

# extract tool call and run it
Expand Down
2 changes: 1 addition & 1 deletion src/routers/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def get_session(sid: int, agent: Agent = Depends(get_agent)):
return {
'sid': sid,
'name': session.name,
'messages': session.message_dict
'messages': session.model_dump()
}


Expand Down

0 comments on commit 3fa3751

Please sign in to comment.