Skip to content

Commit

Permalink
Remove baseTest from Node2VecTest
Browse files Browse the repository at this point in the history
  • Loading branch information
IoannisPanagiotas committed Nov 17, 2023
1 parent f7f9768 commit 79cb574
Showing 1 changed file with 29 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension;
import org.assertj.core.data.Offset;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -32,14 +31,12 @@
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.neo4j.gds.BaseTest;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.PropertyMapping;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.StoreLoaderBuilder;
import org.neo4j.gds.TestProgressTracker;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.collections.hsa.HugeSparseLongArray;
Expand All @@ -55,6 +52,9 @@
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.shuffle.ShuffleUtil;
import org.neo4j.gds.extension.GdlExtension;
import org.neo4j.gds.extension.GdlGraph;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.gdl.GdlFactory;
import org.neo4j.gds.ml.core.tensor.FloatVector;

Expand All @@ -69,11 +69,13 @@
import static org.neo4j.gds.assertj.Extractors.removingThreadId;

@ExtendWith(SoftAssertionsExtension.class)
class Node2VecTest extends BaseTest {
@GdlExtension
class Node2VecTest {

private static final List<Long> NO_SOURCE_NODES = List.of();
private static final Optional<Long> NO_RANDOM_SEED = Optional.empty();

@GdlGraph
private static final String DB_CYPHER =
"CREATE" +
" (a:Node1)" +
Expand All @@ -88,21 +90,19 @@ class Node2VecTest extends BaseTest {
", (b)-[:REL {prop: 1.0}]->(c)" +
", (c)-[:REL {prop: 1.0}]->(b)";

@BeforeEach
void setUp() {
runQuery(DB_CYPHER);
}
@Inject
private Graph graph;

@Inject
private GraphStore graphStore;

@ParameterizedTest(name = "{0}")
@MethodSource("graphs")
void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable<String> nodeLabels) {
Graph graph = new StoreLoaderBuilder()
.databaseService(db)
.nodeLabels(nodeLabels)
.build()
.graph();
void embeddingsShouldHaveTheConfiguredDimension(String msg, List<NodeLabel> nodeLabels) {

var currentGraph = graphStore.getGraph(nodeLabels);
int embeddingDimension = 128;

var trainParameters = new TrainParameters(
0.025,
0.0001,
Expand All @@ -112,8 +112,9 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable<String> nod
embeddingDimension,
EmbeddingInitializer.NORMALIZED
);

HugeObjectArray<FloatVector> node2Vec = new Node2Vec(
graph,
currentGraph,
4,
NO_SOURCE_NODES,
NO_RANDOM_SEED,
Expand All @@ -123,7 +124,7 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable<String> nod
ProgressTracker.NULL_TRACKER
).compute().embeddings();

graph.forEachNode(node -> {
currentGraph.forEachNode(node -> {
assertEquals(embeddingDimension, node2Vec.get(node).data().length);
return true;
}
Expand All @@ -136,19 +137,19 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable<String> nod
"false,3"
})
void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
var storeLoaderBuilder = new StoreLoaderBuilder()
.databaseService(db);
Graph currentGraph;
if (relationshipWeights) {
storeLoaderBuilder.addRelationshipProperty(PropertyMapping.of("prop"));
currentGraph = graph;
} else {
currentGraph = graphStore.getGraph(RelationshipType.of("REL"), Optional.empty());
}
Graph graph = storeLoaderBuilder.build().graph();

int embeddingDimension = 128;
Node2VecStreamConfig config = ImmutableNode2VecStreamConfig
.builder()
.embeddingDimension(embeddingDimension)
.build();
var progressTask = new Node2VecAlgorithmFactory<>().progressTask(graph, config);
var progressTask = new Node2VecAlgorithmFactory<>().progressTask(currentGraph, config);

var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75);
var trainParameters = new TrainParameters(
Expand All @@ -163,7 +164,7 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
var log = Neo4jProxy.testLog();
var progressTracker = new TestProgressTracker(progressTask, log, 4, EmptyTaskRegistryFactory.INSTANCE);
new Node2Vec(
graph,
currentGraph,
4,
NO_SOURCE_NODES,
NO_RANDOM_SEED,
Expand Down Expand Up @@ -226,13 +227,13 @@ void shouldEstimateMemory() {

@Test
void failOnNegativeWeights() {
var graph = GdlFactory.of("CREATE (a)-[:REL {weight: -1}]->(b)").build().getUnion();
var negativeGraph = GdlFactory.of("CREATE (a)-[:REL {weight: -1}]->(b)").build().getUnion();

var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75);
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, 128, EmbeddingInitializer.NORMALIZED);

var node2Vec = new Node2Vec(
graph,
negativeGraph,
4,
NO_SOURCE_NODES,
NO_RANDOM_SEED,
Expand All @@ -253,7 +254,7 @@ void failOnNegativeWeights() {
@Disabled("The order of the randomWalks + its usage in the training is not deterministic yet.")
@Test
void randomSeed(SoftAssertions softly) {
Graph graph = new StoreLoaderBuilder().databaseService(db).build().graph();


int embeddingDimension = 2;
var walkParameters = new WalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75);
Expand Down Expand Up @@ -288,8 +289,8 @@ void randomSeed(SoftAssertions softly) {

static Stream<Arguments> graphs() {
return Stream.of(
Arguments.of("All Labels", List.of()),
Arguments.of("Non Consecutive Original IDs", List.of("Node2", "Isolated"))
Arguments.of("All Labels", List.of(NodeLabel.of("Node1"), NodeLabel.of("Node2"), NodeLabel.of("Isolated"))),
Arguments.of("Non Consecutive Original IDs", List.of(NodeLabel.of("Node2"), NodeLabel.of("Isolated")))
);
}

Expand Down

0 comments on commit 79cb574

Please sign in to comment.