def _drop_unused_categories(frame: pd.DataFrame, column: str) -> pd.DataFrame:
    """Return ``frame`` with unused categories removed from ``column`` when it is categorical."""
    if isinstance(frame[column].dtype, pd.CategoricalDtype):
        return frame.assign(**{column: frame[column].cat.remove_unused_categories()})
    return frame


@d.dedent
def plot_neighborhood_connectivity(
    adata: AnnData,
    cluster_key: str,
    component_key: str,
    cluster_id: Union[list, None] = None,
    condition_key: Union[str, None] = None,
    condition_groups: Union[list, None] = None,
    library_key: str = "library_id",
    figsize: tuple = (12, 6),
    title: Union[str, None] = None,
    save: Union[str, Path, None] = None,
    dpi: int = 600,
    show: bool = True,
    **kwargs,
) -> Union[plt.Axes, pd.DataFrame]:
    """
    Plot the fraction of cells of each neighborhood that belong to a connected component.

    For every (cluster, library) pair the number of cells with a non-missing value in
    ``component_key`` is divided by the total number of cells of that pair. The resulting
    fractions give an idea about neighborhood connectivity and are shown as a violin plot
    with one violin per cluster (split by condition if ``condition_key`` is given).

    Parameters
    ----------
    adata
        AnnData object.
    cluster_key
        Key in ``adata.obs`` that contains the cellcharter cluster information.
    component_key
        Key in ``adata.obs`` that contains the connected component information after
        running ``cc.gr.connected_components``.
    cluster_id
        List of cluster IDs in ``cluster_key`` to plot. ``None`` or an empty list keeps
        all clusters.
    condition_key
        Key in ``adata.obs`` that contains the condition information, if conditions
        should be compared. The condition of a library is taken from its first cell.
    condition_groups
        List of condition groups in ``condition_key`` to plot. Ignored when
        ``condition_key`` is not given; ``None`` or an empty list keeps all conditions.
    library_key
        Key in ``adata.obs`` that contains the sample or image ID information.
    figsize
        Figure size.
    title
        Title of the plot.
    save
        Path the figure is written to. Only used when ``show=True``, because no figure
        is created otherwise.
    dpi
        Resolution of the saved figure.
    show
        If ``True``, the plot is drawn and shown and its axes are returned.
        If ``False``, no plot is created and the result dataframe is returned.
    **kwargs
        Additional keyword arguments passed to ``sns.violinplot()``.

    Returns
    -------
    If ``show=True``, the :class:`matplotlib.axes.Axes` of the displayed plot.
    If ``show=False``, a long-format :class:`pandas.DataFrame` with the columns
    ``library_key``, ``cluster_key``, ``"fraction_in_component"`` and, when
    ``condition_key`` is given, ``"condition"``.

    Raises
    ------
    ValueError
        If ``adata`` is not an AnnData object or one of the given keys is not found in
        ``adata.obs``.

    Examples
    --------
    plot_neighborhood_connectivity(
        adata,
        cluster_key='cellcharter_CN',
        component_key='component',
        condition_key='disease',
        library_key='library_id',
    )
    """
    # Validate the inputs up front so that a wrong key fails with a clear message
    # instead of a pandas KeyError from deep inside crosstab/groupby.
    if not isinstance(adata, AnnData):
        raise ValueError("adata must be an AnnData object")
    if not isinstance(cluster_key, str) or cluster_key not in adata.obs:
        raise ValueError(f"cluster_key '{cluster_key}' not found in adata.obs")
    if not isinstance(component_key, str) or component_key not in adata.obs:
        raise ValueError(f"component_key '{component_key}' not found in adata.obs")
    if library_key not in adata.obs:
        raise ValueError(f"library_key '{library_key}' not found in adata.obs")
    if condition_key and condition_key not in adata.obs:
        raise ValueError(f"condition_key '{condition_key}' not found in adata.obs")

    obs = adata.obs

    # Cells per cluster and library that have a (non-missing) component assigned
    in_component = pd.crosstab(
        index=obs[cluster_key], columns=obs[library_key], values=obs[component_key], aggfunc="count"
    )

    # Total number of cells per cluster and library
    totals = pd.crosstab(index=obs[cluster_key], columns=obs[library_key])

    # Fraction of cells in a connected component, reshaped to one row per (library, cluster)
    result = in_component.div(totals).T.reset_index()
    result = result.melt(id_vars=library_key, var_name=cluster_key, value_name="fraction_in_component")

    if condition_key:
        # One condition per library, taken from its first cell. observed=True only skips
        # library categories without cells, which cannot occur in `result` with a value.
        condition_mapping = obs.groupby(library_key, observed=True)[condition_key].first().to_dict()
        result["condition"] = result[library_key].map(condition_mapping)

    # Filter for specific cluster groups if provided. The explicit None/len test (rather
    # than truthiness) also accepts numpy arrays and pandas indexes.
    if cluster_id is not None and len(cluster_id) > 0:
        result = result[result[cluster_key].isin(cluster_id)]
        # Otherwise seaborn reserves an empty slot for every filtered-out cluster
        result = _drop_unused_categories(result, cluster_key)

    # Filter for specific condition groups if provided
    if condition_key and condition_groups is not None and len(condition_groups) > 0:
        result = result[result["condition"].isin(condition_groups)]
        # Otherwise filtered-out conditions still show up as empty hue levels
        result = _drop_unused_categories(result, "condition")

    # Return the result dataframe without plotting if show=False
    if not show:
        return result

    fig, ax = plt.subplots(figsize=figsize)
    plot_kwargs = {"data": result, "x": cluster_key, "y": "fraction_in_component", "ax": ax}
    if condition_key:
        plot_kwargs["hue"] = "condition"

    sns.violinplot(**plot_kwargs, **kwargs)
    ax.set_ylabel("Neighborhood Connectivity")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

    if title:
        ax.set_title(title)

    fig.tight_layout()

    if save:
        fig.savefig(save, dpi=dpi, bbox_inches="tight")

    plt.show()
    return ax