Skip to content

Commit

Permalink
Use the algorithm runner in the algorithm facade
Browse files Browse the repository at this point in the history
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
  • Loading branch information
vnickolov and IoannisPanagiotas committed Nov 15, 2023
1 parent 84e5a11 commit d82176a
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 378 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,8 @@
*/
package org.neo4j.gds.algorithms.community;

import org.neo4j.gds.Algorithm;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.PreconditionsProvider;
import org.neo4j.gds.algorithms.AlgorithmComputationResult;
import org.neo4j.gds.algorithms.AlgorithmMemoryEstimation;
import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.GraphName;
import org.neo4j.gds.api.User;
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutAlgorithmFactory;
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutResult;
Expand All @@ -35,12 +29,7 @@
import org.neo4j.gds.conductance.ConductanceAlgorithmFactory;
import org.neo4j.gds.conductance.ConductanceBaseConfig;
import org.neo4j.gds.conductance.ConductanceResult;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.loading.GraphStoreCatalogService;
import org.neo4j.gds.core.utils.paged.dss.DisjointSetStruct;
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory;
import org.neo4j.gds.k1coloring.K1ColoringAlgorithmFactory;
import org.neo4j.gds.k1coloring.K1ColoringBaseConfig;
import org.neo4j.gds.k1coloring.K1ColoringResult;
Expand All @@ -56,7 +45,6 @@
import org.neo4j.gds.leiden.LeidenAlgorithmFactory;
import org.neo4j.gds.leiden.LeidenBaseConfig;
import org.neo4j.gds.leiden.LeidenResult;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.louvain.LouvainAlgorithmFactory;
import org.neo4j.gds.louvain.LouvainBaseConfig;
import org.neo4j.gds.louvain.LouvainResult;
Expand All @@ -80,24 +68,11 @@
import java.util.Optional;

public class CommunityAlgorithmsFacade {
private final GraphStoreCatalogService graphStoreCatalogService;
private final TaskRegistryFactory taskRegistryFactory;
private final UserLogRegistryFactory userLogRegistryFactory;
private final AlgorithmMemoryValidationService memoryUsageValidator;
private final Log log;

public CommunityAlgorithmsFacade(
GraphStoreCatalogService graphStoreCatalogService,
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
AlgorithmMemoryValidationService memoryUsageValidator,
Log log
) {
this.graphStoreCatalogService = graphStoreCatalogService;
this.taskRegistryFactory = taskRegistryFactory;
this.userLogRegistryFactory = userLogRegistryFactory;
this.memoryUsageValidator = memoryUsageValidator;
this.log = log;
private final BasicAlgorithmRunner algorithmRunner;

public CommunityAlgorithmsFacade(BasicAlgorithmRunner algorithmRunner) {
this.algorithmRunner = algorithmRunner;
}

AlgorithmComputationResult<DisjointSetStruct> wcc(
Expand All @@ -106,7 +81,7 @@ AlgorithmComputationResult<DisjointSetStruct> wcc(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
Expand All @@ -122,7 +97,7 @@ AlgorithmComputationResult<TriangleCountResult> triangleCount(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
Optional.empty(),
Expand All @@ -138,7 +113,7 @@ AlgorithmComputationResult<KCoreDecompositionResult> kCore(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
Optional.empty(),
Expand All @@ -154,7 +129,7 @@ AlgorithmComputationResult<LouvainResult> louvain(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
Expand All @@ -170,7 +145,7 @@ AlgorithmComputationResult<LeidenResult> leiden(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
Expand All @@ -186,7 +161,7 @@ AlgorithmComputationResult<LabelPropagationResult> labelPropagation(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
configuration,
configuration.relationshipWeightProperty(),
Expand All @@ -202,7 +177,7 @@ AlgorithmComputationResult<HugeLongArray> scc(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
Optional.empty(),
Expand All @@ -218,7 +193,7 @@ AlgorithmComputationResult<ModularityResult> modularity(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
Expand All @@ -234,7 +209,7 @@ AlgorithmComputationResult<KmeansResult> kmeans(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
Optional.empty(),
Expand All @@ -250,7 +225,7 @@ public AlgorithmComputationResult<LocalClusteringCoefficientResult> localCluster
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
Optional.empty(),
Expand All @@ -266,7 +241,7 @@ AlgorithmComputationResult<K1ColoringResult> k1Coloring(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
Optional.empty(),
Expand All @@ -282,7 +257,7 @@ AlgorithmComputationResult<ConductanceResult> conductance(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
Expand All @@ -298,7 +273,7 @@ AlgorithmComputationResult<ApproxMaxKCutResult> approxMaxKCut(
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
Expand All @@ -315,7 +290,7 @@ public AlgorithmComputationResult<ModularityOptimizationResult> modularityOptimi
User user,
DatabaseId databaseId
) {
return run(
return algorithmRunner.run(
graphName,
config,
config.relationshipWeightProperty(),
Expand All @@ -326,65 +301,4 @@ public AlgorithmComputationResult<ModularityOptimizationResult> modularityOptimi
}


private <A extends Algorithm<R>, R, C extends AlgoBaseConfig> AlgorithmComputationResult<R> run(
String graphName,
C config,
Optional<String> relationshipProperty,
GraphAlgorithmFactory<A, C> algorithmFactory,
User user,
DatabaseId databaseId
) {
// TODO: Is this the best place to check for preconditions???
PreconditionsProvider.preconditions().check();

// Go get the graph and graph store from the catalog
var graphWithGraphStore = graphStoreCatalogService.getGraphWithGraphStore(
GraphName.parse(graphName),
config,
relationshipProperty,
user,
databaseId
);

var graph = graphWithGraphStore.getLeft();
var graphStore = graphWithGraphStore.getRight();

// No algorithm execution when the graph is empty
if (graph.isEmpty()) {
return AlgorithmComputationResult.withoutAlgorithmResult(graph, graphStore);
}

// create the algorithm
var algorithmEstimator = new AlgorithmMemoryEstimation<>(
GraphDimensions.of(
graph.nodeCount(),
graph.relationshipCount()
),
algorithmFactory
);

memoryUsageValidator.validateAlgorithmCanRunWithTheAvailableMemory(
config,
algorithmEstimator::memoryEstimation,
graphStoreCatalogService.graphStoreCount()
);
var algorithm = algorithmFactory.build(
graph,
config,
(org.neo4j.logging.Log) log.getNeo4jLog(),
taskRegistryFactory,
userLogRegistryFactory
);

// run the algorithm
try {
var algorithmResult = algorithm.compute();

return AlgorithmComputationResult.of(algorithmResult, graph, graphStore, algorithm.getTerminationFlag());
} catch (Exception e) {
log.warn("Computation failed", e);
algorithm.getProgressTracker().endSubTaskWithFailure();
throw e;
}
}
}
Loading

0 comments on commit d82176a

Please sign in to comment.