diff --git a/src/cellcharter/tl/_gmm.py b/src/cellcharter/tl/_gmm.py
index 06eb050..783c0c0 100644
--- a/src/cellcharter/tl/_gmm.py
+++ b/src/cellcharter/tl/_gmm.py
@@ -6,6 +6,7 @@
 import anndata as ad
 import numpy as np
 import pandas as pd
+import scipy.sparse as sps
 import torch
 from lightkit.data import DataLoader, TensorLike, collate_tensor, dataset_from_tensors
 from pycave import set_logging_level
@@ -103,6 +104,11 @@ def fit(self, data: TensorLike) -> GaussianMixture:
         ----------
         The fitted Gaussian mixture.
         """
+        if sps.issparse(data):
+            raise ValueError(
+                "Sparse data is not supported. You may have forgotten to reduce the dimensionality of the data. Otherwise, please convert the data to a dense format."
+            )
+
         if self.init_strategy == "sklearn":
             if self.batch_size is None:
                 kmeans = KMeans(n_clusters=self.num_components, random_state=self.random_state, n_init=1)
diff --git a/tests/tools/test_gmm.py b/tests/tools/test_gmm.py
new file mode 100644
index 0000000..df7664f
--- /dev/null
+++ b/tests/tools/test_gmm.py
@@ -0,0 +1,25 @@
+import pytest
+import scipy.sparse as sps
+import squidpy as sq
+
+import cellcharter as cc
+
+
+class TestCluster:
+    """Tests for ``cc.tl.Cluster``."""
+
+    @pytest.mark.parametrize("dataset_name", ["mibitof"])
+    def test_sparse(self, dataset_name: str):
+        """Fitting on a sparse ``adata.X`` with ``use_rep=None`` must raise ``ValueError``."""
+        download_dataset = getattr(sq.datasets, dataset_name)
+        adata = download_dataset()
+        adata.X = sps.csr_matrix(adata.X)
+
+        sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
+        cc.gr.remove_long_links(adata)
+
+        gmm = cc.tl.Cluster(n_clusters=10)
+
+        # Match on the message so an unrelated ValueError cannot make the test pass.
+        with pytest.raises(ValueError, match="Sparse data is not supported"):
+            gmm.fit(adata, use_rep=None)