diff --git a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java index 9b731d7ca8..37d5fb1833 100644 --- a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java +++ b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java @@ -61,26 +61,43 @@ public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength .build(); } - static Node2Vec create(Graph graph, Node2VecBaseConfig config, ProgressTracker progressTracker) { + static Node2Vec create( + Graph graph, + int concurrency, + WalkParameters walkParameters, + TrainParameters trainParameters, + ProgressTracker progressTracker + ) { + return create(graph, concurrency, Optional.empty(), walkParameters, trainParameters, progressTracker); + } + + static Node2Vec create( + Graph graph, + int concurrency, + Optional maybeRandomSeed, + WalkParameters walkParameters, + TrainParameters trainParameters, + ProgressTracker progressTracker + ) { return new Node2Vec( graph, - config.concurrency(), - config.walkParameters(), - config.sourceNodes(), - config.randomSeed(), - progressTracker, - config.trainParameters() + concurrency, + List.of(), + maybeRandomSeed, + walkParameters, + trainParameters, + progressTracker ); } public Node2Vec( Graph graph, int concurrency, - WalkParameters walkParameters, List sourceNodes, Optional maybeRandomSeed, - ProgressTracker progressTracker, - TrainParameters trainParameters + WalkParameters walkParameters, + TrainParameters trainParameters, + ProgressTracker progressTracker ) { super(progressTracker); this.graph = graph; diff --git a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java index dd90eb051f..66adfe0adb 100644 --- a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java +++ b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java @@ -47,7 +47,15 @@ public Node2Vec build( ProgressTracker progressTracker ) { validateConfig(configuration, graph); - return Node2Vec.create(graph, configuration, progressTracker); + return new Node2Vec( + graph, + configuration.concurrency(), + configuration.sourceNodes(), + configuration.randomSeed(), + configuration.walkParameters(), + configuration.trainParameters(), + progressTracker + ); } @Override diff --git a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/TrainParameters.java b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/TrainParameters.java index 09b699441f..8c7b0b2e9e 100644 --- a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/TrainParameters.java +++ b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/TrainParameters.java @@ -28,7 +28,7 @@ public class TrainParameters { final int embeddingDimension; final EmbeddingInitializer embeddingInitializer; - TrainParameters( + public TrainParameters( double initialLearningRate, double minLearningRate, int iterations, diff --git a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/WalkParameters.java b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/WalkParameters.java index 0893e4f144..ae76c0b7dd 100644 --- a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/WalkParameters.java +++ b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/WalkParameters.java @@ -23,7 +23,14 @@ public class WalkParameters extends org.neo4j.gds.traversal.WalkParameters { final double negativeSamplingExponent; final double positiveSamplingFactor; - WalkParameters(int walksPerNode, int walkLength, double returnFactor, double inOutFactor, double positiveSamplingFactor, double negativeSamplingExponent) { + public WalkParameters( + int walksPerNode, + int walkLength, + double returnFactor, + double inOutFactor, + double positiveSamplingFactor, + double negativeSamplingExponent + ) { super(walksPerNode, walkLength, returnFactor, inOutFactor); this.negativeSamplingExponent = negativeSamplingExponent; this.positiveSamplingFactor = positiveSamplingFactor; diff --git a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java index 78d04131ca..be78ded2c1 100644 --- a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java +++ b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java @@ -60,22 +60,16 @@ void testModel() { random ); - Node2VecStreamConfig defaults = ImmutableNode2VecStreamConfig.builder().build(); + var trainParameters = new TrainParameters(0.05, 0.0001, 5, 10, 1, 10, EmbeddingInitializer.NORMALIZED); int nodeCount = numberOfClusters * clusterSize; var node2VecModel = new Node2VecModel( nodeId -> nodeId, nodeCount, - 0.05, - defaults.minLearningRate(), - 5, - 10, - defaults.windowSize(), - 1, - defaults.embeddingInitializer(), + trainParameters, 4, - defaults.randomSeed(), + Optional.empty(), walks, probabilitiesBuilder.build(), ProgressTracker.NULL_TRACKER @@ -165,20 +159,14 @@ void randomSeed(int iterations) { CompressedRandomWalks walks = generateRandomWalks(probabilitiesBuilder, numberOfClusters, clusterSize, numberOfWalks, walkLength, random); - Node2VecStreamConfig defaults = ImmutableNode2VecStreamConfig.builder().build(); + var trainParameters = new TrainParameters(0.05, 0.0001, iterations, 10, 1, 2, EmbeddingInitializer.NORMALIZED); int nodeCount = numberOfClusters * clusterSize; var node2VecModel = new Node2VecModel( nodeId -> nodeId, nodeCount, - 0.05, - defaults.minLearningRate(), - iterations, - 2, - defaults.windowSize(), - 1, - defaults.embeddingInitializer(), + trainParameters, 4, Optional.of(1337L), walks, @@ -189,13 +177,7 @@ void randomSeed(int iterations) { var otherNode2VecModel = new Node2VecModel( nodeId -> nodeId, nodeCount, - 0.05, - defaults.minLearningRate(), - iterations, - 2, - defaults.windowSize(), - 1, - defaults.embeddingInitializer(), + trainParameters, 4, Optional.of(1337L), walks, diff --git a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java index 230b0fe8fd..7e9716e5b0 100644 --- a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java +++ b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java @@ -100,9 +100,20 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable nod .graph(); int embeddingDimension = 128; + var trainParameters = new TrainParameters( + 0.025, + 0.0001, + 1, + 10, + 5, + embeddingDimension, + EmbeddingInitializer.NORMALIZED + ); HugeObjectArray node2Vec = Node2Vec.create( graph, - ImmutableNode2VecStreamConfig.builder().embeddingDimension(embeddingDimension).build(), + 4, + new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75), + trainParameters, ProgressTracker.NULL_TRACKER ).compute().embeddings(); @@ -132,11 +143,25 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) { .embeddingDimension(embeddingDimension) .build(); var progressTask = new Node2VecAlgorithmFactory<>().progressTask(graph, config); + + var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75); + var trainParameters = new TrainParameters( + 0.025, + 0.0001, + 1, + 10, + 5, + embeddingDimension, + EmbeddingInitializer.NORMALIZED + ); var log = Neo4jProxy.testLog(); var progressTracker = new TestProgressTracker(progressTask, log, 4, EmptyTaskRegistryFactory.INSTANCE); Node2Vec.create( graph, - config, + 4, + Optional.empty(), + walkParameters, + trainParameters, progressTracker ).compute(); @@ -170,10 +195,12 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) { @Test void shouldEstimateMemory() { var nodeCount = 1000; - var config = ImmutableNode2VecStreamConfig.builder().build(); - var memoryEstimation = Node2Vec.memoryEstimation(config.walksPerNode(), config.walkLength(), config.embeddingDimension()); + var walksPerNode = 10; + var walkLength = 80; + var embeddingDimension = 128; + var memoryEstimation = Node2Vec.memoryEstimation(walksPerNode, walkLength, embeddingDimension); - var numberOfRandomWalks = nodeCount * config.walksPerNode() * config.walkLength(); + var numberOfRandomWalks = nodeCount * walksPerNode * walkLength; var randomWalkMemoryUsageLowerBound = numberOfRandomWalks * Long.BYTES; var estimate = memoryEstimation.estimate(GraphDimensions.of(nodeCount), 1); @@ -193,12 +220,16 @@ void shouldEstimateMemory() { void failOnNegativeWeights() { var graph = GdlFactory.of("CREATE (a)-[:REL {weight: -1}]->(b)").build().getUnion(); - var config = ImmutableNode2VecStreamConfig - .builder() - .relationshipWeightProperty("weight") - .build(); + var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75); + var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, 128, EmbeddingInitializer.NORMALIZED); - var node2Vec = Node2Vec.create(graph, config, ProgressTracker.NULL_TRACKER); + var node2Vec = Node2Vec.create( + graph, + 4, + walkParameters, + trainParameters, + ProgressTracker.NULL_TRACKER + ); assertThatThrownBy(node2Vec::compute) .isInstanceOf(RuntimeException.class) @@ -214,30 +245,26 @@ void randomSeed(SoftAssertions softly) { Graph graph = new StoreLoaderBuilder().databaseService(db).build().graph(); int embeddingDimension = 2; - - var config = ImmutableNode2VecStreamConfig - .builder() - .embeddingDimension(embeddingDimension) - .iterations(1) - .negativeSamplingRate(1) - .windowSize(1) - .walksPerNode(1) - .walkLength(20) - .walkBufferSize(50) - .randomSeed(1337L) - .build(); + var walkParameters = new WalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75); + var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, embeddingDimension, EmbeddingInitializer.NORMALIZED); var embeddings = Node2Vec.create( graph, - config, + 4, + Optional.of(1337L), + walkParameters, + trainParameters, ProgressTracker.NULL_TRACKER - ).compute().embeddings(); + ).compute().embeddings(); var otherEmbeddings = Node2Vec.create( graph, - config, + 4, + Optional.of(1337L), + walkParameters, + trainParameters, ProgressTracker.NULL_TRACKER - ).compute().embeddings(); + ).compute().embeddings(); for (long node = 0; node < graph.nodeCount(); node++) { softly.assertThat(otherEmbeddings.get(node)).isEqualTo(embeddings.get(node)); @@ -318,25 +345,26 @@ void shouldBeFairlyConsistentUnderOriginalIds(EmbeddingInitializer embeddingInit var firstGraph = GraphFactory.create(firstIdMap, firstRelationships); var secondGraph = GraphFactory.create(secondIdMap, secondRelationships); - var config = ImmutableNode2VecStreamConfig - .builder() - .embeddingInitializer(embeddingInitializer) - .embeddingDimension(embeddingDimension) - .randomSeed(1337L) - .concurrency(1) - .build(); + var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.01, 0.75); + var trainParameters = new TrainParameters(0.025, 0.0001, 1, 10, 5, embeddingDimension, embeddingInitializer); var firstEmbeddings = Node2Vec.create( firstGraph, - config, + 4, + Optional.of(1337L), + walkParameters, + trainParameters, ProgressTracker.NULL_TRACKER - ).compute().embeddings(); + ).compute().embeddings(); var secondEmbeddings = Node2Vec.create( secondGraph, - config, + 4, + Optional.of(1337L), + walkParameters, + trainParameters, ProgressTracker.NULL_TRACKER - ).compute().embeddings(); + ).compute().embeddings(); double cosineSum = 0; for (long originalNodeId = 0; originalNodeId < nodeCount; originalNodeId++) {