Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add history persistence component #354

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import uuid

from sqlalchemy import TIMESTAMP, Column, ForeignKey, Integer, String, func
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
ludwiktrammer marked this conversation as resolved.
Show resolved Hide resolved
"""
This class represents the base of the database.
"""


class Conversation(Base):
"""
Represents a conversation in the database.

Attributes:
id: The unique identifier for the conversation.
created_at: The timestamp when the conversation was created.

Table:
conversations: Stores conversation records.
"""

__tablename__ = "conversations"

id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
created_at = Column(TIMESTAMP, server_default=func.now())


class Message(Base):
"""
Represents a message in a conversation.

Attributes:
id: The unique identifier for the message.
conversation_id: The ID of the conversation to which the message belongs.
role: The role of the sender (e.g., 'user', 'assistant').
content: The content of the message.
created_at: The timestamp when the message was created.

Table:
messages: Stores message records.
"""

__tablename__ = "messages"

id = Column(Integer, primary_key=True, autoincrement=True)
conversation_id = Column(String, ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False)
role = Column(String, nullable=False)
content = Column(String, nullable=False)
created_at = Column(TIMESTAMP, server_default=func.now())
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import hashlib
import json

import sqlalchemy

from ragbits.core.prompt import ChatFormat

from .models import Conversation, Message


class ConversationHistoryStore:
ludwiktrammer marked this conversation as resolved.
Show resolved Hide resolved
"""
A class to manage storing, retrieving, and updating conversation histories.

This class provides methods to create a new conversation, fetch an existing conversation,
and update a conversation with new messages. The conversations are stored in a SQLAlchemy-backed
database, and a unique conversation ID is generated based on the message contents.
"""

def __init__(self, sqlalchemy_engine: sqlalchemy.Engine):
"""
Initializes the ConversationHistoryStore with a SQLAlchemy engine.

Args:
sqlalchemy_engine: The SQLAlchemy engine used to interact with the database.
"""
self.sqlalchemy_engine = sqlalchemy_engine

@staticmethod
def generate_conversation_id(messages: ChatFormat) -> str:
ludwiktrammer marked this conversation as resolved.
Show resolved Hide resolved
"""
Generates a unique conversation ID based on the provided messages.

Args:
messages: A list of message objects in the chat format.

Returns:
A unique string representing the conversation ID.
"""
json_obj = json.dumps(messages, separators=(",", ":"))
return hashlib.sha256(json_obj.encode()).hexdigest()

def create_conversation(self, messages: ChatFormat) -> str:
"""
Creates a new conversation in the database or returns an existing one.

Args:
messages: A list of dictionaries, where each dictionary represents a message with 'role' and 'content' keys.

Returns:
The ID of the conversation.
"""
conversation_id = self.generate_conversation_id(messages)

with self.sqlalchemy_engine.connect() as connection:
if connection.execute(sqlalchemy.select(Conversation).filter_by(id=conversation_id)).fetchone():
return conversation_id

connection.execute(sqlalchemy.insert(Conversation).values(id=conversation_id))
connection.execute(
sqlalchemy.insert(Message).values(
[
{"conversation_id": conversation_id, "role": msg["role"], "content": msg["content"]}
for msg in messages
]
)
)
connection.commit()

return conversation_id

def fetch_conversation(self, conversation_id: str) -> ChatFormat:
"""
Fetches a conversation by its ID.

Args:
conversation_id: The ID of the conversation to fetch.

Returns:
A list of message dictionaries, each containing 'role' and 'content'.
"""
with self.sqlalchemy_engine.connect() as connection:
rows = connection.execute(
sqlalchemy.select(Message).filter_by(conversation_id=conversation_id).order_by(Message.created_at)
).fetchall()
return [{"role": row.role, "content": row.content} for row in rows] if rows else []

def update_conversation(self, conversation_id: str, new_messages: ChatFormat) -> str:
"""
Updates a conversation with new messages.

Args:
conversation_id: The ID of the conversation to update.
new_messages: A list of new message objects in the chat format.

Returns:
The ID of the updated conversation.
"""
with self.sqlalchemy_engine.connect() as connection:
connection.execute(
sqlalchemy.insert(Message).values(
[
{"conversation_id": conversation_id, "role": msg["role"], "content": msg["content"]}
for msg in new_messages
]
)
)
connection.commit()

return conversation_id
51 changes: 51 additions & 0 deletions packages/ragbits-conversations/tests/unit/history/test_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from unittest.mock import MagicMock, patch

import pytest
import sqlalchemy

from ragbits.conversations.history.store import ConversationHistoryStore


@pytest.fixture
def store():
engine_mock = MagicMock(spec=sqlalchemy.Engine)
return ConversationHistoryStore(engine_mock)


def test_generate_conversation_id(store: ConversationHistoryStore):
messages = [{"role": "user", "content": "Hello"}]
id = store.generate_conversation_id(messages)
assert isinstance(id, str)
assert len(id) == 64


def test_create_conversation(store: ConversationHistoryStore):
with patch.object(store.sqlalchemy_engine, "connect") as mock_connect:
mock_connection = mock_connect.return_value.__enter__.return_value
mock_connection.execute.return_value.fetchone.return_value = None

id = store.create_conversation([{"role": "user", "content": "Hello"}])
assert isinstance(id, str)
mock_connection.execute.assert_called()
mock_connection.commit.assert_called_once()


def test_fetch_conversation(store: ConversationHistoryStore):
with patch.object(store.sqlalchemy_engine, "connect") as mock_connect:
mock_connection = mock_connect.return_value.__enter__.return_value
mock_connection.execute.return_value.fetchall.return_value = [
MagicMock(role="user", content="Hi"),
MagicMock(role="model", content="Hello"),
]

messages = store.fetch_conversation("id")
assert messages == [{"role": "user", "content": "Hi"}, {"role": "model", "content": "Hello"}]


def test_update_conversation(store: ConversationHistoryStore):
with patch.object(store.sqlalchemy_engine, "connect") as mock_connect:
mock_connection = mock_connect.return_value.__enter__.return_value

store.update_conversation("id", [{"role": "user", "content": "How are you?"}])
mock_connection.execute.assert_called()
mock_connection.commit.assert_called_once()
Loading