From 79aabee5edb66e2bcb80bb353725598aba62b5bc Mon Sep 17 00:00:00 2001 From: Veselin Nikolov Date: Tue, 14 Nov 2023 12:55:31 +0000 Subject: [PATCH] Expose algorithm-metrics-api as transitive of proc-common Co-authored-by: Ioannis Panagiotas --- .../MemoryEstimationExecutorTest.java | 3 +++ .../LinkPredictionTrainingPipelineTest.java | 5 ++++ proc/common/build.gradle | 2 +- ...PropertyComputationResultConsumerTest.java | 3 +++ ...opertiesComputationResultConsumerTest.java | 3 +++ .../neo4j/gds/WriteProcCancellationTest.java | 3 +++ ...fsStreamComputationResultConsumerTest.java | 1 + .../neo4j/gds/pregel/proc/PregelProcTest.java | 11 ++++++++ .../java/org/neo4j/gds/ProcedureRunner.java | 10 ++++++-- .../src/main/java/org/neo4j/gds/BaseTest.java | 25 +++++++++++++++++++ 10 files changed, 63 insertions(+), 3 deletions(-) diff --git a/executor/src/test/java/org/neo4j/gds/executor/MemoryEstimationExecutorTest.java b/executor/src/test/java/org/neo4j/gds/executor/MemoryEstimationExecutorTest.java index cf01f65e04..78a6d3cf77 100644 --- a/executor/src/test/java/org/neo4j/gds/executor/MemoryEstimationExecutorTest.java +++ b/executor/src/test/java/org/neo4j/gds/executor/MemoryEstimationExecutorTest.java @@ -27,6 +27,8 @@ import org.neo4j.gds.NodeProjections; import org.neo4j.gds.ProcedureCallContextReturnColumns; import org.neo4j.gds.RelationshipProjections; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.DatabaseId; @@ -85,6 +87,7 @@ void setup() throws Exception { .nodeLookup(NodeLookup.EMPTY) .modelCatalog(ModelCatalog.EMPTY) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); memoryEstimationExecutor = new MemoryEstimationExecutor<>( diff --git a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java index 47212cef13..ac5ba0c6f0 100644 --- a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java +++ b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java @@ -24,6 +24,8 @@ import org.junit.jupiter.api.Test; import org.neo4j.gds.NodeLabel; import org.neo4j.gds.RelationshipType; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.DatabaseId; @@ -195,6 +197,7 @@ void deriveRelationshipWeightProperty() { .taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE) .userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); var pipeline = new LinkPredictionTrainingPipeline(); @@ -239,6 +242,7 @@ void deriveRelationshipWeightPropertyFromTrainedModel() { .taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE) .userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); var pipeline = new LinkPredictionTrainingPipeline(); @@ -283,6 +287,7 @@ void notDerivePropertyFromUnweightedTrainedModel() { .taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE) .userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); var pipeline = new LinkPredictionTrainingPipeline(); diff --git a/proc/common/build.gradle b/proc/common/build.gradle index 5c5c5cbef2..ca5bf82261 100644 --- a/proc/common/build.gradle +++ b/proc/common/build.gradle @@ -11,11 +11,11 @@ dependencies { annotationProcessor group: 'org.immutables', name: 'value', version: ver.'immutables' api(project(':algo')) + api project(':algorithm-metrics-api') api(project(':model-catalog-api')) implementation project(':annotations') implementation project(':algo-common') - implementation project(':algorithm-metrics-api') implementation project(':config-api') implementation project(':core') implementation project(':core-write') diff --git a/proc/common/src/test/java/org/neo4j/gds/MutatePropertyComputationResultConsumerTest.java b/proc/common/src/test/java/org/neo4j/gds/MutatePropertyComputationResultConsumerTest.java index 0451453ea1..917d0b4bc6 100644 --- a/proc/common/src/test/java/org/neo4j/gds/MutatePropertyComputationResultConsumerTest.java +++ b/proc/common/src/test/java/org/neo4j/gds/MutatePropertyComputationResultConsumerTest.java @@ -22,6 +22,8 @@ import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CSRGraph; import org.neo4j.gds.api.CloseableResourceRegistry; @@ -88,6 +90,7 @@ class MutatePropertyComputationResultConsumerTest { .nodeLookup(NodeLookup.EMPTY) .modelCatalog(ModelCatalog.EMPTY) .isGdsAdmin(false) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); @BeforeEach diff --git a/proc/common/src/test/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumerTest.java b/proc/common/src/test/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumerTest.java index c3bd97e4b4..e83307e4d1 100644 --- a/proc/common/src/test/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumerTest.java +++ b/proc/common/src/test/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumerTest.java @@ -20,6 +20,8 @@ package org.neo4j.gds; import org.junit.jupiter.api.Test; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.DatabaseId; @@ -96,6 +98,7 @@ class WriteNodePropertiesComputationResultConsumerTest extends BaseTest { .modelCatalog(ModelCatalog.EMPTY) .isGdsAdmin(false) .nodePropertyExporterBuilder(new NativeNodePropertiesExporterBuilder(EmptyTransactionContext.INSTANCE)) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); @Test diff --git a/proc/common/src/test/java/org/neo4j/gds/WriteProcCancellationTest.java b/proc/common/src/test/java/org/neo4j/gds/WriteProcCancellationTest.java index 745b4d08de..e70cb2ebfd 100644 --- a/proc/common/src/test/java/org/neo4j/gds/WriteProcCancellationTest.java +++ b/proc/common/src/test/java/org/neo4j/gds/WriteProcCancellationTest.java @@ -20,6 +20,8 @@ package org.neo4j.gds; import org.junit.jupiter.api.Test; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.AlgorithmMetaDataSetter; import org.neo4j.gds.api.CloseableResourceRegistry; import org.neo4j.gds.api.DatabaseId; @@ -124,6 +126,7 @@ public long nodeCount() { .modelCatalog(ModelCatalog.EMPTY) .isGdsAdmin(false) .nodePropertyExporterBuilder(new NativeNodePropertiesExporterBuilder(DatabaseTransactionContext.of(db, tx))) + .algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())) .build(); assertThatThrownBy(() -> resultConsumer.consume(computationResult, executionContext)) diff --git a/proc/path-finding/src/test/java/org/neo4j/gds/paths/traverse/DfsStreamComputationResultConsumerTest.java b/proc/path-finding/src/test/java/org/neo4j/gds/paths/traverse/DfsStreamComputationResultConsumerTest.java index b0e2376b65..90e31c1454 100644 --- a/proc/path-finding/src/test/java/org/neo4j/gds/paths/traverse/DfsStreamComputationResultConsumerTest.java +++ b/proc/path-finding/src/test/java/org/neo4j/gds/paths/traverse/DfsStreamComputationResultConsumerTest.java @@ -55,6 +55,7 @@ class DfsStreamComputationResultConsumerTest { void shouldNotComputePath() { when(graphMock.toOriginalNodeId(anyLong())).then(returnsFirstArg()); + when(computationResultMock.graph()).thenReturn(graphMock); when(computationResultMock.result()).thenReturn(Optional.of(HugeLongArray.of(1L, 2L))); diff --git a/proc/pregel/src/test/java/org/neo4j/gds/pregel/proc/PregelProcTest.java b/proc/pregel/src/test/java/org/neo4j/gds/pregel/proc/PregelProcTest.java index 845d8e023d..4d542cde22 100644 --- a/proc/pregel/src/test/java/org/neo4j/gds/pregel/proc/PregelProcTest.java +++ b/proc/pregel/src/test/java/org/neo4j/gds/pregel/proc/PregelProcTest.java @@ -29,6 +29,8 @@ import org.neo4j.gds.GdsCypher; import org.neo4j.gds.GraphAlgorithmFactory; import org.neo4j.gds.TestTaskStore; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.nodeproperties.ValueType; import org.neo4j.gds.assertj.ConditionFactory; @@ -209,6 +211,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInStreamMode() { proc.procedureTransaction = transactions.tx(); proc.log = NullLog.getInstance(); proc.callContext = ProcedureCallContext.EMPTY; + + proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()); + Map config = Map.of( "maxIterations", 20, "throwInCompute", true @@ -234,6 +239,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInWriteMode() { proc.procedureTransaction = transactions.tx(); proc.log = NullLog.getInstance(); proc.callContext = ProcedureCallContext.EMPTY; + + proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()); + Map config = Map.of( "maxIterations", 20, "throwInCompute", true @@ -258,6 +266,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInMutateMode() { proc.procedureTransaction = transactions.tx(); proc.log = NullLog.getInstance(); proc.callContext = ProcedureCallContext.EMPTY; + + proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()); + Map config = Map.of( "maxIterations", 20, "throwInCompute", true diff --git a/proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java b/proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java index 001f570242..0444c99bcc 100644 --- a/proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java +++ b/proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java @@ -19,6 +19,8 @@ */ package org.neo4j.gds; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.compat.GraphDatabaseApiProxy; import org.neo4j.gds.core.Username; import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; @@ -43,7 +45,8 @@ public static

P instantiateProcedure( TaskRegistryFactory taskRegistryFactory, UserLogRegistryFactory userLogRegistryFactory, Transaction tx, - Username username + Username username, + AlgorithmMetricsService algorithmMetricsService ) { P proc; try { @@ -61,6 +64,8 @@ public static

P instantiateProcedure( proc.userLogRegistryFactory = userLogRegistryFactory; proc.username = username; + proc.algorithmMetricsService = algorithmMetricsService; + return proc; } @@ -82,7 +87,8 @@ public static

P applyOnProcedure( taskRegistryFactory, EmptyUserLogRegistryFactory.INSTANCE, tx, - username + username, + new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()) ); func.accept(proc); return proc; diff --git a/test-utils/src/main/java/org/neo4j/gds/BaseTest.java b/test-utils/src/main/java/org/neo4j/gds/BaseTest.java index afb596fb32..93baaf60b9 100644 --- a/test-utils/src/main/java/org/neo4j/gds/BaseTest.java +++ b/test-utils/src/main/java/org/neo4j/gds/BaseTest.java @@ -23,6 +23,8 @@ import org.assertj.core.api.Assertions; import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.Timeout; +import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService; +import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar; import org.neo4j.gds.compat.Neo4jProxy; import org.neo4j.gds.compat.TestLog; import org.neo4j.gds.core.Settings; @@ -35,6 +37,11 @@ import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Result; import org.neo4j.graphdb.Transaction; +import org.neo4j.kernel.api.procedure.GlobalProcedures; +import org.neo4j.kernel.extension.ExtensionFactory; +import org.neo4j.kernel.extension.context.ExtensionContext; +import org.neo4j.kernel.lifecycle.Lifecycle; +import org.neo4j.kernel.lifecycle.LifecycleAdapter; import org.neo4j.test.TestDatabaseManagementServiceBuilder; import org.neo4j.test.extension.ExtensionCallback; import org.neo4j.test.extension.ImpermanentDbmsExtension; @@ -86,6 +93,24 @@ protected void configuration(TestDatabaseManagementServiceBuilder builder) { builder.setConfigRaw(Map.of("unsupported.dbms.debug.trace_cursors", "true")); testLog = Neo4jProxy.testLog(); builder.setUserLogProvider(new TestLogProvider(testLog)); + + // Hacky as hell but will have to do until we make this BaseTest obsolete + builder.addExtension(new ExtensionFactory("AlgorithmMetricsServiceExtensionFactory") { + @Override + public Lifecycle newInstance(ExtensionContext context, Dependencies dependencies) { + dependencies.globalProcedures().registerComponent( + AlgorithmMetricsService.class, + ctx -> new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()), + false + ); + return new LifecycleAdapter(); + } + + }); + } + + interface Dependencies { + GlobalProcedures globalProcedures(); } protected long clearDb() {