Skip to content

Commit

Permalink
Node2Vec also gets walkBufferSize
Browse files Browse the repository at this point in the history
  • Loading branch information
jjaderberg committed Oct 31, 2023
1 parent 6b22038 commit 9cb8785
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class Node2Vec extends Algorithm<Node2VecModel.Result> {
private final List<Long> sourceNodes;
private final Optional<Long> maybeRandomSeed;
private final TrainParameters trainParameters;
private final int walkBufferSize;


public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength, int embeddingDimension) {
Expand Down Expand Up @@ -84,6 +85,7 @@ static Node2Vec create(
concurrency,
List.of(),
maybeRandomSeed,
1000,
walkParameters,
trainParameters,
progressTracker
Expand All @@ -95,6 +97,7 @@ public Node2Vec(
int concurrency,
List<Long> sourceNodes,
Optional<Long> maybeRandomSeed,
int walkBufferSize,
WalkParameters walkParameters,
TrainParameters trainParameters,
ProgressTracker progressTracker
Expand All @@ -103,6 +106,7 @@ public Node2Vec(
this.graph = graph;
this.concurrency = concurrency;
this.walkParameters = walkParameters;
this.walkBufferSize = walkBufferSize;
this.sourceNodes = sourceNodes;
this.maybeRandomSeed = maybeRandomSeed;
this.trainParameters = trainParameters;
Expand Down Expand Up @@ -140,6 +144,7 @@ public Node2VecModel.Result compute() {
concurrency,
sourceNodes,
walkParameters,
walkBufferSize,
DefaultPool.INSTANCE,
progressTracker,
terminationFlag
Expand Down Expand Up @@ -185,6 +190,7 @@ private List<Node2VecRandomWalkTask> walkTasks(
int concurrency,
List<Long> sourceNodes,
WalkParameters walkParameters,
int walkBufferSize,
ExecutorService executorService,
ProgressTracker progressTracker,
TerminationFlag terminationFlag
Expand All @@ -211,6 +217,7 @@ private List<Node2VecRandomWalkTask> walkTasks(
index,
compressedRandomWalks,
randomWalkPropabilitiesBuilder,
walkBufferSize,
randomSeed,
walkParameters.walkLength,
walkParameters.returnFactor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public Node2Vec build(
configuration.concurrency(),
configuration.sourceNodes(),
configuration.randomSeed(),
configuration.walkBufferSize(),
configuration.walkParameters(),
configuration.trainParameters(),
progressTracker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ final class Node2VecRandomWalkTask implements Runnable {
private final CompressedRandomWalks compressedRandomWalks;
private final RandomWalkProbabilities.Builder randomWalkProbabilitiesBuilder;
private final RandomWalkSampler sampler;
private final int walkBufferSize;
private int walks;
private int maxWalkLength;
private long maxIndex;
Expand All @@ -52,6 +53,7 @@ final class Node2VecRandomWalkTask implements Runnable {
AtomicLong walkIndex,
CompressedRandomWalks compressedRandomWalks,
RandomWalkProbabilities.Builder randomWalkProbabilitiesBuilder,
int walkBufferSize,
long randomSeed,
int walkLength,
double returnFactor,
Expand All @@ -65,6 +67,7 @@ final class Node2VecRandomWalkTask implements Runnable {
this.walkIndex = walkIndex;
this.compressedRandomWalks = compressedRandomWalks;
this.randomWalkProbabilitiesBuilder = randomWalkProbabilitiesBuilder;
this.walkBufferSize = walkBufferSize;

this.sampler = RandomWalkSampler.create(
graph,
Expand All @@ -85,7 +88,7 @@ private boolean consumePath(long[] path) {
randomWalkProbabilitiesBuilder.registerWalk(path);
compressedRandomWalks.add(index, path);
maxWalkLength = Math.max(path.length, maxWalkLength);
if (walks++ == 1000) { //this is just to get the same
if (walks++ == walkBufferSize) {
walks = 0;
return this.terminationFlag.running();
}
Expand Down

0 comments on commit 9cb8785

Please sign in to comment.