HistoryStoreOptionsT = TypeVar("HistoryStoreOptionsT", bound=Options)


class HistoryStore(ConfigurableComponent[HistoryStoreOptionsT], ABC):
    """
    Abstract base class for conversation history stores.

    A concrete store implements three coroutines: creating a conversation from an
    initial list of messages, fetching the messages of a conversation by its ID,
    and adding new messages to an existing conversation.
    """

    # Options type of the store; only annotated here, no value is assigned in this class.
    options_cls: type[HistoryStoreOptionsT]
    # The `stores` package module itself.
    # NOTE(review): presumably used by ConfigurableComponent to look up store classes
    # by name - confirm against its implementation.
    default_module: ClassVar = stores
    # NOTE(review): presumably the configuration key for this component type - confirm.
    configuration_key: ClassVar = "history_store"

    @abstractmethod
    async def create_conversation(self, messages: ChatFormat) -> str:
        """
        Creates a new conversation and stores the given messages.

        Args:
            messages: The messages, in chat format, that make up the initial conversation history.

        Returns:
            A unique identifier for the created conversation.
        """

    @abstractmethod
    async def fetch_conversation(self, conversation_id: str) -> ChatFormat:
        """
        Retrieves a conversation by its unique identifier.

        Args:
            conversation_id: The unique ID of the conversation to fetch.

        Returns:
            The messages of the retrieved conversation, in chat format.
        """

    @abstractmethod
    async def update_conversation(self, conversation_id: str, new_messages: ChatFormat) -> str:
        """
        Updates an existing conversation with new messages.

        Args:
            conversation_id: The unique ID of the conversation to update.
            new_messages: The messages, in chat format, to append to the conversation.

        Returns:
            The ID of the updated conversation.
        """
class _Base(DeclarativeBase):
    """
    Declarative base class shared by the history store models.
    """

    @classmethod
    def set_table_name(cls, name: str) -> None:
        """
        Rebinds the `__tablename__` attribute of the model class.

        NOTE(review): this only assigns the class attribute; nothing here touches a
        table object that was already built for the class. Confirm that the name set
        here is the one actually used when tables are created and queried.

        Args:
            name: The table name to assign to the class.
        """
        cls.__tablename__ = name


class Conversation(_Base):
    """
    Represents a conversation in the database.

    Attributes:
        id: The unique identifier for the conversation; defaults to a random UUID4 string
            generated on the Python side.
        created_at: The timestamp when the conversation was created; filled in by the
            database through a `now()` server default.

    Table:
        ragbits_conversations: Stores conversation records.
    """

    __tablename__ = "ragbits_conversations"
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    created_at = Column(TIMESTAMP, server_default=func.now())


class Message(_Base):
    """
    Represents a message in a conversation.

    Attributes:
        id: The unique, auto-incremented integer identifier for the message.
        conversation_id: The ID of the conversation to which the message belongs; a foreign
            key to `ragbits_conversations.id` declared with `ON DELETE CASCADE`.
        role: The role of the sender (e.g., 'user', 'assistant').
        content: The content of the message.
        created_at: The timestamp when the message was created; filled in by the database
            through a `now()` server default.

    Table:
        ragbits_messages: Stores message records.
    """

    __tablename__ = "ragbits_messages"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # The foreign key target is spelled out as a string, so it always points at the
    # "ragbits_conversations" table name.
    conversation_id = Column(String, ForeignKey("ragbits_conversations.id", ondelete="CASCADE"), nullable=False)
    role = Column(String, nullable=False)
    content = Column(String, nullable=False)
    created_at = Column(TIMESTAMP, server_default=func.now())
class SQLHistoryStoreOptions(Options):
    """
    Stores table names for the database models in SQLHistoryStore.

    Attributes:
        conversations_table: The table name passed to the conversation model.
        messages_table: The table name passed to the message model.
    """

    conversations_table: str = "conversations"
    messages_table: str = "messages"


SQLHistoryStoreOptionsT = TypeVar("SQLHistoryStoreOptionsT", bound=SQLHistoryStoreOptions)


