diff --git a/docs/_static/docstring_previews/feature_importances.png b/docs/_static/docstring_previews/feature_importances.png new file mode 100644 index 00000000..81cc70cd Binary files /dev/null and b/docs/_static/docstring_previews/feature_importances.png differ diff --git a/docs/usage/usage.md b/docs/usage/usage.md index 1a4241da..3d689849 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -196,7 +196,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret tools.paga ``` -### Group comparison +### Feature Ranking ```{eval-rst} .. autosummary:: @@ -205,6 +205,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret tools.rank_features_groups tools.filter_rank_features_groups + tools.rank_features_supervised ``` ### Dataset integration @@ -358,7 +359,7 @@ Visualize clusters using one of the embedding methods passing color='leiden'. plot.paga_compare ``` -### Group comparison +### Feature Ranking ```{eval-rst} .. autosummary:: @@ -372,6 +373,7 @@ Visualize clusters using one of the embedding methods passing color='leiden'. plot.rank_features_groups_dotplot plot.rank_features_groups_matrixplot plot.rank_features_groups_tracksplot + plot.rank_features_supervised ``` ### Survival Analysis diff --git a/ehrapy/plot/__init__.py b/ehrapy/plot/__init__.py index 102c2f26..f20c6220 100644 --- a/ehrapy/plot/__init__.py +++ b/ehrapy/plot/__init__.py @@ -3,3 +3,4 @@ from ehrapy.plot._survival_analysis import kmf, ols from ehrapy.plot._util import * # noqa: F403 from ehrapy.plot.causal_inference._dowhy import causal_effect +from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised diff --git a/ehrapy/plot/feature_ranking/__init__.py b/ehrapy/plot/feature_ranking/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ehrapy/plot/feature_ranking/_feature_importances.py b/ehrapy/plot/feature_ranking/_feature_importances.py new file mode 100644 index 00000000..88b48ce3 --- /dev/null +++ b/ehrapy/plot/feature_ranking/_feature_importances.py @@ -0,0 +1,68 @@ +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from anndata import AnnData +from matplotlib.axes import Axes + + +def rank_features_supervised( + adata: AnnData, + key: str = "feature_importances", + n_features: int = 10, + ax: Axes | None = None, + show: bool = True, + save: str | None = None, + **kwargs, +) -> Axes | None: + """Plot features with greates absolute importances as a barplot. + + Args: + adata: :class:`~anndata.AnnData` object storing the data. A key in adata.var should contain the feature + importances, calculated beforehand. + key: The key in adata.var to use for feature importances. Defaults to 'feature_importances'. + n_features: The number of features to plot. Defaults to 10. + ax: A matplotlib axes object to plot on. If `None`, a new figure will be created. Defaults to `None`. + show: If `True`, show the figure. If `False`, return the axes object. Defaults to `True`. + save: Path to save the figure. If `None`, the figure will not be saved. Defaults to `None`. + **kwargs: Additional arguments passed to `seaborn.barplot`. + + Returns: + If `show == False` a `matplotlib.axes.Axes` object, else `None`. + + Examples: + >>> import ehrapy as ep + >>> adata = ep.dt.mimic_2(encoded=False) + >>> ep.pp.knn_impute(adata, n_neighbours=5) + >>> input_features = [ + ... feat for feat in adata.var_names if feat not in {"service_unit", "day_icu_intime", "tco2_first"} + ... ] + >>> ep.tl.rank_features_supervised(adata, "tco2_first", "continuous", "rf", input_features=input_features) + >>> ep.pl.rank_features_supervised(adata) + + .. image:: /_static/docstring_previews/feature_importances.png + """ + if key not in adata.var.keys(): + raise ValueError( + f"Key {key} not found in adata.var. Make sure to calculate feature importances first with ep.tl.feature_importances." + ) + + df = pd.DataFrame({"importance": adata.var[key]}, index=adata.var_names) + df["absolute_importance"] = df["importance"].abs() + df = df.sort_values("absolute_importance", ascending=False) + + if ax is None: + fig, ax = plt.subplots() + ax = sns.barplot(x=df["importance"][:n_features], y=df.index[:n_features], orient="h", ax=ax, **kwargs) + plt.ylabel("Feature") + plt.xlabel("Importance") + plt.tight_layout() + + if save: + plt.savefig(save, bbox_inches="tight") + if show: + plt.show() + return None + else: + return ax diff --git a/ehrapy/tools/__init__.py b/ehrapy/tools/__init__.py index 18eb61be..d57b936d 100644 --- a/ehrapy/tools/__init__.py +++ b/ehrapy/tools/__init__.py @@ -14,6 +14,7 @@ from ehrapy.tools._scanpy_tl_api import * # noqa: F403 from ehrapy.tools.causal._dowhy import causal_inference from ehrapy.tools.cohort_tracking._cohort_tracker import CohortTracker +from ehrapy.tools.feature_ranking._feature_importances import rank_features_supervised from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups try: # pragma: no cover diff --git a/ehrapy/tools/feature_ranking/_feature_importances.py b/ehrapy/tools/feature_ranking/_feature_importances.py new file mode 100644 index 00000000..79efb779 --- /dev/null +++ b/ehrapy/tools/feature_ranking/_feature_importances.py @@ -0,0 +1,184 @@ +from collections.abc import Iterable +from typing import Literal + +import pandas as pd +from anndata import AnnData +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from sklearn.svm import SVC, SVR + +from ehrapy import logging as logg +from ehrapy.anndata import anndata_to_df +from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG + + +def rank_features_supervised( + adata: AnnData, + predicted_feature: str, + prediction_type: Literal["continuous", "categorical", "auto"] = "auto", + model: Literal["regression", "svm", "rf"] = "regression", + input_features: Iterable[str] | Literal["all"] = "all", + layer: str | None = None, + test_split_size: float = 0.2, + key_added: str = "feature_importances", + feature_scaling: Literal["standard", "minmax"] | None = "standard", + percent_output: bool = False, + **kwargs, +): + """Calculate feature importances for predicting a specified feature in adata.var. + + Args: + adata: :class:`~anndata.AnnData` object storing the data. + predicted_feature: The feature to predict by the model. Must be present in adata.var_names. + prediction_type: Whether the predicted feature is continuous or categorical. If the data type of the predicted feature + is not correct, conversion will be attempted. If set to 'auto', the function will try to infer the data type from the data. + Defaults to 'auto'. + model: The model to use for prediction. Choose between 'regression', 'svm', or 'rf'. Note that multi-class classification + is only possible with 'rf'. Defaults to 'regression'. + input_features: The features in adata.var to use for prediction. Should be a list of feature names. If 'all', all features + in adata.var will be used. Note that non-numeric input features will cause an error, so make sure to encode them properly + before. Defaults to 'all'. + layer: The layer in adata.layers to use for prediction. If None, adata.X will be used. Defaults to None. + test_split_size: The split of data used for testing the model. Should be a float between 0 and 1, representing the proportion. + Defaults to 0.2. + key_added: The key in adata.var to store the feature importances. Defaults to 'feature_importances'. + feature_scaling: The type of feature scaling to use for the input. Choose between 'standard', 'minmax', or None. + 'standard' uses sklearn's StandardScaler, 'minmax' uses MinMaxScaler. Scaler will be fit and transformed + for each feature individually. Defaults to 'standard'. + percent_output: Set to True to output the feature importances as percentages. Note that information about positive or negative + coefficients for regression models will be lost. Defaults to False. + **kwargs: Additional keyword arguments to pass to the model. See the documentation of the respective model in scikit-learn for details. + + Examples: + >>> import ehrapy as ep + >>> adata = ep.dt.mimic_2(encoded=False) + >>> ep.pp.knn_impute(adata, n_neighbours=5) + >>> input_features = [ + ... feat for feat in adata.var_names if feat not in {"service_unit", "day_icu_intime", "tco2_first"} + ... ] + >>> ep.tl.rank_features_supervised( + ... adata, "tco2_first", prediction_type="continuous", model="rf", input_features=input_features + ... ) + """ + if predicted_feature not in adata.var_names: + raise ValueError(f"Feature {predicted_feature} not found in adata.var.") + + if input_features != "all": + for feature in input_features: + if feature not in adata.var_names: + raise ValueError(f"Feature {feature} not found in adata.var.") + + if model not in ["regression", "svm", "rf"]: + raise ValueError(f"Model {model} not recognized. Please choose either 'regression', 'svm', or 'rf'.") + + if feature_scaling not in ["standard", "minmax", None]: + raise ValueError( + f"Feature scaling type {feature_scaling} not recognized. Please choose either 'standard', 'minmax', or None." + ) + + data = anndata_to_df(adata, layer=layer) + + if prediction_type == "auto": + if EHRAPY_TYPE_KEY in adata.var: + prediction_encoding_type = adata.var[EHRAPY_TYPE_KEY][predicted_feature] + if prediction_encoding_type == NON_NUMERIC_TAG or prediction_encoding_type == NON_NUMERIC_ENCODED_TAG: + prediction_type = "categorical" + else: + prediction_type = "continuous" + else: + if pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): + prediction_type = "categorical" + else: + prediction_type = "continuous" + logg.info( + f"Predicted feature {predicted_feature} was detected as {prediction_type}. If this is incorrect, please specify in the prediction_type argument." + ) + + elif prediction_type == "continuous": + if pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): + try: + data[predicted_feature] = data[predicted_feature].astype(float) + except ValueError as e: + raise ValueError( + f"Feature {predicted_feature} is not continuous and conversion to float failed. Either change the prediction " + f"type to 'categorical' or change the feature data type to a continuous type." + ) from e + + elif prediction_type == "categorical": + if not pd.api.types.is_categorical_dtype(data[predicted_feature].dtype): + try: + data[predicted_feature] = data[predicted_feature].astype("category") + except ValueError as e: + raise ValueError( + f"Feature {predicted_feature} is not categorical and conversion to category failed. Either change the prediction " + f"type to 'continuous' or change the feature data type to a categorical type." + ) from e + else: + raise ValueError( + f"Prediction type {prediction_type} not recognized. Please choose 'continuous', 'categorical', or 'auto'." + ) + + if prediction_type == "continuous": + if model == "regression": + predictor = LinearRegression(**kwargs) + elif model == "svm": + predictor = SVR(kernel="linear", **kwargs) + elif model == "rf": + predictor = RandomForestRegressor(**kwargs) + + elif prediction_type == "categorical": + if data[predicted_feature].nunique() > 2 and model in ["regression", "svm"]: + raise ValueError( + f"Feature {predicted_feature} has more than two categories. Please choose 'rf' as model for multi-class classification." + ) + + if model == "regression": + predictor = LogisticRegression(**kwargs) + elif model == "svm": + predictor = SVC(kernel="linear", **kwargs) + elif model == "rf": + predictor = RandomForestClassifier(**kwargs) + + if input_features == "all": + input_features = list(adata.var_names) + input_features.remove(predicted_feature) + + input_data = data[input_features] + labels = data[predicted_feature] + + for feature in input_data.columns: + try: + input_data[feature] = input_data[feature].astype(float) + + if feature_scaling is not None: + scaler = StandardScaler() if feature_scaling == "standard" else MinMaxScaler() + input_data[feature] = scaler.fit_transform(input_data[[feature]]) + except ValueError as e: + raise ValueError( + f"Feature {feature} is not numeric. Please encode non-numeric features before calculating " + f"feature importances or drop them from the input_features list." + ) from e + + x_train, x_test, y_train, y_test = train_test_split(input_data, labels, test_size=test_split_size) + + predictor.fit(x_train, y_train) + + score = predictor.score(x_test, y_test) + evaluation_metric = "R2 score" if prediction_type == "continuous" else "accuracy" + logg.info( + f"Training completed. The model achieved an {evaluation_metric} of {score:.2f} on the test set, consisting of {len(y_test)} samples." + ) + + if model == "regression" or model == "svm": + feature_importances = pd.Series(predictor.coef_.squeeze(), index=input_data.columns) + else: + feature_importances = pd.Series(predictor.feature_importances_.squeeze(), index=input_data.columns) + + if percent_output: + feature_importances = feature_importances.abs() / feature_importances.abs().sum() * 100 + + # Reorder feature importances to match adata.var order and save importances in adata.var + feature_importances = feature_importances.reindex(adata.var_names) + adata.var[key_added] = feature_importances diff --git a/tests/tools/supervised/test_feature_importances.py b/tests/tools/supervised/test_feature_importances.py new file mode 100644 index 00000000..1acc7498 --- /dev/null +++ b/tests/tools/supervised/test_feature_importances.py @@ -0,0 +1,56 @@ +import unittest + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData + +from ehrapy.tools import rank_features_supervised + + +def test_continuous_prediction(): + target = np.random.rand(1000) + X = np.stack((target, target * 2, [1] * 1000)).T + adata = AnnData(X) + adata.var_names = ["target", "feature1", "feature2"] + + for model in ["regression", "svm", "rf"]: + rank_features_supervised(adata, "target", "continuous", model, "all") + assert "feature_importances" in adata.var + assert adata.var["feature_importances"]["feature1"] > 0 + assert adata.var["feature_importances"]["feature2"] == 0 + assert pd.isna(adata.var["feature_importances"]["target"]) + + +def test_categorical_prediction(): + target = np.random.randint(2, size=1000) + X = np.stack((target, target, [1] * 1000)).T + + adata = AnnData(X) + adata.var_names = ["target", "feature1", "feature2"] + + for model in ["regression", "svm", "rf"]: + rank_features_supervised(adata, "target", "categorical", model, "all") + assert "feature_importances" in adata.var + assert adata.var["feature_importances"]["feature1"] > 0 + assert adata.var["feature_importances"]["feature2"] == 0 + assert pd.isna(adata.var["feature_importances"]["target"]) + + +def test_multiclass_prediction(): + target = np.random.randint(4, size=1000) + X = np.stack((target, target, [1] * 1000)).T + + adata = AnnData(X) + adata.var_names = ["target", "feature1", "feature2"] + + rank_features_supervised(adata, "target", "categorical", "rf", "all") + assert "feature_importances" in adata.var + assert adata.var["feature_importances"]["feature1"] > 0 + assert adata.var["feature_importances"]["feature2"] == 0 + assert pd.isna(adata.var["feature_importances"]["target"]) + + for invalid_model in ["regression", "svm"]: + with pytest.raises(ValueError) as excinfo: + rank_features_supervised(adata, "target", "categorical", invalid_model, "all") + assert str(excinfo.value).startswith("Feature target has more than two categories.")