Skip to content

Commit

Permalink
update topology_classifier tests
Browse files Browse the repository at this point in the history
  • Loading branch information
franciscoeacosta committed Sep 3, 2024
1 parent cc51247 commit 299e65d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 53 deletions.
4 changes: 4 additions & 0 deletions neurometry/estimators/topology/topology_classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import numpy as np
import torch
from gtda.diagrams import PersistenceEntropy
from gtda.homology import VietorisRipsPersistence, WeightedRipsPersistence
from sklearn.base import BaseEstimator, ClassifierMixin
Expand Down Expand Up @@ -141,6 +142,9 @@ def fit(self, X, y=None):
Returns the instance itself.
"""

if not isinstance(X, np.ndarray | torch.Tensor):
raise ValueError(f"Expected array-like input for X, but got {type(X).__name__}.")

ref_point_clouds, ref_labels = self._generate_ref_data(X)
self.ref_labels = ref_labels
if self.reduce_dim:
Expand Down
66 changes: 13 additions & 53 deletions tests/neurometry/estimators/topology/test_topology_classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os

import numpy as np
import pytest
from sklearn.exceptions import NotFittedError

Expand All @@ -17,7 +16,7 @@ class BaseTopologyTest:
num_points = 700
encoding_dim = 10
fano_factor = 0.1
num_samples = 100
num_samples = 200
homology_dimensions = (0, 1)

@classmethod
Expand All @@ -30,24 +29,10 @@ def setup_class(cls):
reduce_dim=True,
)

cls.null_data = cls.generate_null_data()
cls.circle_data = cls.generate_circle_data()
cls.sphere_data = cls.generate_sphere_data()
cls.torus_data = cls.generate_torus_data()

@classmethod
def generate_null_data(cls):
rng = np.random.default_rng(seed=0)
task_points = rng.normal(0, 1, (cls.num_points, cls.encoding_dim))
noisy_points, _ = synthetic.synthetic_neural_manifold(
points=task_points,
encoding_dim=cls.encoding_dim,
nonlinearity="sigmoid",
scales=5 * gs.random.rand(cls.encoding_dim),
fano_factor=cls.fano_factor,
)
return noisy_points

@classmethod
def generate_circle_data(cls):
task_points = synthetic.hypersphere(1, cls.num_points)
Expand Down Expand Up @@ -86,22 +71,22 @@ def generate_torus_data(cls):


class TestTopologyClassifier(BaseTopologyTest):
def test_fit(self):
"""Test that the fit method runs without errors."""
"""Unit tests for the topology classifier."""

def test_invalid_input(self):
"""Test classifier with invalid input data."""
invalid_data = "invalid_input"
with pytest.raises(ValueError):
self.classifier.fit(invalid_data)

def test_fit_and_predict_circle(self):
"""Test that the fit method runs without errors and predict on circle data."""
try:
self.classifier.fit(self.null_data)
self.classifier.fit(self.circle_data)
except Exception as e:
pytest.fail(f"Fit method raised an exception: {e}")

def test_predict_null(self):
"""Test prediction on null (noise) data."""
self.classifier.fit(self.null_data)
prediction = self.classifier.predict(self.null_data)
assert prediction[0] == 0, "Prediction for null data should be 0 (null)"

def test_predict_circle(self):
"""Test prediction on circle data."""
self.classifier.fit(self.circle_data)
# If fit is successful, test the prediction
prediction = self.classifier.predict(self.circle_data)
assert prediction[0] == 1, "Prediction for circle data should be 1 (circle)"

Expand All @@ -117,25 +102,6 @@ def test_predict_torus(self):
prediction = self.classifier.predict(self.torus_data)
assert prediction[0] == 3, "Prediction for torus data should be 3 (torus)"

def test_consistent_predictions(self):
"""Test that predictions are consistent across multiple calls."""
self.classifier.fit(self.circle_data)
prediction1 = self.classifier.predict(self.circle_data)
prediction2 = self.classifier.predict(self.circle_data)
assert prediction1[0] == prediction2[0], "Predictions should be consistent across calls"

def test_reduce_dim_false(self):
"""Test classifier with reduce_dim set to False."""
classifier = TopologyClassifier(
num_samples=self.num_samples,
fano_factor=self.fano_factor,
homology_dimensions=self.homology_dimensions,
reduce_dim=False,
)
classifier.fit(self.sphere_data)
prediction = classifier.predict(self.sphere_data)
assert prediction[0] == 2, "Prediction for sphere data should be 2 (sphere) with reduce_dim=False"

def test_not_fitted_error(self):
"""Test that predicting before fitting raises an error."""
classifier = TopologyClassifier(
Expand All @@ -147,9 +113,3 @@ def test_not_fitted_error(self):
with pytest.raises(NotFittedError):
classifier.predict(self.circle_data)

def test_invalid_input(self):
"""Test classifier with invalid input data."""
self.classifier.fit(self.null_data)
invalid_data = "invalid_input"
with pytest.raises(ValueError):
self.classifier.predict(invalid_data)

0 comments on commit 299e65d

Please sign in to comment.