Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 12, 2025
1 parent 8f3304a commit 63f5171
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
34 changes: 23 additions & 11 deletions src/cellcharter/pl/_shape.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import warnings
from itertools import combinations
from pathlib import Path
from typing import Union

Expand All @@ -10,12 +9,12 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.sparse as sps
import seaborn as sns
import spatialdata as sd
import spatialdata_plot # noqa: F401
from anndata import AnnData
from squidpy._docs import d
import scipy.sparse as sps

from ._utils import adjust_box_widths

Expand Down Expand Up @@ -250,7 +249,7 @@ def plot_shapes(data, x, y, hue, hue_order, figsize, title: str | None = None) -
dodge=True,
hue_order=hue_order,
)

if len(data[hue].unique()) > 1:
handles, labels = ax.get_legend_handles_labels()
plt.legend(
Expand All @@ -266,6 +265,7 @@ def plot_shapes(data, x, y, hue, hue_order, figsize, title: str | None = None) -
plt.title(title)
plt.show()


@d.dedent
def shape_metrics(
adata: AnnData,
Expand Down Expand Up @@ -310,7 +310,7 @@ def shape_metrics(
metrics = [metrics]
elif isinstance(metrics, tuple):
metrics = list(metrics)

if cluster_id is not None and not isinstance(cluster_id, list) and not isinstance(cluster_id, np.ndarray):
cluster_id = [cluster_id]

Expand All @@ -329,12 +329,8 @@ def shape_metrics(
if cluster_key is not None:
keys.append(cluster_key)

metrics_df = (adata.obs[[component_key] + keys]
.drop_duplicates()
.dropna()
.set_index(component_key))
metrics_df = adata.obs[[component_key] + keys].drop_duplicates().dropna().set_index(component_key)


for metric in metrics:
metrics_df[metric] = metrics_df.index.map(adata.uns[f"shape_{component_key}"][metric])

Expand All @@ -351,10 +347,26 @@ def shape_metrics(
metrics_melted[cluster_key] = metrics_melted[cluster_key].cat.remove_unused_categories()

if cluster_key is not None:
plot_shapes(metrics_melted, "metric", "value", cluster_key, cluster_id, figsize, f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}')
plot_shapes(
metrics_melted,
"metric",
"value",
cluster_key,
cluster_id,
figsize,
f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}',
)

if condition_key is not None:
plot_shapes(metrics_melted, "metric", "value", condition_key, condition_groups, figsize, f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}')
plot_shapes(
metrics_melted,
"metric",
"value",
condition_key,
condition_groups,
figsize,
f'Spatial domains: {", ".join([str(cluster) for cluster in cluster_id])}',
)
else:
for metric in metrics:
fig = plt.figure(figsize=figsize)
Expand Down
9 changes: 8 additions & 1 deletion src/cellcharter/tl/_shape.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from collections import deque
from concurrent.futures import ProcessPoolExecutor, as_completed

Expand All @@ -15,7 +16,7 @@
from shapely.ops import polygonize, unary_union
from skimage.morphology import skeletonize
from squidpy._docs import d
import warnings


def _alpha_shape(coords, alpha):
"""
Expand Down Expand Up @@ -233,6 +234,7 @@ def _rasterize(boundary, height=1000):
poly = shapely.affinity.scale(poly, scale_factor, scale_factor, origin=(0, 0, 0))
return features.rasterize([poly], out_shape=(height, int(height * (maxx - minx) / (maxy - miny)))), scale_factor


def linearity(
adata: AnnData,
cluster_key: str = "component",
Expand All @@ -255,6 +257,7 @@ def linearity(
copy=copy,
)


@d.dedent
def linearity_metric(
adata: AnnData,
Expand Down Expand Up @@ -333,6 +336,7 @@ def elongation(
copy=copy,
)


@d.dedent
def elongation_metric(
adata: AnnData,
Expand Down Expand Up @@ -410,6 +414,7 @@ def curl(
copy=copy,
)


@d.dedent
def curl_metric(
adata: AnnData,
Expand Down Expand Up @@ -445,6 +450,7 @@ def curl_metric(
return curl_score
adata.uns[f"shape_{cluster_key}"][out_key] = curl_score


def purity_metric(
adata: AnnData,
cluster_key: str = "component",
Expand All @@ -467,6 +473,7 @@ def purity_metric(
copy=copy,
)


@d.dedent
def purity(
adata: AnnData,
Expand Down

0 comments on commit 63f5171

Please sign in to comment.