Skip to content

Commit

Permalink
Simplify shape metrics plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
marcovarrone committed Jan 12, 2025
1 parent bd742b0 commit 8f3304a
Showing 1 changed file with 99 additions and 108 deletions.
207 changes: 99 additions & 108 deletions src/cellcharter/pl/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import spatialdata as sd
import spatialdata_plot # noqa: F401
from anndata import AnnData
from scipy.stats import ttest_ind
from squidpy._docs import d
import scipy.sparse as sps

from ._utils import adjust_box_widths

Expand Down Expand Up @@ -75,6 +75,7 @@ def boundaries(
component_key: str = "component",
alpha_boundary: float = 0.5,
show_cells: bool = True,
cells_radius: float = None,
save: str | Path | None = None,
) -> None:
"""
Expand All @@ -98,6 +99,9 @@ def boundaries(
-------
%(plotting_returns)s
"""
if show_cells is True and cells_radius is None:
raise ValueError("cells_radius must be provided when show_cells is True")

adata = adata[adata.obs[library_key] == sample].copy()
del adata.raw
clusters = adata.obs[component_key].unique()
Expand All @@ -114,7 +118,7 @@ def boundaries(
adata.obs["region"] = "cells"

xy = adata.obsm["spatial"]
cell_circles = sd.models.ShapesModel.parse(xy, geometry=0, radius=3000, index=adata.obs["instance_id"])
cell_circles = sd.models.ShapesModel.parse(xy, geometry=0, radius=cells_radius, index=adata.obs["instance_id"])

obs = pd.DataFrame(list(boundaries.keys()), columns=[component_key], index=np.arange(len(boundaries)).astype(str))
adata_obs = ad.AnnData(X=pd.DataFrame(index=obs.index, columns=adata.var_names), obs=obs)
Expand All @@ -123,6 +127,10 @@ def boundaries(
adata_obs.obs["instance_id"] = np.arange(len(boundaries))
adata_obs.obs[component_key] = pd.Categorical(adata_obs.obs[component_key])

if sps.issparse(adata.X):
# If adata.X is sparse, give adata_obs an empty sparse X with one column per variable in adata
adata_obs.X = sps.csr_matrix((len(adata_obs.obs), len(adata.var_names)))

adata = ad.concat((adata, adata_obs), join="outer")

adata.obs["region"] = adata.obs["region"].astype("category")
Expand Down Expand Up @@ -219,15 +227,54 @@ def plot_shape_metrics(
)


def plot_shapes(data, x, y, hue, hue_order, figsize, title: str | None = None) -> None:
    """
    Draw a grouped boxplot with the individual observations overlaid as a strip plot.

    Parameters
    ----------
    data
        Table containing the columns named by ``x``, ``y`` and ``hue``.
    x
        Column of ``data`` used for the categorical axis.
    y
        Column of ``data`` holding the plotted values.
    hue
        Column of ``data`` used to split each ``x`` category into groups.
        It is also used as the legend title.
    hue_order
        Order of the ``hue`` levels, forwarded to both seaborn layers.
    figsize
        Size of the newly created figure.
    title
        Title passed to :func:`matplotlib.pyplot.title`.

    Returns
    -------
    Nothing; the figure is displayed with :func:`matplotlib.pyplot.show`.
    """
    figure = plt.figure(figsize=figsize)

    # Both layers must agree on the data mapping so that points sit over their box.
    shared_kwargs = {"data": data, "x": x, "y": y, "hue": hue, "hue_order": hue_order}

    sns.boxplot(showfliers=False, **shared_kwargs)
    adjust_box_widths(figure, 0.9)

    axes = sns.stripplot(color="0.08", size=4, jitter=0.13, dodge=True, **shared_kwargs)

    n_levels = len(data[hue].unique())
    if n_levels > 1:
        # Keep only the first ``n_levels`` entries (presumably to avoid the
        # duplicate entries contributed by the strip layer -- TODO confirm).
        handles, labels = axes.get_legend_handles_labels()
        plt.legend(
            handles[:n_levels],
            labels[:n_levels],
            bbox_to_anchor=(1.0, 1.03),
            title=hue,
        )
    else:
        # A single group needs no legend; drop it if seaborn created one.
        legend = axes.get_legend()
        if legend is not None:
            legend.remove()

    plt.ylim(-0.05, 1.05)
    plt.title(title)
    plt.show()

@d.dedent
def shape_metrics(
adata: AnnData,
condition_key: str,
condition_key: str | None = None,
condition_groups: list[str] | None = None,
cluster_key: str | None = None,
cluster_id: list[str] | None = None,
cluster_id: str | list[str] | None = None,
component_key: str = "component",
metrics: str | tuple[str] | list[str] = ("linearity", "curl"),
metrics: str | tuple[str] | list[str] | None = None,
fontsize: str | int = "small",
figsize: tuple[float, float] = (8, 7),
title: str | None = None,
Expand All @@ -249,7 +296,7 @@ def shape_metrics(
component_key
Key in :attr:`anndata.AnnData.obs` where the component labels are stored.
metrics
List of metrics to plot. Available metrics are ``linearity``, ``curl``, ``elongation``, ``purity``.
List of metrics to plot. Available metrics are ``linearity``, ``curl``, ``elongation``, ``purity``. If `None`, all computed metrics are plotted.
figsize
Figure size.
title
Expand All @@ -263,119 +310,63 @@ def shape_metrics(
metrics = [metrics]
elif isinstance(metrics, tuple):
metrics = list(metrics)

if cluster_id is not None and not isinstance(cluster_id, list) and not isinstance(cluster_id, np.ndarray):
cluster_id = [cluster_id]

metrics_df = {metric: adata.uns[f"shape_{component_key}"][metric] for metric in metrics}
metrics_df[condition_key] = (
adata[~adata.obs[condition_key].isna()]
.obs[[component_key, condition_key]]
.drop_duplicates()
.set_index(component_key)
.to_dict()[condition_key]
)
if condition_groups is None and condition_key is not None:
condition_groups = adata.obs[condition_key].cat.categories
else:
if not isinstance(condition_groups, list) and not isinstance(condition_groups, np.ndarray):
condition_groups = [condition_groups]

metrics_df[cluster_key] = (
adata[~adata.obs[condition_key].isna()]
.obs[[component_key, cluster_key]]
if metrics is None:
metrics = [metric for metric in adata.uns[f"shape_{component_key}"].keys() if metric != "boundary"]

keys = []
if condition_key is not None:
keys.append(condition_key)
if cluster_key is not None:
keys.append(cluster_key)

metrics_df = (adata.obs[[component_key] + keys]
.drop_duplicates()
.set_index(component_key)
.to_dict()[cluster_key]
)
.dropna()
.set_index(component_key))


for metric in metrics:
metrics_df[metric] = metrics_df.index.map(adata.uns[f"shape_{component_key}"][metric])

metrics_df = pd.DataFrame(metrics_df)
if cluster_id is not None:
metrics_df = metrics_df[metrics_df[cluster_key].isin(cluster_id)]

metrics_df = pd.melt(
metrics_df[metrics + [condition_key]],
id_vars=[condition_key],
var_name="metric",
)

conditions = (
enumerate(combinations(adata.obs[condition_key].cat.categories, 2))
if condition_groups is None
else [condition_groups]
)

for condition1, condition2 in conditions:
fig = plt.figure(figsize=figsize)
metrics_condition_pair = metrics_df[metrics_df[condition_key].isin([condition1, condition2])]
ax = sns.boxplot(
data=metrics_condition_pair,
x="metric",
hue=condition_key,
y="value",
showfliers=False,
hue_order=[condition1, condition2],
metrics_melted = pd.melt(
metrics_df,
id_vars=keys,
value_vars=metrics,
var_name="metric",
)

ax.tick_params(labelsize=fontsize)
ax.set_xlabel(ax.get_xlabel(), fontsize=fontsize)
ax.tick_params(labelsize=fontsize)
ax.set_ylabel(ax.get_ylabel(), fontsize=fontsize)

adjust_box_widths(fig, 0.9)

ax = sns.stripplot(
data=metrics_condition_pair,
x="metric",
hue=condition_key,
y="value",
color="0.08",
size=4,
jitter=0.13,
dodge=True,
hue_order=condition_groups if condition_groups else None,
)
handles, labels = ax.get_legend_handles_labels()
plt.legend(
handles[0 : len(metrics_condition_pair[condition_key].unique())],
labels[0 : len(metrics_condition_pair[condition_key].unique())],
bbox_to_anchor=(1.24, 1.02),
fontsize=fontsize,
)
metrics_melted[cluster_key] = metrics_melted[cluster_key].cat.remove_unused_categories()

for count, metric in enumerate(["linearity", "curl"]):
pvalue = ttest_ind(
metrics_condition_pair[
(metrics_condition_pair[condition_key] == condition1) & (metrics_condition_pair["metric"] == metric)
]["value"],
metrics_condition_pair[
(metrics_condition_pair[condition_key] == condition2) & (metrics_condition_pair["metric"] == metric)
]["value"],
)[1]
x1, x2 = count, count
y, h, col = (
metrics_condition_pair[(metrics_condition_pair["metric"] == metric)]["value"].max()
+ 0.02
+ 0.05 * count,
0.01,
"k",
if cluster_key is not None:
plot_shapes(metrics_melted, "metric", "value", cluster_key, cluster_id, figsize, f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}')

if condition_key is not None:
plot_shapes(metrics_melted, "metric", "value", condition_key, condition_groups, figsize, f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}')
else:
for metric in metrics:
fig = plt.figure(figsize=figsize)
ax = sns.boxplot(
data=metrics_df,
x=cluster_key,
hue=condition_key,
y=metric,
showfliers=False,
# hue_order=[condition1, condition2],
)
plt.plot([x1 - 0.2, x1 - 0.2, x2 + 0.2, x2 + 0.2], [y, y + h, y + h, y], lw=1.5, c=col)
if pvalue < 0.05:
plt.text(
(x1 + x2) * 0.5,
y + h * 2,
f"p = {pvalue:.2e}",
ha="center",
va="bottom",
color=col,
fontdict={"fontsize": fontsize},
)
else:
plt.text(
(x1 + x2) * 0.5,
y + h * 2,
"ns",
ha="center",
va="bottom",
color=col,
fontdict={"fontsize": fontsize},
)
if title is not None:
plt.title(title, fontdict={"fontsize": fontsize})
plt.show()
plt.show()


@d.dedent
Expand Down

0 comments on commit 8f3304a

Please sign in to comment.