From b3efe530851bb28fa6ef85cc8f3b3233c97bd071 Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Mon, 13 Nov 2023 11:34:58 +0000 Subject: [PATCH 1/8] Create Algorithm Metrics API module Co-authored-by: Ioannis Panagiotas --- algorithm-metrics-api/build.gradle | 16 +++++ .../metrics/AlgorithmMetricRegistrar.java | 26 +++++++++ .../metrics/AlgorithmMetricsService.java | 38 ++++++++++++ .../PassthroughAlgorithmMetricRegistrar.java | 36 ++++++++++++ .../metrics/AlgorithmMetricsServiceTest.java | 58 +++++++++++++++++++ settings.gradle | 3 + 6 files changed, 177 insertions(+) create mode 100644 algorithm-metrics-api/build.gradle create mode 100644 algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricRegistrar.java create mode 100644 algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsService.java create mode 100644 algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetricRegistrar.java create mode 100644 algorithm-metrics-api/src/test/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsServiceTest.java diff --git a/algorithm-metrics-api/build.gradle b/algorithm-metrics-api/build.gradle new file mode 100644 index 0000000000..840167babe --- /dev/null +++ b/algorithm-metrics-api/build.gradle @@ -0,0 +1,16 @@ +apply plugin: 'java-library' + +description = 'Neo4j Graph Data Science :: Algorithm Metrics API' + +group = 'org.neo4j.gds' + +dependencies { + + testImplementation( + platform(dep.junit5bom), + dep.junit5jupiter, + ) + + testImplementation group: 'org.assertj', name: 'assertj-core', version: ver.'assertj' + testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: ver.'mockito-junit-jupiter' +} diff --git a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricRegistrar.java b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricRegistrar.java new file mode 100644 index 0000000000..9f98c0ed99 --- /dev/null +++ b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricRegistrar.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.algorithms.metrics; + +public interface AlgorithmMetricRegistrar { + + void started(String algorithm); + void failed(String algorithm); +} diff --git a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsService.java b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsService.java new file mode 100644 index 0000000000..2dc6c24c0a --- /dev/null +++ b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsService.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.algorithms.metrics; + +public class AlgorithmMetricsService { + + private final AlgorithmMetricRegistrar metricRegistrar; + + public AlgorithmMetricsService(AlgorithmMetricRegistrar metricRegistrar) { + this.metricRegistrar = metricRegistrar; + } + + public void started(String algorithm) { + metricRegistrar.started(algorithm); + } + + public void failed(String algorithm) { + metricRegistrar.failed(algorithm); + } + +} diff --git a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetricRegistrar.java b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetricRegistrar.java new file mode 100644 index 0000000000..bb4fbad2d4 --- /dev/null +++ b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetricRegistrar.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.algorithms.metrics; + +/** + * No-op metrics registrar; to be used when metrics are not enabled in Neo4j. + */ +public class PassthroughAlgorithmMetricRegistrar implements AlgorithmMetricRegistrar { + + @Override + public void started(String algorithm) { + + } + + @Override + public void failed(String algorithm) { + + } +} diff --git a/algorithm-metrics-api/src/test/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsServiceTest.java b/algorithm-metrics-api/src/test/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsServiceTest.java new file mode 100644 index 0000000000..a869605f09 --- /dev/null +++ b/algorithm-metrics-api/src/test/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsServiceTest.java @@ -0,0 +1,58 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.algorithms.metrics; + +import org.junit.jupiter.api.Test; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +class AlgorithmMetricsServiceTest { + + @Test + void shouldRegisterStarted() { + // given + var registrarMock = mock(AlgorithmMetricRegistrar.class); + var metricsService = new AlgorithmMetricsService(registrarMock); + + // when + metricsService.started("foo"); + + // then + verify(registrarMock, times(1)).started("foo"); + verifyNoMoreInteractions(registrarMock); + } + + @Test + void shouldRegisterFailed() { + // given + var registrarMock = mock(AlgorithmMetricRegistrar.class); + var metricsService = new AlgorithmMetricsService(registrarMock); + + // when + metricsService.failed("foo"); + + // then + verify(registrarMock, times(1)).failed("foo"); + verifyNoMoreInteractions(registrarMock); + } +} diff --git a/settings.gradle b/settings.gradle index a13a280bb0..2e5cbbf37c 100644 --- a/settings.gradle +++ b/settings.gradle @@ -19,6 +19,9 @@ project(':algo-common').projectDir = file('algo-common') include('algo-test') project(':algo-test').projectDir = file('algo-test') +include('algorithm-metrics-api') +project(':algorithm-metrics-api').projectDir = file('algorithm-metrics-api') + include('alpha-proc') project(':alpha-proc').projectDir = file('alpha/alpha-proc') From 84e5a11bd31d4767583d95fb8f2659a6b6d5adc2 Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Mon, 13 Nov 2023 13:23:14 +0000 Subject: [PATCH 2/8] Extract algorithm run logic Co-authored-by: Ioannis Panagiotas --- algo/build.gradle | 1 + .../community/BasicAlgorithmRunner.java | 131 ++++++++++++++++ .../community/BasicAlgorithmRunnerTest.java | 142 ++++++++++++++++++ 3 files changed, 274 insertions(+) create mode 100644 algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java create mode 100644 algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java diff --git a/algo/build.gradle b/algo/build.gradle index 4386cbd71b..a1e07dbb4b 100644 --- a/algo/build.gradle +++ b/algo/build.gradle @@ -27,6 +27,7 @@ dependencies { compileOnly group: 'org.neo4j', name: 'neo4j-graph-algo', version: ver.'neo4j' implementation project(':algo-common') + implementation project(':algorithm-metrics-api') implementation project(':annotations') implementation project(':collections-memory-estimation') implementation project(':config-api') diff --git a/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java b/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java new file mode 100644 index 0000000000..ff14781ff8 --- /dev/null +++ b/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java @@ -0,0 +1,131 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.algorithms.community; + +import org.neo4j.gds.Algorithm; +import org.neo4j.gds.GraphAlgorithmFactory; +import org.neo4j.gds.PreconditionsProvider; +import org.neo4j.gds.algorithms.AlgorithmComputationResult; +import org.neo4j.gds.algorithms.AlgorithmMemoryEstimation; +import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.api.DatabaseId; +import org.neo4j.gds.api.GraphName; +import org.neo4j.gds.api.User; +import org.neo4j.gds.config.AlgoBaseConfig; +import org.neo4j.gds.core.GraphDimensions; +import org.neo4j.gds.core.loading.GraphStoreCatalogService; +import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; +import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory; +import org.neo4j.gds.logging.Log; + +import java.util.Optional; + +public class BasicAlgorithmRunner { + private final GraphStoreCatalogService graphStoreCatalogService; + private final TaskRegistryFactory taskRegistryFactory; + private final UserLogRegistryFactory userLogRegistryFactory; + private final AlgorithmMemoryValidationService memoryUsageValidator; + + private final AlgorithmMetricsService algorithmMetricsService; + + private final Log log; + + public BasicAlgorithmRunner( + GraphStoreCatalogService graphStoreCatalogService, + TaskRegistryFactory taskRegistryFactory, + UserLogRegistryFactory userLogRegistryFactory, + AlgorithmMemoryValidationService memoryUsageValidator, + AlgorithmMetricsService algorithmMetricsService, + Log log + ) { + this.graphStoreCatalogService = graphStoreCatalogService; + this.taskRegistryFactory = taskRegistryFactory; + this.userLogRegistryFactory = userLogRegistryFactory; + this.memoryUsageValidator = memoryUsageValidator; + this.algorithmMetricsService = algorithmMetricsService; + this.log = log; + } + + public , R, C extends AlgoBaseConfig> AlgorithmComputationResult run( + String graphName, + C config, + Optional relationshipProperty, + GraphAlgorithmFactory algorithmFactory, + User user, + DatabaseId databaseId + ) { + // TODO: Is this the best place to check for preconditions??? + PreconditionsProvider.preconditions().check(); + + // Go get the graph and graph store from the catalog + var graphWithGraphStore = graphStoreCatalogService.getGraphWithGraphStore( + GraphName.parse(graphName), + config, + relationshipProperty, + user, + databaseId + ); + + var graph = graphWithGraphStore.getLeft(); + var graphStore = graphWithGraphStore.getRight(); + + // No algorithm execution when the graph is empty + if (graph.isEmpty()) { + return AlgorithmComputationResult.withoutAlgorithmResult(graph, graphStore); + } + + // create the algorithm + var algorithmEstimator = new AlgorithmMemoryEstimation<>( + GraphDimensions.of( + graph.nodeCount(), + graph.relationshipCount() + ), + algorithmFactory + ); + + memoryUsageValidator.validateAlgorithmCanRunWithTheAvailableMemory( + config, + algorithmEstimator::memoryEstimation, + graphStoreCatalogService.graphStoreCount() + ); + var algorithm = algorithmFactory.build( + graph, + config, + (org.neo4j.logging.Log) log.getNeo4jLog(), + taskRegistryFactory, + userLogRegistryFactory + ); + + // run the algorithm + try { + algorithmMetricsService.started(algorithmFactory.taskName()); + var algorithmResult = algorithm.compute(); + + return AlgorithmComputationResult.of(algorithmResult, graph, graphStore, algorithm.getTerminationFlag()); + } catch (Exception e) { + log.warn("Computation failed", e); + algorithm.getProgressTracker().endSubTaskWithFailure(); + algorithmMetricsService.failed(algorithmFactory.taskName()); + throw e; + } + } + +} diff --git a/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java b/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java new file mode 100644 index 0000000000..9edf4c83fe --- /dev/null +++ b/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java @@ -0,0 +1,142 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.algorithms.community; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.neo4j.gds.Algorithm; +import org.neo4j.gds.GraphAlgorithmFactory; +import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.api.DatabaseId; +import org.neo4j.gds.api.Graph; +import org.neo4j.gds.api.GraphStore; +import org.neo4j.gds.api.User; +import org.neo4j.gds.compat.Neo4jProxy; +import org.neo4j.gds.config.AlgoBaseConfig; +import org.neo4j.gds.core.loading.GraphStoreCatalogService; +import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; +import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory; +import org.neo4j.gds.logging.Log; + +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThatException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +class BasicAlgorithmRunnerTest { + + @Test + void shouldRegisterAlgorithmMetricCountForSuccess() { + var graphMock = mock(Graph.class); + when(graphMock.isEmpty()).thenReturn(false); + + var graphStoreCatalogServiceMock = mock(GraphStoreCatalogService.class); + when(graphStoreCatalogServiceMock.getGraphWithGraphStore(any(), any(), any(), any(), any())) + .thenReturn(Pair.of(graphMock, mock(GraphStore.class))); + + var algorithmMetricsServiceMock = mock(AlgorithmMetricsService.class); + + var logMock = mock(Log.class); + when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); + + var runner = new BasicAlgorithmRunner( + graphStoreCatalogServiceMock, + TaskRegistryFactory.empty(), + EmptyUserLogRegistryFactory.INSTANCE, + mock(AlgorithmMemoryValidationService.class), + algorithmMetricsServiceMock, + logMock + ); + + var algorithmMock = mock(Algorithm.class, RETURNS_DEEP_STUBS); + when(algorithmMock.compute()).thenReturn("WooHoo"); + var algorithmFactoryMock = mock(GraphAlgorithmFactory.class); + when(algorithmFactoryMock.taskName()).thenReturn("TestingMetrics"); + when(algorithmFactoryMock.build(any(), any(), any(), any(), any())).thenReturn(algorithmMock); + + runner.run( + "foo", + mock(AlgoBaseConfig.class), + Optional.empty(), + algorithmFactoryMock, + mock(User.class), + DatabaseId.EMPTY + ); + + verify(algorithmMetricsServiceMock, times(1)).started("TestingMetrics"); + verify(algorithmMetricsServiceMock, times(0)).failed("TestingMetrics"); + verifyNoMoreInteractions(algorithmMetricsServiceMock); + } + + + @Test + void shouldRegisterAlgorithmMetricCountForFailure() { + var graphMock = mock(Graph.class); + when(graphMock.isEmpty()).thenReturn(false); + + var graphStoreCatalogServiceMock = mock(GraphStoreCatalogService.class); + when(graphStoreCatalogServiceMock.getGraphWithGraphStore(any(), any(), any(), any(), any())) + .thenReturn(Pair.of(graphMock, mock(GraphStore.class))); + + var algorithmMetricsServiceMock = mock(AlgorithmMetricsService.class); + + var logMock = mock(Log.class); + when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); + + var runner = new BasicAlgorithmRunner( + graphStoreCatalogServiceMock, + TaskRegistryFactory.empty(), + EmptyUserLogRegistryFactory.INSTANCE, + mock(AlgorithmMemoryValidationService.class), + algorithmMetricsServiceMock, + logMock + ); + + var algorithmMock = mock(Algorithm.class, RETURNS_DEEP_STUBS); + when(algorithmMock.compute()).thenThrow(new RuntimeException("Ooops")); + + var algorithmFactoryMock = mock(GraphAlgorithmFactory.class); + when(algorithmFactoryMock.taskName()).thenReturn("TestingMetrics"); + when(algorithmFactoryMock.build(any(), any(), any(), any(), any())).thenReturn(algorithmMock); + + assertThatException().isThrownBy( + () -> runner.run( + "foo", + mock(AlgoBaseConfig.class), + Optional.empty(), + algorithmFactoryMock, + mock(User.class), + DatabaseId.EMPTY + ) + ).withMessage("Ooops"); + + verify(algorithmMetricsServiceMock, times(1)).started("TestingMetrics"); + verify(algorithmMetricsServiceMock, times(1)).failed("TestingMetrics"); + verifyNoMoreInteractions(algorithmMetricsServiceMock); + } + +} From d82176a182df5f363607fadc4c35ff9a3d7da953 Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Mon, 13 Nov 2023 15:05:17 +0000 Subject: [PATCH 3/8] Use the algorithm runner in the algorithm facade Co-authored-by: Ioannis Panagiotas --- .../community/CommunityAlgorithmsFacade.java | 122 +++----------- ...ityAlgorithmsStreamBusinessFacadeTest.java | 149 ------------------ proc/community/build.gradle | 1 + .../k1coloring/K1ColoringStreamProcTest.java | 16 +- .../LabelPropagationMutateProcTest.java | 93 ++++------- .../ModularityOptimizationMutateProcTest.java | 17 +- .../org/neo4j/gds/wcc/WccMutateProcTest.java | 88 +++++------ .../org/neo4j/gds/wcc/WccStatsProcTest.java | 17 +- .../org/neo4j/gds/wcc/WccWriteProcTest.java | 16 +- .../OpenGraphDataScienceExtension.java | 3 +- procedures/integration/build.gradle | 1 + .../CommunityProcedureProvider.java | 10 +- 12 files changed, 155 insertions(+), 378 deletions(-) delete mode 100644 algo/src/test/java/org/neo4j/gds/algorithms/community/CommunityAlgorithmsStreamBusinessFacadeTest.java diff --git a/algo/src/main/java/org/neo4j/gds/algorithms/community/CommunityAlgorithmsFacade.java b/algo/src/main/java/org/neo4j/gds/algorithms/community/CommunityAlgorithmsFacade.java index 3bd9e6a1b8..5d1d913f82 100644 --- a/algo/src/main/java/org/neo4j/gds/algorithms/community/CommunityAlgorithmsFacade.java +++ b/algo/src/main/java/org/neo4j/gds/algorithms/community/CommunityAlgorithmsFacade.java @@ -19,14 +19,8 @@ */ package org.neo4j.gds.algorithms.community; -import org.neo4j.gds.Algorithm; -import org.neo4j.gds.GraphAlgorithmFactory; -import org.neo4j.gds.PreconditionsProvider; import org.neo4j.gds.algorithms.AlgorithmComputationResult; -import org.neo4j.gds.algorithms.AlgorithmMemoryEstimation; -import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; import org.neo4j.gds.api.DatabaseId; -import org.neo4j.gds.api.GraphName; import org.neo4j.gds.api.User; import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutAlgorithmFactory; import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutResult; @@ -35,12 +29,7 @@ import org.neo4j.gds.conductance.ConductanceAlgorithmFactory; import org.neo4j.gds.conductance.ConductanceBaseConfig; import org.neo4j.gds.conductance.ConductanceResult; -import org.neo4j.gds.config.AlgoBaseConfig; -import org.neo4j.gds.core.GraphDimensions; -import org.neo4j.gds.core.loading.GraphStoreCatalogService; import org.neo4j.gds.core.utils.paged.dss.DisjointSetStruct; -import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; -import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory; import org.neo4j.gds.k1coloring.K1ColoringAlgorithmFactory; import org.neo4j.gds.k1coloring.K1ColoringBaseConfig; import org.neo4j.gds.k1coloring.K1ColoringResult; @@ -56,7 +45,6 @@ import org.neo4j.gds.leiden.LeidenAlgorithmFactory; import org.neo4j.gds.leiden.LeidenBaseConfig; import org.neo4j.gds.leiden.LeidenResult; -import org.neo4j.gds.logging.Log; import org.neo4j.gds.louvain.LouvainAlgorithmFactory; import org.neo4j.gds.louvain.LouvainBaseConfig; import org.neo4j.gds.louvain.LouvainResult; @@ -80,24 +68,11 @@ import java.util.Optional; public class CommunityAlgorithmsFacade { - private final GraphStoreCatalogService graphStoreCatalogService; - private final TaskRegistryFactory taskRegistryFactory; - private final UserLogRegistryFactory userLogRegistryFactory; - private final AlgorithmMemoryValidationService memoryUsageValidator; - private final Log log; - public CommunityAlgorithmsFacade( - GraphStoreCatalogService graphStoreCatalogService, - TaskRegistryFactory taskRegistryFactory, - UserLogRegistryFactory userLogRegistryFactory, - AlgorithmMemoryValidationService memoryUsageValidator, - Log log - ) { - this.graphStoreCatalogService = graphStoreCatalogService; - this.taskRegistryFactory = taskRegistryFactory; - this.userLogRegistryFactory = userLogRegistryFactory; - this.memoryUsageValidator = memoryUsageValidator; - this.log = log; + private final BasicAlgorithmRunner algorithmRunner; + + public CommunityAlgorithmsFacade(BasicAlgorithmRunner algorithmRunner) { + this.algorithmRunner = algorithmRunner; } AlgorithmComputationResult wcc( @@ -106,7 +81,7 @@ AlgorithmComputationResult wcc( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, config.relationshipWeightProperty(), @@ -122,7 +97,7 @@ AlgorithmComputationResult triangleCount( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, Optional.empty(), @@ -138,7 +113,7 @@ AlgorithmComputationResult kCore( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, Optional.empty(), @@ -154,7 +129,7 @@ AlgorithmComputationResult louvain( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, config.relationshipWeightProperty(), @@ -170,7 +145,7 @@ AlgorithmComputationResult leiden( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, config.relationshipWeightProperty(), @@ -186,7 +161,7 @@ AlgorithmComputationResult labelPropagation( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, configuration, configuration.relationshipWeightProperty(), @@ -202,7 +177,7 @@ AlgorithmComputationResult scc( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, Optional.empty(), @@ -218,7 +193,7 @@ AlgorithmComputationResult modularity( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, config.relationshipWeightProperty(), @@ -234,7 +209,7 @@ AlgorithmComputationResult kmeans( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, Optional.empty(), @@ -250,7 +225,7 @@ public AlgorithmComputationResult localCluster User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, Optional.empty(), @@ -266,7 +241,7 @@ AlgorithmComputationResult k1Coloring( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, Optional.empty(), @@ -282,7 +257,7 @@ AlgorithmComputationResult conductance( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, config.relationshipWeightProperty(), @@ -298,7 +273,7 @@ AlgorithmComputationResult approxMaxKCut( User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, config.relationshipWeightProperty(), @@ -315,7 +290,7 @@ public AlgorithmComputationResult modularityOptimi User user, DatabaseId databaseId ) { - return run( + return algorithmRunner.run( graphName, config, config.relationshipWeightProperty(), @@ -326,65 +301,4 @@ public AlgorithmComputationResult modularityOptimi } - private , R, C extends AlgoBaseConfig> AlgorithmComputationResult run( - String graphName, - C config, - Optional relationshipProperty, - GraphAlgorithmFactory algorithmFactory, - User user, - DatabaseId databaseId - ) { - // TODO: Is this the best place to check for preconditions??? - PreconditionsProvider.preconditions().check(); - - // Go get the graph and graph store from the catalog - var graphWithGraphStore = graphStoreCatalogService.getGraphWithGraphStore( - GraphName.parse(graphName), - config, - relationshipProperty, - user, - databaseId - ); - - var graph = graphWithGraphStore.getLeft(); - var graphStore = graphWithGraphStore.getRight(); - - // No algorithm execution when the graph is empty - if (graph.isEmpty()) { - return AlgorithmComputationResult.withoutAlgorithmResult(graph, graphStore); - } - - // create the algorithm - var algorithmEstimator = new AlgorithmMemoryEstimation<>( - GraphDimensions.of( - graph.nodeCount(), - graph.relationshipCount() - ), - algorithmFactory - ); - - memoryUsageValidator.validateAlgorithmCanRunWithTheAvailableMemory( - config, - algorithmEstimator::memoryEstimation, - graphStoreCatalogService.graphStoreCount() - ); - var algorithm = algorithmFactory.build( - graph, - config, - (org.neo4j.logging.Log) log.getNeo4jLog(), - taskRegistryFactory, - userLogRegistryFactory - ); - - // run the algorithm - try { - var algorithmResult = algorithm.compute(); - - return AlgorithmComputationResult.of(algorithmResult, graph, graphStore, algorithm.getTerminationFlag()); - } catch (Exception e) { - log.warn("Computation failed", e); - algorithm.getProgressTracker().endSubTaskWithFailure(); - throw e; - } - } } diff --git a/algo/src/test/java/org/neo4j/gds/algorithms/community/CommunityAlgorithmsStreamBusinessFacadeTest.java b/algo/src/test/java/org/neo4j/gds/algorithms/community/CommunityAlgorithmsStreamBusinessFacadeTest.java deleted file mode 100644 index 68eca36f0e..0000000000 --- a/algo/src/test/java/org/neo4j/gds/algorithms/community/CommunityAlgorithmsStreamBusinessFacadeTest.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.algorithms.community; - -import org.apache.commons.lang3.tuple.Pair; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; -import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; -import org.neo4j.gds.api.Graph; -import org.neo4j.gds.api.GraphStore; -import org.neo4j.gds.compat.Neo4jProxy; -import org.neo4j.gds.core.loading.GraphStoreCatalogService; -import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; -import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory; -import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory; -import org.neo4j.gds.extension.GdlExtension; -import org.neo4j.gds.extension.GdlGraph; -import org.neo4j.gds.extension.Inject; -import org.neo4j.gds.extension.TestGraph; -import org.neo4j.gds.logging.Log; -import org.neo4j.gds.wcc.WccBaseConfig; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -class CommunityAlgorithmsStreamBusinessFacadeTest { - - @Nested - @GdlExtension - class WccTest { - @GdlGraph - private static final String TEST_GRAPH = - "CREATE" + - " (a:Node)" + - ", (b:Node)" + - ", (c:Node)" + - ", (d:Node)" + - ", (e:Node)" + - ", (f:Node)" + - ", (g:Node)" + - ", (h:Node)" + - ", (i:Node)" + - // {J} - ", (j:Node)" + - // {A, B, C, D} - ", (a)-[:TYPE]->(b)" + - ", (b)-[:TYPE]->(c)" + - ", (c)-[:TYPE]->(d)" + - ", (d)-[:TYPE]->(a)" + - // {E, F, G} - ", (e)-[:TYPE]->(f)" + - ", (f)-[:TYPE]->(g)" + - ", (g)-[:TYPE]->(e)" + - // {H, I} - ", (i)-[:TYPE]->(h)" + - ", (h)-[:TYPE]->(i)"; - - @Inject - TestGraph graph; - - @Inject - GraphStore graphStore; - - @Test - void wcc() { - // given - var graphStoreCatalogServiceMock = mock(GraphStoreCatalogService.class); - doReturn(Pair.of(graph, graphStore)) - .when(graphStoreCatalogServiceMock) - .getGraphWithGraphStore(any(), any(), any(), any(), any()); - - var config = mock(WccBaseConfig.class); - when(config.concurrency()).thenReturn(4); - var logMock = mock(Log.class); - when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); - - var algorithmsBusinessFacade = new CommunityAlgorithmsStreamBusinessFacade( - new CommunityAlgorithmsFacade( - graphStoreCatalogServiceMock, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - mock(AlgorithmMemoryValidationService.class), - logMock - ) - ); - - // when - var wccComputationResult = algorithmsBusinessFacade.wcc( - "meh", - config, - null, - null - ); - - //then - assertThat(wccComputationResult.result()) - .isNotEmpty() - .get() - .satisfies(disjointSetStruct -> { - assertThat(disjointSetStruct.size()).isEqualTo(10); - }); - assertThat(wccComputationResult.graph()).isSameAs(graph); - } - - @Test - void wccOnEmptyGraph() { - // given - var graphStoreCatalogServiceMock = mock(GraphStoreCatalogService.class); - var graphMock = mock(Graph.class); - when(graphMock.isEmpty()).thenReturn(true); - doReturn(Pair.of(graphMock, mock(GraphStore.class))) - .when(graphStoreCatalogServiceMock) - .getGraphWithGraphStore(any(), any(), any(), any(), any()); - var algorithmsBusinessFacade = new CommunityAlgorithmsStreamBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogServiceMock, - mock(TaskRegistryFactory.class), - mock(UserLogRegistryFactory.class), - null, null - ) - ); - - // when - var wccComputationResult = algorithmsBusinessFacade.wcc("meh", mock(WccBaseConfig.class), null, null); - - //then - assertThat(wccComputationResult.result()).isEmpty(); - } - } -} diff --git a/proc/community/build.gradle b/proc/community/build.gradle index 374bdc657e..93f93b5ee3 100644 --- a/proc/community/build.gradle +++ b/proc/community/build.gradle @@ -41,6 +41,7 @@ dependencies { testAnnotationProcessor project(':annotations') testAnnotationProcessor project(':config-generator') + testImplementation project(':algorithm-metrics-api') testImplementation project(':logging') testImplementation project(':memory-usage') testImplementation project(':native-projection') diff --git a/proc/community/src/test/java/org/neo4j/gds/k1coloring/K1ColoringStreamProcTest.java b/proc/community/src/test/java/org/neo4j/gds/k1coloring/K1ColoringStreamProcTest.java index 7a0a7fef49..5f3564e88d 100644 --- a/proc/community/src/test/java/org/neo4j/gds/k1coloring/K1ColoringStreamProcTest.java +++ b/proc/community/src/test/java/org/neo4j/gds/k1coloring/K1ColoringStreamProcTest.java @@ -33,8 +33,11 @@ import org.neo4j.gds.Orientation; import org.neo4j.gds.TestProcedureRunner; import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.community.BasicAlgorithmRunner; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsStreamBusinessFacade; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.ProcedureReturnColumns; @@ -195,10 +198,15 @@ void shouldRegisterTaskWithCorrectJobId() { proc.taskRegistryFactory = taskRegistryFactory; var algorithmsStreamBusinessFacade = new CommunityAlgorithmsStreamBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - taskRegistryFactory, - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock + new CommunityAlgorithmsFacade( + new BasicAlgorithmRunner( + graphStoreCatalogService, + taskRegistryFactory, + EmptyUserLogRegistryFactory.INSTANCE, + memoryUsageValidator, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + logMock + ) )); proc.facade = new GraphDataScience( null, diff --git a/proc/community/src/test/java/org/neo4j/gds/labelpropagation/LabelPropagationMutateProcTest.java b/proc/community/src/test/java/org/neo4j/gds/labelpropagation/LabelPropagationMutateProcTest.java index 7a6cc4fa43..e18e1dbfac 100644 --- a/proc/community/src/test/java/org/neo4j/gds/labelpropagation/LabelPropagationMutateProcTest.java +++ b/proc/community/src/test/java/org/neo4j/gds/labelpropagation/LabelPropagationMutateProcTest.java @@ -43,9 +43,12 @@ import org.neo4j.gds.TestProcedureRunner; import org.neo4j.gds.TestSupport; import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.community.BasicAlgorithmRunner; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsMutateBusinessFacade; import org.neo4j.gds.algorithms.community.MutateNodePropertyService; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.DefaultValue; import org.neo4j.gds.api.Graph; @@ -306,21 +309,7 @@ void testWriteBackGraphMutationOnFilteredGraph() { storeLoaderBuilder.putRelationshipProjectionsWithIdentifier(relationshipType.name(), projection)); GraphLoader loader = storeLoaderBuilder.build(); GraphStoreCatalog.set(loader.projectConfig(), loader.graphStore()); - var logMock = mock(org.neo4j.gds.logging.Log.class); - when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); - - final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); - final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( - logMock, - false - ); - var algorithmsMutateBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock), - new MutateNodePropertyService(logMock) - ); + var algorithmsMutateBusinessFacade = communityAlgorithmsMutateBusinessFacade(); TestProcedureRunner.applyOnProcedure(db, LabelPropagationMutateProc.class, procedure -> { procedure.facade = new GraphDataScience( @@ -378,6 +367,32 @@ void testWriteBackGraphMutationOnFilteredGraph() { ); } + @NotNull + private static CommunityAlgorithmsMutateBusinessFacade communityAlgorithmsMutateBusinessFacade() { + var logMock = mock(org.neo4j.gds.logging.Log.class); + when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); + + final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); + final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( + logMock, + false + ); + var algorithmsMutateBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( + new CommunityAlgorithmsFacade( + new BasicAlgorithmRunner( + graphStoreCatalogService, + TaskRegistryFactory.empty(), + EmptyUserLogRegistryFactory.INSTANCE, + memoryUsageValidator, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + logMock + ) + ), + new MutateNodePropertyService(logMock) + ); + return algorithmsMutateBusinessFacade; + } + @Test void testGraphMutation() { var graphStore = runMutation(ensureGraphExists(), Map.of("mutateProperty", MUTATE_PROPERTY)); @@ -424,21 +439,7 @@ void testGraphMutationOnFilteredGraph() { @Test void testMutateFailsOnExistingToken() { String graphName = ensureGraphExists(); - var logMock = mock(org.neo4j.gds.logging.Log.class); - when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); - - final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); - final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( - logMock, - false - ); - var algorithmsMutateBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock), - new MutateNodePropertyService(logMock) - ); + var algorithmsMutateBusinessFacade = communityAlgorithmsMutateBusinessFacade(); TestProcedureRunner.applyOnProcedure(db, LabelPropagationMutateProc.class, procedure -> { procedure.facade = new GraphDataScience( @@ -497,21 +498,7 @@ void testExceptionLogging() { @Test void testRunOnEmptyGraph() { - var logMock = mock(org.neo4j.gds.logging.Log.class); - when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); - - final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); - final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( - logMock, - false - ); - var algorithmsMutateBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock), - new MutateNodePropertyService(logMock) - ); + var algorithmsMutateBusinessFacade = communityAlgorithmsMutateBusinessFacade(); TestProcedureRunner.applyOnProcedure(db, LabelPropagationMutateProc.class, (procedure) -> { procedure.facade = new GraphDataScience( @@ -573,21 +560,7 @@ private String ensureGraphExists() { @NotNull private GraphStore runMutation(String graphName, Map config) { - var logMock = mock(org.neo4j.gds.logging.Log.class); - when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); - - final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); - final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( - logMock, - false - ); - var algorithmsMutateBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock), - new MutateNodePropertyService(logMock) - ); + var algorithmsMutateBusinessFacade = communityAlgorithmsMutateBusinessFacade(); TestProcedureRunner.applyOnProcedure(db, LabelPropagationMutateProc.class, procedure -> { procedure.facade = new GraphDataScience( diff --git a/proc/community/src/test/java/org/neo4j/gds/modularityoptimization/ModularityOptimizationMutateProcTest.java b/proc/community/src/test/java/org/neo4j/gds/modularityoptimization/ModularityOptimizationMutateProcTest.java index c09bb52d4b..03220edcaa 100644 --- a/proc/community/src/test/java/org/neo4j/gds/modularityoptimization/ModularityOptimizationMutateProcTest.java +++ b/proc/community/src/test/java/org/neo4j/gds/modularityoptimization/ModularityOptimizationMutateProcTest.java @@ -43,6 +43,7 @@ import org.neo4j.gds.TestProcedureRunner; import org.neo4j.gds.TestSupport; import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.community.BasicAlgorithmRunner; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsEstimateBusinessFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsMutateBusinessFacade; @@ -50,6 +51,8 @@ import org.neo4j.gds.algorithms.community.CommunityAlgorithmsStreamBusinessFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsWriteBusinessFacade; import org.neo4j.gds.algorithms.community.MutateNodePropertyService; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; @@ -565,10 +568,16 @@ private GraphDataScience createFacade() { ); var algorithmsMutateBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock), + new CommunityAlgorithmsFacade( + new BasicAlgorithmRunner( + graphStoreCatalogService, + TaskRegistryFactory.empty(), + EmptyUserLogRegistryFactory.INSTANCE, + memoryUsageValidator, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + logMock + ) + ), new MutateNodePropertyService(logMock) ); diff --git a/proc/community/src/test/java/org/neo4j/gds/wcc/WccMutateProcTest.java b/proc/community/src/test/java/org/neo4j/gds/wcc/WccMutateProcTest.java index a1bf496243..9d3337e22b 100644 --- a/proc/community/src/test/java/org/neo4j/gds/wcc/WccMutateProcTest.java +++ b/proc/community/src/test/java/org/neo4j/gds/wcc/WccMutateProcTest.java @@ -40,9 +40,12 @@ import org.neo4j.gds.TestProcedureRunner; import org.neo4j.gds.TestSupport; import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.community.BasicAlgorithmRunner; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsMutateBusinessFacade; import org.neo4j.gds.algorithms.community.MutateNodePropertyService; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.DefaultValue; import org.neo4j.gds.api.Graph; @@ -336,21 +339,7 @@ void testWriteBackGraphMutationOnFilteredGraph() { GraphLoader loader = storeLoaderBuilder.build(); GraphStoreCatalog.set(loader.projectConfig(), loader.graphStore()); - var logMock = mock(org.neo4j.gds.logging.Log.class); - when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); - - final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); - final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( - logMock, - false - ); - var algorithmsMutateBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock), - new MutateNodePropertyService(logMock) - ); + var algorithmsMutateBusinessFacade = communityAlgorithmsMutateBusinessFacade(); applyOnProcedure(procedure -> { procedure.facade = new GraphDataScience( @@ -408,6 +397,32 @@ void testWriteBackGraphMutationOnFilteredGraph() { ); } + @NotNull + private static CommunityAlgorithmsMutateBusinessFacade communityAlgorithmsMutateBusinessFacade() { + var logMock = mock(org.neo4j.gds.logging.Log.class); + when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); + + final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); + final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( + logMock, + false + ); + var algorithmsMutateBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( + new CommunityAlgorithmsFacade( + new BasicAlgorithmRunner( + graphStoreCatalogService, + TaskRegistryFactory.empty(), + EmptyUserLogRegistryFactory.INSTANCE, + memoryUsageValidator, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + logMock + ) + ), + new MutateNodePropertyService(logMock) + ); + return algorithmsMutateBusinessFacade; + } + @Test void testGraphMutation() { GraphStore graphStore = runMutation(ensureGraphExists(), Map.of()); @@ -460,20 +475,7 @@ void testMutateFailsOnExistingToken() { String graphName = ensureGraphExists(); applyOnProcedure(procedure -> { - var logMock = mock(org.neo4j.gds.logging.Log.class); - when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); - final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); - final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( - logMock, - false - ); - var algorithmsBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock), - new MutateNodePropertyService(logMock) - ); + var algorithmsBusinessFacade = communityAlgorithmsMutateBusinessFacade(); procedure.facade = new GraphDataScience( null, @@ -532,20 +534,8 @@ void testExceptionLogging() { @Test void testRunOnEmptyGraph() { + var algorithmsBusinessFacade = communityAlgorithmsMutateBusinessFacade(); applyOnProcedure((proc) -> { - var logMock = mock(org.neo4j.gds.logging.Log.class); - final GraphStoreCatalogService graphStoreCatalogService = new GraphStoreCatalogService(); - final AlgorithmMemoryValidationService memoryUsageValidator = new AlgorithmMemoryValidationService( - logMock, - false - ); - var algorithmsBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, null), - new MutateNodePropertyService(logMock) - ); proc.facade = new GraphDataScience( null, null, @@ -613,10 +603,16 @@ private GraphStore runMutation(String graphName, Map additionalC false ); var algorithmsBusinessFacade = new CommunityAlgorithmsMutateBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock), + new CommunityAlgorithmsFacade( + new BasicAlgorithmRunner( + graphStoreCatalogService, + TaskRegistryFactory.empty(), + EmptyUserLogRegistryFactory.INSTANCE, + memoryUsageValidator, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + logMock + ) + ), new MutateNodePropertyService(logMock) ); diff --git a/proc/community/src/test/java/org/neo4j/gds/wcc/WccStatsProcTest.java b/proc/community/src/test/java/org/neo4j/gds/wcc/WccStatsProcTest.java index 7954ead414..7a3f63be49 100644 --- a/proc/community/src/test/java/org/neo4j/gds/wcc/WccStatsProcTest.java +++ b/proc/community/src/test/java/org/neo4j/gds/wcc/WccStatsProcTest.java @@ -36,8 +36,11 @@ import org.neo4j.gds.TestProcedureRunner; import org.neo4j.gds.TestSupport; import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.community.BasicAlgorithmRunner; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsStatsBusinessFacade; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.ImmutableGraphLoaderContext; import org.neo4j.gds.api.ProcedureReturnColumns; @@ -270,10 +273,16 @@ private GraphDataScience createFacade() { ); var statsBusinessFacade = new CommunityAlgorithmsStatsBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock) + new CommunityAlgorithmsFacade( + new BasicAlgorithmRunner( + graphStoreCatalogService, + TaskRegistryFactory.empty(), + EmptyUserLogRegistryFactory.INSTANCE, + memoryUsageValidator, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + logMock + ) + ) ); return new GraphDataScience( diff --git a/proc/community/src/test/java/org/neo4j/gds/wcc/WccWriteProcTest.java b/proc/community/src/test/java/org/neo4j/gds/wcc/WccWriteProcTest.java index 0d2c4d4bd1..b9c2d3523f 100644 --- a/proc/community/src/test/java/org/neo4j/gds/wcc/WccWriteProcTest.java +++ b/proc/community/src/test/java/org/neo4j/gds/wcc/WccWriteProcTest.java @@ -39,9 +39,12 @@ import org.neo4j.gds.TestProcedureRunner; import org.neo4j.gds.TestSupport; import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.community.BasicAlgorithmRunner; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsWriteBusinessFacade; import org.neo4j.gds.algorithms.community.WriteNodePropertyService; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.ImmutableGraphLoaderContext; import org.neo4j.gds.api.ProcedureReturnColumns; @@ -471,10 +474,15 @@ void testRunOnEmptyGraph() { var taskRegistry = EmptyTaskRegistryFactory.INSTANCE; var algorithmsBusinessFacade = new CommunityAlgorithmsWriteBusinessFacade( - new CommunityAlgorithmsFacade(graphStoreCatalogService, - taskRegistry, - EmptyUserLogRegistryFactory.INSTANCE, - memoryUsageValidator, logMock + new CommunityAlgorithmsFacade( + new BasicAlgorithmRunner( + graphStoreCatalogService, + taskRegistry, + EmptyUserLogRegistryFactory.INSTANCE, + memoryUsageValidator, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + logMock + ) ), new WriteNodePropertyService( wccWriteProc.executionContext().nodePropertyExporterBuilder(), diff --git a/procedures/extension/src/main/java/org/neo4j/gds/extension/OpenGraphDataScienceExtension.java b/procedures/extension/src/main/java/org/neo4j/gds/extension/OpenGraphDataScienceExtension.java index befb41adc6..a6e8cf9508 100644 --- a/procedures/extension/src/main/java/org/neo4j/gds/extension/OpenGraphDataScienceExtension.java +++ b/procedures/extension/src/main/java/org/neo4j/gds/extension/OpenGraphDataScienceExtension.java @@ -54,7 +54,8 @@ public Lifecycle newInstance(ExtensionContext extensionContext, Dependencies dep var log = new LogAccessor().getLog(dependencies.logService(), getClass()); var extensionBuilder = ExtensionBuilder.create( - log, dependencies.config(), + log, + dependencies.config(), dependencies.globalProcedures() ); diff --git a/procedures/integration/build.gradle b/procedures/integration/build.gradle index fa589f7e20..532bc21ed0 100644 --- a/procedures/integration/build.gradle +++ b/procedures/integration/build.gradle @@ -13,6 +13,7 @@ dependencies { testImplementation(group: 'org.neo4j', name: it, version: ver.'neo4j') } + implementation project(':algorithm-metrics-api') implementation project(':config-api') implementation project(':core') implementation project(':core-write') diff --git a/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/CommunityProcedureProvider.java b/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/CommunityProcedureProvider.java index a41be23797..1c1bc9f4fd 100644 --- a/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/CommunityProcedureProvider.java +++ b/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/CommunityProcedureProvider.java @@ -21,6 +21,7 @@ import org.neo4j.gds.ProcedureCallContextReturnColumns; import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.community.BasicAlgorithmRunner; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsEstimateBusinessFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsFacade; import org.neo4j.gds.algorithms.community.CommunityAlgorithmsMutateBusinessFacade; @@ -29,6 +30,8 @@ import org.neo4j.gds.algorithms.community.CommunityAlgorithmsWriteBusinessFacade; import org.neo4j.gds.algorithms.community.MutateNodePropertyService; import org.neo4j.gds.algorithms.community.WriteNodePropertyService; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.GraphLoaderContext; import org.neo4j.gds.api.ImmutableGraphLoaderContext; @@ -121,15 +124,18 @@ public CommunityProcedureFacade createCommunityProcedureFacade(Context context) var exportBuildersProvider = exporterBuildersProviderService.identifyExportBuildersProvider(graphDatabaseService); - // algorithm facade - var communityAlgorithmsFacade = new CommunityAlgorithmsFacade( + var algorithmRunner = new BasicAlgorithmRunner( graphStoreCatalogService, taskRegistryFactory, userLogRegistryFactory, algorithmMemoryValidationService, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), log ); + // algorithm facade + var communityAlgorithmsFacade = new CommunityAlgorithmsFacade(algorithmRunner); + // moar services var fictitiousGraphStoreEstimationService = new FictitiousGraphStoreEstimationService(); var graphLoaderContext = buildGraphLoaderContext( From 45ac77ee20d5e2bf1d9e4c59feb1a5680274b220 Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Tue, 14 Nov 2023 08:08:55 +0000 Subject: [PATCH 4/8] Add timer capability to the Prometheus metrics Co-authored-by: Ioannis Panagiotas --- .../community/BasicAlgorithmRunner.java | 7 ++-- .../community/BasicAlgorithmRunnerTest.java | 28 +++++++++++---- .../algorithms/metrics/AlgorithmMetric.java | 36 +++++++++++++++++++ .../metrics/AlgorithmMetricRegistrar.java | 4 +-- .../metrics/AlgorithmMetricsService.java | 9 ++--- .../metrics/PassthroughAlgorithmMetric.java | 36 +++++++++++++++++++ .../PassthroughAlgorithmMetricRegistrar.java | 9 ++--- .../metrics/AlgorithmMetricsServiceTest.java | 19 ++-------- 8 files changed, 107 insertions(+), 41 deletions(-) create mode 100644 algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetric.java create mode 100644 algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetric.java diff --git a/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java b/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java index ff14781ff8..85f1ea5878 100644 --- a/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java +++ b/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java @@ -115,15 +115,16 @@ public , R, C extends AlgoBaseConfig> AlgorithmComputatio ); // run the algorithm - try { - algorithmMetricsService.started(algorithmFactory.taskName()); + var algorithmMetric = algorithmMetricsService.create(algorithmFactory.taskName()); + try(algorithmMetric) { + algorithmMetric.start(); var algorithmResult = algorithm.compute(); return AlgorithmComputationResult.of(algorithmResult, graph, graphStore, algorithm.getTerminationFlag()); } catch (Exception e) { log.warn("Computation failed", e); algorithm.getProgressTracker().endSubTaskWithFailure(); - algorithmMetricsService.failed(algorithmFactory.taskName()); + algorithmMetric.failed(); throw e; } } diff --git a/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java b/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java index 9edf4c83fe..1c4206391f 100644 --- a/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java +++ b/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java @@ -24,6 +24,7 @@ import org.neo4j.gds.Algorithm; import org.neo4j.gds.GraphAlgorithmFactory; import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetric; import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.Graph; @@ -40,6 +41,7 @@ import static org.assertj.core.api.Assertions.assertThatException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -58,7 +60,9 @@ void shouldRegisterAlgorithmMetricCountForSuccess() { when(graphStoreCatalogServiceMock.getGraphWithGraphStore(any(), any(), any(), any(), any())) .thenReturn(Pair.of(graphMock, mock(GraphStore.class))); + var algorithmMetricMock = mock(AlgorithmMetric.class); var algorithmMetricsServiceMock = mock(AlgorithmMetricsService.class); + when(algorithmMetricsServiceMock.create(anyString())).thenReturn(algorithmMetricMock); var logMock = mock(Log.class); when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); @@ -87,9 +91,14 @@ void shouldRegisterAlgorithmMetricCountForSuccess() { DatabaseId.EMPTY ); - verify(algorithmMetricsServiceMock, times(1)).started("TestingMetrics"); - verify(algorithmMetricsServiceMock, times(0)).failed("TestingMetrics"); - verifyNoMoreInteractions(algorithmMetricsServiceMock); + verify(algorithmMetricsServiceMock, times(1)).create("TestingMetrics"); + verify(algorithmMetricMock, times(1)).start(); + verify(algorithmMetricMock, times(1)).close(); + verify(algorithmMetricMock, times(0)).failed(); + verifyNoMoreInteractions( + algorithmMetricsServiceMock, + algorithmMetricMock + ); } @@ -102,7 +111,9 @@ void shouldRegisterAlgorithmMetricCountForFailure() { when(graphStoreCatalogServiceMock.getGraphWithGraphStore(any(), any(), any(), any(), any())) .thenReturn(Pair.of(graphMock, mock(GraphStore.class))); + var algorithmMetricMock = mock(AlgorithmMetric.class); var algorithmMetricsServiceMock = mock(AlgorithmMetricsService.class); + when(algorithmMetricsServiceMock.create(anyString())).thenReturn(algorithmMetricMock); var logMock = mock(Log.class); when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); @@ -134,9 +145,14 @@ void shouldRegisterAlgorithmMetricCountForFailure() { ) ).withMessage("Ooops"); - verify(algorithmMetricsServiceMock, times(1)).started("TestingMetrics"); - verify(algorithmMetricsServiceMock, times(1)).failed("TestingMetrics"); - verifyNoMoreInteractions(algorithmMetricsServiceMock); + verify(algorithmMetricsServiceMock, times(1)).create("TestingMetrics"); + verify(algorithmMetricMock, times(1)).start(); + verify(algorithmMetricMock, times(1)).close(); + verify(algorithmMetricMock, times(1)).failed(); + verifyNoMoreInteractions( + algorithmMetricsServiceMock, + algorithmMetricMock + ); } } diff --git a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetric.java b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetric.java new file mode 100644 index 0000000000..abe76bba13 --- /dev/null +++ b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetric.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.algorithms.metrics; + +public abstract class AlgorithmMetric implements AutoCloseable { + + protected final String algorithm; + + protected AlgorithmMetric(String algorithm) { + this.algorithm = algorithm; + } + + public abstract void start(); + + public abstract void failed(); + + @Override + public abstract void close(); +} diff --git a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricRegistrar.java b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricRegistrar.java index 9f98c0ed99..0f793c4655 100644 --- a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricRegistrar.java +++ b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricRegistrar.java @@ -21,6 +21,6 @@ public interface AlgorithmMetricRegistrar { - void started(String algorithm); - void failed(String algorithm); + AlgorithmMetric create(String algorithm); + } diff --git a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsService.java b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsService.java index 2dc6c24c0a..fc6350364c 100644 --- a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsService.java +++ b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsService.java @@ -27,12 +27,7 @@ public AlgorithmMetricsService(AlgorithmMetricRegistrar metricRegistrar) { this.metricRegistrar = metricRegistrar; } - public void started(String algorithm) { - metricRegistrar.started(algorithm); + public AlgorithmMetric create(String algorithm) { + return metricRegistrar.create(algorithm); } - - public void failed(String algorithm) { - metricRegistrar.failed(algorithm); - } - } diff --git a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetric.java b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetric.java new file mode 100644 index 0000000000..597d809a08 --- /dev/null +++ b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetric.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.algorithms.metrics; + +public final class PassthroughAlgorithmMetric extends AlgorithmMetric { + + PassthroughAlgorithmMetric(String algorithm) { + super(algorithm); + } + + @Override + public void start() {} + + @Override + public void failed() {} + + @Override + public void close() {} +} diff --git a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetricRegistrar.java b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetricRegistrar.java index bb4fbad2d4..4cbdf232d0 100644 --- a/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetricRegistrar.java +++ b/algorithm-metrics-api/src/main/java/org/neo4j/gds/algorithms/metrics/PassthroughAlgorithmMetricRegistrar.java @@ -25,12 +25,7 @@ public class PassthroughAlgorithmMetricRegistrar implements AlgorithmMetricRegistrar { @Override - public void started(String algorithm) { - - } - - @Override - public void failed(String algorithm) { - + public AlgorithmMetric create(String algorithm) { + return new PassthroughAlgorithmMetric(algorithm); } } diff --git a/algorithm-metrics-api/src/test/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsServiceTest.java b/algorithm-metrics-api/src/test/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsServiceTest.java index a869605f09..47bab58d5e 100644 --- a/algorithm-metrics-api/src/test/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsServiceTest.java +++ b/algorithm-metrics-api/src/test/java/org/neo4j/gds/algorithms/metrics/AlgorithmMetricsServiceTest.java @@ -29,30 +29,17 @@ class AlgorithmMetricsServiceTest { @Test - void shouldRegisterStarted() { + void shouldCreateAlgorithmMetric() { // given var registrarMock = mock(AlgorithmMetricRegistrar.class); var metricsService = new AlgorithmMetricsService(registrarMock); // when - metricsService.started("foo"); + metricsService.create("foo"); // then - verify(registrarMock, times(1)).started("foo"); + verify(registrarMock, times(1)).create("foo"); verifyNoMoreInteractions(registrarMock); } - @Test - void shouldRegisterFailed() { - // given - var registrarMock = mock(AlgorithmMetricRegistrar.class); - var metricsService = new AlgorithmMetricsService(registrarMock); - - // when - metricsService.failed("foo"); - - // then - verify(registrarMock, times(1)).failed("foo"); - verifyNoMoreInteractions(registrarMock); - } } From 11b99b1196ab233d594e01154bbaadce4398973d Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Tue, 14 Nov 2023 09:51:35 +0000 Subject: [PATCH 5/8] Hook AlgorithmMetricsService to ProcedureExecutor Co-authored-by: Ioannis Panagiotas --- executor/build.gradle | 1 + .../neo4j/gds/executor/ExecutionContext.java | 9 +++++++++ .../neo4j/gds/executor/ProcedureExecutor.java | 19 ++++++++++++++++--- .../gds/executor/ProcedureExecutorTest.java | 3 +++ proc/common/build.gradle | 1 + .../src/main/java/org/neo4j/gds/BaseProc.java | 5 +++++ procedures/extension/build.gradle | 1 + .../OpenGraphDataScienceExtension.java | 14 +++++++++++++- .../integration/ExtensionBuilder.java | 5 ++++- 9 files changed, 53 insertions(+), 5 deletions(-) diff --git a/executor/build.gradle b/executor/build.gradle index 4f9c3a06fb..aaf31e574f 100644 --- a/executor/build.gradle +++ b/executor/build.gradle @@ -10,6 +10,7 @@ dependencies { annotationProcessor group: 'org.immutables', name: 'builder', version: ver.'immutables' annotationProcessor group: 'org.immutables', name: 'value', version: ver.'immutables' + implementation project(':algorithm-metrics-api') implementation project(':annotations') implementation project(':algo') implementation project(':algo-common') diff --git a/executor/src/main/java/org/neo4j/gds/executor/ExecutionContext.java b/executor/src/main/java/org/neo4j/gds/executor/ExecutionContext.java index ee333204b6..d3759c4f04 100644 --- a/executor/src/main/java/org/neo4j/gds/executor/ExecutionContext.java +++ b/executor/src/main/java/org/neo4j/gds/executor/ExecutionContext.java @@ -21,6 +21,8 @@ import org.jetbrains.annotations.Nullable; import org.neo4j.common.DependencyResolver; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.annotation.ValueClass; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; @@ -70,6 +72,8 @@ public interface ExecutionContext { boolean isGdsAdmin(); + AlgorithmMetricsService algorithmMetricsService(); + @Nullable RelationshipStreamExporterBuilder relationshipStreamExporterBuilder(); @@ -173,6 +177,11 @@ public boolean isGdsAdmin() { return false; } + @Override + public AlgorithmMetricsService algorithmMetricsService() { + return new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()); + } + @Override public @Nullable RelationshipStreamExporterBuilder relationshipStreamExporterBuilder() { return null; diff --git a/executor/src/main/java/org/neo4j/gds/executor/ProcedureExecutor.java b/executor/src/main/java/org/neo4j/gds/executor/ProcedureExecutor.java index 93b0d78462..7611e96771 100644 --- a/executor/src/main/java/org/neo4j/gds/executor/ProcedureExecutor.java +++ b/executor/src/main/java/org/neo4j/gds/executor/ProcedureExecutor.java @@ -23,6 +23,7 @@ import org.neo4j.gds.AlgorithmFactory; import org.neo4j.gds.GraphAlgorithmFactory; import org.neo4j.gds.GraphStoreAlgorithmFactory; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; import org.neo4j.gds.config.AlgoBaseConfig; @@ -110,7 +111,8 @@ public RESULT compute( algo.getProgressTracker().setEstimatedResourceFootprint(memoryEstimationInBytes, config.concurrency()); - ALGO_RESULT result = executeAlgorithm(builder, algo); + + ALGO_RESULT result = executeAlgorithm(builder, algo, executionContext.algorithmMetricsService()); var computationResult = builder .graph(graph) @@ -125,15 +127,26 @@ public RESULT compute( private ALGO_RESULT executeAlgorithm( ImmutableComputationResult.Builder builder, - ALGO algo + ALGO algo, + AlgorithmMetricsService algorithmMetricsService ) { return runWithExceptionLogging( "Computation failed", () -> { - try (ProgressTimer ignored = ProgressTimer.start(builder::computeMillis)) { + var algorithmMetric = algorithmMetricsService.create( + // we don't want to use `spec.name()` because it's different for the different procedure modes; + // we want to capture the algorithm name as defined by the algorithm factory `taskName()` + algoSpec.algorithmFactory(executionContext).taskName() + ); + try ( + ProgressTimer ignored = ProgressTimer.start(builder::computeMillis); + algorithmMetric; + ) { + algorithmMetric.start(); return algo.compute(); } catch (Throwable e) { algo.getProgressTracker().endSubTaskWithFailure(); + algorithmMetric.failed(); throw e; } finally { if (algoSpec.releaseProgressTask()) { diff --git a/executor/src/test/java/org/neo4j/gds/executor/ProcedureExecutorTest.java b/executor/src/test/java/org/neo4j/gds/executor/ProcedureExecutorTest.java index 147f5beead..5c92e79796 100644 --- a/executor/src/test/java/org/neo4j/gds/executor/ProcedureExecutorTest.java +++ b/executor/src/test/java/org/neo4j/gds/executor/ProcedureExecutorTest.java @@ -23,6 +23,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.neo4j.gds.ProcedureCallContextReturnColumns; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.GraphStore; @@ -136,6 +138,7 @@ private ExecutionContext executionContext(TaskStore taskStore) { .algorithmMetaDataSetter(AlgorithmMetaDataSetter.EMPTY) .nodeLookup(NodeLookup.EMPTY) .userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); } diff --git a/proc/common/build.gradle b/proc/common/build.gradle index 6a95c597a9..5c5c5cbef2 100644 --- a/proc/common/build.gradle +++ b/proc/common/build.gradle @@ -15,6 +15,7 @@ dependencies { implementation project(':annotations') implementation project(':algo-common') + implementation project(':algorithm-metrics-api') implementation project(':config-api') implementation project(':core') implementation project(':core-write') diff --git a/proc/common/src/main/java/org/neo4j/gds/BaseProc.java b/proc/common/src/main/java/org/neo4j/gds/BaseProc.java index 3ccb80ce63..2320155cb2 100644 --- a/proc/common/src/main/java/org/neo4j/gds/BaseProc.java +++ b/proc/common/src/main/java/org/neo4j/gds/BaseProc.java @@ -19,6 +19,7 @@ */ package org.neo4j.gds; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.compat.GraphDatabaseApiProxy; import org.neo4j.gds.config.BaseConfig; @@ -75,6 +76,9 @@ public abstract class BaseProc { @Context public Username username = Username.EMPTY_USERNAME; + @Context + public AlgorithmMetricsService algorithmMetricsService; + protected String username() { return username.username(); } @@ -157,6 +161,7 @@ public ExecutionContext executionContext() { .algorithmMetaDataSetter(new TransactionAlgorithmMetaDataSetter(transaction)) .nodeLookup(new TransactionNodeLookup(transaction)) .isGdsAdmin(transactionContext().isGdsAdmin()) + .algorithmMetricsService(algorithmMetricsService) .build(); } diff --git a/procedures/extension/build.gradle b/procedures/extension/build.gradle index e3ed8a1c61..f51938d5e7 100644 --- a/procedures/extension/build.gradle +++ b/procedures/extension/build.gradle @@ -18,6 +18,7 @@ dependencies { compileOnly(group: 'org.neo4j', name: 'neo4j-logging', version: ver.'neo4j') { transitive = false } // the necessary GDS things for the extension to construct the application + implementation project(':algorithm-metrics-api') implementation project(':config-api') implementation project(':core') implementation project(':core-utils') diff --git a/procedures/extension/src/main/java/org/neo4j/gds/extension/OpenGraphDataScienceExtension.java b/procedures/extension/src/main/java/org/neo4j/gds/extension/OpenGraphDataScienceExtension.java index a6e8cf9508..9d38c84ddd 100644 --- a/procedures/extension/src/main/java/org/neo4j/gds/extension/OpenGraphDataScienceExtension.java +++ b/procedures/extension/src/main/java/org/neo4j/gds/extension/OpenGraphDataScienceExtension.java @@ -21,6 +21,8 @@ import org.neo4j.annotations.service.ServiceProvider; import org.neo4j.configuration.Config; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.applications.graphstorecatalog.CatalogBusinessFacade; import org.neo4j.gds.core.write.NativeExportBuildersProvider; import org.neo4j.gds.procedures.GraphDataScience; @@ -64,10 +66,20 @@ public Lifecycle newInstance(ExtensionContext extensionContext, Dependencies dep // we have no extra checks to do in OpenGDS Optional> businessFacadeDecorator = Optional.empty(); + var algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()); + extensionBuilder .withComponent( GraphDataScience.class, - () -> extensionBuilder.gdsProvider(exporterBuildersProviderService, businessFacadeDecorator) + () -> extensionBuilder.gdsProvider( + exporterBuildersProviderService, + businessFacadeDecorator, + algorithmMetricsService + ) + ) + .withComponent( + AlgorithmMetricsService.class, + () -> ctx -> algorithmMetricsService ) .registerExtension(); diff --git a/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/ExtensionBuilder.java b/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/ExtensionBuilder.java index 9e288c0752..511560b2c0 100644 --- a/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/ExtensionBuilder.java +++ b/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/ExtensionBuilder.java @@ -20,6 +20,7 @@ package org.neo4j.gds.procedures.integration; import org.neo4j.function.ThrowingFunction; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; import org.neo4j.gds.applications.graphstorecatalog.CatalogBusinessFacade; import org.neo4j.gds.core.loading.GraphStoreCatalogService; import org.neo4j.gds.core.utils.progress.ProgressFeatureSettings; @@ -187,10 +188,12 @@ public void registerExtension() { * * @param exporterBuildersProviderService The catalog of writers * @param businessFacadeDecorator Any checks added across requests + * @param algorithmMetricsService */ public ThrowingFunction gdsProvider( ExporterBuildersProviderService exporterBuildersProviderService, - Optional> businessFacadeDecorator + Optional> businessFacadeDecorator, + AlgorithmMetricsService algorithmMetricsService ) { var catalogFacadeProvider = createCatalogFacadeProvider( exporterBuildersProviderService, From 759750d75d6984c85047fb025416c75a33de5247 Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Tue, 14 Nov 2023 11:14:08 +0000 Subject: [PATCH 6/8] Conditionally register algorithm metrics Co-authored-by: Ioannis Panagiotas --- .../integration/CommunityProcedureProvider.java | 8 +++++--- .../gds/procedures/integration/ExtensionBuilder.java | 9 ++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/CommunityProcedureProvider.java b/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/CommunityProcedureProvider.java index 1c1bc9f4fd..d6a87f6573 100644 --- a/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/CommunityProcedureProvider.java +++ b/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/CommunityProcedureProvider.java @@ -31,7 +31,6 @@ import org.neo4j.gds.algorithms.community.MutateNodePropertyService; import org.neo4j.gds.algorithms.community.WriteNodePropertyService; import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; -import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.GraphLoaderContext; import org.neo4j.gds.api.ImmutableGraphLoaderContext; @@ -75,6 +74,7 @@ public class CommunityProcedureProvider { private final TerminationFlagService terminationFlagService; private final UserLogServices userLogServices; private final UserAccessor userAccessor; + private final AlgorithmMetricsService algorithmMetricsService; public CommunityProcedureProvider( Log log, @@ -87,7 +87,8 @@ public CommunityProcedureProvider( TaskRegistryFactoryService taskRegistryFactoryService, TerminationFlagService terminationFlagService, UserLogServices userLogServices, - UserAccessor userAccessor + UserAccessor userAccessor, + AlgorithmMetricsService algorithmMetricsService ) { this.log = log; this.graphStoreCatalogService = graphStoreCatalogService; @@ -101,6 +102,7 @@ public CommunityProcedureProvider( this.terminationFlagService = terminationFlagService; this.userLogServices = userLogServices; this.userAccessor = userAccessor; + this.algorithmMetricsService = algorithmMetricsService; } public CommunityProcedureFacade createCommunityProcedureFacade(Context context) throws ProcedureException { @@ -129,7 +131,7 @@ public CommunityProcedureFacade createCommunityProcedureFacade(Context context) taskRegistryFactory, userLogRegistryFactory, algorithmMemoryValidationService, - new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + algorithmMetricsService, log ); diff --git a/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/ExtensionBuilder.java b/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/ExtensionBuilder.java index 511560b2c0..1d65744e2e 100644 --- a/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/ExtensionBuilder.java +++ b/procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/ExtensionBuilder.java @@ -200,7 +200,7 @@ public ThrowingFunction gdsProvid businessFacadeDecorator ); - var communityProcedureProvider = createCommunityProcedureProvider(exporterBuildersProviderService); + var communityProcedureProvider = createCommunityProcedureProvider(exporterBuildersProviderService, algorithmMetricsService); return new GraphDataScienceProvider(log, catalogFacadeProvider, communityProcedureProvider); } @@ -226,7 +226,9 @@ private CatalogFacadeProvider createCatalogFacadeProvider( ); } - private CommunityProcedureProvider createCommunityProcedureProvider(ExporterBuildersProviderService exporterBuildersProviderService) { + private CommunityProcedureProvider createCommunityProcedureProvider(ExporterBuildersProviderService exporterBuildersProviderService, + AlgorithmMetricsService algorithmMetricsService + ) { var algorithmMetaDataSetterService = new AlgorithmMetaDataSetterService(); return new CommunityProcedureProvider( @@ -240,7 +242,8 @@ private CommunityProcedureProvider createCommunityProcedureProvider(ExporterBuil taskRegistryFactoryService, terminationFlagService, userLogServices, - userAccessor + userAccessor, + algorithmMetricsService ); } } From 79aabee5edb66e2bcb80bb353725598aba62b5bc Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Tue, 14 Nov 2023 12:55:31 +0000 Subject: [PATCH 7/8] Expose algorithm-metrics-api as transitive of proc-common Co-authored-by: Ioannis Panagiotas --- .../MemoryEstimationExecutorTest.java | 3 +++ .../LinkPredictionTrainingPipelineTest.java | 5 ++++ proc/common/build.gradle | 2 +- ...PropertyComputationResultConsumerTest.java | 3 +++ ...opertiesComputationResultConsumerTest.java | 3 +++ .../neo4j/gds/WriteProcCancellationTest.java | 3 +++ ...fsStreamComputationResultConsumerTest.java | 1 + .../neo4j/gds/pregel/proc/PregelProcTest.java | 11 ++++++++ .../java/org/neo4j/gds/ProcedureRunner.java | 10 ++++++-- .../src/main/java/org/neo4j/gds/BaseTest.java | 25 +++++++++++++++++++ 10 files changed, 63 insertions(+), 3 deletions(-) diff --git a/executor/src/test/java/org/neo4j/gds/executor/MemoryEstimationExecutorTest.java b/executor/src/test/java/org/neo4j/gds/executor/MemoryEstimationExecutorTest.java index cf01f65e04..78a6d3cf77 100644 --- a/executor/src/test/java/org/neo4j/gds/executor/MemoryEstimationExecutorTest.java +++ b/executor/src/test/java/org/neo4j/gds/executor/MemoryEstimationExecutorTest.java @@ -27,6 +27,8 @@ import org.neo4j.gds.NodeProjections; import org.neo4j.gds.ProcedureCallContextReturnColumns; import org.neo4j.gds.RelationshipProjections; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.DatabaseId; @@ -85,6 +87,7 @@ void setup() throws Exception { .nodeLookup(NodeLookup.EMPTY) .modelCatalog(ModelCatalog.EMPTY) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); memoryEstimationExecutor = new MemoryEstimationExecutor<>( diff --git a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java index 47212cef13..ac5ba0c6f0 100644 --- a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java +++ b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java @@ -24,6 +24,8 @@ import org.junit.jupiter.api.Test; import org.neo4j.gds.NodeLabel; import org.neo4j.gds.RelationshipType; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.DatabaseId; @@ -195,6 +197,7 @@ void deriveRelationshipWeightProperty() { .taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE) .userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); var pipeline = new LinkPredictionTrainingPipeline(); @@ -239,6 +242,7 @@ void deriveRelationshipWeightPropertyFromTrainedModel() { .taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE) .userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); var pipeline = new LinkPredictionTrainingPipeline(); @@ -283,6 +287,7 @@ void notDerivePropertyFromUnweightedTrainedModel() { .taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE) .userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); var pipeline = new LinkPredictionTrainingPipeline(); diff --git a/proc/common/build.gradle b/proc/common/build.gradle index 5c5c5cbef2..ca5bf82261 100644 --- a/proc/common/build.gradle +++ b/proc/common/build.gradle @@ -11,11 +11,11 @@ dependencies { annotationProcessor group: 'org.immutables', name: 'value', version: ver.'immutables' api(project(':algo')) + api project(':algorithm-metrics-api') api(project(':model-catalog-api')) implementation project(':annotations') implementation project(':algo-common') - implementation project(':algorithm-metrics-api') implementation project(':config-api') implementation project(':core') implementation project(':core-write') diff --git a/proc/common/src/test/java/org/neo4j/gds/MutatePropertyComputationResultConsumerTest.java b/proc/common/src/test/java/org/neo4j/gds/MutatePropertyComputationResultConsumerTest.java index 0451453ea1..917d0b4bc6 100644 --- a/proc/common/src/test/java/org/neo4j/gds/MutatePropertyComputationResultConsumerTest.java +++ b/proc/common/src/test/java/org/neo4j/gds/MutatePropertyComputationResultConsumerTest.java @@ -22,6 +22,8 @@ import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CSRGraph; import org.neo4j.gds.api.CloseableResourceRegistry; @@ -88,6 +90,7 @@ class MutatePropertyComputationResultConsumerTest { .nodeLookup(NodeLookup.EMPTY) .modelCatalog(ModelCatalog.EMPTY) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); @BeforeEach diff --git a/proc/common/src/test/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumerTest.java b/proc/common/src/test/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumerTest.java index c3bd97e4b4..e83307e4d1 100644 --- a/proc/common/src/test/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumerTest.java +++ b/proc/common/src/test/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumerTest.java @@ -20,6 +20,8 @@ package org.neo4j.gds; import org.junit.jupiter.api.Test; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.DatabaseId; @@ -96,6 +98,7 @@ class WriteNodePropertiesComputationResultConsumerTest extends BaseTest { .modelCatalog(ModelCatalog.EMPTY) .isGdsAdmin(false) .nodePropertyExporterBuilder(new NativeNodePropertiesExporterBuilder(EmptyTransactionContext.INSTANCE)) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); @Test diff --git a/proc/common/src/test/java/org/neo4j/gds/WriteProcCancellationTest.java b/proc/common/src/test/java/org/neo4j/gds/WriteProcCancellationTest.java index 745b4d08de..e70cb2ebfd 100644 --- a/proc/common/src/test/java/org/neo4j/gds/WriteProcCancellationTest.java +++ b/proc/common/src/test/java/org/neo4j/gds/WriteProcCancellationTest.java @@ -20,6 +20,8 @@ package org.neo4j.gds; import org.junit.jupiter.api.Test; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.DatabaseId; @@ -124,6 +126,7 @@ public long nodeCount() { .modelCatalog(ModelCatalog.EMPTY) .isGdsAdmin(false) .nodePropertyExporterBuilder(new NativeNodePropertiesExporterBuilder(DatabaseTransactionContext.of(db, tx))) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); assertThatThrownBy(() -> resultConsumer.consume(computationResult, executionContext)) diff --git a/proc/path-finding/src/test/java/org/neo4j/gds/paths/traverse/DfsStreamComputationResultConsumerTest.java b/proc/path-finding/src/test/java/org/neo4j/gds/paths/traverse/DfsStreamComputationResultConsumerTest.java index b0e2376b65..90e31c1454 100644 --- a/proc/path-finding/src/test/java/org/neo4j/gds/paths/traverse/DfsStreamComputationResultConsumerTest.java +++ b/proc/path-finding/src/test/java/org/neo4j/gds/paths/traverse/DfsStreamComputationResultConsumerTest.java @@ -55,6 +55,7 @@ class DfsStreamComputationResultConsumerTest { void shouldNotComputePath() { when(graphMock.toOriginalNodeId(anyLong())).then(returnsFirstArg()); + when(computationResultMock.graph()).thenReturn(graphMock); when(computationResultMock.result()).thenReturn(Optional.of(HugeLongArray.of(1L, 2L))); diff --git a/proc/pregel/src/test/java/org/neo4j/gds/pregel/proc/PregelProcTest.java b/proc/pregel/src/test/java/org/neo4j/gds/pregel/proc/PregelProcTest.java index 845d8e023d..4d542cde22 100644 --- a/proc/pregel/src/test/java/org/neo4j/gds/pregel/proc/PregelProcTest.java +++ b/proc/pregel/src/test/java/org/neo4j/gds/pregel/proc/PregelProcTest.java @@ -29,6 +29,8 @@ import org.neo4j.gds.GdsCypher; import org.neo4j.gds.GraphAlgorithmFactory; import org.neo4j.gds.TestTaskStore; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.nodeproperties.ValueType; import org.neo4j.gds.assertj.ConditionFactory; @@ -209,6 +211,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInStreamMode() { proc.procedureTransaction = transactions.tx(); proc.log = NullLog.getInstance(); proc.callContext = ProcedureCallContext.EMPTY; + + proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()); + Map config = Map.of( "maxIterations", 20, "throwInCompute", true @@ -234,6 +239,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInWriteMode() { proc.procedureTransaction = transactions.tx(); proc.log = NullLog.getInstance(); proc.callContext = ProcedureCallContext.EMPTY; + + proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()); + Map config = Map.of( "maxIterations", 20, "throwInCompute", true @@ -258,6 +266,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInMutateMode() { proc.procedureTransaction = transactions.tx(); proc.log = NullLog.getInstance(); proc.callContext = ProcedureCallContext.EMPTY; + + proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()); + Map config = Map.of( "maxIterations", 20, "throwInCompute", true diff --git a/proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java b/proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java index 001f570242..0444c99bcc 100644 --- a/proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java +++ b/proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java @@ -19,6 +19,8 @@ */ package org.neo4j.gds; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.compat.GraphDatabaseApiProxy; import org.neo4j.gds.core.Username; import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; @@ -43,7 +45,8 @@ public static

P instantiateProcedure( TaskRegistryFactory taskRegistryFactory, UserLogRegistryFactory userLogRegistryFactory, Transaction tx, - Username username + Username username, + AlgorithmMetricsService algorithmMetricsService ) { P proc; try { @@ -61,6 +64,8 @@ public static

P instantiateProcedure( proc.userLogRegistryFactory = userLogRegistryFactory; proc.username = username; + proc.algorithmMetricsService = algorithmMetricsService; + return proc; } @@ -82,7 +87,8 @@ public static

P applyOnProcedure( taskRegistryFactory, EmptyUserLogRegistryFactory.INSTANCE, tx, - username + username, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()) ); func.accept(proc); return proc; diff --git a/test-utils/src/main/java/org/neo4j/gds/BaseTest.java b/test-utils/src/main/java/org/neo4j/gds/BaseTest.java index afb596fb32..93baaf60b9 100644 --- a/test-utils/src/main/java/org/neo4j/gds/BaseTest.java +++ b/test-utils/src/main/java/org/neo4j/gds/BaseTest.java @@ -23,6 +23,8 @@ import org.assertj.core.api.Assertions; import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.Timeout; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.compat.Neo4jProxy; import org.neo4j.gds.compat.TestLog; import org.neo4j.gds.core.Settings; @@ -35,6 +37,11 @@ import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Result; import org.neo4j.graphdb.Transaction; +import org.neo4j.kernel.api.procedure.GlobalProcedures; +import org.neo4j.kernel.extension.ExtensionFactory; +import org.neo4j.kernel.extension.context.ExtensionContext; +import org.neo4j.kernel.lifecycle.Lifecycle; +import org.neo4j.kernel.lifecycle.LifecycleAdapter; import org.neo4j.test.TestDatabaseManagementServiceBuilder; import org.neo4j.test.extension.ExtensionCallback; import org.neo4j.test.extension.ImpermanentDbmsExtension; @@ -86,6 +93,24 @@ protected void configuration(TestDatabaseManagementServiceBuilder builder) { builder.setConfigRaw(Map.of("unsupported.dbms.debug.trace_cursors", "true")); testLog = Neo4jProxy.testLog(); builder.setUserLogProvider(new TestLogProvider(testLog)); + + // Hacky as hell but will have to do until we make this BaseTest obsolete + builder.addExtension(new ExtensionFactory("AlgorithmMetricsServiceExtensionFactory") { + @Override + public Lifecycle newInstance(ExtensionContext context, Dependencies dependencies) { + dependencies.globalProcedures().registerComponent( + AlgorithmMetricsService.class, + ctx -> new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + false + ); + return new LifecycleAdapter(); + } + + }); + } + + interface Dependencies { + GlobalProcedures globalProcedures(); } protected long clearDb() { From 3141b300d1af2e31d9938afc20884feeada416de Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Wed, 15 Nov 2023 11:58:09 +0000 Subject: [PATCH 8/8] Address review comments Co-authored-by: Lasse Westh-Nielsen Co-authored-by: Ioannis Panagiotas --- .../community/BasicAlgorithmRunner.java | 11 ++- .../community/BasicAlgorithmRunnerTest.java | 80 ++++--------------- 2 files changed, 23 insertions(+), 68 deletions(-) diff --git a/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java b/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java index 85f1ea5878..eb4f13fbf1 100644 --- a/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java +++ b/algo/src/main/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunner.java @@ -115,12 +115,15 @@ public , R, C extends AlgoBaseConfig> AlgorithmComputatio ); // run the algorithm - var algorithmMetric = algorithmMetricsService.create(algorithmFactory.taskName()); + var algorithmResult = runAlgorithm(algorithm, algorithmFactory.taskName()); + return AlgorithmComputationResult.of(algorithmResult, graph, graphStore, algorithm.getTerminationFlag()); + } + + R runAlgorithm(Algorithm algorithm, String algorithmName) { + var algorithmMetric = algorithmMetricsService.create(algorithmName); try(algorithmMetric) { algorithmMetric.start(); - var algorithmResult = algorithm.compute(); - - return AlgorithmComputationResult.of(algorithmResult, graph, graphStore, algorithm.getTerminationFlag()); + return algorithm.compute(); } catch (Exception e) { log.warn("Computation failed", e); algorithm.getProgressTracker().endSubTaskWithFailure(); diff --git a/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java b/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java index 1c4206391f..ed673080fa 100644 --- a/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java +++ b/algo/src/test/java/org/neo4j/gds/algorithms/community/BasicAlgorithmRunnerTest.java @@ -19,28 +19,14 @@ */ package org.neo4j.gds.algorithms.community; -import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; import org.neo4j.gds.Algorithm; -import org.neo4j.gds.GraphAlgorithmFactory; -import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService; import org.neo4j.gds.algorithms.metrics.AlgorithmMetric; import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; -import org.neo4j.gds.api.DatabaseId; -import org.neo4j.gds.api.Graph; -import org.neo4j.gds.api.GraphStore; -import org.neo4j.gds.api.User; import org.neo4j.gds.compat.Neo4jProxy; -import org.neo4j.gds.config.AlgoBaseConfig; -import org.neo4j.gds.core.loading.GraphStoreCatalogService; -import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; -import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory; import org.neo4j.gds.logging.Log; -import java.util.Optional; - import static org.assertj.core.api.Assertions.assertThatException; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; @@ -53,43 +39,24 @@ class BasicAlgorithmRunnerTest { @Test void shouldRegisterAlgorithmMetricCountForSuccess() { - var graphMock = mock(Graph.class); - when(graphMock.isEmpty()).thenReturn(false); - - var graphStoreCatalogServiceMock = mock(GraphStoreCatalogService.class); - when(graphStoreCatalogServiceMock.getGraphWithGraphStore(any(), any(), any(), any(), any())) - .thenReturn(Pair.of(graphMock, mock(GraphStore.class))); - var algorithmMetricMock = mock(AlgorithmMetric.class); var algorithmMetricsServiceMock = mock(AlgorithmMetricsService.class); when(algorithmMetricsServiceMock.create(anyString())).thenReturn(algorithmMetricMock); - var logMock = mock(Log.class); - when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); - var runner = new BasicAlgorithmRunner( - graphStoreCatalogServiceMock, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - mock(AlgorithmMemoryValidationService.class), + null, + null, + null, + null, algorithmMetricsServiceMock, - logMock + null ); - var algorithmMock = mock(Algorithm.class, RETURNS_DEEP_STUBS); + var algorithmMock = mock(Algorithm.class); when(algorithmMock.compute()).thenReturn("WooHoo"); - var algorithmFactoryMock = mock(GraphAlgorithmFactory.class); - when(algorithmFactoryMock.taskName()).thenReturn("TestingMetrics"); - when(algorithmFactoryMock.build(any(), any(), any(), any(), any())).thenReturn(algorithmMock); - - runner.run( - "foo", - mock(AlgoBaseConfig.class), - Optional.empty(), - algorithmFactoryMock, - mock(User.class), - DatabaseId.EMPTY - ); + + + runner.runAlgorithm(algorithmMock, "TestingMetrics"); verify(algorithmMetricsServiceMock, times(1)).create("TestingMetrics"); verify(algorithmMetricMock, times(1)).start(); @@ -104,13 +71,6 @@ void shouldRegisterAlgorithmMetricCountForSuccess() { @Test void shouldRegisterAlgorithmMetricCountForFailure() { - var graphMock = mock(Graph.class); - when(graphMock.isEmpty()).thenReturn(false); - - var graphStoreCatalogServiceMock = mock(GraphStoreCatalogService.class); - when(graphStoreCatalogServiceMock.getGraphWithGraphStore(any(), any(), any(), any(), any())) - .thenReturn(Pair.of(graphMock, mock(GraphStore.class))); - var algorithmMetricMock = mock(AlgorithmMetric.class); var algorithmMetricsServiceMock = mock(AlgorithmMetricsService.class); when(algorithmMetricsServiceMock.create(anyString())).thenReturn(algorithmMetricMock); @@ -119,10 +79,10 @@ void shouldRegisterAlgorithmMetricCountForFailure() { when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog()); var runner = new BasicAlgorithmRunner( - graphStoreCatalogServiceMock, - TaskRegistryFactory.empty(), - EmptyUserLogRegistryFactory.INSTANCE, - mock(AlgorithmMemoryValidationService.class), + null, + null, + null, + null, algorithmMetricsServiceMock, logMock ); @@ -130,18 +90,10 @@ void shouldRegisterAlgorithmMetricCountForFailure() { var algorithmMock = mock(Algorithm.class, RETURNS_DEEP_STUBS); when(algorithmMock.compute()).thenThrow(new RuntimeException("Ooops")); - var algorithmFactoryMock = mock(GraphAlgorithmFactory.class); - when(algorithmFactoryMock.taskName()).thenReturn("TestingMetrics"); - when(algorithmFactoryMock.build(any(), any(), any(), any(), any())).thenReturn(algorithmMock); - assertThatException().isThrownBy( - () -> runner.run( - "foo", - mock(AlgoBaseConfig.class), - Optional.empty(), - algorithmFactoryMock, - mock(User.class), - DatabaseId.EMPTY + () -> runner.runAlgorithm( + algorithmMock, + "TestingMetrics" ) ).withMessage("Ooops");