Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Falkordb vectorstore Integration #16047

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from typing import Any, Dict, List, Optional, Tuple
import logging

from falkordb import FalkorDB

from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode, MetadataMode
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
FilterOperator,
MetadataFilters,
MetadataFilter,
FilterCondition,
)
from llama_index.core.vector_stores.utils import (
metadata_dict_to_node,
node_to_metadata_dict,
)

_logger = logging.getLogger(__name__)

def clean_params(params: List[BaseNode]) -> List[Dict[str, Any]]:
    """Convert llama-index nodes into plain dicts for a Cypher UNWIND payload.

    Args:
        params: Nodes to persist; each must already carry an embedding.

    Returns:
        One dict per node with keys ``text``, ``embedding``, ``id`` and
        ``metadata`` (text/embedding stripped out of the metadata dict).
    """
    # NOTE: the original used a local also named `clean_params` (shadowing the
    # function) and `id` (shadowing the builtin); renamed for clarity.
    cleaned: List[Dict[str, Any]] = []
    for node in params:
        metadata = node_to_metadata_dict(node, remove_text=True, flat_metadata=False)
        # Drop doc-reference keys so they are not duplicated as node properties.
        for key in ("document_id", "doc_id"):
            metadata.pop(key, None)
        cleaned.append(
            {
                "text": node.get_content(metadata_mode=MetadataMode.NONE),
                "embedding": node.get_embedding(),
                "id": node.node_id,
                "metadata": metadata,
            }
        )
    return cleaned

def _to_falkordb_operator(operator: FilterOperator) -> str:
    """Translate a llama-index ``FilterOperator`` into its Cypher token.

    Unrecognized operators silently fall back to equality (``=``).
    """
    cypher_ops = {
        FilterOperator.EQ: "=",
        FilterOperator.NE: "<>",
        FilterOperator.LT: "<",
        FilterOperator.LTE: "<=",
        FilterOperator.GT: ">",
        FilterOperator.GTE: ">=",
        FilterOperator.IN: "IN",
        FilterOperator.NIN: "NOT IN",
        FilterOperator.CONTAINS: "CONTAINS",
    }
    return cypher_ops.get(operator, "=")

def construct_metadata_filter(filters: MetadataFilters) -> Tuple[str, Dict[str, Any]]:
    """Build a Cypher WHERE fragment and its bound parameters from filters.

    Returns:
        ``(snippet, params)`` where ``snippet`` joins one comparison per
        filter with AND (or OR when the condition is ``FilterCondition.OR``)
        and ``params`` maps ``param_<i>`` names to filter values.
    """
    # NOTE(review): assumes a flat filter list — nested MetadataFilters
    # entries inside `filters.filters` are not handled; confirm callers
    # only pass flat filters.
    snippets: List[str] = []
    bound: Dict[str, Any] = {}
    for i, item in enumerate(filters.filters):
        placeholder = f"param_{i}"
        snippets.append(
            f"n.`{item.key}` {_to_falkordb_operator(item.operator)} ${placeholder}"
        )
        bound[placeholder] = item.value

    joiner = " OR " if filters.condition == FilterCondition.OR else " AND "
    return joiner.join(snippets), bound

