Skip to content

Commit

Permalink
first commit with new neighborhood_connectivity function into _shap.py
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasHats committed Dec 23, 2024
1 parent 0ecc3c0 commit b6cd81b
Showing 1 changed file with 125 additions and 0 deletions.
125 changes: 125 additions & 0 deletions src/cellcharter/pl/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from anndata import AnnData
from scipy.stats import ttest_ind
from squidpy._docs import d
from typing import Union

from ._utils import adjust_box_widths

Expand Down Expand Up @@ -375,3 +376,127 @@ def shape_metrics(
if title is not None:
plt.title(title, fontdict={"fontsize": fontsize})
plt.show()


@d.dedent
def plot_neighborhood_connectivity(
adata,
cluster_key: str,
component_key: str,
cluster_id: list = None,
condition_key: str = None,
condition_groups: list = None,
library_key: str = "library_id",
figsize: tuple = (12, 6),
title: str = None,
save: Union[str, Path] = None,
dpi: int = 600,
show: bool = True,
**kwargs,
) -> Union[plt.Axes, pd.DataFrame]:
"""
Calculates and plots the fraction of cells from neighborhoods that are associated to
a connected component and thereby gives an idea about neighborhood connectivity
Parameters
--------------------------------------------------------------------------------------
adata : AnnData
AnnData object
cluster_key: str
Key in adata.obs that contains the cellcharter cluster information
component_key: str
Key in adata.obs that contains the connected component information after running
cc.gr.connected_components
condition_groups: list
List of condition groups in condition_key to plot
condition_key: str
Key in adata.obs that contains the condition information if users want to compared
conditions
library_key: str
Key in adata.obs that contains the sample or image ID information
cluster_id: list
List of cluster IDs in cluster_key to plot
figsize: tuple
Figure size
title: str
Title of the plot
show: bool
If True, the plot will be shown and axes are always returned,
if False, the function will return the result dataframe
Additional keyword arguments (**kwargs) will be passed to sns.violinplot()
Returns
--------------------------------------------------------------------------------------
Union[None, pd.DataFrame]
If show=True, returns ax and displays the plot
If show=False, returns the computed DataFrame
Examples
--------------------------------------------------------------------------------------
cc.pl_neighborhood_connectivity(adata,
cluster_key='cellcharter_CN',
component_key='component',
condition_key='disease',
library_key='library_id'
)
"""
# Checking mandatory parameters
if not isinstance(adata, ad.AnnData):
raise ValueError("adata must be an AnnData object")
if not isinstance(cluster_key, str) or cluster_key not in adata.obs:
raise ValueError(f"cluster_key '{cluster_key}' not found in adata.obs")
if not isinstance(component_key, str) or component_key not in adata.obs:
raise ValueError(f"component_key '{component_key}' not found in adata.obs")

# Create a dataframe using the component, condition and cluster keys
components = pd.crosstab(
index=adata.obs[cluster_key], columns=adata.obs[library_key], values=adata.obs[component_key], aggfunc="count"
)

# Create a dataframe with the total number of cells in each cluster per image/library key
totals = pd.crosstab(index=adata.obs[cluster_key], columns=adata.obs[library_key])

# Calculate the fraction of cells in a connected component and map the conditions to the dataframe
result = np.divide(components, totals)
result = result.T.reset_index()
result = result.melt(id_vars=library_key, var_name=cluster_key, value_name="fraction_in_component")
if condition_key:
condition_mapping = adata.obs.groupby(library_key)[condition_key].first().to_dict()
result["condition"] = result[library_key].map(condition_mapping)

# Filter for specific cluster groups if provided
if cluster_id:
result = result[result[cluster_key].isin(cluster_id)]

# Filter for specific condition groups if provided
if condition_key and condition_groups:
result = result[result["condition"].isin(condition_groups)]

del components, totals
# Create the violinplot if show=True
if show:
fig, ax = plt.subplots(figsize=figsize)
plot_kwargs = {"data": result, "x": cluster_key, "y": "fraction_in_component", "ax": ax}

if condition_key:
plot_kwargs["hue"] = "condition"
sns.violinplot(**plot_kwargs, **kwargs)
ax.set_ylabel("Neighborhood Connectivity")
ax.set_xlabel("")
ax.tick_params(axis="x", rotation=90)

if title:
ax.set_title(title)

plt.tight_layout()

if save:
plt.savefig(save, dpi=dpi, bbox_inches="tight")

plt.show()
return ax

# Return the result dataframe if show=False
else:
return result

0 comments on commit b6cd81b

Please sign in to comment.