Skip to content

Commit

Permalink
[NOID] Fixes #4091: Update RAG docs with vector db examples (#4116) (#…
Browse files Browse the repository at this point in the history
…4272)

* [NOID] Fixes #4091: Update RAG docs with vector db examples (#4116)

* [NOID] test fixes
  • Loading branch information
vga91 authored Dec 6, 2024
1 parent dc72f76 commit be96c18
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ CALL apoc.vectordb.chroma.query($host, '<collection_id>',
----


[NOTE]
====
To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.chroma.query and the `apoc.vectordb.chroma.get` procedures.
For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"include": ["metadatas", "documents", "distances"]},
so that we do not return the other values that we do not need.
====

It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow:

[source,cypher]
----
CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, [<id1>, <id2>], $conf) YIELD node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value
RETURN value
----

which returns a string that answers the `$question` by leveraging the embeddings of the db vector.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

== Qdrant

Here is a list of all available Qdrant procedures,
Expand Down Expand Up @@ -218,7 +217,15 @@ For example, by executing a `CALL apoc.vectordb.qdrant.query(...) YIELD metadata
so that we do not return the other values that we do not need.
====

It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow:

[source,cypher]
----
CALL apoc.vectordb.qdrant.getAndUpdate($host, $collection, [<id1>, <id2>], $conf) YIELD node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value
RETURN value
----

which returns a string that answers the `$question` by leveraging the embeddings of the db vector.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

== Weaviate

Here is a list of all available Weaviate procedures,
Expand Down Expand Up @@ -235,7 +234,15 @@ For example, by executing a `CALL apoc.vectordb.weaviate.query(...) YIELD metada
so that we do not return the other values that we do not need.
====

It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow:

[source,cypher]
----
CALL apoc.vectordb.weaviate.getAndUpdate($host, $collection, [<id1>, <id2>], $conf) YIELD score, node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value
RETURN value
----

which returns a string that answers the `$question` by leveraging the embeddings of the db vector.

Expand Down
30 changes: 28 additions & 2 deletions full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package apoc.full.it.vectordb;

import static apoc.ml.Prompt.API_KEY_CONF;
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand All @@ -10,6 +11,7 @@
import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated;
import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll;
import static apoc.vectordb.VectorDbTestUtil.ragSetup;
import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
Expand All @@ -21,9 +23,11 @@
import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME;

import apoc.ml.Prompt;
import apoc.util.TestUtil;
import apoc.vectordb.ChromaDb;
import apoc.vectordb.VectorDb;
import apoc.vectordb.VectorDbTestUtil;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -60,9 +64,9 @@ public static void setUp() throws Exception {
sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME);

CHROMA_CONTAINER.start();
HOST = CHROMA_CONTAINER.getEndpoint();

TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class);
HOST = CHROMA_CONTAINER.getEndpoint();
TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class, Prompt.class);

testCall(
db,
Expand Down Expand Up @@ -452,4 +456,26 @@ public void queryVectorsWithSystemDbStorage() {

assertNodesCreated(db);
}

@Test
public void queryVectorsWithRag() {
String openAIKey = ragSetup(db);

Map<String, Object> conf = map(
ALL_RESULTS_KEY, true, MAPPING_KEY, map(NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo"));

testResult(
db,
"CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector\n"
+ "WITH collect(node) as paths\n"
+ "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n"
+ "RETURN value",
map(
"host", HOST,
"conf", conf,
"collection", COLL_ID.get(),
"confPrompt", map(API_KEY_CONF, openAIKey),
"attributes", List.of("city", "foo")),
VectorDbTestUtil::assertRagWithVectors);
}
}
47 changes: 45 additions & 2 deletions full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package apoc.full.it.vectordb;

import static apoc.ml.Prompt.API_KEY_CONF;
import static apoc.ml.RestAPIConfig.HEADERS_KEY;
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
Expand All @@ -22,17 +23,21 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.neo4j.configuration.GraphDatabaseSettings.*;

import apoc.ml.Prompt;
import apoc.util.TestUtil;
import apoc.util.Util;
import apoc.vectordb.Qdrant;
import apoc.vectordb.VectorDb;
import apoc.vectordb.VectorDbTestUtil;
import java.util.List;
import java.util.Map;
import org.assertj.core.api.Assertions;
import org.junit.AfterClass;
import org.junit.Assume;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
Expand Down Expand Up @@ -72,9 +77,9 @@ public static void setUp() throws Exception {
sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME);

QDRANT_CONTAINER.start();
HOST = QDRANT_CONTAINER.getHost() + ":" + QDRANT_CONTAINER.getMappedPort(6333);

TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class);
HOST = QDRANT_CONTAINER.getHost() + ":" + QDRANT_CONTAINER.getMappedPort(6333);
TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class, Prompt.class);

