Skip to content

Commit

Permalink
Merge pull request #44 from CSOgroup/sparse-error
Browse files: browse the repository at this point in the history
Raise ValueError on sparse array
  • Loading branch information
marcovarrone authored Jun 7, 2024
2 parents f61f541 + ee06e64 commit 44106ad
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/cellcharter/tl/_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import anndata as ad
import numpy as np
import pandas as pd
import scipy.sparse as sps
import torch
from lightkit.data import DataLoader, TensorLike, collate_tensor, dataset_from_tensors
from pycave import set_logging_level
Expand Down Expand Up @@ -103,6 +104,11 @@ def fit(self, data: TensorLike) -> GaussianMixture:
----------
The fitted Gaussian mixture.
"""
if sps.issparse(data):
raise ValueError(
"Sparse data is not supported. You may have forgotten to reduce the dimensionality of the data. Otherwise, please convert the data to a dense format."
)

if self.init_strategy == "sklearn":
if self.batch_size is None:
kmeans = KMeans(n_clusters=self.num_components, random_state=self.random_state, n_init=1)
Expand Down
22 changes: 22 additions & 0 deletions tests/tools/test_gmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
import scipy.sparse as sps
import squidpy as sq

import cellcharter as cc


class TestCluster:
    """Input-validation tests for ``cc.tl.Cluster``."""

    @pytest.mark.parametrize("dataset_name", ["mibitof"])
    def test_sparse(self, dataset_name: str):
        """Check that fitting on a sparse ``adata.X`` raises the sparse-data ``ValueError``.

        Parameters
        ----------
        dataset_name
            Name of the loader function looked up on ``sq.datasets``.
        """
        download_dataset = getattr(sq.datasets, dataset_name)
        adata = download_dataset()
        # Replace X with a CSR matrix so that the data handed to the model is sparse.
        adata.X = sps.csr_matrix(adata.X)

        sq.gr.spatial_neighbors(adata, coord_type="generic", delaunay=True)
        cc.gr.remove_long_links(adata)

        gmm = cc.tl.Cluster(n_clusters=10)

        # Match on the message so that an unrelated ValueError raised elsewhere
        # in the fit call cannot make this test pass by accident.
        # NOTE(review): assumes use_rep=None makes fit read adata.X -- confirm.
        with pytest.raises(ValueError, match="Sparse data is not supported"):
            gmm.fit(adata, use_rep=None)

0 comments on commit 44106ad

Please sign in to comment.