Skip to content

Commit

Permalink
Almost purge Node2Vec config from algo code
Browse files Browse the repository at this point in the history
  • Loading branch information
jjaderberg committed Oct 31, 2023
1 parent 6941b24 commit 6b22038
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 74 deletions.
37 changes: 27 additions & 10 deletions algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,43 @@ public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength
.build();
}

static Node2Vec create(Graph graph, Node2VecBaseConfig config, ProgressTracker progressTracker) {
static Node2Vec create(
Graph graph,
int concurrency,
WalkParameters walkParameters,
TrainParameters trainParameters,
ProgressTracker progressTracker
) {
return create(graph, concurrency, Optional.empty(), walkParameters, trainParameters, progressTracker);
}

static Node2Vec create(
Graph graph,
int concurrency,
Optional<Long> maybeRandomSeed,
WalkParameters walkParameters,
TrainParameters trainParameters,
ProgressTracker progressTracker
) {
return new Node2Vec(
graph,
config.concurrency(),
config.walkParameters(),
config.sourceNodes(),
config.randomSeed(),
progressTracker,
config.trainParameters()
concurrency,
List.of(),
maybeRandomSeed,
walkParameters,
trainParameters,
progressTracker
);
}

public Node2Vec(
Graph graph,
int concurrency,
WalkParameters walkParameters,
List<Long> sourceNodes,
Optional<Long> maybeRandomSeed,
ProgressTracker progressTracker,
TrainParameters trainParameters
WalkParameters walkParameters,
TrainParameters trainParameters,
ProgressTracker progressTracker
) {
super(progressTracker);
this.graph = graph;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,15 @@ public Node2Vec build(
ProgressTracker progressTracker
) {
validateConfig(configuration, graph);
return Node2Vec.create(graph, configuration, progressTracker);
return new Node2Vec(
graph,
configuration.concurrency(),
configuration.sourceNodes(),
configuration.randomSeed(),
configuration.walkParameters(),
configuration.trainParameters(),
progressTracker
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class TrainParameters {
final int embeddingDimension;
final EmbeddingInitializer embeddingInitializer;

TrainParameters(
public TrainParameters(
double initialLearningRate,
double minLearningRate,
int iterations,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ public class WalkParameters extends org.neo4j.gds.traversal.WalkParameters {
final double negativeSamplingExponent;
final double positiveSamplingFactor;

WalkParameters(int walksPerNode, int walkLength, double returnFactor, double inOutFactor, double positiveSamplingFactor, double negativeSamplingExponent) {
public WalkParameters(
int walksPerNode,
int walkLength,
double returnFactor,
double inOutFactor,
double positiveSamplingFactor,
double negativeSamplingExponent
) {
super(walksPerNode, walkLength, returnFactor, inOutFactor);
this.negativeSamplingExponent = negativeSamplingExponent;
this.positiveSamplingFactor = positiveSamplingFactor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,16 @@ void testModel() {
random
);

Node2VecStreamConfig defaults = ImmutableNode2VecStreamConfig.builder().build();
var trainParameters = new TrainParameters(0.05, 0.0001, 5, 10, 1, 10, EmbeddingInitializer.NORMALIZED);

int nodeCount = numberOfClusters * clusterSize;

var node2VecModel = new Node2VecModel(
nodeId -> nodeId,
nodeCount,
0.05,
defaults.minLearningRate(),
5,
10,
defaults.windowSize(),
1,
defaults.embeddingInitializer(),
trainParameters,
4,
defaults.randomSeed(),
Optional.empty(),
walks,
probabilitiesBuilder.build(),
ProgressTracker.NULL_TRACKER
Expand Down Expand Up @@ -165,20 +159,14 @@ void randomSeed(int iterations) {

CompressedRandomWalks walks = generateRandomWalks(probabilitiesBuilder, numberOfClusters, clusterSize, numberOfWalks, walkLength, random);

Node2VecStreamConfig defaults = ImmutableNode2VecStreamConfig.builder().build();
var trainParameters = new TrainParameters(0.05, 0.0001, iterations, 10, 1, 2, EmbeddingInitializer.NORMALIZED);

int nodeCount = numberOfClusters * clusterSize;

var node2VecModel = new Node2VecModel(
nodeId -> nodeId,
nodeCount,
0.05,
defaults.minLearningRate(),
iterations,
2,
defaults.windowSize(),
1,
defaults.embeddingInitializer(),
trainParameters,
4,
Optional.of(1337L),
walks,
Expand All @@ -189,13 +177,7 @@ void randomSeed(int iterations) {
var otherNode2VecModel = new Node2VecModel(
nodeId -> nodeId,
nodeCount,
0.05,
defaults.minLearningRate(),
iterations,
2,
defaults.windowSize(),
1,
defaults.embeddingInitializer(),
trainParameters,
4,
Optional.of(1337L),
walks,
Expand Down
102 changes: 65 additions & 37 deletions algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,20 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable<String> nod
.graph();

int embeddingDimension = 128;
var trainParameters = new TrainParameters(
0.025,
0.0001,
1,
10,
5,
embeddingDimension,
EmbeddingInitializer.NORMALIZED
);
HugeObjectArray<FloatVector> node2Vec = Node2Vec.create(
graph,
ImmutableNode2VecStreamConfig.builder().embeddingDimension(embeddingDimension).build(),
4,
new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75),
trainParameters,
ProgressTracker.NULL_TRACKER
).compute().embeddings();

Expand Down Expand Up @@ -132,11 +143,25 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
.embeddingDimension(embeddingDimension)
.build();
var progressTask = new Node2VecAlgorithmFactory<>().progressTask(graph, config);

var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75);
var trainParameters = new TrainParameters(
0.025,
0.0001,
1,
10,
5,
embeddingDimension,
EmbeddingInitializer.NORMALIZED
);
var log = Neo4jProxy.testLog();
var progressTracker = new TestProgressTracker(progressTask, log, 4, EmptyTaskRegistryFactory.INSTANCE);
Node2Vec.create(
graph,
config,
4,
Optional.empty(),
walkParameters,
trainParameters,
progressTracker
).compute();

Expand Down Expand Up @@ -170,10 +195,12 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
@Test
void shouldEstimateMemory() {
var nodeCount = 1000;
var config = ImmutableNode2VecStreamConfig.builder().build();
var memoryEstimation = Node2Vec.memoryEstimation(config.walksPerNode(), config.walkLength(), config.embeddingDimension());
var walksPerNode = 10;
var walkLength = 80;
var embeddingDimension = 128;
var memoryEstimation = Node2Vec.memoryEstimation(walksPerNode, walkLength, embeddingDimension);

var numberOfRandomWalks = nodeCount * config.walksPerNode() * config.walkLength();
var numberOfRandomWalks = nodeCount * walksPerNode * walkLength;
var randomWalkMemoryUsageLowerBound = numberOfRandomWalks * Long.BYTES;

var estimate = memoryEstimation.estimate(GraphDimensions.of(nodeCount), 1);
Expand All @@ -193,12 +220,16 @@ void shouldEstimateMemory() {
void failOnNegativeWeights() {
var graph = GdlFactory.of("CREATE (a)-[:REL {weight: -1}]->(b)").build().getUnion();

var config = ImmutableNode2VecStreamConfig
.builder()
.relationshipWeightProperty("weight")
.build();
var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75);
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, 128, EmbeddingInitializer.NORMALIZED);

var node2Vec = Node2Vec.create(graph, config, ProgressTracker.NULL_TRACKER);
var node2Vec = Node2Vec.create(
graph,
4,
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
);

assertThatThrownBy(node2Vec::compute)
.isInstanceOf(RuntimeException.class)
Expand All @@ -214,30 +245,26 @@ void randomSeed(SoftAssertions softly) {
Graph graph = new StoreLoaderBuilder().databaseService(db).build().graph();

int embeddingDimension = 2;

var config = ImmutableNode2VecStreamConfig
.builder()
.embeddingDimension(embeddingDimension)
.iterations(1)
.negativeSamplingRate(1)
.windowSize(1)
.walksPerNode(1)
.walkLength(20)
.walkBufferSize(50)
.randomSeed(1337L)
.build();
var walkParameters = new WalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75);
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, embeddingDimension, EmbeddingInitializer.NORMALIZED);

var embeddings = Node2Vec.create(
graph,
config,
4,
Optional.of(1337L),
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
).compute().embeddings();
).compute().embeddings();

var otherEmbeddings = Node2Vec.create(
graph,
config,
4,
Optional.of(1337L),
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
).compute().embeddings();
).compute().embeddings();

for (long node = 0; node < graph.nodeCount(); node++) {
softly.assertThat(otherEmbeddings.get(node)).isEqualTo(embeddings.get(node));
Expand Down Expand Up @@ -318,25 +345,26 @@ void shouldBeFairlyConsistentUnderOriginalIds(EmbeddingInitializer embeddingInit
var firstGraph = GraphFactory.create(firstIdMap, firstRelationships);
var secondGraph = GraphFactory.create(secondIdMap, secondRelationships);

var config = ImmutableNode2VecStreamConfig
.builder()
.embeddingInitializer(embeddingInitializer)
.embeddingDimension(embeddingDimension)
.randomSeed(1337L)
.concurrency(1)
.build();
var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.01, 0.75);
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 10, 5, embeddingDimension, embeddingInitializer);

var firstEmbeddings = Node2Vec.create(
firstGraph,
config,
4,
Optional.of(1337L),
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
).compute().embeddings();
).compute().embeddings();

var secondEmbeddings = Node2Vec.create(
secondGraph,
config,
4,
Optional.of(1337L),
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
).compute().embeddings();
).compute().embeddings();

double cosineSum = 0;
for (long originalNodeId = 0; originalNodeId < nodeCount; originalNodeId++) {
Expand Down

0 comments on commit 6b22038

Please sign in to comment.