diff --git a/src/cellcharter/pl/_shape.py b/src/cellcharter/pl/_shape.py index f6a4cd2..2f609ff 100644 --- a/src/cellcharter/pl/_shape.py +++ b/src/cellcharter/pl/_shape.py @@ -1,7 +1,6 @@ from __future__ import annotations import warnings -from itertools import combinations from pathlib import Path from typing import Union @@ -10,12 +9,12 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +import scipy.sparse as sps import seaborn as sns import spatialdata as sd import spatialdata_plot # noqa: F401 from anndata import AnnData from squidpy._docs import d -import scipy.sparse as sps from ._utils import adjust_box_widths @@ -250,7 +249,7 @@ def plot_shapes(data, x, y, hue, hue_order, figsize, title: str | None = None) - dodge=True, hue_order=hue_order, ) - + if len(data[hue].unique()) > 1: handles, labels = ax.get_legend_handles_labels() if len(handles) > 1: @@ -267,6 +266,7 @@ def plot_shapes(data, x, y, hue, hue_order, figsize, title: str | None = None) - plt.title(title) plt.show() + @d.dedent def shape_metrics( adata: AnnData, @@ -311,7 +311,7 @@ def shape_metrics( metrics = [metrics] elif isinstance(metrics, tuple): metrics = list(metrics) - + if cluster_id is not None and not isinstance(cluster_id, list) and not isinstance(cluster_id, np.ndarray): cluster_id = [cluster_id] @@ -330,12 +330,8 @@ def shape_metrics( if cluster_key is not None: keys.append(cluster_key) - metrics_df = (adata.obs[[component_key] + keys] - .drop_duplicates() - .dropna() - .set_index(component_key)) + metrics_df = adata.obs[[component_key] + keys].drop_duplicates().dropna().set_index(component_key) - for metric in metrics: metrics_df[metric] = metrics_df.index.map(adata.uns[f"shape_{component_key}"][metric]) @@ -353,9 +349,27 @@ def shape_metrics( if cluster_key is not None: plot_shapes(metrics_melted, "metric", "value", cluster_key, cluster_id, figsize, f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])} by domain') + plot_shapes( + metrics_melted, + "metric", + "value", + cluster_key, + cluster_id, + figsize, + f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}', + ) if condition_key is not None: plot_shapes(metrics_melted, "metric", "value", condition_key, condition_groups, figsize, f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])} by condition') + plot_shapes( + metrics_melted, + "metric", + "value", + condition_key, + condition_groups, + figsize, + f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}', + ) else: for metric in metrics: plot_shapes( diff --git a/src/cellcharter/tl/_shape.py b/src/cellcharter/tl/_shape.py index 066fa3c..99b7ea3 100644 --- a/src/cellcharter/tl/_shape.py +++ b/src/cellcharter/tl/_shape.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from collections import deque from concurrent.futures import ProcessPoolExecutor, as_completed @@ -15,7 +16,7 @@ from shapely.ops import polygonize, unary_union from skimage.morphology import skeletonize from squidpy._docs import d -import warnings + def _alpha_shape(coords, alpha): """ @@ -233,6 +234,7 @@ def _rasterize(boundary, height=1000): poly = shapely.affinity.scale(poly, scale_factor, scale_factor, origin=(0, 0, 0)) return features.rasterize([poly], out_shape=(height, int(height * (maxx - minx) / (maxy - miny)))), scale_factor + def linearity( adata: AnnData, cluster_key: str = "component", @@ -255,6 +257,7 @@ def linearity( copy=copy, ) + @d.dedent def linearity_metric( adata: AnnData, @@ -333,6 +336,7 @@ def elongation( copy=copy, ) + @d.dedent def elongation_metric( adata: AnnData, @@ -410,6 +414,7 @@ def curl( copy=copy, ) + @d.dedent def curl_metric( adata: AnnData, @@ -445,6 +450,7 @@ def curl_metric( return curl_score adata.uns[f"shape_{cluster_key}"][out_key] = curl_score + def purity_metric( adata: AnnData, cluster_key: str = "component", @@ -467,6 +473,7 @@ def purity_metric( copy=copy, ) + @d.dedent def purity( adata: AnnData,