class FalkorDBVectorStore(BasePydanticVectorStore):
    """llama-index vector store backed by a FalkorDB graph.

    Each node is persisted as a graph node with label ``node_label``
    carrying the raw text, the embedding vector and flattened metadata;
    similarity search runs through a FalkorDB vector index.
    """

    stores_text: bool = True
    flat_metadata: bool = True

    distance_strategy: str        # "cosine" or "euclidean"
    index_name: str               # name of the vector index
    node_label: str               # graph label used for stored chunks
    embedding_node_property: str  # property holding the embedding vector
    text_node_property: str       # property holding the chunk text
    embedding_dimension: int      # dimensionality of stored embeddings

    _driver: FalkorDB = PrivateAttr()
    _database: str = PrivateAttr()

    def __init__(
        self,
        url: str,
        database: str = "falkor",
        index_name: str = "vector",
        node_label: str = "Chunk",
        embedding_node_property: str = "embedding",
        text_node_property: str = "text",
        distance_strategy: str = "cosine",
        embedding_dimension: int = 1536,
        **kwargs: Any,
    ) -> None:
        """Connect to FalkorDB and ensure the vector index exists.

        Raises:
            ValueError: if ``distance_strategy`` is unsupported, or any of
                the index/label/property names is empty.
        """
        # Validate configuration BEFORE opening a connection so a bad
        # config cannot leave a half-initialized driver behind.
        if distance_strategy not in ["cosine", "euclidean"]:
            raise ValueError("distance_strategy must be either 'euclidean' or 'cosine'")

        for prop, value in zip(
            ["index_name", "node_label", "embedding_node_property", "text_node_property"],
            [index_name, node_label, embedding_node_property, text_node_property],
        ):
            if not value:
                raise ValueError(f"Parameter `{prop}` must not be None or empty string")

        super().__init__(
            distance_strategy=distance_strategy,
            index_name=index_name,
            node_label=node_label,
            embedding_node_property=embedding_node_property,
            text_node_property=text_node_property,
            embedding_dimension=embedding_dimension,
        )

        self._driver = FalkorDB.from_url(url).select_graph(database)
        self._database = database

        # Reuse an existing vector index if one is found; otherwise create it.
        if not self.retrieve_existing_index():
            self.create_new_index()

    @property
    def client(self) -> FalkorDB:
        """Return the underlying FalkorDB graph handle."""
        return self._driver

    def create_new_index(self) -> None:
        """Create the vector index over the configured label/property."""
        index_query = (
            f"CREATE VECTOR INDEX {self.index_name} "
            f"FOR (n:`{self.node_label}`) "
            f"ON (n.`{self.embedding_node_property}`) "
            f"OPTIONS {{dimension: {self.embedding_dimension}, metric: '{self.distance_strategy}'}}"
        )
        self._driver.query(index_query)

    def retrieve_existing_index(self) -> bool:
        """Look up an existing vector index and adopt its settings.

        Returns:
            True when a matching index was found (instance attributes are
            overwritten from it), False otherwise.
        """
        # NOTE(review): this filters `label = $index_name`, i.e. compares the
        # node LABEL to the index NAME, and reads result rows by string key —
        # confirm both against the FalkorDB `db.indexes()` result shape.
        index_information = self._driver.query(
            "CALL db.indexes() "
            "YIELD label, properties, types, options, entitytype "
            "WHERE types = ['VECTOR'] AND label = $index_name",
            params={"index_name": self.index_name},
        )
        if index_information.result_set:
            index = index_information.result_set[0]
            self.node_label = index['entitytype']
            self.embedding_node_property = index['properties'][0]
            self.embedding_dimension = index['options']['dimension']
            self.distance_strategy = index['options']['metric']
            return True
        return False

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Upsert nodes (MERGE on id) and return their node ids."""
        ids = [r.node_id for r in nodes]
        import_query = (
            "UNWIND $data AS row "
            f"MERGE (c:`{self.node_label}` {{id: row.id}}) "
            f"SET c.`{self.embedding_node_property}` = row.embedding, "
            f"c.`{self.text_node_property}` = row.text, "
            "c += row.metadata"
        )

        self._driver.query(import_query, params={"data": clean_params(nodes)})
        return ids

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Run a top-k similarity search, optionally filtered on metadata."""
        base_query = (
            f"MATCH (n:`{self.node_label}`) "
            f"WHERE n.`{self.embedding_node_property}` IS NOT NULL "
        )

        if query.filters:
            filter_snippets, filter_params = construct_metadata_filter(query.filters)
            base_query += f"AND {filter_snippets} "
        else:
            filter_params = {}

        # (Scraped review-UI text that had been pasted into this string
        # concatenation was removed — it made the method a syntax error.)
        similarity_query = (
            f"WITH n, vector.similarity.{self.distance_strategy}("
            f"n.`{self.embedding_node_property}`, $embedding) AS score "
            "ORDER BY score DESC LIMIT toInteger($k) "
        )

        return_query = (
            f"RETURN n.`{self.text_node_property}` AS text, score, "
            "n.id AS id, "
            f"n {{.*, `{self.text_node_property}`: NULL, "
            f"`{self.embedding_node_property}`: NULL, id: NULL}} AS metadata"
        )

        full_query = base_query + similarity_query + return_query

        parameters = {
            "k": query.similarity_top_k,
            "embedding": query.query_embedding,
            **filter_params,
        }

        results = self._driver.query(full_query, params=parameters)

        # NOTE(review): rows are indexed by alias name here; confirm the
        # FalkorDB client returns mapping-style rows rather than positional
        # lists for aliased RETURN columns.
        nodes = []
        similarities = []
        ids = []
        for record in results.result_set:
            node = metadata_dict_to_node(record["metadata"])
            node.set_content(str(record["text"]))
            nodes.append(node)
            similarities.append(record["score"])
            ids.append(record["id"])

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete all chunks whose ``ref_doc_id`` property matches.

        NOTE(review): relies on a `ref_doc_id` property being present on the
        stored nodes (presumably written via the metadata dict) — verify
        `node_to_metadata_dict` preserves it, since `doc_id`/`document_id`
        are stripped in `clean_params`.
        """
        self._driver.query(
            f"MATCH (n:`{self.node_label}`) WHERE n.ref_doc_id = $id DETACH DELETE n",
            params={"id": ref_doc_id},
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import pytest
from typing import List
from unittest.mock import MagicMock, patch

from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode
from llama_index.vector_stores.falkordb import FalkorDBVectorStore
from llama_index.vector_stores.types import (
VectorStoreQuery,
MetadataFilters,
ExactMatchFilter,
)

# Mock FalkorDB client
class MockFalkorDBClient:
    """In-memory stand-in for a FalkorDB graph used by these tests.

    Dispatches on substrings of the Cypher text. Fixes over the previous
    version: the delete branch was unreachable (shadowed by the generic
    MATCH branch), the substrings lacked the backticks the real store emits
    (e.g. ``MERGE (c:`Chunk` ...``) so MERGE/DELETE never matched, and
    unknown queries returned a truthy MagicMock ``result_set`` which made
    index-existence checks succeed spuriously.
    """

    def __init__(self):
        self.nodes = {}          # node id -> stored row dict
        self.query_results = []  # canned rows returned for read queries

    def query(self, query, params=None):
        if "CREATE VECTOR INDEX" in query:
            return MagicMock()
        # Index introspection: report "no index" so the store creates one.
        if "db.indexes" in query or "SHOW INDEXES" in query:
            return MagicMock(result_set=[])
        # Delete must be checked before the generic MATCH read branch.
        if "DETACH DELETE" in query:
            self.nodes.pop(params["id"], None)
            return MagicMock()
        if "MERGE" in query:
            for row in params["data"]:
                self.nodes[row["id"]] = row
            return MagicMock()
        if "MATCH" in query:
            return MagicMock(result_set=self.query_results)
        return MagicMock(result_set=[])

    def set_query_results(self, results):
        """Set the canned rows returned by subsequent read queries."""
        self.query_results = results

@pytest.fixture
def mock_falkordb():
    """Patch the FalkorDB driver so stores talk to a MockFalkorDBClient."""
    with patch("falkordb.FalkorDB") as driver_cls:
        fake_client = MockFalkorDBClient()
        driver_cls.from_url.return_value.select_graph.return_value = fake_client
        yield fake_client

@pytest.fixture
def falkordb_store(mock_falkordb):
    """A FalkorDBVectorStore wired to the mocked driver fixture."""
    store = FalkorDBVectorStore(
        url="bolt://localhost:7687",
        database="testdb",
        index_name="test_index",
        node_label="Chunk",
        embedding_node_property="embedding",
        text_node_property="text",
    )
    return store

def test_falkordb_add(falkordb_store):
    """add() returns the node ids and stores both rows in the mock client."""
    first = TextNode(
        text="Hello world",
        id_="1",
        embedding=[1.0, 0.0, 0.0],
        metadata={"key": "value"},
    )
    second = TextNode(
        text="Hello world 2",
        id_="2",
        embedding=[0.0, 1.0, 0.0],
        metadata={"key2": "value2"},
    )
    returned_ids = falkordb_store.add([first, second])
    assert returned_ids == ["1", "2"]
    assert len(falkordb_store.client.nodes) == 2

def test_falkordb_delete(falkordb_store):
    """A stored node disappears from the mock store after delete()."""
    falkordb_store.add(
        [TextNode(text="Hello world", id_="test_node", embedding=[1.0, 0.0, 0.0])]
    )
    assert "test_node" in falkordb_store.client.nodes

    falkordb_store.delete("test_node")
    assert "test_node" not in falkordb_store.client.nodes

def test_falkordb_query(falkordb_store, mock_falkordb):
    """query() surfaces the canned rows as nodes, similarities and ids."""
    canned_rows = [
        {"text": "Hello world", "score": 0.9, "id": "1", "metadata": {"key": "value"}},
        {"text": "Hello world 2", "score": 0.7, "id": "2", "metadata": {"key2": "value2"}},
    ]
    mock_falkordb.set_query_results(canned_rows)

    result = falkordb_store.query(
        VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=2)
    )

    assert len(result.nodes) == 2
    assert result.nodes[0].text == "Hello world"
    assert result.nodes[1].text == "Hello world 2"
    assert result.similarities == [0.9, 0.7]

def test_falkordb_query_with_filters(falkordb_store, mock_falkordb):
    """A metadata-filtered query returns the single canned match."""
    mock_falkordb.set_query_results(
        [{"text": "Hello world", "score": 0.9, "id": "1", "metadata": {"key": "value"}}]
    )

    filtered_query = VectorStoreQuery(
        query_embedding=[1.0, 0.0, 0.0],
        similarity_top_k=2,
        filters=MetadataFilters(filters=[ExactMatchFilter(key="key", value="value")]),
    )
    result = falkordb_store.query(filtered_query)

    assert len(result.nodes) == 1
    assert result.nodes[0].text == "Hello world"
    assert result.similarities == [0.9]

def test_falkordb_update(falkordb_store):
    """Updating a node with the same id overwrites its stored payload."""
    falkordb_store.add(
        [TextNode(text="Original text", id_="update_node", embedding=[1.0, 0.0, 0.0])]
    )

    replacement = TextNode(
        text="Updated text",
        id_="update_node",
        embedding=[0.0, 1.0, 0.0],
    )
    # NOTE(review): update() is not defined on FalkorDBVectorStore in this
    # diff — presumably inherited from the base store; confirm it exists.
    falkordb_store.update(replacement)

    stored = falkordb_store.client.nodes["update_node"]
    assert stored["text"] == "Updated text"
    assert stored["embedding"] == [0.0, 1.0, 0.0]

def test_falkordb_get(falkordb_store):
    """get() retrieves a previously added node by id."""
    falkordb_store.add(
        [TextNode(text="Get test", id_="get_node", embedding=[1.0, 1.0, 1.0])]
    )

    # NOTE(review): get() is not defined on FalkorDBVectorStore in this
    # diff — presumably inherited from the base store; confirm it exists.
    fetched = falkordb_store.get("get_node")
    assert fetched is not None
    assert fetched.text == "Get test"
    assert fetched.embedding == [1.0, 1.0, 1.0]

def test_falkordb_nonexistent_get(falkordb_store):
    """get() on an unknown id yields None rather than raising."""
    assert falkordb_store.get("nonexistent_node") is None

# Allow running this test module directly (python <file>); delegates to pytest.
if __name__ == "__main__":
    pytest.main()