Skip to content

Commit

Permalink
Merge branch 'nhood_connectivity' of github.com:LukasHats/cellcharter…
Browse files Browse the repository at this point in the history
… into nhood_connectivity
  • Loading branch information
marcovarrone committed Jan 12, 2025
2 parents 4a86a80 + 63f5171 commit 231b612
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
32 changes: 23 additions & 9 deletions src/cellcharter/pl/_shape.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import warnings
from itertools import combinations
from pathlib import Path
from typing import Union

Expand All @@ -10,12 +9,12 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.sparse as sps
import seaborn as sns
import spatialdata as sd
import spatialdata_plot # noqa: F401
from anndata import AnnData
from squidpy._docs import d
import scipy.sparse as sps

from ._utils import adjust_box_widths

Expand Down Expand Up @@ -250,7 +249,7 @@ def plot_shapes(data, x, y, hue, hue_order, figsize, title: str | None = None) -
dodge=True,
hue_order=hue_order,
)

if len(data[hue].unique()) > 1:
handles, labels = ax.get_legend_handles_labels()
if len(handles) > 1:
Expand All @@ -267,6 +266,7 @@ def plot_shapes(data, x, y, hue, hue_order, figsize, title: str | None = None) -
plt.title(title)
plt.show()


@d.dedent
def shape_metrics(
adata: AnnData,
Expand Down Expand Up @@ -311,7 +311,7 @@ def shape_metrics(
metrics = [metrics]
elif isinstance(metrics, tuple):
metrics = list(metrics)

if cluster_id is not None and not isinstance(cluster_id, list) and not isinstance(cluster_id, np.ndarray):
cluster_id = [cluster_id]

Expand All @@ -330,12 +330,8 @@ def shape_metrics(
if cluster_key is not None:
keys.append(cluster_key)

metrics_df = (adata.obs[[component_key] + keys]
.drop_duplicates()
.dropna()
.set_index(component_key))
metrics_df = adata.obs[[component_key] + keys].drop_duplicates().dropna().set_index(component_key)


for metric in metrics:
metrics_df[metric] = metrics_df.index.map(adata.uns[f"shape_{component_key}"][metric])

Expand All @@ -353,9 +349,27 @@ def shape_metrics(

if cluster_key is not None:
plot_shapes(metrics_melted, "metric", "value", cluster_key, cluster_id, figsize, f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])} by domain')
plot_shapes(
metrics_melted,
"metric",
"value",
cluster_key,
cluster_id,
figsize,
f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}',
)

if condition_key is not None:
plot_shapes(metrics_melted, "metric", "value", condition_key, condition_groups, figsize, f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])} by condition')
plot_shapes(
metrics_melted,
"metric",
"value",
condition_key,
condition_groups,
figsize,
f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}',
)
else:
for metric in metrics:
plot_shapes(
Expand Down
9 changes: 8 additions & 1 deletion src/cellcharter/tl/_shape.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from collections import deque
from concurrent.futures import ProcessPoolExecutor, as_completed

Expand All @@ -15,7 +16,7 @@
from shapely.ops import polygonize, unary_union
from skimage.morphology import skeletonize
from squidpy._docs import d
import warnings


def _alpha_shape(coords, alpha):
"""
Expand Down Expand Up @@ -233,6 +234,7 @@ def _rasterize(boundary, height=1000):
poly = shapely.affinity.scale(poly, scale_factor, scale_factor, origin=(0, 0, 0))
return features.rasterize([poly], out_shape=(height, int(height * (maxx - minx) / (maxy - miny)))), scale_factor


def linearity(
adata: AnnData,
cluster_key: str = "component",
Expand All @@ -255,6 +257,7 @@ def linearity(
copy=copy,
)


@d.dedent
def linearity_metric(
adata: AnnData,
Expand Down Expand Up @@ -333,6 +336,7 @@ def elongation(
copy=copy,
)


@d.dedent
def elongation_metric(
adata: AnnData,
Expand Down Expand Up @@ -410,6 +414,7 @@ def curl(
copy=copy,
)


@d.dedent
def curl_metric(
adata: AnnData,
Expand Down Expand Up @@ -445,6 +450,7 @@ def curl_metric(
return curl_score
adata.uns[f"shape_{cluster_key}"][out_key] = curl_score


def purity_metric(
adata: AnnData,
cluster_key: str = "component",
Expand All @@ -467,6 +473,7 @@ def purity_metric(
copy=copy,
)


@d.dedent
def purity(
adata: AnnData,
Expand Down

0 comments on commit 231b612

Please sign in to comment.