Skip to content

Commit

Permalink
Remove baseTest from FilteredKnnIdMappingTest
Browse files Browse the repository at this point in the history
  • Loading branch information
IoannisPanagiotas committed Nov 17, 2023
1 parent 26b8548 commit 419380f
Showing 1 changed file with 22 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,40 @@
package org.neo4j.gds.similarity.filteredknn;

import org.junit.jupiter.api.Test;
import org.neo4j.gds.BaseTest;
import org.neo4j.gds.PropertyMapping;
import org.neo4j.gds.StoreLoaderBuilder;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.extension.GdlExtension;
import org.neo4j.gds.extension.GdlGraph;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.similarity.filtering.NodeFilterSpecFactory;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnNodePropertySpec;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

import static org.assertj.core.api.Assertions.assertThat;

public class FilteredKnnIdMappingTest extends BaseTest {
@GdlExtension
public class FilteredKnnIdMappingTest {

@GdlGraph(idOffset = 4242)
private static final String DB_CYPHER =
"CREATE" +
" (a {name: 'a', knn: 1.2})" +
", (b {name: 'b', knn: 1.1})" +
", (c {name: 'c', knn: 2.1})" +
", (d {name: 'd', knn: 3.1})" +
", (e {name: 'e', knn: 4.1})";
" (a { knn: 1.2})" +
", (b { knn: 1.1})" +
", (c { knn: 2.1})" +
", (d { knn: 3.1})" +
", (e { knn: 4.1})";

@Inject
private Graph graph;
@Test
void shouldIdMapTheSourceNodeFilter() {
// Offset the Neo ID space, then get the lowest Neo ID to use for sourceNodeFilter
runQuery("UNWIND range(0, 10) AS foo CREATE ()");
runQuery("MATCH (n) DELETE n");
runQuery(DB_CYPHER);
var lowestNeoId = runQuery("MATCH (n) RETURN id(n) AS id ORDER BY id ASC LIMIT 1", (r) -> (Long) r.next().get("id"));

var graph = new StoreLoaderBuilder()
.databaseService(db)
.nodeProperties(List.of(PropertyMapping.of("knn")))
.build()
.graphStore()
.getUnion();

var lowestOriginalId = LongStream.range(0, graph.nodeCount()).map(graph::toOriginalNodeId).min().orElse(-1);
assertThat(lowestOriginalId).isPositive();

var config = ImmutableFilteredKnnBaseConfig.builder()
.nodeProperties(List.of(new KnnNodePropertySpec("knn")))
Expand All @@ -64,8 +62,9 @@ void shouldIdMapTheSourceNodeFilter() {
.maxIterations(1)
.randomSeed(20L)
.concurrency(1)
.sourceNodeFilter(NodeFilterSpecFactory.create(lowestNeoId))
.sourceNodeFilter(NodeFilterSpecFactory.create(lowestOriginalId))
.build();

var knn = FilteredKnn.createWithoutSeeding(graph, config, KnnContext.empty());

var result = knn.compute();
Expand All @@ -76,6 +75,7 @@ void shouldIdMapTheSourceNodeFilter() {
.map(res -> res.node1)
.map(graph::toOriginalNodeId)
.collect(Collectors.<Long>toSet());
assertThat(sourceNodesInResult).containsExactly(lowestNeoId);
assertThat(sourceNodesInResult).containsExactly(lowestOriginalId);

}
}

0 comments on commit 419380f

Please sign in to comment.