diff --git a/algo/src/test/java/org/neo4j/gds/triangle/intersect/HugeIntersectionTest.java b/algo/src/test/java/org/neo4j/gds/triangle/intersect/HugeIntersectionTest.java index 6e87ba6ba5..5694e0c7b0 100644 --- a/algo/src/test/java/org/neo4j/gds/triangle/intersect/HugeIntersectionTest.java +++ b/algo/src/test/java/org/neo4j/gds/triangle/intersect/HugeIntersectionTest.java @@ -19,78 +19,86 @@ */ package org.neo4j.gds.triangle.intersect; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.neo4j.gds.BaseTest; import org.neo4j.gds.Orientation; -import org.neo4j.gds.StoreLoaderBuilder; -import org.neo4j.gds.api.RelationshipIntersect; -import org.neo4j.graphdb.Node; -import org.neo4j.graphdb.RelationshipType; +import org.neo4j.gds.RelationshipType; +import org.neo4j.gds.api.Graph; +import org.neo4j.gds.core.concurrency.DefaultPool; +import org.neo4j.gds.core.loading.construction.GraphFactory; +import org.neo4j.gds.core.loading.construction.RelationshipsBuilder; -import java.util.Arrays; -import java.util.PrimitiveIterator; +import java.util.ArrayList; import java.util.Random; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.neo4j.gds.compat.GraphDatabaseApiProxy.applyInFullAccessTransaction; +import static org.assertj.core.api.Assertions.assertThat; -final class HugeIntersectionTest extends BaseTest { +final class HugeIntersectionTest { private static final int DEGREE = 25; - public static final RelationshipType TYPE = RelationshipType.withName("TYPE"); - private static RelationshipIntersect INTERSECT; - private static long START1; - private static long START2; - private static long[] TARGETS; - - @BeforeEach - void setup() { - Random random = new Random(0L); - long[] neoStarts = new long[2]; - long[] neoTargets = applyInFullAccessTransaction(db, tx -> { - Node start1 = tx.createNode(); - Node start2 = tx.createNode(); - Node start3 = tx.createNode(); - neoStarts[0] = start1.getId(); - neoStarts[1] = start2.getId(); - start1.createRelationshipTo(start2, TYPE); - long[] targets = new long[DEGREE]; - int some = 0; - for (int i = 0; i < DEGREE; i++) { - Node target = tx.createNode(); - start1.createRelationshipTo(target, TYPE); - start3.createRelationshipTo(target, TYPE); - if (random.nextBoolean()) { - start2.createRelationshipTo(target, TYPE); - targets[some++] = target.getId(); - } - } - return Arrays.copyOf(targets, some); - }); - var graph = new StoreLoaderBuilder() - .databaseService(db) - .globalOrientation(Orientation.UNDIRECTED) - .build() - .graph(); + @Test + void intersectWithTargets() { + + ArrayList targets = new ArrayList<>(); + var graph = produceGraph(targets); + var targetIterator = targets.iterator(); - INTERSECT = RelationshipIntersectFactoryLocator.lookup(graph) + var intersect = RelationshipIntersectFactoryLocator.lookup(graph) .orElseThrow(IllegalArgumentException::new) .load(graph, ImmutableRelationshipIntersectConfig.builder().build()); - START1 = graph.toMappedNodeId(neoStarts[0]); - START2 = graph.toMappedNodeId(neoStarts[1]); - TARGETS = Arrays.stream(neoTargets).map(graph::toMappedNodeId).toArray(); - Arrays.sort(TARGETS); - } - @Test - void intersectWithTargets() { - PrimitiveIterator.OfLong targets = Arrays.stream(TARGETS).iterator(); - INTERSECT.intersectAll(START1, (a, b, c) -> { - assertEquals(START1, a); - assertEquals(START2, b); - assertEquals(targets.nextLong(), c); + var start1 = Math.min(graph.toMappedNodeId(DEGREE + 1), graph.toMappedNodeId(DEGREE)); + var start2 = Math.max(graph.toMappedNodeId(DEGREE + 1), graph.toMappedNodeId(DEGREE)); + + + intersect.intersectAll(start2, (a, b, c) -> { + + Long next = targetIterator.next(); + var targetMappedId = graph.toMappedNodeId(next); + assertThat(a).isEqualTo(targetMappedId); + assertThat(b).isEqualTo(start1); + assertThat(c).isEqualTo(start2); + }); + + assertThat(targetIterator.hasNext()).isFalse(); + } + + Graph produceGraph(ArrayList targets) { + Random random = new Random(0); + var nodesBuilder = GraphFactory.initNodesBuilder() + .maxOriginalId(2 + DEGREE) + .concurrency(1) + .build(); + + for (long i = 0; i < 3 + DEGREE; ++i) { + nodesBuilder.addNode(i); + } + + var idMap = nodesBuilder.build().idMap(); + RelationshipsBuilder relationshipsBuilder = GraphFactory.initRelationshipsBuilder() + .nodes(idMap) + .relationshipType(RelationshipType.of("FOO")) + .orientation(Orientation.UNDIRECTED) + .executorService(DefaultPool.INSTANCE) + .build(); + + + long start0 = DEGREE + 2, start1 = DEGREE + 1, start2 = DEGREE; + relationshipsBuilder.add(start1, start2); + for (int targetId = 0; targetId < start2; targetId++) { + + relationshipsBuilder.add(start1, targetId); + relationshipsBuilder.add(start0, targetId); + + if (random.nextBoolean()) { + relationshipsBuilder.add(start2, targetId); + targets.add((long) targetId); + } + } + + var relationships = relationshipsBuilder.build(); + + return GraphFactory.create(idMap, relationships); } }