testCall(
db,
Expand Down Expand Up @@ -203,6 +208,44 @@ public void queryVectors() {
});
}

@Test
public void queryVectorsWithRag() {
String openAIKey = System.getenv("OPENAI_KEY");
;
Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey);

db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})");

Map<String, Object> conf = map(
ALL_RESULTS_KEY,
true,
HEADERS_KEY,
READONLY_AUTHORIZATION,
MAPPING_KEY,
map(NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo"));

testResult(
db,
"CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n"
+ "WITH collect(node) as paths\n"
+ "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n"
+ "RETURN value",
map(
"host",
HOST,
"conf",
conf,
"confPrompt",
map(API_KEY_CONF, openAIKey),
"attributes",
List.of("city", "foo")),
r -> {
Map<String, Object> row = r.next();
Object value = row.get("value");
assertTrue("The actual value is: " + value, value.toString().contains("Berlin"));
});
}

@Test
public void queryVectorsWithoutVectorResult() {
testResult(
Expand Down
43 changes: 42 additions & 1 deletion full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package apoc.full.it.vectordb;

import static apoc.ml.Prompt.API_KEY_CONF;
import static apoc.ml.RestAPIConfig.HEADERS_KEY;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testCallEmpty;
Expand All @@ -8,6 +9,16 @@
import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE;
import static apoc.vectordb.VectorDbTestUtil.*;
import static apoc.vectordb.VectorDbTestUtil.EntityType.*;
import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE;
import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE;
import static apoc.vectordb.VectorDbTestUtil.EntityType.REL;
import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult;
import static apoc.vectordb.VectorDbTestUtil.assertLondonResult;
import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated;
import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll;
import static apoc.vectordb.VectorDbTestUtil.getAuthHeader;
import static apoc.vectordb.VectorDbTestUtil.ragSetup;
import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
Expand All @@ -23,6 +34,7 @@
import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME;

import apoc.ml.Prompt;
import apoc.util.MapUtil;
import apoc.util.TestUtil;
import apoc.vectordb.VectorDb;
Expand Down Expand Up @@ -84,7 +96,7 @@ public static void setUp() throws Exception {
WEAVIATE_CONTAINER.start();
HOST = WEAVIATE_CONTAINER.getHttpHostAddress();

TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class);
TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class, Prompt.class);

testCall(
db,
Expand Down Expand Up @@ -593,4 +605,33 @@ private static void assertQueryVectorsWithSystemDbStorage(String keyConfig, Stri
});
assertNodesCreated(db);
}

@Test
public void queryVectorsWithRag() {
String openAIKey = ragSetup(db);

Map<String, Object> conf = MapUtil.map(
FIELDS_KEY,
FIELDS,
ALL_RESULTS_KEY,
true,
HEADERS_KEY,
READONLY_AUTHORIZATION,
MAPPING_KEY,
MapUtil.map(EMBEDDING_KEY, "vect", NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo"));

testResult(
db,
"CALL apoc.vectordb.weaviate.getAndUpdate($host, 'TestCollection', [$id1], $conf) YIELD score, node, metadata, id, vector\n"
+ "WITH collect(node) as paths\n"
+ "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n"
+ "RETURN value",
MapUtil.map(
"host", HOST,
"id1", ID_1,
"conf", conf,
"confPrompt", MapUtil.map(API_KEY_CONF, openAIKey),
"attributes", List.of("city", "foo")),
VectorDbTestUtil::assertRagWithVectors);
}
}
33 changes: 33 additions & 0 deletions full/src/test/java/apoc/vectordb/VectorDbTestUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import static apoc.util.Util.map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import apoc.util.MapUtil;
import java.util.Map;
import org.junit.Assume;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
Expand Down Expand Up @@ -89,4 +92,34 @@ private static void assertBerlinProperties(Map props) {
public static Map<String, String> getAuthHeader(String key) {
return map("Authorization", "Bearer " + key);
}

public static void assertReadOnlyProcWithMappingResults(Result r, String node) {
Map<String, Object> row = r.next();
Map<String, Object> props = ((Entity) row.get(node)).getAllProperties();
assertEquals(MapUtil.map("readID", "one"), props);
assertNotNull(row.get("vector"));
assertNotNull(row.get("id"));

row = r.next();
props = ((Entity) row.get(node)).getAllProperties();
assertEquals(MapUtil.map("readID", "two"), props);
assertNotNull(row.get("vector"));
assertNotNull(row.get("id"));

assertFalse(r.hasNext());
}

public static void assertRagWithVectors(Result r) {
Map<String, Object> row = r.next();
Object value = row.get("value");
assertTrue("The actual value is: " + value, value.toString().contains("Berlin"));
}

public static String ragSetup(GraphDatabaseService db) {
String openAIKey = System.getenv("OPENAI_KEY");
;
Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey);
db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})");
return openAIKey;
}
}

0 comments on commit be96c18

Please sign in to comment.