Skip to content

Commit

Permalink
Azure CosmosDB NoSql Storage Integrations (run-llama#16138)
Browse files Browse the repository at this point in the history
  • Loading branch information
aayush3011 authored and jzhao62 committed Oct 4, 2024
1 parent bb462b9 commit 341b4c5
Show file tree
Hide file tree
Showing 40 changed files with 1,036 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
poetry_requirements(
name="poetry",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files

lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pytest.
pytest tests

watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# LlamaIndex Chat_Store Integration: Azure CosmosDB NoSQL Chat Store
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_sources()
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from llama_index.storage.chat_store.azurecosmosnosql.base import (
AzureCosmosNoSqlChatStore,
)

__all__ = ["AzureCosmosNoSqlChatStore"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import logging
from typing import Any, Dict, List, Optional

from azure.cosmos import CosmosClient, DatabaseProxy, ContainerProxy
from llama_index.core.llms import ChatMessage
from llama_index.core.storage.chat_store import BaseChatStore

# Database name used when the caller does not supply `chat_db_name`.
DEFAULT_CHAT_DATABASE = "ChatMessagesDB"
# Container name used when the caller does not supply `chat_container_name`.
DEFAULT_CHAT_CONTAINER = "ChatMessagesContainer"


logger = logging.getLogger(__name__)


# Convert a ChatMessage to a JSON object
def _message_to_dict(message: ChatMessage) -> dict:
    """Serialize one chat message into a plain dictionary via its ``dict()`` method."""
    payload = message.dict()
    return payload


# Convert a list of ChatMessages to a list of JSON objects
def _messages_to_dict(messages: List[ChatMessage]) -> List[dict]:
    """Serialize each chat message in ``messages``, preserving their order."""
    serialized: List[dict] = []
    for chat_message in messages:
        serialized.append(_message_to_dict(chat_message))
    return serialized


# Convert a JSON object to a ChatMessage
def _dict_to_message(d: dict) -> ChatMessage:
    """Rebuild a ``ChatMessage`` from its dictionary form using ``model_validate``."""
    restored = ChatMessage.model_validate(d)
    return restored


class AzureCosmosNoSqlChatStore(BaseChatStore):
    """Chat store backed by an Azure Cosmos DB NoSQL container.

    Every chat key is persisted as a single item shaped like
    ``{"id": <key>, "messages": [<serialized ChatMessage>, ...]}``.

    NOTE(review): point reads and deletes address items with
    ``partition_key=key``, i.e. they assume the container is partitioned on
    ``/id`` (the only key-derived field this class writes). Confirm this
    matches any ``partition_key`` passed in ``cosmos_container_properties``.
    """

    _cosmos_client = CosmosClient
    _database = DatabaseProxy
    _container = ContainerProxy

    def __init__(
        self,
        cosmos_client: CosmosClient,
        chat_db_name: str = DEFAULT_CHAT_DATABASE,
        chat_container_name: str = DEFAULT_CHAT_CONTAINER,
        cosmos_container_properties: Optional[Dict[str, Any]] = None,
        cosmos_database_properties: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Create the database and container (if missing) and bind to them.

        Args:
            cosmos_client: Client used for all Cosmos DB calls.
            chat_db_name: Name of the database to create or reuse.
            chat_container_name: Name of the container to create or reuse.
            cosmos_container_properties: Optional settings forwarded to
                ``create_container_if_not_exists`` (``partition_key``,
                ``indexing_policy``, ``default_ttl``, ...). When omitted, or
                when it has no ``partition_key``, the container is
                partitioned on ``/id``.
            cosmos_database_properties: Optional settings forwarded to
                ``create_database_if_not_exists`` (``offer_throughput``,
                ``session_token``, ...).
            **kwargs: Accepted for interface compatibility; not used here.
        """
        from azure.cosmos import PartitionKey

        super().__init__(
            cosmos_client=cosmos_client,
            chat_db_name=chat_db_name,
            chat_container_name=chat_container_name,
            cosmos_container_properties=cosmos_container_properties,
            cosmos_database_properties=cosmos_database_properties,
        )

        # Both property dicts default to None; normalise them so the lookups
        # below do not fail with AttributeError/TypeError.
        database_properties = cosmos_database_properties or {}
        container_properties = cosmos_container_properties or {}

        self._cosmos_client = cosmos_client

        # Create the database if it doesn't already exist
        self._database = self._cosmos_client.create_database_if_not_exists(
            id=chat_db_name,
            offer_throughput=database_properties.get("offer_throughput"),
            session_token=database_properties.get("session_token"),
            initial_headers=database_properties.get("initial_headers"),
            etag=database_properties.get("etag"),
            match_condition=database_properties.get("match_condition"),
        )

        # Items only carry "id" and "messages", so "/id" is the natural
        # partition key when the caller does not choose one.
        partition_key = container_properties.get("partition_key")
        if partition_key is None:
            partition_key = PartitionKey(path="/id")

        # Create the container if it doesn't already exist
        self._container = self._database.create_container_if_not_exists(
            id=chat_container_name,
            partition_key=partition_key,
            indexing_policy=container_properties.get("indexing_policy"),
            default_ttl=container_properties.get("default_ttl"),
            offer_throughput=container_properties.get("offer_throughput"),
            unique_key_policy=container_properties.get("unique_key_policy"),
            conflict_resolution_policy=container_properties.get(
                "conflict_resolution_policy"
            ),
            analytical_storage_ttl=container_properties.get(
                "analytical_storage_ttl"
            ),
            computed_properties=container_properties.get("computed_properties"),
            etag=container_properties.get("etag"),
            match_condition=container_properties.get("match_condition"),
            session_token=container_properties.get("session_token"),
            initial_headers=container_properties.get("initial_headers"),
        )

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        chat_db_name: str = DEFAULT_CHAT_DATABASE,
        chat_container_name: str = DEFAULT_CHAT_CONTAINER,
        cosmos_container_properties: Optional[Dict[str, Any]] = None,
        cosmos_database_properties: Optional[Dict[str, Any]] = None,
    ) -> "AzureCosmosNoSqlChatStore":
        """Create a chat store from a Cosmos DB connection string."""
        cosmos_client = CosmosClient.from_connection_string(connection_string)

        return cls(
            cosmos_client,
            chat_db_name,
            chat_container_name,
            cosmos_container_properties,
            cosmos_database_properties,
        )

    @classmethod
    def from_account_and_key(
        cls,
        endpoint: str,
        key: str,
        chat_db_name: str = DEFAULT_CHAT_DATABASE,
        chat_container_name: str = DEFAULT_CHAT_CONTAINER,
        cosmos_container_properties: Optional[Dict[str, Any]] = None,
        cosmos_database_properties: Optional[Dict[str, Any]] = None,
    ) -> "AzureCosmosNoSqlChatStore":
        """Create a chat store from an account endpoint URL and key."""
        cosmos_client = CosmosClient(endpoint, key)
        return cls(
            cosmos_client,
            chat_db_name,
            chat_container_name,
            cosmos_container_properties,
            cosmos_database_properties,
        )

    @classmethod
    def from_aad_token(
        cls,
        endpoint: str,
        chat_db_name: str = DEFAULT_CHAT_DATABASE,
        chat_container_name: str = DEFAULT_CHAT_CONTAINER,
        cosmos_container_properties: Optional[Dict[str, Any]] = None,
        cosmos_database_properties: Optional[Dict[str, Any]] = None,
    ) -> "AzureCosmosNoSqlChatStore":
        """Create a chat store authenticated with ``DefaultAzureCredential``."""
        # Imported lazily so azure-identity is only needed on this code path.
        from azure.identity import DefaultAzureCredential

        credential = DefaultAzureCredential()
        return cls._from_clients(
            endpoint,
            credential,
            chat_db_name,
            chat_container_name,
            cosmos_container_properties,
            cosmos_database_properties,
        )

    def set_messages(self, key: str, messages: List[ChatMessage]) -> None:
        """Replace the stored history for ``key`` with ``messages``.

        Raises:
            ValueError: If the container has not been initialized.
        """
        if not self._container:
            raise ValueError("Container not initialized")
        self._container.upsert_item(
            body={
                "id": key,
                "messages": _messages_to_dict(messages),
            }
        )

    def get_messages(self, key: str) -> List[ChatMessage]:
        """Return the stored history for ``key``, or ``[]`` if none exists."""
        from azure.cosmos.exceptions import CosmosResourceNotFoundError

        try:
            # read_item requires the partition key value; see the class note.
            response = self._container.read_item(item=key, partition_key=key)
        except CosmosResourceNotFoundError:
            # An unknown key simply has no history yet.
            return []
        message_history = response.get("messages", []) if response else []
        return [_dict_to_message(message) for message in message_history]

    def add_message(self, key: str, message: ChatMessage) -> None:
        """Append ``message`` to the history stored for ``key``.

        The history is read, extended and written back with an upsert, so it
        works both for new keys and for keys that already have an item.
        """
        current_messages = self.get_messages(key)
        current_messages.append(message)
        self.set_messages(key, current_messages)

    def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
        """Delete the item for ``key`` and return the messages it held.

        Returns ``None`` when no item exists for ``key``.
        """
        from azure.cosmos.exceptions import CosmosResourceNotFoundError

        messages_to_delete = self.get_messages(key)
        try:
            self._container.delete_item(item=key, partition_key=key)
        except CosmosResourceNotFoundError:
            return None
        return messages_to_delete

    def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
        """Remove and return the message at ``idx`` for ``key``.

        Returns ``None`` (and logs an error) when ``idx`` is out of range.
        """
        current_messages = self.get_messages(key)
        try:
            message_to_delete = current_messages.pop(idx)
        except IndexError:
            logger.error("No message exists at index, %s, for key %s", idx, key)
            return None
        self.set_messages(key, current_messages)
        return message_to_delete

    def delete_last_message(self, key: str) -> Optional[ChatMessage]:
        """Remove and return the most recent message for ``key``."""
        return self.delete_message(key, -1)

    def get_keys(self) -> List[str]:
        """Return the ``id`` of every item in the container."""
        items = self._container.read_all_items()
        return [item["id"] for item in items]

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "AzureCosmosNoSqlChatStore"

    @classmethod
    def _from_clients(
        cls,
        endpoint: str,
        credential: Any,
        chat_db_name: str = DEFAULT_CHAT_DATABASE,
        chat_container_name: str = DEFAULT_CHAT_CONTAINER,
        cosmos_container_properties: Optional[Dict[str, Any]] = None,
        cosmos_database_properties: Optional[Dict[str, Any]] = None,
    ) -> "AzureCosmosNoSqlChatStore":
        """Build a ``CosmosClient`` from ``endpoint`` and ``credential`` and wrap it."""
        cosmos_client = CosmosClient(url=endpoint, credential=credential)
        return cls(
            cosmos_client,
            chat_db_name,
            chat_container_name,
            cosmos_container_properties,
            cosmos_database_properties,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]

[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"

[tool.llamahub]
contains_example = false
import_path = "llama_index.storage.chat_store.azurecosmosnosql"

[tool.llamahub.class_authors]
AzureCosmosNoSqlChatStore = "Aayush"

[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
authors = ["Aayush Kataria <aayushkataria3011@gmail.com>"]
description = "llama-index storage-chat-store azure cosmosdb nosql integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-storage-chat-store-azurecosmosnosql"
readme = "README.md"
version = "1.0.0"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
azure-cosmos = "^4.7.0"
azure-identity = "^1.7.1"
llama-index-core = "^0.11.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"

[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"

[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"

[[tool.poetry.packages]]
include = "llama_index/"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_tests()
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from llama_index.core.storage.chat_store.base import BaseChatStore
from llama_index.storage.chat_store.azurecosmosnosql import AzureCosmosNoSqlChatStore


def test_class():
    """The chat store must list ``BaseChatStore`` somewhere in its MRO."""
    mro_names = {klass.__name__ for klass in AzureCosmosNoSqlChatStore.__mro__}
    assert BaseChatStore.__name__ in mro_names
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
poetry_requirements(
name="poetry",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files

lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pytest.
pytest tests

watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# LlamaIndex Docstore Integration: Azure CosmosDB NoSQL Document Store
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_sources()
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from llama_index.storage.docstore.azurecosmosnosql.base import (
AzureCosmosNoSqlDocumentStore,
)

__all__ = ["AzureCosmosNoSqlDocumentStore"]
Loading

0 comments on commit 341b4c5

Please sign in to comment.