class SQLHistoryStore(HistoryStore[SQLHistoryStoreOptions]):
    """
    A conversation history store backed by a SQLAlchemy async engine.

    This class provides methods to create a new conversation, fetch an existing conversation,
    and update a conversation with new messages. Each conversation is identified by a random
    UUID4 string, and every message is stored as a row holding the conversation ID, the role
    and the content.
    """

    options_cls = SQLHistoryStoreOptions

    def __init__(self, sqlalchemy_engine: AsyncEngine, default_options: SQLHistoryStoreOptions | None = None) -> None:
        """
        Initializes the SQLHistoryStore with a SQLAlchemy engine.

        Args:
            sqlalchemy_engine: The SQLAlchemy engine used to interact with the database.
            default_options: An optional SQLHistoryStoreOptions specifying table names.
        """
        super().__init__(default_options=default_options)
        self.sqlalchemy_engine = sqlalchemy_engine

        # NOTE(review): the models are module-level classes, so these assignments are shared
        # by every SQLHistoryStore instance. set_table_name only rebinds `__tablename__`, and
        # the models declare "ragbits_conversations" / "ragbits_messages" while the option
        # defaults are "conversations" / "messages" - confirm which names the tables end up with.
        Conversation.set_table_name(self.default_options.conversations_table)
        Message.set_table_name(self.default_options.messages_table)

    @staticmethod
    async def _insert_messages(session: AsyncSession, conversation_id: str, messages: ChatFormat) -> None:
        """
        Inserts one row per message for the given conversation.

        Does nothing when `messages` is empty, so no INSERT without rows is ever sent.

        Args:
            session: The session (with an open transaction) to execute the insert in.
            conversation_id: The ID of the conversation the messages belong to.
            messages: The messages to store; each must provide 'role' and 'content' keys.
        """
        rows = [
            {"conversation_id": conversation_id, "role": message["role"], "content": message["content"]}
            for message in messages
        ]
        if rows:
            await session.execute(sqlalchemy.insert(Message).values(rows))

    async def init_db(self) -> None:
        """
        Initializes the database tables by creating them in the database.
        Conditional by default, will not attempt to recreate tables already
        present in the target database.
        """
        async with self.sqlalchemy_engine.begin() as conn:
            await conn.run_sync(_Base.metadata.create_all)

    async def create_conversation(self, messages: ChatFormat) -> str:
        """
        Creates a new conversation in the database and stores the given messages.

        The conversation ID is a UUID4 string generated here and passed explicitly to the
        INSERT, so the statement does not depend on backend support for `RETURNING`.
        The conversation row and its messages are written in a single transaction.

        Args:
            messages: A list of dictionaries, where each dictionary represents a message
                with 'role' and 'content' keys. May be empty, in which case only the
                conversation row is created.

        Returns:
            The ID of the created conversation.
        """
        conversation_id = str(uuid.uuid4())
        # `session.begin()` commits when the block exits normally and rolls back when it
        # raises, so no explicit commit is needed.
        async with AsyncSession(self.sqlalchemy_engine) as session, session.begin():
            await session.execute(sqlalchemy.insert(Conversation).values(id=conversation_id))
            await self._insert_messages(session, conversation_id, messages)
        return conversation_id

    async def fetch_conversation(self, conversation_id: str) -> ChatFormat:
        """
        Fetches a conversation by its ID.

        Args:
            conversation_id: The ID of the conversation to fetch.

        Returns:
            A list of message dictionaries, each containing 'role' and 'content', in the
            order the messages were stored. Empty if no message matches the given ID.
        """
        query = (
            sqlalchemy.select(Message.role, Message.content)
            .where(Message.conversation_id == conversation_id)
            # Messages written by one statement share the same `created_at`, so the
            # auto-incremented id is needed as a tie-breaker to keep the order stable.
            .order_by(Message.created_at, Message.id)
        )
        async with AsyncSession(self.sqlalchemy_engine) as session:
            result = await session.execute(query)
            return [{"role": role, "content": content} for role, content in result.all()]

    async def update_conversation(self, conversation_id: str, new_messages: ChatFormat) -> str:
        """
        Updates a conversation with new messages.

        Args:
            conversation_id: The ID of the conversation to update.
            new_messages: A list of new message objects in the chat format. May be empty,
                in which case nothing is written.

        Returns:
            The ID of the updated conversation.
        """
        # The transaction is committed by `session.begin()` on exit.
        async with AsyncSession(self.sqlalchemy_engine) as session, session.begin():
            await self._insert_messages(session, conversation_id, new_messages)
        return conversation_id

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Initializes the class with the provided configuration.

        The "sqlalchemy_engine" entry is validated as an object construction config and
        replaced by an async engine created from its `url` setting; all other entries are
        passed to the constructor unchanged. The given dictionary itself is not modified.

        Args:
            config: A dictionary containing configuration details for the class.
                Must contain a "sqlalchemy_engine" entry whose config provides "url".

        Returns:
            An instance of the class initialized with the provided configuration.

        Raises:
            KeyError: If "sqlalchemy_engine" or its "url" setting is missing.
        """
        engine_options = ObjectContructionConfig.model_validate(config["sqlalchemy_engine"])
        # Build the constructor arguments in a new dict so the caller's config stays untouched.
        kwargs = {**config, "sqlalchemy_engine": create_async_engine(engine_options.config["url"])}
        return cls(**kwargs)
MESSAGES: ChatFormat = [
    {"role": "user", "content": "Hi"},
    {"role": "model", "content": "Hello"},
]


@pytest.fixture
async def async_engine():
    """Yields an in-memory SQLite engine with the history tables created, and disposes it afterwards."""
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=True)
    async with engine.begin() as conn:
        # Conversation and Message share one MetaData, so a single create_all covers both tables.
        await conn.run_sync(Conversation.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest.fixture
async def async_session(async_engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Returns a session factory bound to the test engine."""
    return async_sessionmaker(async_engine, expire_on_commit=False)


