diff --git a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java index 78df6d5465..69c1644956 100644 --- a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java +++ b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java @@ -23,7 +23,6 @@ import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; import org.assertj.core.data.Offset; import org.assertj.core.data.Percentage; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -32,14 +31,12 @@ import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; -import org.neo4j.gds.BaseTest; import org.neo4j.gds.NodeLabel; import org.neo4j.gds.Orientation; -import org.neo4j.gds.PropertyMapping; import org.neo4j.gds.RelationshipType; -import org.neo4j.gds.StoreLoaderBuilder; import org.neo4j.gds.TestProgressTracker; import org.neo4j.gds.api.Graph; +import org.neo4j.gds.api.GraphStore; import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.collections.ha.HugeObjectArray; import org.neo4j.gds.collections.hsa.HugeSparseLongArray; @@ -55,6 +52,9 @@ import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.shuffle.ShuffleUtil; +import org.neo4j.gds.extension.GdlExtension; +import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.Inject; import org.neo4j.gds.gdl.GdlFactory; import org.neo4j.gds.ml.core.tensor.FloatVector; @@ -69,11 +69,13 @@ import static org.neo4j.gds.assertj.Extractors.removingThreadId; @ExtendWith(SoftAssertionsExtension.class) -class Node2VecTest extends BaseTest { +@GdlExtension +class Node2VecTest { private static final List NO_SOURCE_NODES = List.of(); private static final Optional NO_RANDOM_SEED = Optional.empty(); + @GdlGraph private static final String DB_CYPHER = "CREATE" + " (a:Node1)" + @@ -88,21 +90,19 @@ class Node2VecTest extends BaseTest { ", (b)-[:REL {prop: 1.0}]->(c)" + ", (c)-[:REL {prop: 1.0}]->(b)"; - @BeforeEach - void setUp() { - runQuery(DB_CYPHER); - } + @Inject + private Graph graph; + + @Inject + private GraphStore graphStore; @ParameterizedTest(name = "{0}") @MethodSource("graphs") - void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable nodeLabels) { - Graph graph = new StoreLoaderBuilder() - .databaseService(db) - .nodeLabels(nodeLabels) - .build() - .graph(); + void embeddingsShouldHaveTheConfiguredDimension(String msg, List nodeLabels) { + var currentGraph = graphStore.getGraph(nodeLabels); int embeddingDimension = 128; + var trainParameters = new TrainParameters( 0.025, 0.0001, @@ -112,8 +112,9 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable nod embeddingDimension, EmbeddingInitializer.NORMALIZED ); + HugeObjectArray node2Vec = new Node2Vec( - graph, + currentGraph, 4, NO_SOURCE_NODES, NO_RANDOM_SEED, @@ -123,7 +124,7 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable nod ProgressTracker.NULL_TRACKER ).compute().embeddings(); - graph.forEachNode(node -> { + currentGraph.forEachNode(node -> { assertEquals(embeddingDimension, node2Vec.get(node).data().length); return true; } @@ -136,19 +137,19 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable nod "false,3" }) void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) { - var storeLoaderBuilder = new StoreLoaderBuilder() - .databaseService(db); + Graph currentGraph; if (relationshipWeights) { - storeLoaderBuilder.addRelationshipProperty(PropertyMapping.of("prop")); + currentGraph = graph; + } else { + currentGraph = graphStore.getGraph(RelationshipType.of("REL"), Optional.empty()); } - Graph graph = storeLoaderBuilder.build().graph(); int embeddingDimension = 128; Node2VecStreamConfig config = ImmutableNode2VecStreamConfig .builder() .embeddingDimension(embeddingDimension) .build(); - var progressTask = new Node2VecAlgorithmFactory<>().progressTask(graph, config); + var progressTask = new Node2VecAlgorithmFactory<>().progressTask(currentGraph, config); var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75); var trainParameters = new TrainParameters( @@ -163,7 +164,7 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) { var log = Neo4jProxy.testLog(); var progressTracker = new TestProgressTracker(progressTask, log, 4, EmptyTaskRegistryFactory.INSTANCE); new Node2Vec( - graph, + currentGraph, 4, NO_SOURCE_NODES, NO_RANDOM_SEED, @@ -226,13 +227,13 @@ void shouldEstimateMemory() { @Test void failOnNegativeWeights() { - var graph = GdlFactory.of("CREATE (a)-[:REL {weight: -1}]->(b)").build().getUnion(); + var negativeGraph = GdlFactory.of("CREATE (a)-[:REL {weight: -1}]->(b)").build().getUnion(); var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75); var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, 128, EmbeddingInitializer.NORMALIZED); var node2Vec = new Node2Vec( - graph, + negativeGraph, 4, NO_SOURCE_NODES, NO_RANDOM_SEED, @@ -253,7 +254,7 @@ void failOnNegativeWeights() { @Disabled("The order of the randomWalks + its usage in the training is not deterministic yet.") @Test void randomSeed(SoftAssertions softly) { - Graph graph = new StoreLoaderBuilder().databaseService(db).build().graph(); + int embeddingDimension = 2; var walkParameters = new WalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75); @@ -288,8 +289,8 @@ void randomSeed(SoftAssertions softly) { static Stream graphs() { return Stream.of( - Arguments.of("All Labels", List.of()), - Arguments.of("Non Consecutive Original IDs", List.of("Node2", "Isolated")) + Arguments.of("All Labels", List.of(NodeLabel.of("Node1"), NodeLabel.of("Node2"), NodeLabel.of("Isolated"))), + Arguments.of("Non Consecutive Original IDs", List.of(NodeLabel.of("Node2"), NodeLabel.of("Isolated"))) ); }