From 03c31618994cfb2fc025d9cc0d53490c24f2f460 Mon Sep 17 00:00:00 2001 From: Lilly Date: Wed, 27 Mar 2024 15:32:22 +0100 Subject: [PATCH 01/13] Added feature importances function and tests --- ehrapy/tools/__init__.py | 1 + ehrapy/tools/supervised/__init__.py | 0 .../tools/supervised/_feature_importances.py | 139 ++++++++++++++++++ .../supervised/test_feature_importances.py | 56 +++++++ 4 files changed, 196 insertions(+) create mode 100644 ehrapy/tools/supervised/__init__.py create mode 100644 ehrapy/tools/supervised/_feature_importances.py create mode 100644 tests/tools/supervised/test_feature_importances.py diff --git a/ehrapy/tools/__init__.py b/ehrapy/tools/__init__.py index 18eb61be..511c97d5 100644 --- a/ehrapy/tools/__init__.py +++ b/ehrapy/tools/__init__.py @@ -15,6 +15,7 @@ from ehrapy.tools.causal._dowhy import causal_inference from ehrapy.tools.cohort_tracking._cohort_tracker import CohortTracker from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups +from ehrapy.tools.supervised._feature_importances import feature_importances try: # pragma: no cover from ehrapy.tools.nlp._medcat import ( diff --git a/ehrapy/tools/supervised/__init__.py b/ehrapy/tools/supervised/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ehrapy/tools/supervised/_feature_importances.py b/ehrapy/tools/supervised/_feature_importances.py new file mode 100644 index 00000000..cbd9e1fb --- /dev/null +++ b/ehrapy/tools/supervised/_feature_importances.py @@ -0,0 +1,139 @@ +from typing import Literal + +import pandas as pd +from anndata import AnnData +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from sklearn.svm import SVC, SVR + +from ehrapy import logging as logg + + +def feature_importances( + adata: AnnData, + predicted_feature: str, + prediction_type: Literal["continuous", "categorical"], + model: Literal["regression", "svm", "rf"] = "regression", + input_features: list[str] | Literal["all"] = "all", + layer: str | None = None, + test_split_size: float = 0.2, + key_added: str = "feature_importances", + feature_scaling: Literal["standard", "minmax"] | None = "standard", + **kwargs, +): + """ + Calculate feature importances for a given model and predicted feature. + + Args: + adata: :class:`~anndata.AnnData` object storing the data. + predicted_feature: The feature to predict by the model. + prediction_type: Whether the predicted feature is continuous or categorical. If the data type of the predicted feature + is not correct, conversion will be attempted. + model: The model to use for prediction. Choose between 'regression', 'svm', or 'rf'. Note that multi-class classification + is only possible with 'rf'. Defaults to 'regression'. + input_features: The features in adata.var to use for prediction. Should be a list of feature names. If 'all', all features + in adata.var will be used. Note that non-numeric features will be dropped, so make sure to encode them properly before. + Defaults to 'all'. + layer: The layer in adata.layers to use for prediction. If None, adata.X will be used. Defaults to None. + test_split_size: The size of the test set to used to evaluate the model. Defaults to 0.2. + key_added: The key in adata.var to store the feature importances. Defaults to 'feature_importances'. + feature_transformation: The type of feature transformation to use. Choose between 'standard', 'minmax', 'normalize', or None. + 'standard' uses sklearn's StandardScaler, 'minmax' uses MinMaxScaler, 'normalize' uses Normalizer. Will be fit and transformed + for each feature individually. Defaults to 'standard'. + **kwargs: Additional keyword arguments to pass to the model. See the documentation of the respective model in scikit-learn for details. + + Returns: + None + """ + if predicted_feature not in adata.var_names: + raise ValueError(f"Feature {predicted_feature} not found in adata.var.") + + if input_features != "all": + for feature in input_features: + if feature not in adata.var_names: + raise ValueError(f"Feature {feature} not found in adata.var.") + + if model not in ["regression", "svm", "rf"]: + raise ValueError(f"Model {model} not recognized. Please choose either 'regression', 'svm', or 'rf'.") + + if prediction_type not in ["continuous", "categorical"]: + raise ValueError( + f"Prediction type {prediction_type} not recognized. Please choose either 'continuous' or 'categorical'." + ) + + if layer is not None: + data = adata.layers[layer].to_df() + else: + data = adata.to_df() + + if prediction_type == "continuous": + if pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): + try: + data[predicted_feature] = data[predicted_feature].astype(float) + except ValueError as e: + raise ValueError( + f"Feature {predicted_feature} is not continuous and conversion to float failed. Either change the prediction " + f"type to 'categorical' or change the feature data type to a continuous type." + ) from e + + if model == "regression": + predictor = LinearRegression(**kwargs) + elif model == "svm": + predictor = SVR(kernel="linear", **kwargs) + elif model == "rf": + predictor = RandomForestRegressor(**kwargs) + + elif prediction_type == "categorical": + if not pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): + try: + data[predicted_feature] = data[predicted_feature].astype("category") + except ValueError as e: + raise ValueError( + f"Feature {predicted_feature} is not categorical and conversion to category failed. Either change the prediction " + f"type to 'continuous' or change the feature data type to a categorical type." + ) from e + + if data[predicted_feature].nunique() > 2 and model in ["regression", "svm"]: + raise ValueError( + f"Feature {predicted_feature} has more than two categories. Please choose random forest (rf) as model for multi-class classification." + ) + + if model == "regression": + predictor = LogisticRegression(**kwargs) + elif model == "svm": + predictor = SVC(kernel="linear", **kwargs) + elif model == "rf": + predictor = RandomForestClassifier(**kwargs) + + if input_features == "all": + input_features = list(adata.var_names) + input_features.remove(predicted_feature) + + input_data = data[input_features] + labels = data[predicted_feature] + + for feature in input_data.columns: + try: + input_data[feature] = input_data[feature].astype(float) + + if feature_scaling is not None: + scaler = StandardScaler() if feature_scaling == "standard" else MinMaxScaler() + input_data[feature] = scaler.fit_transform(input_data[[feature]]) + except ValueError: + logg.warning(f"Feature {feature} could not be converted to float. Feature will be dropped.") + input_data.drop(feature, axis=1, inplace=True) + + x_train, x_test, y_train, y_test = train_test_split(input_data, labels, test_size=test_split_size) + + predictor.fit(x_train, y_train) + + if model == "regression" or model == "svm": + feature_importances = pd.Series(predictor.coef_.squeeze(), index=input_data.columns) + else: + feature_importances = pd.Series(predictor.feature_importances_.squeeze(), index=input_data.columns) + + # Reorder feature importances to match adata.var order and save in adata.var + feature_importances = feature_importances.reindex(adata.var_names) + adata.var[key_added] = feature_importances diff --git a/tests/tools/supervised/test_feature_importances.py b/tests/tools/supervised/test_feature_importances.py new file mode 100644 index 00000000..b2560df6 --- /dev/null +++ b/tests/tools/supervised/test_feature_importances.py @@ -0,0 +1,56 @@ +import unittest + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData + +from ehrapy.tools import feature_importances + + +def test_continuous_prediction(): + target = np.random.rand(1000) + X = np.stack((target, target * 2, [1] * 1000)).T + adata = AnnData(X) + adata.var_names = ["target", "feature1", "feature2"] + + for model in ["regression", "svm", "rf"]: + feature_importances(adata, "target", "continuous", model, "all") + assert "feature_importances" in adata.var + assert adata.var["feature_importances"]["feature1"] > 0 + assert adata.var["feature_importances"]["feature2"] == 0 + assert pd.isna(adata.var["feature_importances"]["target"]) + + +def test_categorical_prediction(): + target = np.random.randint(2, size=1000) + X = np.stack((target, target, [1] * 1000)).T + + adata = AnnData(X) + adata.var_names = ["target", "feature1", "feature2"] + + for model in ["regression", "svm", "rf"]: + feature_importances(adata, "target", "categorical", model, "all") + assert "feature_importances" in adata.var + assert adata.var["feature_importances"]["feature1"] > 0 + assert adata.var["feature_importances"]["feature2"] == 0 + assert pd.isna(adata.var["feature_importances"]["target"]) + + +def test_multiclass_prediction(): + target = np.random.randint(4, size=1000) + X = np.stack((target, target, [1] * 1000)).T + + adata = AnnData(X) + adata.var_names = ["target", "feature1", "feature2"] + + feature_importances(adata, "target", "categorical", "rf", "all") + assert "feature_importances" in adata.var + assert adata.var["feature_importances"]["feature1"] > 0 + assert adata.var["feature_importances"]["feature2"] == 0 + assert pd.isna(adata.var["feature_importances"]["target"]) + + for invalid_model in ["regression", "svm"]: + with pytest.raises(ValueError) as excinfo: + feature_importances(adata, "target", "categorical", invalid_model, "all") + assert str(excinfo.value).startswith("Feature target has more than two categories.") From a0c2551616361b9c4d7fe8e401da407b5879e779 Mon Sep 17 00:00:00 2001 From: Lilly Date: Thu, 28 Mar 2024 11:07:35 +0100 Subject: [PATCH 02/13] Added feature importances plotting function --- ehrapy/plot/supervised/__init__.py | 0 .../plot/supervised/_feature_importances.py | 30 +++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 ehrapy/plot/supervised/__init__.py create mode 100644 ehrapy/plot/supervised/_feature_importances.py diff --git a/ehrapy/plot/supervised/__init__.py b/ehrapy/plot/supervised/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ehrapy/plot/supervised/_feature_importances.py b/ehrapy/plot/supervised/_feature_importances.py new file mode 100644 index 00000000..4770ec18 --- /dev/null +++ b/ehrapy/plot/supervised/_feature_importances.py @@ -0,0 +1,30 @@ +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from anndata import AnnData + + +def feature_importances(adata: AnnData, key: str = "feature_importances", n_features: int = 10): + """ + Plot features with greates absolute importances as a barplot. + + Args: + adata: :class:`~anndata.AnnData` object storing the data. A key in adata.var should contain the feature + importances, calculated beforehand. + key: The key in adata.var to use for feature importances. Defaults to 'feature_importances'. + n_features: The number of features to plot. Defaults to 10. + + Returns: + None + """ + if key not in adata.var.keys(): + raise ValueError(f"Key {key} not found in adata.var.") + + df = pd.DataFrame({"importance": adata.var[key]}, index=adata.var_names) + df["absolute_importance"] = df["importance"].abs() + df = df.sort_values("absolute_importance", ascending=False) + sns.barplot(x=df["importance"][:n_features], y=df.index[:n_features], orient="h") + plt.ylabel("Feature") + plt.xlabel("Importance") + plt.tight_layout() + plt.show() From 8c8a5a935052625f880b9944a037395d628ccf8a Mon Sep 17 00:00:00 2001 From: Lilly Date: Thu, 28 Mar 2024 11:08:39 +0100 Subject: [PATCH 03/13] Add function in init --- ehrapy/plot/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ehrapy/plot/__init__.py b/ehrapy/plot/__init__.py index 102c2f26..462c7c6f 100644 --- a/ehrapy/plot/__init__.py +++ b/ehrapy/plot/__init__.py @@ -3,3 +3,4 @@ from ehrapy.plot._survival_analysis import kmf, ols from ehrapy.plot._util import * # noqa: F403 from ehrapy.plot.causal_inference._dowhy import causal_effect +from ehrapy.plot.supervised._feature_importances import feature_importances From bb726f61e644c9598f5341feeba2c1b77c7fb532 Mon Sep 17 00:00:00 2001 From: Lilly Date: Tue, 2 Apr 2024 14:40:57 +0200 Subject: [PATCH 04/13] Added evaluation on test set --- ehrapy/tools/supervised/_feature_importances.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ehrapy/tools/supervised/_feature_importances.py b/ehrapy/tools/supervised/_feature_importances.py index cbd9e1fb..6897a53a 100644 --- a/ehrapy/tools/supervised/_feature_importances.py +++ b/ehrapy/tools/supervised/_feature_importances.py @@ -129,6 +129,10 @@ def feature_importances( predictor.fit(x_train, y_train) + score = predictor.score(x_test, y_test) + evaluation_metric = "R2 score" if prediction_type == "continuous" else "accuracy" + logg.info(f"Training completed. The model achieved an {evaluation_metric} of {score:.2f} on the test set.") + if model == "regression" or model == "svm": feature_importances = pd.Series(predictor.coef_.squeeze(), index=input_data.columns) else: From 0305bf330b7a2429a6e315d0626aff51d0133321 Mon Sep 17 00:00:00 2001 From: Lilly Date: Tue, 2 Apr 2024 14:55:16 +0200 Subject: [PATCH 05/13] Raise error for non-numeric features --- ehrapy/tools/supervised/_feature_importances.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ehrapy/tools/supervised/_feature_importances.py b/ehrapy/tools/supervised/_feature_importances.py index 6897a53a..e10710e9 100644 --- a/ehrapy/tools/supervised/_feature_importances.py +++ b/ehrapy/tools/supervised/_feature_importances.py @@ -121,9 +121,11 @@ def feature_importances( if feature_scaling is not None: scaler = StandardScaler() if feature_scaling == "standard" else MinMaxScaler() input_data[feature] = scaler.fit_transform(input_data[[feature]]) - except ValueError: - logg.warning(f"Feature {feature} could not be converted to float. Feature will be dropped.") - input_data.drop(feature, axis=1, inplace=True) + except ValueError as e: + raise ValueError( + f"Feature {feature} is not numeric. Please encode non-numeric features before calculating " + f"feature importances or drop them from the input_features list." + ) from e x_train, x_test, y_train, y_test = train_test_split(input_data, labels, test_size=test_split_size) From 6cfd94f4b383da2531ff04d056efaa65944b50d7 Mon Sep 17 00:00:00 2001 From: Lilly Date: Tue, 2 Apr 2024 17:17:06 +0200 Subject: [PATCH 06/13] Added output as percentage option --- .../tools/supervised/_feature_importances.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/ehrapy/tools/supervised/_feature_importances.py b/ehrapy/tools/supervised/_feature_importances.py index e10710e9..b57c16e5 100644 --- a/ehrapy/tools/supervised/_feature_importances.py +++ b/ehrapy/tools/supervised/_feature_importances.py @@ -21,27 +21,31 @@ def feature_importances( test_split_size: float = 0.2, key_added: str = "feature_importances", feature_scaling: Literal["standard", "minmax"] | None = "standard", + percent_output: bool = False, **kwargs, ): """ - Calculate feature importances for a given model and predicted feature. + Calculate feature importances for predicting a specified feature in adata.var using a given model. Args: adata: :class:`~anndata.AnnData` object storing the data. - predicted_feature: The feature to predict by the model. + predicted_feature: The feature to predict by the model. Must be present in adata.var_names. prediction_type: Whether the predicted feature is continuous or categorical. If the data type of the predicted feature is not correct, conversion will be attempted. model: The model to use for prediction. Choose between 'regression', 'svm', or 'rf'. Note that multi-class classification is only possible with 'rf'. Defaults to 'regression'. input_features: The features in adata.var to use for prediction. Should be a list of feature names. If 'all', all features - in adata.var will be used. Note that non-numeric features will be dropped, so make sure to encode them properly before. - Defaults to 'all'. + in adata.var will be used. Note that non-numeric input features will cause an error, so make sure to encode them properly + before. Defaults to 'all'. layer: The layer in adata.layers to use for prediction. If None, adata.X will be used. Defaults to None. - test_split_size: The size of the test set to used to evaluate the model. Defaults to 0.2. + test_split_size: The split of data used for testing the model. Should be a float between 0 and 1, representing the proportion. + Defaults to 0.2. key_added: The key in adata.var to store the feature importances. Defaults to 'feature_importances'. - feature_transformation: The type of feature transformation to use. Choose between 'standard', 'minmax', 'normalize', or None. - 'standard' uses sklearn's StandardScaler, 'minmax' uses MinMaxScaler, 'normalize' uses Normalizer. Will be fit and transformed - for each feature individually. Defaults to 'standard'. + feature_scaling: The type of feature scaling to use for the input. Choose between 'standard', 'minmax', or None. + 'standard' uses sklearn's StandardScaler, 'minmax' uses MinMaxScaler. Scaler will be fit and transformed + for each feature individually. Defaults to 'standard'. + percent_output: Set to True to output the feature importances as percentages. Note that information about positive or negative + coefficients for regression models will be lost. Defaults to False. **kwargs: Additional keyword arguments to pass to the model. See the documentation of the respective model in scikit-learn for details. Returns: @@ -63,6 +67,11 @@ def feature_importances( f"Prediction type {prediction_type} not recognized. Please choose either 'continuous' or 'categorical'." ) + if feature_scaling not in ["standard", "minmax", None]: + raise ValueError( + f"Feature scaling type {feature_scaling} not recognized. Please choose either 'standard', 'minmax', or None." + ) + if layer is not None: data = adata.layers[layer].to_df() else: @@ -140,6 +149,9 @@ def feature_importances( else: feature_importances = pd.Series(predictor.feature_importances_.squeeze(), index=input_data.columns) - # Reorder feature importances to match adata.var order and save in adata.var + if percent_output: + feature_importances = feature_importances.abs() / feature_importances.abs().sum() * 100 + + # Reorder feature importances to match adata.var order and save importances in adata.var feature_importances = feature_importances.reindex(adata.var_names) adata.var[key_added] = feature_importances From b268006e6fd977eb7c0d30aca67748117b6db06e Mon Sep 17 00:00:00 2001 From: Lilly Date: Wed, 3 Apr 2024 14:28:55 +0200 Subject: [PATCH 07/13] Harmonize plotting function --- .../plot/supervised/_feature_importances.py | 40 +++++++++++++++---- .../tools/supervised/_feature_importances.py | 6 +-- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/ehrapy/plot/supervised/_feature_importances.py b/ehrapy/plot/supervised/_feature_importances.py index 4770ec18..2afee742 100644 --- a/ehrapy/plot/supervised/_feature_importances.py +++ b/ehrapy/plot/supervised/_feature_importances.py @@ -1,30 +1,56 @@ +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt import pandas as pd import seaborn as sns from anndata import AnnData +from matplotlib.axes import Axes -def feature_importances(adata: AnnData, key: str = "feature_importances", n_features: int = 10): - """ - Plot features with greates absolute importances as a barplot. +def feature_importances( + adata: AnnData, + key: str = "feature_importances", + n_features: int = 10, + ax: Axes | None = None, + show: bool = True, + save: str | None = None, + **kwargs, +) -> Axes | None: + """Plot features with greates absolute importances as a barplot. Args: adata: :class:`~anndata.AnnData` object storing the data. A key in adata.var should contain the feature importances, calculated beforehand. key: The key in adata.var to use for feature importances. Defaults to 'feature_importances'. n_features: The number of features to plot. Defaults to 10. + ax: A matplotlib axes object to plot on. If `None`, a new figure will be created. Defaults to `None`. + show: If `True`, show the figure. If `False`, return the axes object. Defaults to `True`. + save: Path to save the figure. If `None`, the figure will not be saved. Defaults to `None`. + **kwargs: Additional arguments passed to `seaborn.barplot`. Returns: - None + If `show == False` a `matplotlib.axes.Axes` object, else `None`. """ if key not in adata.var.keys(): - raise ValueError(f"Key {key} not found in adata.var.") + raise ValueError( + f"Key {key} not found in adata.var. Make sure to calculate feature importances first with ep.tl.feature_importances." + ) df = pd.DataFrame({"importance": adata.var[key]}, index=adata.var_names) df["absolute_importance"] = df["importance"].abs() df = df.sort_values("absolute_importance", ascending=False) - sns.barplot(x=df["importance"][:n_features], y=df.index[:n_features], orient="h") + + if ax is None: + fig, ax = plt.subplots() + sns.barplot(x=df["importance"][:n_features], y=df.index[:n_features], orient="h", ax=ax, **kwargs) plt.ylabel("Feature") plt.xlabel("Importance") plt.tight_layout() - plt.show() + + if save: + plt.savefig(save, bbox_inches="tight") + if show: + plt.show() + return None + else: + return ax diff --git a/ehrapy/tools/supervised/_feature_importances.py b/ehrapy/tools/supervised/_feature_importances.py index b57c16e5..e483c981 100644 --- a/ehrapy/tools/supervised/_feature_importances.py +++ b/ehrapy/tools/supervised/_feature_importances.py @@ -24,8 +24,7 @@ def feature_importances( percent_output: bool = False, **kwargs, ): - """ - Calculate feature importances for predicting a specified feature in adata.var using a given model. + """Calculate feature importances for predicting a specified feature in adata.var using a given model. Args: adata: :class:`~anndata.AnnData` object storing the data. @@ -47,9 +46,6 @@ def feature_importances( percent_output: Set to True to output the feature importances as percentages. Note that information about positive or negative coefficients for regression models will be lost. Defaults to False. **kwargs: Additional keyword arguments to pass to the model. See the documentation of the respective model in scikit-learn for details. - - Returns: - None """ if predicted_feature not in adata.var_names: raise ValueError(f"Feature {predicted_feature} not found in adata.var.") From 254873c0046fd22c4c60a1f92dec9752df31b6e3 Mon Sep 17 00:00:00 2001 From: Lilly Date: Wed, 3 Apr 2024 14:52:24 +0200 Subject: [PATCH 08/13] PR Reviews --- ehrapy/tools/supervised/_feature_importances.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ehrapy/tools/supervised/_feature_importances.py b/ehrapy/tools/supervised/_feature_importances.py index e483c981..88e7cd50 100644 --- a/ehrapy/tools/supervised/_feature_importances.py +++ b/ehrapy/tools/supervised/_feature_importances.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable from typing import Literal import pandas as pd @@ -9,6 +10,7 @@ from sklearn.svm import SVC, SVR from ehrapy import logging as logg +from ehrapy.anndata import anndata_to_df def feature_importances( @@ -16,7 +18,7 @@ def feature_importances( predicted_feature: str, prediction_type: Literal["continuous", "categorical"], model: Literal["regression", "svm", "rf"] = "regression", - input_features: list[str] | Literal["all"] = "all", + input_features: Iterable[str] | Literal["all"] = "all", layer: str | None = None, test_split_size: float = 0.2, key_added: str = "feature_importances", @@ -68,10 +70,7 @@ def feature_importances( f"Feature scaling type {feature_scaling} not recognized. Please choose either 'standard', 'minmax', or None." ) - if layer is not None: - data = adata.layers[layer].to_df() - else: - data = adata.to_df() + data = anndata_to_df(adata, layer=layer) if prediction_type == "continuous": if pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): @@ -138,7 +137,9 @@ def feature_importances( score = predictor.score(x_test, y_test) evaluation_metric = "R2 score" if prediction_type == "continuous" else "accuracy" - logg.info(f"Training completed. The model achieved an {evaluation_metric} of {score:.2f} on the test set.") + logg.info( + f"Training completed. The model achieved an {evaluation_metric} of {score:.2f} on the test set, consisting of {len(y_test)} samples." + ) if model == "regression" or model == "svm": feature_importances = pd.Series(predictor.coef_.squeeze(), index=input_data.columns) From 5475269f6d80f08323cc452bc1e9ffb4d86e7e87 Mon Sep 17 00:00:00 2001 From: Lilly Date: Wed, 3 Apr 2024 15:06:49 +0200 Subject: [PATCH 09/13] Return Figure axes when show=False --- ehrapy/plot/supervised/_feature_importances.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ehrapy/plot/supervised/_feature_importances.py b/ehrapy/plot/supervised/_feature_importances.py index 2afee742..733b3991 100644 --- a/ehrapy/plot/supervised/_feature_importances.py +++ b/ehrapy/plot/supervised/_feature_importances.py @@ -42,7 +42,7 @@ def feature_importances( if ax is None: fig, ax = plt.subplots() - sns.barplot(x=df["importance"][:n_features], y=df.index[:n_features], orient="h", ax=ax, **kwargs) + ax = sns.barplot(x=df["importance"][:n_features], y=df.index[:n_features], orient="h", ax=ax, **kwargs) plt.ylabel("Feature") plt.xlabel("Importance") plt.tight_layout() From 1e1acf2cb1ffd019f47dc8cab0da66f21740f064 Mon Sep 17 00:00:00 2001 From: Lilly Date: Wed, 3 Apr 2024 16:31:13 +0200 Subject: [PATCH 10/13] Refactor API --- docs/usage/usage.md | 6 ++++-- .../plot/{supervised => feature_ranking}/__init__.py | 0 .../_feature_importances.py | 2 +- ehrapy/tools/__init__.py | 2 +- .../_feature_importances.py | 4 ++-- ehrapy/tools/supervised/__init__.py | 0 tests/tools/supervised/test_feature_importances.py | 10 +++++----- 7 files changed, 13 insertions(+), 11 deletions(-) rename ehrapy/plot/{supervised => feature_ranking}/__init__.py (100%) rename ehrapy/plot/{supervised => feature_ranking}/_feature_importances.py (98%) rename ehrapy/tools/{supervised => feature_ranking}/_feature_importances.py (99%) delete mode 100644 ehrapy/tools/supervised/__init__.py diff --git a/docs/usage/usage.md b/docs/usage/usage.md index 1a4241da..3d689849 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -196,7 +196,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret tools.paga ``` -### Group comparison +### Feature Ranking ```{eval-rst} .. autosummary:: @@ -205,6 +205,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret tools.rank_features_groups tools.filter_rank_features_groups + tools.rank_features_supervised ``` ### Dataset integration @@ -358,7 +359,7 @@ Visualize clusters using one of the embedding methods passing color='leiden'. plot.paga_compare ``` -### Group comparison +### Feature Ranking ```{eval-rst} .. autosummary:: @@ -372,6 +373,7 @@ Visualize clusters using one of the embedding methods passing color='leiden'. plot.rank_features_groups_dotplot plot.rank_features_groups_matrixplot plot.rank_features_groups_tracksplot + plot.rank_features_supervised ``` ### Survival Analysis diff --git a/ehrapy/plot/supervised/__init__.py b/ehrapy/plot/feature_ranking/__init__.py similarity index 100% rename from ehrapy/plot/supervised/__init__.py rename to ehrapy/plot/feature_ranking/__init__.py diff --git a/ehrapy/plot/supervised/_feature_importances.py b/ehrapy/plot/feature_ranking/_feature_importances.py similarity index 98% rename from ehrapy/plot/supervised/_feature_importances.py rename to ehrapy/plot/feature_ranking/_feature_importances.py index 733b3991..6e467a05 100644 --- a/ehrapy/plot/supervised/_feature_importances.py +++ b/ehrapy/plot/feature_ranking/_feature_importances.py @@ -7,7 +7,7 @@ from matplotlib.axes import Axes -def feature_importances( +def rank_features_supervised( adata: AnnData, key: str = "feature_importances", n_features: int = 10, diff --git a/ehrapy/tools/__init__.py b/ehrapy/tools/__init__.py index 511c97d5..d57b936d 100644 --- a/ehrapy/tools/__init__.py +++ b/ehrapy/tools/__init__.py @@ -14,8 +14,8 @@ from ehrapy.tools._scanpy_tl_api import * # noqa: F403 from ehrapy.tools.causal._dowhy import causal_inference from ehrapy.tools.cohort_tracking._cohort_tracker import CohortTracker +from ehrapy.tools.feature_ranking._feature_importances import rank_features_supervised from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups -from ehrapy.tools.supervised._feature_importances import feature_importances try: # pragma: no cover from ehrapy.tools.nlp._medcat import ( diff --git a/ehrapy/tools/supervised/_feature_importances.py b/ehrapy/tools/feature_ranking/_feature_importances.py similarity index 99% rename from ehrapy/tools/supervised/_feature_importances.py rename to ehrapy/tools/feature_ranking/_feature_importances.py index 88e7cd50..8d4799f3 100644 --- a/ehrapy/tools/supervised/_feature_importances.py +++ b/ehrapy/tools/feature_ranking/_feature_importances.py @@ -13,7 +13,7 @@ from ehrapy.anndata import anndata_to_df -def feature_importances( +def rank_features_supervised( adata: AnnData, predicted_feature: str, prediction_type: Literal["continuous", "categorical"], @@ -26,7 +26,7 @@ def feature_importances( percent_output: bool = False, **kwargs, ): - """Calculate feature importances for predicting a specified feature in adata.var using a given model. + """Calculate feature importances for predicting a specified feature in adata.var. Args: adata: :class:`~anndata.AnnData` object storing the data. diff --git a/ehrapy/tools/supervised/__init__.py b/ehrapy/tools/supervised/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/tools/supervised/test_feature_importances.py b/tests/tools/supervised/test_feature_importances.py index b2560df6..1acc7498 100644 --- a/tests/tools/supervised/test_feature_importances.py +++ b/tests/tools/supervised/test_feature_importances.py @@ -5,7 +5,7 @@ import pytest from anndata import AnnData -from ehrapy.tools import feature_importances +from ehrapy.tools import rank_features_supervised def test_continuous_prediction(): @@ -15,7 +15,7 @@ def test_continuous_prediction(): adata.var_names = ["target", "feature1", "feature2"] for model in ["regression", "svm", "rf"]: - feature_importances(adata, "target", "continuous", model, "all") + rank_features_supervised(adata, "target", "continuous", model, "all") assert "feature_importances" in adata.var assert adata.var["feature_importances"]["feature1"] > 0 assert adata.var["feature_importances"]["feature2"] == 0 @@ -30,7 +30,7 @@ def test_categorical_prediction(): adata.var_names = ["target", "feature1", "feature2"] for model in ["regression", "svm", "rf"]: - feature_importances(adata, "target", "categorical", model, "all") + rank_features_supervised(adata, "target", "categorical", model, "all") assert "feature_importances" in adata.var assert adata.var["feature_importances"]["feature1"] > 0 assert adata.var["feature_importances"]["feature2"] == 0 @@ -44,7 +44,7 @@ def test_multiclass_prediction(): adata = AnnData(X) adata.var_names = ["target", "feature1", "feature2"] - feature_importances(adata, "target", "categorical", "rf", "all") + rank_features_supervised(adata, "target", "categorical", "rf", "all") assert "feature_importances" in adata.var assert adata.var["feature_importances"]["feature1"] > 0 assert adata.var["feature_importances"]["feature2"] == 0 @@ -52,5 +52,5 @@ def test_multiclass_prediction(): for invalid_model in ["regression", "svm"]: with pytest.raises(ValueError) as excinfo: - feature_importances(adata, "target", "categorical", invalid_model, "all") + rank_features_supervised(adata, "target", "categorical", invalid_model, "all") assert str(excinfo.value).startswith("Feature target has more than two categories.") From 201299fcad8fbe2b28b6c25dde416a422df589d1 Mon Sep 17 00:00:00 2001 From: Lilly Date: Wed, 3 Apr 2024 16:46:13 +0200 Subject: [PATCH 11/13] Added docs examples --- .../docstring_previews/feature_importances.png | Bin 0 -> 20235 bytes ehrapy/plot/__init__.py | 2 +- .../feature_ranking/_feature_importances.py | 12 ++++++++++++ .../feature_ranking/_feature_importances.py | 9 +++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 docs/_static/docstring_previews/feature_importances.png diff --git a/docs/_static/docstring_previews/feature_importances.png b/docs/_static/docstring_previews/feature_importances.png new file mode 100644 index 0000000000000000000000000000000000000000..81cc70cdc0c68afcb7d04793d247731e3f760b2d GIT binary patch literal 20235 zcmb7s1yq*V+V*RY7>o)ADj1;BilEXe76Q_pf;0*uA?@f`sDQ6D0wUclB_I|e-63Jn zAV@d-*GB&{=R0f8Ip616vqli!r}n<@tM2E%%!RY-Xc%ZH6v{gBbEo7el%+Nl%96U@ zR^mI;&)&M?fBaUbRjlOA4XkW0TIy4zE?QZbnOm6|>+Z7Ax4dR-e&sONF)kjCT}D<` z7S{y0xiA0Y2e`~F4Y_yC%)iBlthP9(dW}M%y-5CB5+xdGOrc!3E`I8S!u6-$TkNhs zoS0u2s%PG~`sAsbGJHEN!D{<@t0Yj*tg+ zb>AvNqn^J|Iq7kX}%b0Rz0cS;-c)xQ?Q;wG1cFBs}5Iv z{lNd}LU%Rz$(}iLM`ZHzvfLl9*TiV&=dE13^;w(q3}^k!+0&<4ckMd<#~**V-oGy^ zBO?>kn`u85^Y!c3;WTavB}r?!;9-A0Q|}{}o5FgkBTXAq^gouA)V%s*nO8vJkXDxE z>B`DVh9hRabqSiW37R=39Ur~yV!l;Ni`K>{#{TsgYN-Z~=7-!}qE6geX+PZ)$#?Pf z@}0*ljz>xdJ9pUC9AL*8D-52rZE^WZLX0Cm@TDl2~sHmuQ ze}im}?VvaFwnt~A0***M+I}>z%dw}Yhc>=zAj!1Yo0ERShLGA=_U=WZidkTed_>_;49lJUbDimL4Nu-5d6; z(A{*fHP_YEmB*o+LOEWb7D;+{%*Hr- zOACt%FLve8M#ErMUS5UopE7)yb0=CInsOa7yty=Ec8lFlS>+tX})oOn?mV*R6n}UarkI!Zfm1k5IiH}7^;_qDP zpWM22OGrp)xOqE;qT@uW zof*gxupQ8F8>3L3{a~#rC@Ux^7;Q4*Jx}fLXKbcEKmDLJm4ltVI4zOcX{7WhgV6P) z`T2=_b{?LphqawTE}vdp9(zM8YV^Tlcec%U|CQ?)!u$>!`j0tMD6g{VdUb5uKWwP( zyS0io#;h@Aczj&R(eb6RosyFG&6_vV49gkq92`s+=I8j`l3f;@@r4A{roSO1MC@*i z!^8js7XPMrr_K_;Y^xqBE{WG=-tqR`yQa8eD)X&^AE(OKQ7AMKCwKPmlkyK+wrYKM z{=!_M$aME!>%msdkdP3yj4Oe4ajK!!k!k{wY&AZ>YM_~)Y&I!> z-lD5PzP?+9Cz`Jr#5?sjBr_bn^7N#bn3tblv_YxA@!NYFs9`=@_Tt4p-1VpjF`Viz zq9pxz41|wTCJ{qii70=g z&dxXH&vW<@ZQGX(Th~%^X69TDeJiIKli@%+;bg55OP=be*LG2KXB1}4%sY9TX1X)L zd^mXzmgoZ-{iIlxMwO`i4blrCO~%5^mM=b+o3WJKC;NdL%Tm(@;v7k0G^`v6qn^UFl4y z?3odfjLZt2tpl@%*^3V(+z;BrKV&=hIA_VI;j0_ z(7fb8!*4bRSv9jW!zGgV(H?28Nk39mJuj7WB6Hb~&T4ti-5)=GeDE0wi;OI8Y0*qKtuON#K~w*vrzn<;My8agmD@fzn22@C zx^Lh6v@2+qCQWJ4vvYHNj*}Mpmo7C=Y{h0*M2+pOi;sQwY!_BnK|PR-2pb@*u4{jqG7<&GJbg_+=1=b6j3(eh`}r(E6Kl$4a}#6;_R zYhy_hyVg}!R9ILzIzJ#|rl8O+bRT2@hgy)vg-TE%cju~hKI z`UFjUjC0&>BT{?gI$c|EkLWgUjzIV9Ddx_vX-qZbyY%7rQr|;j)X}+#Hq*-Ir_YOv zSEB!Sbah2~czEn6Ja_J#4E6~rHQ2x6GZTaLnODEH4-6!beY)78NjVr87>JPzKkqhI z_;6b&DKb4(;n|K;mTnH8Q&Lm;EIuC>p6O2|I05y6*M8J!+uD`uwjTVwJ{Ui3uq`Ki^4iCUU_&$@XAwgf(K(|iB$aYjZ) zkX`3CzAMctLqkIoJHFv|D3tp1jSb`so~Onr#mD^l=Wlr^m;~ zX?ILp+u9yoQJa>=8y+5Bgf&LbuB3n-srbT7)$aQB5BRzD>({d>#rbMzYF?{U#kH{G zS^@>E&jS^($v*$nST*uOpyZn{wC*FEwpXua4a%6EJb80=cGi+(WMbmzI=%DM$jGi` z<-uNF;eaD`X(jdbM+=jytE%XJ|6NDTfS18~q*E$9A|l4Brz&eMYGSM~8waD`xbgYp zy=P6zp7NWtzRV)cg!_pQ0TcRs1qtUK=X2Gzwu%(ap#D3&U{y=&7A}5xM$c{L^FT)F ztz!ZL{#u3x1~DeJG5&`Q%Y;8=WoEWd`_c|~WxsdZ8H;6V3S8HeX8Z(wJB0sQ=fO>z zDK}1*nCz2Qh?YBzy;TQ@Az4f?tklzEhw2M|jE~pN&&^!p_>^Ot_%b(_Dcc;qq~r5v z553oKabvITm}$U9;5%PdRyLs_<*0!=RQ3MRc0Rk|OM(2?&H%6y^k7mK$0|-9_Qp+{fNEG$ zgSMa!nJI1Rcco)}j=R+T`u3XMijZ3b5&%euRMAc;Mai5P=_;>BJ*$b3+FlwcsEiue zTN5pxFwVqh;)z-`hE~S5XV1yR{P}F-KmfMRvOxCG$|wM?-ay;dN?b`49<|7@q0`p0 zXU~@U9X^X4VDjlj{OdPwYO}5x;#^KsYr-YAT>kVzD%WA+$V}D}vJgJxwUrtj@N}LD zyUSBHuiL{zX}&8q#bDnX&+X-JFXiX=zEC8nU{_X7PWRV$RzuoLMxI52L=9$sy4#%O z&rTY^q^Z(j_Bwy3;iCOO3U$lsGmNX>k2+k(KiQp5Dqe+^o;~Q!*fv(!*13#A`M`ct zQ{;eZO4xaC4r3e-&FmYaAbxoJi6s1U(uU>#{DoxDjr(GLY#;NSzJ1xg^n>mC0Yy&D zr4;t#Nw>uPUVP)wTzVs4r{p8GQsthDGt(=v4gES)fBVuZmIlwN@CViMhuQl{dtGfq zncw)cy}mDWSbW+KtZ964) zlU5p^84B+I_t@g!qK;D`zuBC1+fm7#Kf6b~DeGFKqobo`dK>8_babQ45{DZz%!B90 zGxABtNKQ^39vZ5|QXa928|$m@s|uGeNLcKOxA!!lN*_mwr82v?xS*^GE6)JznSuyP zC#5Wy7u;MCqmh+Tm#CcxL~1-U{vEa0rTrW8{{8GiLK;V}G+$q;xRi3^W?Y4^TnRuv zBO_y0Qys{`g$oy?GoK5h^eg5&J1NGg$Wj%apSmx!bH8AKq1RH1jEd^0YYSK@7!Wsl zWh^#0SS&eBTtnC)z4r$|ZP#q(@L=ZD&eK58BwfJBC<@TcM4)VHq>Im}VpnaVwl-S1 z(%tnt2}~E4!w-O2k&b=()T#EZu2SdEQ{CO&kC-()^X63lh}+>bGp?wn<{$P{EZMei z`s)Mknlxki+IV&4!Tg0h9v&V^YqRD|S?~+ONQTpV$Qj)ejq#$UNC61#%NvQ6aNa)LCN@$X1JjAO}XefRMiuP~~W7 z$+ngNFC_#sZ+2MX5u>mgs%i}$jMvkraStD^jn~RmLnp7p`O@v&87~`p#u%-O=C|J} zwwu2UVOLH#M{RW&6xP?*H%3u&eegg&%d&&lre8%|4#?r74|n2+59iQ1zviWv1RRyd zX;t8fTYrAFgzU_TmH`x=>G_$#bc-)1PoF+*43>fqdBXgzu+X$2>0)nvVsKr&dL)Wa zjC8PwfYbD-n8OBodKN5@+qZ8ETlh`_x>td6jh4I2%Y(dCxze+be!0uk+n1Q{os*Mu zhn9&;-1X5TzPZE4j!lMi0T#b_qbrvDDZ{+Jqu7TOO0NU5C-LCxQHS*nZTFjw545Oj zX=zahmS=!b#Mq5=Qp0}yu$b!-U8saIFbcjM3xsDs)#0O<=V&J`hiW}mr*1AzMSIgX zH|O56WeeI;bA~q%Q&D-jG?f%kXcV9GJ`vgvbf5rY+VSmMC|GYV5KwPZdhBLSwJ;RU z7-$RBu%oTl#ZbHJQuK@IHf$L3$+fVy_5!`d#f`tZ_!YOqKcRDOv|1YETIMzly+X*z z6{HR9mk#n|&dSR2oA=Xx^8x%+CrKqq5G@vs?OL-$U9|lE>L1@)97>s4cke!dPFekM z+hNIK;GS@>W8kFQfBxD1$vnRnoB6qXtff3m@*(=ZS4^D~aQ z3c{zBtb5yd4vfAwO#F#P{gASdu&`93w$m8OVYUL$p$rySZHm4Gcw!?*7Z7e;vF>7% z7qJHZlS)iYNuk@eE&9wOhK`1aQ};LGSTt+aRAcogeEzJnV#SIctY?N}7CH&0S1l|) zjJ<$1z~9;y80P`dA>uS0g$=}C7tj~e$Taii%a;&RQ3{z{YG_4gpPhcNnIKQzb!(PU zZoDqSanF0PVFjgKY0lhjXlN*2wTj8nqRpWM_ht;&+>xnv1Y0%TtWnO3O|c44B3-X= zef7ciu3#6go(tVPTh~tgFdnG!1p)|tA1xpG@zW=D&1|bfeOdE!gD!nQ3HIO9>ucgv z53k#COa=(H5||R`^-A3}TzTKc0DYj9s)T)d>=qkn94|fERs%7E z$S&A}M!+xx0z&(WAQA^CRGZ_-Bh-_z7TY#fHn#Ir$-_zy=-DHHfc*_}L{qFOcLCj- zzJIt4VnrQx5uaJZ!JQ(GD)aO6G^x*X}uH1tfKbSiaWIx_-JU)VFdb7x*l)LHc z!l-(qbpHIHvV_k?>!6|QXM(L}PDMx8d-~ZKR=NF!5{{M#{e#~A!U_lI-09HW53HqW z+53(oM6b{daI|Jwz02%SA<|-DIcn&@9 zhwo3$zT8eZ{s(F1CcS(TK04_S-0GrU|9 z==svw8kb8Y_f}Bo=zsw{xirZ^SoD;<<*KjzK0u*tKAUu_sFGHsqV`)Mr{*0+*3OP_ zhR`!bWo5?!c~o`)nADdC)z<0}=o3e8tAIuPx5xWWnz=X)4NVoG0E;+!7Iedp6AYx- z0B{$hFp}@xyC-w%)Ttmh>+yco%PrZpCG@0)yX}{r6-Nizv17;Z+)HrUYSg?mvql~- zPIVvAO); zM^L3f>^l1TV%>HM_F_|eGqW%=N1-9cC?^J?f5voo8$sNycE+AmPBkc1NYoMrX5C+S zXyF*0L;e>Th_pm9s0!TO3|54ed|`NGL|#S3m+ZV^ z=t51VV`NMK^Xo>BIe+0o*|nWa&XXZ`)^2;woi};%E-e$mDm9RR2^+>;1r2#G=1#OY zbZswgZ7p<3WebZR!k^4gst681DN?{~g`n8yrL0vsJ=Vu}{f7=wZTDYzT5)f~ULsp0 zB^`Xw%nr?pqtN88?vp1^^bc+&9f|-iRB1x{@}@es^9u@^zW#GHn_SpEQsq27=?J?= z8B;Jbi$nPfMeUcM?voChRhf8BG<+xHySkmv13sx$-0}jD*+LBUrbl- zC+NvA$nMPLrZias;UE$y#;Yl=UcDM;=j}>2wcuCV)~1bPPeZ>TR1R9JZ=DyHrW*Ri z$7?%Xn(KizVLYHiFBn}Odl%>Sr)0))_FaBt8wp5iPN7weQJ zoh}HX4Gk)sswt<|G4NJkx6~%i3qJzLb;{W^ve6{B@#6o8t;}J`9N*K0Fe1 zx&Qn3YcDfW9e6m=XpP(^P`zrQG{zI}1bP851-t0grIh2l4buy*K(+T5b;*Oj5(Q4% zmou1GjmO#E(-SRrcYQ~u+zZ=5E#QVCoGpPu(1|4b60~w-AOg2d?0{52leEed(~{JQN&Y5)>BG#S8_cEzr<0p|7>c6gzyr{@L6Y@RoDran=- zUV=L}46u>_P25<*)1N;)Ne%^yT$!foq))n=;H3S8v)Eb=I8`pyL@c z>%oIE@)44ivr{94uHpuDWOfa>qgzJeBvt2TCX{hwYqPCQQI?K~igr5j3kVQp1Ddxo ziWWV#A2>uTY!qr3Ocx`3ueT*90xoChn@IG5A*uj!=4@A)ll%(i1(wB<0Q!EGacRkg={=`_Vq&*(A$%L(c41?+r0d3 z0xeci4*ZT|H@?&j=GwP#rH3c(OphhYXvVK>vKlHS-(^z##>a zP%>fr8aXC_50whN6*+JMdw+a+^YO!n3T(epMfs2?K|u*%AH7pu!S!TU6Es`mcUUb+ z=MSPlLu`sxj8!6BDKRloWMTGNb)NbQGg|&zQlI8{F zjW}8mg7lhfv2S`mWh9cy#-ac%*iHh z-xX-<>F*x{-_<|;)$3Pp-iYCb)kg+9#}Eq(RZQ*W_2h?VLcy)sw6)uIe5#LEuOAio z$yc7_=jZoGL(+nhH8LTFXZe7g>JvyPHEL5auxwfhG<;%VqAxePvU8LLizZ_AM?*3& zY0gYW9gBhf-PV-vl8*}N1qnrOMl`Sd<3~9HC?IU<1#DzF-!AbT3-t2``6P5>|EGb0 zgQ{0}8N9r_#u`r3AK8BN3e?1KVy#uDl#9$rqCCfF)p%oEQf;+U!S!Ympzue}`GjI!cn&>h`d0KC^aOLlD9M;OxXHxT9UZmeGp;%XW17eI(UC+fD`L&UmV{lnOIuqTrCm~1R(ANu zkBT_e)X{=fOhzC9_VZI+weZH!NY7Dk-nwNy^z}5P0f{Z#S~W23i;_2~p7ZkqYtBpMMwCeExmmBg1%CDX!ItW=ZkIO}qN6o0yCX@a zA~e#r`McPacH;v_U?;AkWmLd|Cd>@%jw&L8s5$+!2zheE{Dx^rS_&8=vFC}+;LQc6j_#lxjScpk0{SihCOZUL zG=feQQT>l}sdRS*0h1!JV^v8oL*`WnYlFBPIzg!_SqF^rK{t0Qx8rX#LD(*YXlRI9 ze10WFkK%CSWwCtHC$t{7300V5=g(IWYK-2~$Z;ER{BvbHYl9Pv!wF3h2pQFzu<0ls z`NKkXU!4IT35;0BB+~r!?rD*$V+RsS)&`9QsD~GOC0RjMY}cvN$$%+K}P^zJDU zlc9-hZ?1ygt*H-r7}gAxhb*Jec;bwnJ$vSGMgY%f1+7^!N3$Ow)YXJ>P zh1tkGd-jyb?ZBB^a@Q9Oz($GadV}qtj*S(E738HCfTBR;r9};P%?InG9E&eE2@SM& z$!FRPA4CbR5spUgrc}Z$O^8VJOibK8W$pK|*glq)e!ou-ly~FolSqetX!FCbthaD! z$N&tEX#CDmNZd&q?}XR(2lsBK$W-7`0)_14fPCt)Aw9jUyH;IJO-&^M2@r!mlxYLb zjZ>j8OmVkjF6ryDa(*XL>*7ryA`G4sFq{v7YaA*X8$Z9Qw&T!kNb8pd7#>UK&0kL_ z3l@RM9Cy->hQ2m$LFcC!dt>Qd6lfrIqN5Z%N7~+j%Y<9-#RE z1D2e}8Z33VP-&zb1spa1*x0BVs}x@a)Y;zK8v}jJKkF0@*?uO-CB4Lm*QBc~a2i4t zvE9KSU_ahduz{>WO4 z*TgVgKbQjisBj1GV0c|D=RL3YKc{*AGZwTMNhr*7oJt|_kE!WtcsUno)~!=iP*4Ct zMJ!`(xR`rUu2#7v)S`#Rbd;1+6nrnM{T1Jv(j}nWD1eofsOelxy3!&57nY5S>jKg@ z-e@$0%v!}QS+>&X^Q%7y>GX?5UtXaZbhG8gl4S}hm&70K6jXwaU7>EC8_uKuG0Lz! zh%BWn(=fDtg5}U-unN?s$_aHEb|%<^uQ2A2)H3NURq2y?fuJwYbcW zFj;q3oPZP}i<=D<%@k(TxLbgGl#K z7^wrZf)L>Auy5ckU-WLplj?iiNB05}In+?W!W29*Ef}argpby#kw%1$OJycPDQRt#~He&&?|4{L%hsRv1c=z<;nZjPPv3Y@Q8TQeOJ@!%&RlspFspqaQglIN# zl8QNg{06~V1-Dia(~nBrQ~(A0-nRVw7=`HPq|*u8kAAr7HR9~Lg*!p_oof*iSTHUk zg59U~lDoKK1(A@u%7g0xlZf7IXK&AjEd$sf3k|iSx0em97kaabQv-^eJD2oCL!&?Z z>Xyu_VukKZ^$;>i&JB?pqw*l(l#C1?EJY~QEaFIuMa)6ai-xY(iwhuLEa08jw@k2f zgXr?=>gv&Kz)AdwU#sJY6nb?maA+?)EFhv%OOk-djT^j^xgRMNStS36DM1{`70Z_s z_({xjk`^Ej7@dJ?*(`Q_!Zn=e*mQQ5+p>%*pGVZ=A7C>ogYs)RPwCJ7ctcB!NVgsQ z!*9!IAw0d=8~zyTVAo`z0IWr9c+_*wLL1zSqY_uAdIp?`CS5%_^v(Lq>lK_TNyp)x z_ckPl8s}i$Mx%UD!?a!I6v?ARs0~J~wR9%ndI0+47!qCn@nzteME@X>u;AceSd--Y zq|C&~hTdr`Lye7-MpuE_2TkDXj$(wOp6DHkw(4@FrmxI_YZR-IWr?WWQ)o`s-`=nL z-kPia%SB-k2NJD7g|Z&&RRlUJ9v$HVWuRHVzHWc8<8$OT$f_gj7x}(YHA8SQAQT=v z%_uMkfvaEtgtSCJk+>XA-fFmA94piBjyqh^FCKk58L%PttKNMl3@d<5K7ynTnZ@D% z6L$)>Qgd^&WC0fd5NdigZdfhC4rASwVpN4JOMSR4J=zP?J_aSeA=v(=tuO7!HK_KD z{`esS<&7k~u*;$miL%EzLgC*e?}VnPtgHJR@x)SlDHdDt#Y3HmKS9V_My?&iCrB@y z_;1NwgLkjr#=nfb(w43NtM?MmmNEfO8fG2r{HvK48F9QV&s;*GI|kA8XE;jfolyx% z8qv^!O#(ANoDeBz1aVva-z8YX+p@up(G*HE;lhjN?8zYDE2Nw*WGXiSLd7IiQec%_Nl-GhWxL6C(|KUde_b+Ci z{G-A3Jf3*;OL8sGwHq;54Dy1(dUIjI2sG5NBe?wg&~IhM|&m^c-t!6go@~_vIeD z4Q-Nf^*&;*psiG*O~fynhvf5+2>yT)2minB&()hO*XE-@DneWF0kb7CAe;^=y(O3#R^`PeLjVzj z0W{A8QNQ723`!XYfH-vMkbqUUhsgY-UZAjp5&=IzGw@*cgL*m)=zq$*8cNbT@7}%Z zAygdj%qu8pzG8$_j9={m$jq+ovh(DZRS0285{rWHaZ&|PSL;w$LD~4Ef`nA1g~vQ0 zcQLoD4nZGsZ^Oyx&~J5hx-rq_qK#rq4FgGKg@i?1a;OJj6D;D$fz}bix?`p`K~o)t znMx8GCg`jr_=_~nuus=Dq*D;T00M09?2G^wXiF_^TbQ>c@R3SHWng>-JQho9>soY8 zX%vy7)>bXRJSsEvQXwT2+8fqiTw9hYBxp2_HbJ38x2*yq0GkQW%K@tEyzX*z?RhF% z2Z5$2nW;5a6Rm2jlnvj&796_6! z*)jO&Oek&Yzbbz28#iYlA4EVJ& zK{y!|Kw_K7*HGz8v?Me%p6=MWlOD1)$pxTzOwY6}BtUdSNXK^$vD#;!zqz*T?S{(= zNjY-Gn$6vau|Tv4LbK}Hae|lub+GRUB_$jjKV;I9o!)3FE-hUT6+I0;AXyzSHob_Y z4IEnx*h{tg@%aP{&f8G1z5Z6&SG^ zahVuI1_Ew+^OMLeQ9nqA8iprP`GKsw@GJqcMrTLLW7RXRG-X@IHZG=i_lX@js^8HB zThSA68o3sK?_4;~a1Uw)2Vd$Hd@R{}sME0J*i24%@pa}49TFdMl$ZT1#FbSYz9gTT z^MFfBa*NBGl3PLkTjqF0?M<9VN8NIqKCDY<5Wb1F)lFccyXbsWd;6t#@81*A)xIa( zXHnRtJlW598cBES*~zOU4hN)51vMdn3<;V z_b1FTV&Wm0p_D&29uQvqfe*p~ik$djSd~KA)9{5B+HcqbOyZx6?Xg#jEGXBFRU~xt zzxLGbon-0kT(Nu-Jtz^y!fLb3Z6VH2LdzD7Mf1AWld zs*pTmG7g5NGlz;H9fj<%j0gR+ja%YE&r?DiZKAk~O1Hpd(gBNe2(vpM>*03o^6%p8o#Q+PsdK{(~auGtEl)PW*Wi*vNpm7-S9EMFxPs9kv{#F{4G04A+csDI}lWc^xqnVI~%` z-Uj4@O*`);>OYx-80)D%yJX1{A~li#BqD4FxVXyF=3s!c!`+F;(?CI`p{1?IW*+-B zdcfrP;|~l1m^ICZx{Fl7sD@J(3X~{ZmDCFht>)Q@TMpiaAoS2N6JHLylB5wq?#-E3 z2^0ej*!?qi0&dWoY`+K$ac1xw395eZ;6dKZfS?zv{25ef;;lpN9VIX(lSVn)%6J{4 zkP6@pvD`c%?rzw7CJJ?e&^riU*3+ZL)#He_$fBf?n}L5eLN_`uE$tHL4Kr^r1qWvSX;c{^ssADoRI~?8DjoV}I9x^$4(@X6?Cflk#YbAD269==OS@}iQKM&k z967Xva^v!@hPhw}c$hcn=DsWOfG#UT)5eB2ue%3@%>FVj15>us6~bqsU(}Q74z&Mx zmE@D)EhHE?*-(h0Sr;oq`)-~@WJ&swoL+IU1hO|UVuD4Wt_d-X;4wwiL&5U|Z6NW? z*fJP}paB=~V=>{An7EhE!ut>lQmEs%~8sl?p*D(Zo zIDZwCK6TmdySgKv2qQu8P1wQ-^p)3ncG58p(5`{zX3Hka(eBC_u0gP4tZBwfZ6s5J z6vS~0-XjO03f7_s@j}7}r)l>4zX#$Zm;mw=ie{*u{ZQkYt!)MoBIynQ2B#qdU&SJn z62}or45VuiL}!;D+ST+U+EiOQ?s4a?)e*-jB628MsZ~(9jeSMg$_{5X-L#Jxhy8sXg^G ziP#j7*gfHF5ZnkmvOU@jLcbI8Wp+&n>Z)+#w@L1hz{x1t&@_`;HezU_T2x{qSf(#v zZ;g^9G1MI;^yD<8qJ(OP(8!Iz>i+tSc-7Q1Nc$20ohaBy=>Z7#Kus>Hod;_smH+}= z9o^jrXldPrY99xdj>s>e9C-MfjSJP&lNs>`r?FaPGKfH?(XM_yh55AeR04d7TtwNV%nL*a!o3)EcbOF<<%1d8VT|L{aD?95_wUKXAKC}4a9hj8`lzS_Aj2)T z0qCr!iIz#Cjo1jREGz`<$V*B-#Qx_)Dhkn1OmPw1zt~o=+x$YdWBoDEC0i|7QTun3 za0t?+K&oaHC)p^1{N6p-LZVXq<85%DNDLhimt9#b#x{c$bqBi4V2%{zY(mk+3E?5K z5GjlRoJ>xmkC}@gwP0`o9jAOm1`15l1&w5f8}?(pep(p9P$!cJC@FC3eDnm-tq^f( zS0U~`4kJONzLm6$W-t5~gL-Y(zpW#RCJ_B&qMuaEc6PJ zgS(6%Eu2v z<>fh|(LYDjOzQLqA^^&YD}-u2c^HBDB)}g0S+v%VxczU&)8taZd{RxGLq_#XXq8;zrU42i6Jy_!@8=z5&Q>`?fuU8X z4gWJjsB8aE1?I*Kv?1n0bN*kes}>^?A3~QppQXp$vV8F0#>@UCKw$8zfRX|>mI>p9 z#9V>>uH}|v=Pk};)S-LQh}_g%g2@?DT#1gK)w&<{5#VI@)o<@N#X&HjJ0&C+cSFEs z{w1l9__%1O&RMP82(m|!Ap6yKWhG?98d;neV5&uBq~jw!NqmB;R)ZN3UmxSH!lTE4 zM62Ybg~MSQVQ{D;Xzbsn#+mddLDN)3ZwC62CPID=RgrLVFhvz=8tMztAg57GNEG-k z|MlB)aU@_K(aTqqu>@?K)D_=|?>N;OMP)7sFpA235^XR*R zAcULe&nXeHADIH8?jQ$==wCCh7#;}S`a5RMm75nLyJ!W|C*=IW`kxtUYNkEc*ZT+S zpJlmq0^>)O?;iAgHTSiv`0IrOtyeQ>vkK?|f+*Vygl#s1>z&0|Oz$6kqL9PTE;|?X z*Ui88bLo}{yhCx0v6wXhIv+RF09*zUH==_E>7?BHG9aV+@JGaOt$Bu2~B4N zRQ7OP+~uB)jBLf&*(&#f*_TmXX_K9|=rLbE5PB#8yY1C7@OWKb27O+&Y?{T=BnUyn zcK-8F%lEx=yiE2cTBFsom9{lmGLtZL0sdOE$x-SqWVGDT>YY5cH1fY0vo#sW_jVTLvIQTnYlY0ap zu)Us9C<-C1Zm3h_-3VmaD{I-ks|r^Q-?aZi6a;$0!i`lSrg5q~wkDrK;++6g0p0{~ zzPho*zSGZ;Kt|*m#nIsxGve4fDoJX67WqNvjQJN+nFW&h>&mdT*b0`0NUD@orR z`~}y(+pLSNh4)JkRBfyZ7Ypo;N(6Bz+|lDe~}el257H3U~oXOOVZ!c@5naoUfzIUX*w*fKzv7vJTkn$cF6 z?L>oVNZS3^SuM{w*^ha^iRErNgNJ!|IAo~=Q9q52Btc**mCB1o)AG_z9xOl2%1&wU zM4NJ6teq{4CLE5)$AGNaHiNBX@+TSt8b77Lpjuzx8>9u|AZYX=?bG34veKo$2g3AZ zkjT=Tn!A1E^O#zY@04l0Iy zKq~y6)2{8A(YDa!pvrm$?$myC7AfT;#HBzFzhm5pbOs*Z$2nI6naB$}F=KHK0tGMU zr&L3ClBp|*iTiYx$tetGey7T2CZ>DwC1>QlZgFQ%?RPj_$XUi@7j}xPn zKePu=mVh1*A-N%0ZwxGu*aIX+*9wF$;bY(hnNBgyNS62`WD{Xs@>%yPKwtP&72fzK zmSFcUL`;-jVxs2HPsc)zFzcy`Q$lJ!9cB+w#^<0jR6Je8I5Bf^FCDVN%@sJc^ows+l4&dO-|Vi1ZM!Gjm_!r6R!pJU z69EL}D+2L|_N_su2f%|5;Ke1x&SCDs2noX>-9O1D#vc^r7~rAJ302v%S&$Rg=h-EKJwu59AC$0dkZc`FwG)y8e`t3l`te@PB{hkFP2Q z`{3hnXx<^ZE+`9msb>_pn zLkQ4EVSS0W^r=a6|g+f$E4u19En%Y`74{ujV$8>uSQ_G>I>g zT2QO3U5f`zgjECp$#5`o2PGRkaJy=9Y|V)PNY(<}ioN2P*-z1O_|8X0I^cMb*X5A+ z!4UB{56L1Wl+W}_MH|WFG2+^CVCHF-9p^Ejk&9I2@aSj_^yGAyTNpm<_{|%Mt_pz1 z#M^7Okl~{3N6cj43o#qR$0tk2ezZFTR}+n)3zAsHY-O5$@n#6%rht$noQi@!`hL54 zH_VBt>PwV;SR>bOlJ^N%wx6@Jv+EgjS;%3HBay$}=B!jAW0AqD7eo`odF+QcK`cI! zT!LgiI##buEFzLaO*g5n!3Z3Ye{nB(zxUyN1`rLAm=Qz#$OGdMz@!A(;hJN>vq>Nq zV^ILUoJ3Z^T*blfF2Br8XzE0MA;T(Y<>#nm$Q_%R9Xc)OB7=Zc6clmdtdd0g-zIZ4 z=lK|JFh}1D2fb5Qijdq!lJZmq6bErRbU~@Uz^6|w*V@*_DwmIDBgG2K5HH5Li5KON zHxbaY4lgVie)?++f`!P)v~i&rr2q+*#T_Sz1$TzTnE^Y?lLJ8zd#oWC!m?r&Ck+t0 z+(50XyF2RQn>!fnIwgbfL+G|?sS{vk-Y6T)>&%h~}J ztdKDwsW`k2h~>I_gzxpW`3#;UVq8`KpkIK;|W zmL-i0l&n{%f^@skhlu$~UY|o=R6>S3!UW&=W5L4xe*)CyfUHF{)H^Jy!Da5Sq6NvY zHnAcIibK&hf=fc+C$KSj=LWF{5n6IF(nNi%$Ji3l`XIDKVFME96>I+QCkmAnV-yZ4 zt5e#ky92b7mfu*q8&ndDT>;dJQ>> import ehrapy as ep + >>> adata = ep.dt.mimic_2(encoded=False) + >>> ep.pp.knn_impute(adata, n_neighbours=5) + >>> input_features = [ + ... feat for feat in adata.var_names if feat not in {"service_unit", "day_icu_intime", "tco2_first"} + ... ] + >>> ep.tl.rank_features_supervised(adata, "tco2_first", "continuous", "rf", input_features=input_features) + >>> ep.pl.rank_features_supervised(adata) + + .. image:: /_static/docstring_previews/feature_importances.png """ if key not in adata.var.keys(): raise ValueError( diff --git a/ehrapy/tools/feature_ranking/_feature_importances.py b/ehrapy/tools/feature_ranking/_feature_importances.py index 8d4799f3..505df92f 100644 --- a/ehrapy/tools/feature_ranking/_feature_importances.py +++ b/ehrapy/tools/feature_ranking/_feature_importances.py @@ -48,6 +48,15 @@ def rank_features_supervised( percent_output: Set to True to output the feature importances as percentages. Note that information about positive or negative coefficients for regression models will be lost. Defaults to False. **kwargs: Additional keyword arguments to pass to the model. See the documentation of the respective model in scikit-learn for details. + + Examples: + >>> import ehrapy as ep + >>> adata = ep.dt.mimic_2(encoded=False) + >>> ep.pp.knn_impute(adata, n_neighbours=5) + >>> input_features = [ + ... feat for feat in adata.var_names if feat not in {"service_unit", "day_icu_intime", "tco2_first"} + ... ] + >>> ep.tl.rank_features_supervised(adata, "tco2_first", "continuous", "rf", input_features=input_features) """ if predicted_feature not in adata.var_names: raise ValueError(f"Feature {predicted_feature} not found in adata.var.") From 0c65ce1ff45e2cfba299bf48e5a1a6079830fdf9 Mon Sep 17 00:00:00 2001 From: Lilly Date: Sun, 7 Apr 2024 18:07:40 +0200 Subject: [PATCH 12/13] Auto prediction_type inference --- .../feature_ranking/_feature_importances.py | 53 +++++++++++++------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/ehrapy/tools/feature_ranking/_feature_importances.py b/ehrapy/tools/feature_ranking/_feature_importances.py index 505df92f..a93da689 100644 --- a/ehrapy/tools/feature_ranking/_feature_importances.py +++ b/ehrapy/tools/feature_ranking/_feature_importances.py @@ -11,12 +11,13 @@ from ehrapy import logging as logg from ehrapy.anndata import anndata_to_df +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG def rank_features_supervised( adata: AnnData, predicted_feature: str, - prediction_type: Literal["continuous", "categorical"], + prediction_type: Literal["continuous", "categorical", "auto"] = "auto", model: Literal["regression", "svm", "rf"] = "regression", input_features: Iterable[str] | Literal["all"] = "all", layer: str | None = None, @@ -32,7 +33,8 @@ def rank_features_supervised( adata: :class:`~anndata.AnnData` object storing the data. predicted_feature: The feature to predict by the model. Must be present in adata.var_names. prediction_type: Whether the predicted feature is continuous or categorical. If the data type of the predicted feature - is not correct, conversion will be attempted. + is not correct, conversion will be attempted. If set to 'auto', the function will try to infer the data type from the data. + Defaults to 'auto'. model: The model to use for prediction. Choose between 'regression', 'svm', or 'rf'. Note that multi-class classification is only possible with 'rf'. Defaults to 'regression'. input_features: The features in adata.var to use for prediction. Should be a list of feature names. If 'all', all features @@ -56,7 +58,9 @@ def rank_features_supervised( >>> input_features = [ ... feat for feat in adata.var_names if feat not in {"service_unit", "day_icu_intime", "tco2_first"} ... ] - >>> ep.tl.rank_features_supervised(adata, "tco2_first", "continuous", "rf", input_features=input_features) + >>> ep.tl.rank_features_supervised( + ... adata, "tco2_first", prediction_type="continuous", model="rf", input_features=input_features + ... ) """ if predicted_feature not in adata.var_names: raise ValueError(f"Feature {predicted_feature} not found in adata.var.") @@ -69,11 +73,6 @@ def rank_features_supervised( if model not in ["regression", "svm", "rf"]: raise ValueError(f"Model {model} not recognized. Please choose either 'regression', 'svm', or 'rf'.") - if prediction_type not in ["continuous", "categorical"]: - raise ValueError( - f"Prediction type {prediction_type} not recognized. Please choose either 'continuous' or 'categorical'." - ) - if feature_scaling not in ["standard", "minmax", None]: raise ValueError( f"Feature scaling type {feature_scaling} not recognized. Please choose either 'standard', 'minmax', or None." @@ -81,7 +80,23 @@ def rank_features_supervised( data = anndata_to_df(adata, layer=layer) - if prediction_type == "continuous": + if prediction_type == "auto": + if EHRAPY_TYPE_KEY in adata.var: + prediction_encoding_type = adata.var[EHRAPY_TYPE_KEY][predicted_feature] + if prediction_encoding_type == NON_NUMERIC_TAG or prediction_encoding_type == NON_NUMERIC_ENCODED_TAG: + prediction_type = "categorical" + else: + prediction_type = "continuous" + else: + if pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): + prediction_type = "categorical" + else: + prediction_type = "continuous" + logg.info( + f"Predicted feature {predicted_feature} was detected as {prediction_type}. If this is incorrect, please specify in the prediction_type argument." + ) + + elif prediction_type == "continuous": if pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): try: data[predicted_feature] = data[predicted_feature].astype(float) @@ -91,13 +106,6 @@ def rank_features_supervised( f"type to 'categorical' or change the feature data type to a continuous type." ) from e - if model == "regression": - predictor = LinearRegression(**kwargs) - elif model == "svm": - predictor = SVR(kernel="linear", **kwargs) - elif model == "rf": - predictor = RandomForestRegressor(**kwargs) - elif prediction_type == "categorical": if not pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): try: @@ -107,7 +115,20 @@ def rank_features_supervised( f"Feature {predicted_feature} is not categorical and conversion to category failed. Either change the prediction " f"type to 'continuous' or change the feature data type to a categorical type." ) from e + else: + raise ValueError( + f"Prediction type {prediction_type} not recognized. Please choose 'continuous', 'categorical', or 'auto'." + ) + + if prediction_type == "continuous": + if model == "regression": + predictor = LinearRegression(**kwargs) + elif model == "svm": + predictor = SVR(kernel="linear", **kwargs) + elif model == "rf": + predictor = RandomForestRegressor(**kwargs) + elif prediction_type == "categorical": if data[predicted_feature].nunique() > 2 and model in ["regression", "svm"]: raise ValueError( f"Feature {predicted_feature} has more than two categories. Please choose random forest (rf) as model for multi-class classification." From 3c0a979b8d48e85086436202bc9cfd8841db5945 Mon Sep 17 00:00:00 2001 From: Lilly May <93096564+Lilly-May@users.noreply.github.com> Date: Sun, 7 Apr 2024 18:09:06 +0200 Subject: [PATCH 13/13] Simplified error message Co-authored-by: Lukas Heumos --- ehrapy/tools/feature_ranking/_feature_importances.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ehrapy/tools/feature_ranking/_feature_importances.py b/ehrapy/tools/feature_ranking/_feature_importances.py index a93da689..79efb779 100644 --- a/ehrapy/tools/feature_ranking/_feature_importances.py +++ b/ehrapy/tools/feature_ranking/_feature_importances.py @@ -131,7 +131,7 @@ def rank_features_supervised( elif prediction_type == "categorical": if data[predicted_feature].nunique() > 2 and model in ["regression", "svm"]: raise ValueError( - f"Feature {predicted_feature} has more than two categories. Please choose random forest (rf) as model for multi-class classification." + f"Feature {predicted_feature} has more than two categories. Please choose 'rf' as model for multi-class classification." ) if model == "regression":