@pytest.fixture
async def history_store(async_engine: AsyncEngine) -> SQLHistoryStore:
    """Returns a store bound to the test engine with its tables initialized."""
    store = SQLHistoryStore(async_engine)
    await store.init_db()
    return store


@pytest.mark.asyncio
async def test_create_conversation(history_store: SQLHistoryStore):
    conversation_id = await history_store.create_conversation(MESSAGES)
    assert conversation_id is not None
    assert isinstance(conversation_id, str)


@pytest.mark.asyncio
async def test_fetch_conversation(history_store: SQLHistoryStore):
    conversation_id = await history_store.create_conversation(MESSAGES)
    fetched_messages = await history_store.fetch_conversation(conversation_id)
    assert fetched_messages == MESSAGES


@pytest.mark.asyncio
async def test_update_conversation(history_store: SQLHistoryStore):
    conversation_id = await history_store.create_conversation(MESSAGES)
    new_messages: ChatFormat = [
        {"role": "user", "content": "How are you?"},
    ]
    updated_conversation_id = await history_store.update_conversation(conversation_id, new_messages)
    assert updated_conversation_id == conversation_id
    fetched_messages = await history_store.fetch_conversation(conversation_id)
    # Compare the whole history so that the order of the earlier messages is checked too.
    assert fetched_messages == [*MESSAGES, *new_messages]


@pytest.mark.asyncio
async def test_from_config():
    config = {"sqlalchemy_engine": {"type": "AsyncEngine", "config": {"url": "sqlite+aiosqlite:///:memory:"}}}
    store = SQLHistoryStore.from_config(config)
    try:
        assert store is not None
        assert isinstance(store, SQLHistoryStore)
    finally:
        # The engine is created inside from_config, so the test has to release it itself.
        await store.sqlalchemy_engine.dispose()
"https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792 }, +] + [[package]] name = "alembic" version = "1.14.0" @@ -621,7 +633,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -2234,7 +2246,7 @@ version = "1.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" }, @@ -2585,6 +2597,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824 }, { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519 }, { url = 
"https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741 }, + { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628 }, + { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351 }, ] [[package]] @@ -2858,7 +2872,6 @@ name = "nvidia-nvjitlink-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/11/8c/386018fdffdce2ff8d43fedf192ef7d14cab7501cbf78a106dd2e9f1fc1f/nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:3bf10d85bb1801e9c894c6e197e44dd137d2a0a9e43f8450e9ad13f2df0dd52d", size = 19270432 }, { url = "https://files.pythonhosted.org/packages/fe/e4/486de766851d58699bcfeb3ba6a3beb4d89c3809f75b9d423b9508a8760f/nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9ae346d16203ae4ea513be416495167a0101d33d2d14935aa9c1829a3fb45142", size = 19745114 }, ] @@ -3353,7 +3366,7 @@ name = "portalocker" version = "2.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "platform_system == 'Windows'" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } wheels = [ 
@@ -3424,8 +3437,6 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, - { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, @@ -4172,6 +4183,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "aiosqlite" }, { name = "griffe" }, { name = "griffe-typingdoc" }, { name = "mkdocs" }, @@ -4203,6 +4215,7 @@ requires-dist = [ 
[package.metadata.requires-dev] dev = [ + { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "griffe", specifier = ">=1.3.2" }, { name = "griffe-typingdoc", specifier = ">=0.2.7" }, { name = "mkdocs", specifier = ">=1.6.1" }, @@ -5163,19 +5176,19 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = 
"platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.12' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -5201,7 +5214,7 @@ name = "tqdm" version = "4.66.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [