Calculation of feature importances in a supervised setting #677

Merged 14 commits on Apr 7, 2024
Refactor API
Lilly-May committed Apr 3, 2024
commit 1e1acf2cb1ffd019f47dc8cab0da66f21740f064
6 changes: 4 additions & 2 deletions docs/usage/usage.md
@@ -196,7 +196,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret
tools.paga
```

-### Group comparison
+### Feature Ranking

```{eval-rst}
.. autosummary::
@@ -205,6 +205,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret

tools.rank_features_groups
tools.filter_rank_features_groups
+tools.rank_features_supervised
```

### Dataset integration
@@ -358,7 +359,7 @@ Visualize clusters using one of the embedding methods passing color='leiden'.
plot.paga_compare
```

-### Group comparison
+### Feature Ranking

```{eval-rst}
.. autosummary::
@@ -372,6 +373,7 @@ Visualize clusters using one of the embedding methods passing color='leiden'.
plot.rank_features_groups_dotplot
plot.rank_features_groups_matrixplot
plot.rank_features_groups_tracksplot
+plot.rank_features_supervised
```

### Survival Analysis
File renamed without changes.
@@ -7,7 +7,7 @@
from matplotlib.axes import Axes


-def feature_importances(
+def rank_features_supervised(
adata: AnnData,
key: str = "feature_importances",
n_features: int = 10,
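
For orientation, here is a minimal sketch of how the renamed plotting function might be called. It assumes the usual `import ehrapy as ep` alias, that the function is exposed as `ep.pl.rank_features_supervised` (mirroring `plot.rank_features_supervised` in the usage docs above), and that it reads per-feature scores from `adata.var[key]`; the `key` and `n_features` defaults are the ones shown in the signature in this hunk.

```python
# Hedged sketch, not the PR's own example: assumes ep.pl maps to ehrapy.plot
# and that the plot reads per-feature scores from adata.var[key].
import numpy as np
from anndata import AnnData
import ehrapy as ep

adata = AnnData(np.zeros((5, 3)))
adata.var_names = ["target", "feature1", "feature2"]

# Normally written by the tool (ep.tl.rank_features_supervised); filled in by hand here
# so the snippet stands alone. The predicted feature itself carries NaN, as in the tests.
adata.var["feature_importances"] = [np.nan, 0.9, 0.1]

# Plot the two top-ranked features; key defaults to "feature_importances" and
# n_features to 10 according to the signature above.
ep.pl.rank_features_supervised(adata, key="feature_importances", n_features=2)
```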
2 changes: 1 addition & 1 deletion ehrapy/tools/__init__.py
@@ -14,8 +14,8 @@
from ehrapy.tools._scanpy_tl_api import * # noqa: F403
from ehrapy.tools.causal._dowhy import causal_inference
from ehrapy.tools.cohort_tracking._cohort_tracker import CohortTracker
+from ehrapy.tools.feature_ranking._feature_importances import rank_features_supervised
from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups
-from ehrapy.tools.supervised._feature_importances import feature_importances

try: # pragma: no cover
from ehrapy.tools.nlp._medcat import (
@@ -13,7 +13,7 @@
from ehrapy.anndata import anndata_to_df


-def feature_importances(
+def rank_features_supervised(
adata: AnnData,
predicted_feature: str,
prediction_type: Literal["continuous", "categorical"],
@@ -26,7 +26,7 @@ def feature_importances(
percent_output: bool = False,
**kwargs,
):
"""Calculate feature importances for predicting a specified feature in adata.var using a given model.
"""Calculate feature importances for predicting a specified feature in adata.var.

Args:
adata: :class:`~anndata.AnnData` object storing the data.
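
Based on the new import in `ehrapy/tools/__init__.py` and the positional calls in the tests below, here is a self-contained usage sketch of the renamed tool on toy data. The role of the fourth argument (model choice) and the fifth argument (`"all"`, presumably selecting which features to use as predictors) is inferred from the tests, not spelled out in this diff.

```python
# Hedged sketch grounded in the test calls below; toy data only.
import numpy as np
from anndata import AnnData
from ehrapy.tools import rank_features_supervised  # import path added by this PR

rng = np.random.default_rng(0)
feature1 = rng.normal(size=100)
feature2 = rng.normal(size=100)
target = 3 * feature1  # target depends on feature1 only

adata = AnnData(np.column_stack([target, feature1, feature2]))
adata.var_names = ["target", "feature1", "feature2"]

# Rank features by their importance for predicting "target" with a random forest.
rank_features_supervised(adata, "target", "continuous", "rf", "all")

# Importances land in adata.var["feature_importances"]; the predicted feature itself
# is expected to be NaN (see the assertions in the tests).
print(adata.var["feature_importances"])
```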
Empty file.
10 changes: 5 additions & 5 deletions tests/tools/supervised/test_feature_importances.py
@@ -5,7 +5,7 @@
import pytest
from anndata import AnnData

-from ehrapy.tools import feature_importances
+from ehrapy.tools import rank_features_supervised


def test_continuous_prediction():
@@ -15,7 +15,7 @@ def test_continuous_prediction():
adata.var_names = ["target", "feature1", "feature2"]

for model in ["regression", "svm", "rf"]:
feature_importances(adata, "target", "continuous", model, "all")
rank_features_supervised(adata, "target", "continuous", model, "all")
assert "feature_importances" in adata.var
assert adata.var["feature_importances"]["feature1"] > 0
assert adata.var["feature_importances"]["feature2"] == 0
@@ -30,7 +30,7 @@ def test_categorical_prediction():
adata.var_names = ["target", "feature1", "feature2"]

for model in ["regression", "svm", "rf"]:
feature_importances(adata, "target", "categorical", model, "all")
rank_features_supervised(adata, "target", "categorical", model, "all")
assert "feature_importances" in adata.var
assert adata.var["feature_importances"]["feature1"] > 0
assert adata.var["feature_importances"]["feature2"] == 0
@@ -44,13 +44,13 @@ def test_multiclass_prediction():
adata = AnnData(X)
adata.var_names = ["target", "feature1", "feature2"]

feature_importances(adata, "target", "categorical", "rf", "all")
rank_features_supervised(adata, "target", "categorical", "rf", "all")
assert "feature_importances" in adata.var
assert adata.var["feature_importances"]["feature1"] > 0
assert adata.var["feature_importances"]["feature2"] == 0
assert pd.isna(adata.var["feature_importances"]["target"])

for invalid_model in ["regression", "svm"]:
with pytest.raises(ValueError) as excinfo:
feature_importances(adata, "target", "categorical", invalid_model, "all")
rank_features_supervised(adata, "target", "categorical", invalid_model, "all")
assert str(excinfo.value).startswith("Feature target has more than two categories.")