Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calculation of feature importances in a supervised setting #677

Merged
merged 14 commits into from
Apr 7, 2024
Prev Previous commit
Next Next commit
Harmonize plotting function
  • Loading branch information
Lilly-May committed Apr 3, 2024
commit b268006e6fd977eb7c0d30aca67748117b6db06e
40 changes: 33 additions & 7 deletions ehrapy/plot/supervised/_feature_importances.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,56 @@
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from anndata import AnnData
from matplotlib.axes import Axes


def feature_importances(
    adata: AnnData,
    key: str = "feature_importances",
    n_features: int = 10,
    ax: Axes | None = None,
    show: bool = True,
    save: str | None = None,
    **kwargs,
) -> Axes | None:
    """Plot the features with the greatest absolute importances as a horizontal barplot.

    Args:
        adata: :class:`~anndata.AnnData` object storing the data. A key in adata.var should contain the feature
            importances, calculated beforehand.
        key: The key in adata.var to use for feature importances. Defaults to 'feature_importances'.
        n_features: The number of features to plot. Defaults to 10.
        ax: A matplotlib axes object to plot on. If `None`, a new figure will be created. Defaults to `None`.
        show: If `True`, show the figure. If `False`, return the axes object. Defaults to `True`.
        save: Path to save the figure. If `None`, the figure will not be saved. Defaults to `None`.
        **kwargs: Additional arguments passed to `seaborn.barplot`.

    Returns:
        If `show == False` a `matplotlib.axes.Axes` object, else `None`.

    Raises:
        ValueError: If `key` is not a column of `adata.var`.
    """
    if key not in adata.var.columns:
        raise ValueError(
            f"Key {key} not found in adata.var. Make sure to calculate feature importances first with ep.tl.feature_importances."
        )

    df = pd.DataFrame({"importance": adata.var[key]}, index=adata.var_names)
    df["absolute_importance"] = df["importance"].abs()
    # Rank by magnitude so strongly negative importances are kept, but plot the signed values.
    top = df.sort_values("absolute_importance", ascending=False)[:n_features]

    if ax is None:
        _, ax = plt.subplots()
    sns.barplot(x=top["importance"], y=top.index, orient="h", ax=ax, **kwargs)

    # Label and lay out through the axes and its own figure rather than the pyplot
    # state machine: a caller-supplied `ax` is not necessarily pyplot's current axes.
    ax.set_ylabel("Feature")
    ax.set_xlabel("Importance")
    figure = ax.get_figure()
    figure.tight_layout()

    if save:
        figure.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
        return None
    return ax
6 changes: 1 addition & 5 deletions ehrapy/tools/supervised/_feature_importances.py
Original file line number Diff line number Diff line change
@@ -24,8 +24,7 @@ def feature_importances(
percent_output: bool = False,
**kwargs,
):
"""
Calculate feature importances for predicting a specified feature in adata.var using a given model.
"""Calculate feature importances for predicting a specified feature in adata.var using a given model.

Args:
adata: :class:`~anndata.AnnData` object storing the data.
@@ -47,9 +46,6 @@ def feature_importances(
percent_output: Set to True to output the feature importances as percentages. Note that information about positive or negative
coefficients for regression models will be lost. Defaults to False.
**kwargs: Additional keyword arguments to pass to the model. See the documentation of the respective model in scikit-learn for details.

Returns:
None
"""
if predicted_feature not in adata.var_names:
raise ValueError(f"Feature {predicted_feature} not found in adata.var.")
Loading