From 2b3ec3aec2fca17367c3fbb6ffb78991ec4d278f Mon Sep 17 00:00:00 2001 From: ioannispan Date: Tue, 6 Aug 2024 13:27:55 +0200 Subject: [PATCH] Change the streaming mechanism --- .../WeightedAllShortestPaths.java | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java b/algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java index 5d4bff2885..2a2d25cf25 100644 --- a/algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java +++ b/algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java @@ -61,7 +61,7 @@ public class WeightedAllShortestPaths extends MSBFSASPAlgorithm { private final ExecutorService executorService; private final Graph graph; private final AtomicInteger counter; // nodeId counter (init with nodeCount, counts down for each node) - + private final AtomicInteger runningTaskCounter = new AtomicInteger(0); private volatile boolean outputStreamOpen; public WeightedAllShortestPaths(Graph graph, ExecutorService executorService, Concurrency concurrency, TerminationFlag terminationFlag) { @@ -93,6 +93,7 @@ public Stream compute() { for (int i = 0; i < concurrency.value(); i++) { executorService.submit(new ShortestPathTask()); + runningTaskCounter.incrementAndGet(); } return AllShortestPathsStream.stream(resultQueue, () -> { @@ -125,20 +126,23 @@ public void run() { int startNode; while (outputStreamOpen && terminationFlag.running() && (startNode = counter.getAndIncrement()) < nodeCount) { compute(startNode); - for (int i = 0; i < nodeCount; i++) { - var result = AllShortestPathsStreamResult.result( - graph.toOriginalNodeId(startNode), - graph.toOriginalNodeId(i), - distance[i] - ); - try { - resultQueue.put(result); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } - } - progressTracker.logProgress(); + } + if (runningTaskCounter.decrementAndGet() == 0 && outputStreamOpen) { + resultQueue.add(AllShortestPathsStreamResult.DONE); + } + } + + private void streamResult(int source, int target, double distance){ + var result = AllShortestPathsStreamResult.result( + graph.toOriginalNodeId(source), + graph.toOriginalNodeId(target), + distance + ); + try { + resultQueue.put(result); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); } } @@ -149,6 +153,7 @@ void compute(int startNode) { while (outputStreamOpen && !queue.isEmpty()) { final int node = queue.pop(); final double sourceDistance = distance[node]; + streamResult(startNode,node,sourceDistance); threadLocalGraph.forEachRelationship( node, Double.NaN, @@ -162,6 +167,7 @@ void compute(int startNode) { return true; })); } + progressTracker.logProgress(); } } }