diff --git a/ehrapy/_settings.py b/ehrapy/_settings.py index 3547af0a..f733c059 100644 --- a/ehrapy/_settings.py +++ b/ehrapy/_settings.py @@ -53,7 +53,7 @@ def __init__( figdir: str | Path = "./figures/", cache_compression: str | None = "lzf", max_memory=15, - n_jobs: int = 1, + n_jobs: int = -1, logfile: str | Path | None = None, categories_to_ignore: Iterable[str] = ("N/A", "dontknow", "no_gate", "?"), _frameon: bool = True, diff --git a/ehrapy/core/_tool_available.py b/ehrapy/_utils_available.py similarity index 79% rename from ehrapy/core/_tool_available.py rename to ehrapy/_utils_available.py index 75153b41..7b116681 100644 --- a/ehrapy/core/_tool_available.py +++ b/ehrapy/_utils_available.py @@ -4,7 +4,7 @@ from subprocess import PIPE, Popen -def _check_module_importable(package: str) -> bool: # pragma: no cover +def _check_module_importable(package: str) -> bool: """Checks whether a module is installed and can be loaded. Args: @@ -19,7 +19,7 @@ def _check_module_importable(package: str) -> bool: # pragma: no cover return module_available -def _shell_command_accessible(command: list[str]) -> bool: # pragma: no cover +def _shell_command_accessible(command: list[str]) -> bool: """Checks whether the provided command is accessible in the current shell. Args: @@ -29,7 +29,7 @@ def _shell_command_accessible(command: list[str]) -> bool: # pragma: no cover True if the command is accessible, False otherwise. 
""" command_accessible = Popen(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True) - (commmand_stdout, command_stderr) = command_accessible.communicate() + command_accessible.communicate() if command_accessible.returncode != 0: return False diff --git a/ehrapy/_doc_util.py b/ehrapy/_utils_doc.py similarity index 100% rename from ehrapy/_doc_util.py rename to ehrapy/_utils_doc.py diff --git a/ehrapy/_utils_rendering.py b/ehrapy/_utils_rendering.py new file mode 100644 index 00000000..43596c54 --- /dev/null +++ b/ehrapy/_utils_rendering.py @@ -0,0 +1,21 @@ +import functools + +from rich.progress import Progress, SpinnerColumn + + +def spinner(message: str = "Running task"): + def wrap(func): + @functools.wraps(func) + def wrapped_f(*args, **kwargs): + with Progress( + "[progress.description]{task.description}", + SpinnerColumn(), + refresh_per_second=1500, + ) as progress: + progress.add_task(f"[blue]{message}", total=1) + result = func(*args, **kwargs) + return result + + return wrapped_f + + return wrap diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py index 36bdd6a5..a82721d3 100644 --- a/ehrapy/anndata/anndata_ext.py +++ b/ehrapy/anndata/anndata_ext.py @@ -3,7 +3,7 @@ import random from collections import OrderedDict from string import ascii_letters -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple import numpy as np import pandas as pd @@ -303,7 +303,7 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: return new_adata -def _get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]: +def get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]: """Fetches the column indices in X for a given list of column names Args: @@ -383,7 +383,7 @@ def set_numeric_vars( if copy: adata = adata.copy() - vars_idx = _get_column_indices(adata, vars) + vars_idx = get_column_indices(adata, vars) adata.X[:, vars_idx] = values @@ 
-663,3 +663,49 @@ def get_rank_features_df( class NotEncodedError(AssertionError): pass + + +def _are_ndarrays_equal(arr1: np.ndarray, arr2: np.ndarray) -> np.bool_: + """Check if two arrays are equal member-wise. + + Note: Two NaN are considered equal. + + Args: + arr1: First array to compare + arr2: Second array to compare + + Returns: + True if the two arrays are equal member-wise + """ + return np.all(np.equal(arr1, arr2, dtype=object) | ((arr1 != arr1) & (arr2 != arr2))) + + +def _is_val_missing(data: np.ndarray) -> np.ndarray[Any, np.dtype[np.bool_]]: + """Check if values in an AnnData matrix are missing. + + Args: + data: The AnnData matrix to check + + Returns: + An array of bool representing the missingness of the original data, with the same shape + """ + return np.isin(data, [None, ""]) | (data != data) + + +def _to_dense_matrix(adata: AnnData, layer: str | None = None) -> np.ndarray: # pragma: no cover + """Extract a layer from an AnnData object and convert it to a dense matrix if required. + + Args: + adata: The AnnData where to extract the layer from. + layer: Name of the layer to extract. If omitted, X is considered. + + Returns: + The layer as a dense matrix. If a conversion was required, this function returns a copy of the original layer, + otherwise this function returns a reference. 
+ """ + from scipy.sparse import issparse + + if layer is None: + return adata.X.toarray() if issparse(adata.X) else adata.X + else: + return adata.layers[layer].toarray() if issparse(adata.layers[layer]) else adata.layers[layer] diff --git a/ehrapy/plot/_scanpy_pl_api.py b/ehrapy/plot/_scanpy_pl_api.py index d09ca8ed..a1d032a2 100644 --- a/ehrapy/plot/_scanpy_pl_api.py +++ b/ehrapy/plot/_scanpy_pl_api.py @@ -9,7 +9,7 @@ import scanpy as sc from scanpy.plotting import DotPlot, MatrixPlot, StackedViolin -from ehrapy._doc_util import ( +from ehrapy._utils_doc import ( _doc_params, doc_adata_color_etc, doc_common_groupby_plot_args, diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py index 6ff60a2d..60facfeb 100644 --- a/ehrapy/preprocessing/_imputation.py +++ b/ehrapy/preprocessing/_imputation.py @@ -7,22 +7,20 @@ import numpy as np import pandas as pd from lamin_utils import logger -from rich import print -from rich.progress import Progress, SpinnerColumn -from sklearn.experimental import enable_iterative_imputer # required to enable IterativeImputer (experimental feature) +from sklearn.experimental import enable_iterative_imputer # noinspection PyUnresolvedReference from sklearn.impute import SimpleImputer -from sklearn.preprocessing import OrdinalEncoder from ehrapy import settings +from ehrapy._utils_available import _check_module_importable +from ehrapy._utils_rendering import spinner from ehrapy.anndata import check_feature_types -from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY -from ehrapy.anndata.anndata_ext import _get_column_indices -from ehrapy.core._tool_available import _check_module_importable +from ehrapy.anndata.anndata_ext import get_column_indices if TYPE_CHECKING: from anndata import AnnData +@spinner("Performing explicit impute") def explicit_impute( adata: AnnData, replacement: (str | int) | (dict[str, str | int]), @@ -30,7 +28,7 @@ def explicit_impute( impute_empty_strings: bool = True, 
warning_threshold: int = 70, copy: bool = False, -) -> AnnData: +) -> AnnData | None: """Replaces all missing values in all columns or a subset of columns specified by the user with the passed replacement value. There are two scenarios to cover: @@ -47,7 +45,7 @@ def explicit_impute( Returns: If copy is True, a modified copy of the original AnnData object with imputed X. - If copy is False, the original AnnData object is modified in place. + If copy is False, the original AnnData object is modified in place, and None is returned. Examples: Replace all missing values in adata with the value 0: @@ -56,7 +54,7 @@ def explicit_impute( >>> adata = ep.dt.mimic_2(encoded=True) >>> ep.pp.explicit_impute(adata, replacement=0) """ - if copy: # pragma: no cover + if copy: adata = adata.copy() if isinstance(replacement, int) or isinstance(replacement, str): @@ -64,32 +62,25 @@ def explicit_impute( else: _warn_imputation_threshold(adata, var_names=replacement.keys(), threshold=warning_threshold) # type: ignore - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task("[blue]Running explicit imputation", total=1) - # 1: Replace all missing values with the specified value - if isinstance(replacement, (int, str)): - _replace_explicit(adata.X, replacement, impute_empty_strings) - - # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named - elif isinstance(replacement, dict): - for idx, column_name in enumerate(adata.var_names): - imputation_value = _extract_impute_value(replacement, column_name) - # only replace if an explicit value got passed or could be extracted from replacement - if imputation_value: - _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings) - else: - logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.") - else: - raise 
ValueError( # pragma: no cover - f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!" - ) + # 1: Replace all missing values with the specified value + if isinstance(replacement, (int, str)): + _replace_explicit(adata.X, replacement, impute_empty_strings) + + # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named + elif isinstance(replacement, dict): + for idx, column_name in enumerate(adata.var_names): + imputation_value = _extract_impute_value(replacement, column_name) + # only replace if an explicit value got passed or could be extracted from replacement + if imputation_value: + _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings) + else: + logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.") + else: + raise ValueError( # pragma: no cover + f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!" + ) - if copy: - return adata + return adata if copy else None def _replace_explicit(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> None: @@ -119,6 +110,7 @@ def _extract_impute_value(replacement: dict[str, str | int], column_name: str) - return None +@spinner("Performing simple impute") def simple_impute( adata: AnnData, var_names: Iterable[str] | None = None, @@ -126,9 +118,12 @@ def simple_impute( strategy: Literal["mean", "median", "most_frequent"] = "mean", copy: bool = False, warning_threshold: int = 70, -) -> AnnData: +) -> AnnData | None: """Impute missing values in numerical data using mean/median/most frequent imputation. + If required and using mean or median strategy, the data needs to be properly encoded as this imputation requires + numerical data only. + Args: adata: The annotated data matrix to impute missing values on. 
var_names: A list of column names to apply imputation on (if None, impute all columns). @@ -137,13 +132,8 @@ def simple_impute( copy:Whether to return a copy of `adata` or modify it inplace. Returns: - An updated AnnData object with imputed values. - - Raises: - ValueError: - If the selected imputation strategy is not applicable to the data. - ValueError: - If an unknown imputation strategy is provided. + If copy is True, a modified copy of the original AnnData object with imputed X. + If copy is False, the original AnnData object is modified in place, and None is returned. Examples: >>> import ehrapy as ep @@ -155,43 +145,35 @@ def simple_impute( _warn_imputation_threshold(adata, var_names, threshold=warning_threshold) - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task(f"[blue]Running simple imputation with {strategy}", total=1) - # Imputation using median and mean strategy works with numerical data only - if strategy in {"median", "mean"}: - try: - _simple_impute(adata, var_names, strategy) - except ValueError: - raise ValueError( - f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation" - "to certain columns using var_names parameter or use a different mode." - ) from None - # most_frequent imputation works with non-numerical data as well - elif strategy == "most_frequent": + if strategy in {"median", "mean"}: + try: _simple_impute(adata, var_names, strategy) - # unknown simple imputation strategy - else: - raise ValueError( # pragma: no cover - f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent." + except ValueError: + raise ValueError( + f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation " + "to certain columns using var_names parameter or use a different mode." 
) from None + # most_frequent imputation works with non-numerical data as well + elif strategy == "most_frequent": + _simple_impute(adata, var_names, strategy) + else: + raise ValueError( + f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent." + ) from None - if copy: - return adata + return adata if copy else None def _simple_impute(adata: AnnData, var_names: Iterable[str] | None, strategy: str) -> None: imputer = SimpleImputer(strategy=strategy) - if isinstance(var_names, Iterable): - column_indices = _get_column_indices(adata, var_names) + if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): + column_indices = get_column_indices(adata, var_names) adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices]) else: adata.X = imputer.fit_transform(adata.X) +@spinner("Performing KNN impute") @check_feature_types def knn_impute( adata: AnnData, @@ -206,9 +188,7 @@ def knn_impute( ) -> AnnData: """Imputes missing values in the input AnnData object using K-nearest neighbor imputation. - When using KNN Imputation with mixed data (non-numerical and numerical), encoding using ordinal encoding is required - since KNN Imputation can only work on numerical data. The encoding itself is just a utility and will be undone once - imputation ran successfully. + If required, the data needs to be properly encoded as this imputation requires numerical data only. .. warning:: Currently, both `n_neighbours` and `n_neighbors` are accepted as parameters for the number of neighbors. @@ -234,10 +214,8 @@ def knn_impute( kwargs: Gathering keyword arguments of earlier ehrapy versions for backwards compatibility. It is encouraged to use the here listed, current arguments. Returns: - An updated AnnData object with imputed values. - - Raises: - ValueError: If the input data matrix contains only categorical (non-numeric) values. 
+ If copy is True, a modified copy of the original AnnData object with imputed X. + If copy is False, the original AnnData object is modified in place, and None is returned. Examples: >>> import ehrapy as ep @@ -274,40 +252,26 @@ def knn_impute( from sklearnex import patch_sklearn, unpatch_sklearn patch_sklearn() + try: - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task("[blue]Running KNN imputation", total=1) - # numerical only data needs no encoding since KNN Imputation can be applied directly - if np.issubdtype(adata.X.dtype, np.number): - _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs) - else: - # ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation - enc = OrdinalEncoder() - column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG - adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) - # impute the data using KNN imputation - _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs) - # imputing on encoded columns might result in float numbers; those can not be decoded - # cast them to int to ensure they can be decoded - adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int) - # knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding - adata.X = adata.X.astype("object") - # decode ordinal encoding to obtain imputed original data - adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices]) + if np.issubdtype(adata.X.dtype, np.number): + _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs) + else: + # Raise exception since non-numerical data can not be imputed using KNN Imputation + raise ValueError( + "Can only impute numerical data. Try to restrict imputation to certain columns using " + "var_names parameter or perform an encoding of your data." 
+ ) + except ValueError as e: if "Data matrix has wrong shape" in str(e): logger.error("Check that your matrix does not contain any NaN only columns!") - raise + raise if _check_module_importable("sklearnex"): # pragma: no cover unpatch_sklearn() - if copy: - return adata + return adata if copy else None def _knn_impute( @@ -326,8 +290,8 @@ def _knn_impute( imputer = FaissImputer(n_neighbors=n_neighbors, **kwargs) - if isinstance(var_names, Iterable): - column_indices = _get_column_indices(adata, var_names) + if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): + column_indices = get_column_indices(adata, var_names) adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices]) # this is required since X dtype has to be numerical in order to correctly round floats adata.X = adata.X.astype("float64") @@ -335,17 +299,18 @@ def _knn_impute( adata.X = imputer.fit_transform(adata.X) +@spinner("Performing miss-forest impute") def miss_forest_impute( adata: AnnData, - var_names: dict[str, list[str]] | list[str] | None = None, + var_names: Iterable[str] | None = None, *, num_initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean", max_iter: int = 3, - n_estimators=100, + n_estimators: int = 100, random_state: int = 0, warning_threshold: int = 70, copy: bool = False, -) -> AnnData: +) -> AnnData | None: """Impute data using the MissForest strategy. This function uses the MissForest strategy to impute missing values in the data matrix of an AnnData object. @@ -353,12 +318,12 @@ def miss_forest_impute( and using the trained model to predict the missing values. See https://academic.oup.com/bioinformatics/article/28/1/112/219101. - This requires the computation of which columns in X contain numerical only (including NaNs) and which contain non-numerical data. + + If required, the data needs to be properly encoded as this imputation requires numerical data only. 
Args: adata: The AnnData object to use MissForest Imputation on. - var_names: List of columns to impute or a dict with two keys ('numerical' and 'non_numerical') indicating which var - contain mixed data and which numerical data only. + var_names: Iterable of columns to impute num_initial_strategy: The initial strategy to replace all missing numerical values with. max_iter: The maximum number of iterations if the stop criterion has not been met yet. n_estimators: The number of trees to fit for every missing variable. Has a big effect on the run time. @@ -368,21 +333,20 @@ def miss_forest_impute( copy: Whether to return a copy or act in place. Returns: - The imputed (but unencoded) AnnData object. + If copy is True, a modified copy of the original AnnData object with imputed X. + If copy is False, the original AnnData object is modified in place, and None is returned. Examples: >>> import ehrapy as ep >>> adata = ep.dt.mimic_2(encoded=True) >>> ep.pp.miss_forest_impute(adata) """ - if copy: # pragma: no cover + if copy: adata = adata.copy() if var_names is None: _warn_imputation_threshold(adata, list(adata.var_names), threshold=warning_threshold) - elif isinstance(var_names, dict): - _warn_imputation_threshold(adata, var_names.keys(), threshold=warning_threshold) # type: ignore - elif isinstance(var_names, list): + elif isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): _warn_imputation_threshold(adata, var_names, threshold=warning_threshold) if _check_module_importable("sklearnex"): # pragma: no cover @@ -394,74 +358,49 @@ def miss_forest_impute( from sklearn.impute import IterativeImputer try: - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task("[blue]Running MissForest imputation", total=1) - - if settings.n_jobs == 1: # pragma: no cover - logger.warning("The number of jobs is only 1. 
To decrease the runtime set ep.settings.n_jobs=-1.") - - imp_num = IterativeImputer( - estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs), - initial_strategy=num_initial_strategy, - max_iter=max_iter, - random_state=random_state, - ) - # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data - imp_cat = IterativeImputer( - estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs), - initial_strategy="most_frequent", - max_iter=max_iter, - random_state=random_state, + imp_num = IterativeImputer( + estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs), + initial_strategy=num_initial_strategy, + max_iter=max_iter, + random_state=random_state, + ) + # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data + IterativeImputer( + estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs), + initial_strategy="most_frequent", + max_iter=max_iter, + random_state=random_state, + ) + + if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): # type: ignore + num_indices = get_column_indices(adata, var_names) + else: + num_indices = get_column_indices(adata, adata.var_names) + + if set(num_indices).issubset(_get_non_numerical_column_indices(adata.X)): + raise ValueError( + "Can only impute numerical data. Try to restrict imputation to certain columns using " + "var_names parameter." 
) - if isinstance(var_names, list): - var_indices = _get_column_indices(adata, var_names) # type: ignore - adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices]) - elif isinstance(var_names, dict) or var_names is None: - if var_names: - try: - non_num_vars = var_names["non_numerical"] - num_vars = var_names["numerical"] - except KeyError: # pragma: no cover - raise ValueError( - "One or both of your keys provided for var_names are unknown. Only " - "numerical and non_numerical are available!" - ) from None - non_num_indices = _get_column_indices(adata, non_num_vars) - num_indices = _get_column_indices(adata, num_vars) - - # infer non numerical and numerical indices automatically - else: - non_num_indices_set = _get_non_numerical_column_indices(adata.X) - num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set] - non_num_indices = list(non_num_indices_set) - - # encode all non numerical columns - if non_num_indices: - enc = OrdinalEncoder() - adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices]) - # this step is the most expensive one and might extremely slow down the impute process - if num_indices: - adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices]) - if non_num_indices: - adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices]) - adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices]) + # this step is the most expensive one and might extremely slow down the impute process + if num_indices: + adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices]) + else: + raise ValueError("Cannot find any feature to perform imputation") + except ValueError as e: if "Data matrix has wrong shape" in str(e): logger.error("Check that your matrix does not contain any NaN only columns!") - raise + raise if _check_module_importable("sklearnex"): # pragma: no cover unpatch_sklearn() - if copy: - return adata + 
return adata if copy else None +@spinner("Performing mice-forest impute") @check_feature_types def mice_forest_impute( adata: AnnData, @@ -475,12 +414,14 @@ def mice_forest_impute( variable_parameters: dict | None = None, verbose: bool = False, copy: bool = False, -) -> AnnData: +) -> AnnData | None: """Impute data using the miceforest. See https://github.com/AnotherSamWilson/miceforest Fast, memory efficient Multiple Imputation by Chained Equations (MICE) with lightgbm. + If required, the data needs to be properly encoded as this imputation requires numerical data only. + Args: adata: The AnnData object containing the data to impute. var_names: A list of variable names to impute. If None, impute all variables. @@ -497,7 +438,8 @@ def mice_forest_impute( copy: Whether to return a copy of the AnnData object or modify it in-place. Returns: - The imputed AnnData object. + If copy is True, a modified copy of the original AnnData object with imputed X. + If copy is False, the original AnnData object is modified in place, and None is returned. 
Examples: >>> import ehrapy as ep @@ -509,49 +451,31 @@ def mice_forest_impute( adata = adata.copy() _warn_imputation_threshold(adata, var_names, threshold=warning_threshold) + try: - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task("[blue]Running miceforest", total=1) - if np.issubdtype(adata.X.dtype, np.number): - _miceforest_impute( - adata, - var_names, - save_all_iterations_data, - random_state, - inplace, - iterations, - variable_parameters, - verbose, - ) - else: - # ordinal encoding is used since non-numerical data can not be imputed using miceforest - enc = OrdinalEncoder() - column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG - adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) - # impute the data using miceforest - _miceforest_impute( - adata, - var_names, - save_all_iterations_data, - random_state, - inplace, - iterations, - variable_parameters, - verbose, - ) - adata.X = adata.X.astype("object") - # decode ordinal encoding to obtain imputed original data - adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices]) + if np.issubdtype(adata.X.dtype, np.number): + _miceforest_impute( + adata, + var_names, + save_all_iterations_data, + random_state, + inplace, + iterations, + variable_parameters, + verbose, + ) + else: + raise ValueError( + "Can only impute numerical data. Try to restrict imputation to certain columns using " + "var_names parameter." 
+ ) + except ValueError as e: if "Data matrix has wrong shape" in str(e): logger.warning("Check that your matrix does not contain any NaN only columns!") - raise + raise - return adata + return adata if copy else None def _miceforest_impute( @@ -562,8 +486,8 @@ def _miceforest_impute( data_df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names) data_df = data_df.apply(pd.to_numeric, errors="coerce") - if isinstance(var_names, Iterable): - column_indices = _get_column_indices(adata, var_names) + if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): + column_indices = get_column_indices(adata, var_names) selected_columns = data_df.iloc[:, column_indices] selected_columns = selected_columns.reset_index(drop=True) @@ -616,27 +540,21 @@ def _warn_imputation_threshold(adata: AnnData, var_names: Iterable[str] | None, return var_name_to_pct -def _get_non_numerical_column_indices(X: np.ndarray) -> set: +def _get_non_numerical_column_indices(arr: np.ndarray) -> set: """Return indices of columns, that contain at least one non-numerical value that is not "Nan".""" - def _is_float_or_nan(val): # pragma: no cover + def _is_float_or_nan(val) -> bool: # pragma: no cover """Check whether a given item is a float or np.nan""" try: - float(val) - except ValueError: - if val is np.nan: - return True + _ = float(val) + return not isinstance(val, bool) + except (ValueError, TypeError): return False - else: - if not isinstance(val, bool): - return True - else: - return False - is_numeric_numpy = np.vectorize(_is_float_or_nan, otypes=[bool]) - mask = np.apply_along_axis(is_numeric_numpy, 0, X) + def _is_float_or_nan_row(row) -> list[bool]: # pragma: no cover + return [_is_float_or_nan(val) for val in row] + mask = np.apply_along_axis(_is_float_or_nan_row, 0, arr) _, column_indices = np.where(~mask) - non_num_indices = set(column_indices) - return non_num_indices + return set(column_indices) diff --git 
a/ehrapy/preprocessing/_normalization.py b/ehrapy/preprocessing/_normalization.py index 6f3d5e1a..4541cef3 100644 --- a/ehrapy/preprocessing/_normalization.py +++ b/ehrapy/preprocessing/_normalization.py @@ -13,8 +13,8 @@ daskml_pp = None from ehrapy.anndata.anndata_ext import ( - _get_column_indices, assert_numeric_vars, + get_column_indices, get_numeric_vars, set_numeric_vars, ) @@ -48,7 +48,7 @@ def _scale_func_group( adata = _prep_adata_norm(adata, copy) - var_idx = _get_column_indices(adata, vars) + var_idx = get_column_indices(adata, vars) var_values = np.take(adata.X, var_idx, axis=1) if group_key is None: @@ -379,7 +379,7 @@ def log_norm( "or offset negative values with ep.pp.offset_negative_values()." ) - var_idx = _get_column_indices(adata, vars) + var_idx = get_column_indices(adata, vars) var_values = np.take(adata.X, var_idx, axis=1) if offset == 1: diff --git a/tests/anndata/test_anndata_ext.py b/tests/anndata/test_anndata_ext.py index ab3b971a..aca4ca77 100644 --- a/tests/anndata/test_anndata_ext.py +++ b/tests/anndata/test_anndata_ext.py @@ -11,7 +11,9 @@ from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.anndata.anndata_ext import ( NotEncodedError, + _are_ndarrays_equal, _assert_encoded, + _is_val_missing, anndata_to_df, assert_numeric_vars, delete_from_obs, @@ -500,3 +502,17 @@ def test_set_numeric_vars(adata_strings_encoded): with pytest.raises(NotEncodedError, match=r"not yet been encoded"): set_numeric_vars(adata_strings, values) + + +def test_are_ndarrays_equal(impute_num_adata): + impute_num_adata_copy = impute_num_adata.copy() + assert _are_ndarrays_equal(impute_num_adata.X, impute_num_adata_copy.X) + impute_num_adata_copy.X[0, 0] = 42.0 + assert not _are_ndarrays_equal(impute_num_adata.X, impute_num_adata_copy.X) + + +def test_is_val_missing(impute_num_adata): + assert np.array_equal( + _is_val_missing(impute_num_adata.X), + np.array([[False, False, True], [False, False, False], [True, False, 
False], [False, False, True]]), + ) diff --git a/tests/preprocessing/test_imputation.py b/tests/preprocessing/test_imputation.py index d35d6055..21379ef0 100644 --- a/tests/preprocessing/test_imputation.py +++ b/tests/preprocessing/test_imputation.py @@ -1,11 +1,14 @@ import os import warnings +from collections.abc import Iterable from pathlib import Path import numpy as np import pytest +from anndata import AnnData from sklearn.exceptions import ConvergenceWarning +from ehrapy.anndata.anndata_ext import _are_ndarrays_equal, _is_val_missing, _to_dense_matrix from ehrapy.preprocessing._imputation import ( _warn_imputation_threshold, explicit_impute, @@ -20,17 +23,127 @@ _TEST_PATH = f"{TEST_DATA_PATH}/imputation" +def _base_check_imputation( + adata_before_imputation: AnnData, + adata_after_imputation: AnnData, + before_imputation_layer: str | None = None, + after_imputation_layer: str | None = None, + imputed_var_names: Iterable[str] | None = None, +): + """Provides a base check for all imputations: + + - Imputation doesn't leave any NaN behind + - Imputation doesn't modify anything in non-imputed columns (if the imputation on a subset was requested) + - Imputation doesn't modify any data that wasn't NaN + + Args: + adata_before_imputation: AnnData before imputation + adata_after_imputation: AnnData after imputation + before_imputation_layer: Layer to consider in the original ``AnnData``, ``X`` if not specified + after_imputation_layer: Layer to consider in the imputed ``AnnData``, ``X`` if not specified + imputed_var_names: Names of the features that were imputed, will consider all of them if not specified + + Raises: + AssertionError: If any of the checks fail. 
+ """ + + layer_before = _to_dense_matrix(adata_before_imputation, before_imputation_layer) + layer_after = _to_dense_matrix(adata_after_imputation, after_imputation_layer) + + if layer_before.shape != layer_after.shape: + raise AssertionError("The shapes of the two layers do not match") + + var_indices = ( + np.arange(layer_before.shape[1]) + if imputed_var_names is None + else [ + adata_before_imputation.var_names.get_loc(var_name) + for var_name in imputed_var_names + if var_name in imputed_var_names + ] + ) + + before_nan_mask = _is_val_missing(layer_before) + imputed_mask = np.zeros(layer_before.shape[1], dtype=bool) + imputed_mask[var_indices] = True + + # Ensure no NaN remains in the imputed columns of layer_after + if np.any(before_nan_mask[:, imputed_mask] & _is_val_missing(layer_after[:, imputed_mask])): + raise AssertionError("NaN found in imputed columns of layer_after.") + + # Ensure unchanged values outside imputed columns + unchanged_mask = ~imputed_mask + if not _are_ndarrays_equal(layer_before[:, unchanged_mask], layer_after[:, unchanged_mask]): + raise AssertionError("Values outside imputed columns were modified.") + + # Ensure imputation does not alter non-NaN values in the imputed columns + imputed_non_nan_mask = (~before_nan_mask) & imputed_mask + if not _are_ndarrays_equal(layer_before[imputed_non_nan_mask], layer_after[imputed_non_nan_mask]): + raise AssertionError("Non-NaN values in imputed columns were modified.") + + # If reaching here: all checks passed + return + + +def test_base_check_imputation_incompatible_shapes(impute_num_adata): + adata_imputed = knn_impute(impute_num_adata, copy=True) + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed[1:, :]) + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed[:, 1:]) + + +def test_base_check_imputation_nan_detected_after_complete_imputation(impute_num_adata): + adata_imputed = knn_impute(impute_num_adata, 
copy=True) + adata_imputed.X[0, 2] = np.nan + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed) + + +def test_base_check_imputation_nan_detected_after_partial_imputation(impute_num_adata): + var_names = ("col2", "col3") + adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True) + adata_imputed.X[0, 2] = np.nan + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) + + +def test_base_check_imputation_nan_ignored_if_not_in_imputed_column(impute_num_adata): + var_names = ("col2", "col3") + adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True) + # col1 has a NaN at row 2, should get ignored + _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) + + +def test_base_check_imputation_change_detected_in_non_imputed_column(impute_num_adata): + var_names = ("col2", "col3") + adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True) + # col1 has a NaN at row 2, let's simulate it has been imputed by mistake + adata_imputed.X[2, 0] = 42.0 + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) + + +def test_base_check_imputation_change_detected_in_imputed_column(impute_num_adata): + adata_imputed = knn_impute(impute_num_adata, copy=True) + # col3 didn't have a NaN at row 1, let's simulate it has been modified by mistake + adata_imputed.X[1, 2] = 42.0 + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed) + + def test_mean_impute_no_copy(impute_num_adata): + adata_not_imputed = impute_num_adata.copy() simple_impute(impute_num_adata) - assert not np.isnan(impute_num_adata.X).any() + _base_check_imputation(adata_not_imputed, impute_num_adata) def test_mean_impute_copy(impute_num_adata): adata_imputed = simple_impute(impute_num_adata, copy=True) assert 
id(impute_num_adata) != id(adata_imputed) - assert not np.isnan(adata_imputed.X).any() + _base_check_imputation(impute_num_adata, adata_imputed) def test_mean_impute_throws_error_non_numerical(impute_adata): @@ -39,23 +152,25 @@ def test_mean_impute_throws_error_non_numerical(impute_adata): def test_mean_impute_subset(impute_adata): - adata_imputed = simple_impute(impute_adata, var_names=["intcol", "indexcol"], copy=True) + var_names = ("intcol", "indexcol") + adata_imputed = simple_impute(impute_adata, var_names=var_names, copy=True) - assert not np.all([item != item for item in adata_imputed.X[::, 1:2]]) + _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names) assert np.any([item != item for item in adata_imputed.X[::, 3:4]]) def test_median_impute_no_copy(impute_num_adata): + adata_not_imputed = impute_num_adata.copy() simple_impute(impute_num_adata, strategy="median") - assert not np.isnan(impute_num_adata.X).any() + _base_check_imputation(adata_not_imputed, impute_num_adata) -def test_median_impute_copy(impute_num_adata, impute_adata): +def test_median_impute_copy(impute_num_adata): adata_imputed = simple_impute(impute_num_adata, strategy="median", copy=True) - assert id(impute_adata) != id(adata_imputed) - assert not np.isnan(adata_imputed.X).any() + _base_check_imputation(impute_num_adata, adata_imputed) + assert id(impute_num_adata) != id(adata_imputed) def test_median_impute_throws_error_non_numerical(impute_adata): @@ -64,156 +179,137 @@ def test_median_impute_throws_error_non_numerical(impute_adata): def test_median_impute_subset(impute_adata): - adata_imputed = simple_impute(impute_adata, var_names=["intcol", "indexcol"], strategy="median", copy=True) + var_names = ("intcol", "indexcol") + adata_imputed = simple_impute(impute_adata, var_names=var_names, strategy="median", copy=True) - assert not np.all([item != item for item in adata_imputed.X[::, 1:2]]) - assert np.any([item != item for item in adata_imputed.X[::, 3:4]]) + 
_base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names) def test_most_frequent_impute_no_copy(impute_adata): + adata_not_imputed = impute_adata.copy() simple_impute(impute_adata, strategy="most_frequent") - assert not (np.all([item != item for item in impute_adata.X])) + _base_check_imputation(adata_not_imputed, impute_adata) def test_most_frequent_impute_copy(impute_adata): adata_imputed = simple_impute(impute_adata, strategy="most_frequent", copy=True) + _base_check_imputation(impute_adata, adata_imputed) assert id(impute_adata) != id(adata_imputed) - assert not (np.all([item != item for item in adata_imputed.X])) + + +def test_unknown_simple_imputation_strategy(impute_adata): + with pytest.raises(ValueError): + simple_impute(impute_adata, strategy="invalid_strategy", copy=True) # type: ignore def test_most_frequent_impute_subset(impute_adata): - adata_imputed = simple_impute(impute_adata, var_names=["intcol", "strcol"], strategy="most_frequent", copy=True) + var_names = ("intcol", "strcol") + adata_imputed = simple_impute(impute_adata, var_names=var_names, strategy="most_frequent", copy=True) - assert not (np.all([item != item for item in adata_imputed.X[::, 1:3]])) + _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names) def test_knn_impute_check_backend(impute_num_adata): - knn_impute(impute_num_adata, backend="faiss") - knn_impute(impute_num_adata, backend="scikit-learn") + knn_impute(impute_num_adata, backend="faiss", copy=True) + knn_impute(impute_num_adata, backend="scikit-learn", copy=True) with pytest.raises( ValueError, match="Unknown backend 'invalid_backend' for KNN imputation. 
Choose between 'scikit-learn' and 'faiss'.", ): - knn_impute(impute_num_adata, backend="invalid_backend") + knn_impute(impute_num_adata, backend="invalid_backend") # type: ignore def test_knn_impute_no_copy(impute_num_adata): + adata_not_imputed = impute_num_adata.copy() knn_impute(impute_num_adata) - assert not (np.all([item != item for item in impute_num_adata.X])) + _base_check_imputation(adata_not_imputed, impute_num_adata) def test_knn_impute_copy(impute_num_adata): adata_imputed = knn_impute(impute_num_adata, n_neighbors=3, copy=True) + _base_check_imputation(impute_num_adata, adata_imputed) assert id(impute_num_adata) != id(adata_imputed) - assert not (np.all([item != item for item in adata_imputed.X])) def test_knn_impute_non_numerical_data(impute_adata): - adata_imputed = knn_impute(impute_adata, n_neighbors=3, copy=True) - - assert not (np.all([item != item for item in adata_imputed.X])) + with pytest.raises(ValueError): + knn_impute(impute_adata, n_neighbors=3, copy=True) def test_knn_impute_numerical_data(impute_num_adata): adata_imputed = knn_impute(impute_num_adata, copy=True) - assert not (np.all([item != item for item in adata_imputed.X])) - - -def test_knn_impute_list_str(impute_adata): - adata_imputed = knn_impute(impute_adata, var_names=["intcol", "strcol", "boolcol"], copy=True) - - assert not (np.all([item != item for item in adata_imputed.X])) + _base_check_imputation(impute_num_adata, adata_imputed) def test_missforest_impute_non_numerical_data(impute_adata): - adata_imputed = miss_forest_impute(impute_adata, copy=True) - - assert not (np.all([item != item for item in adata_imputed.X])) + with pytest.raises(ValueError): + miss_forest_impute(impute_adata, copy=True) def test_missforest_impute_numerical_data(impute_num_adata): warnings.filterwarnings("ignore", category=ConvergenceWarning) adata_imputed = miss_forest_impute(impute_num_adata, copy=True) - assert not (np.all([item != item for item in adata_imputed.X])) + 
_base_check_imputation(impute_num_adata, adata_imputed) def test_missforest_impute_subset(impute_num_adata): - adata_imputed = miss_forest_impute( - impute_num_adata, var_names={"non_numerical": ["intcol"], "numerical": ["strcol"]}, copy=True - ) - - assert not (np.all([item != item for item in adata_imputed.X])) - - -def test_missforest_impute_list_str(impute_num_adata): - warnings.filterwarnings("ignore", category=ConvergenceWarning) - adata_imputed = miss_forest_impute(impute_num_adata, var_names=["col1", "col2", "col3"], copy=True) - - assert not (np.all([item != item for item in adata_imputed.X])) - - -def test_missforest_impute_dict(impute_adata): warnings.filterwarnings("ignore", category=ConvergenceWarning) - adata_imputed = miss_forest_impute( - impute_adata, var_names={"numerical": ["intcol", "datetime"], "non_numerical": ["strcol", "boolcol"]}, copy=True - ) + var_names = ("col2", "col3") + adata_imputed = miss_forest_impute(impute_num_adata, var_names=var_names, copy=True) - assert not (np.all([item != item for item in adata_imputed.X])) + _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def test_miceforest_impute_no_copy(impute_iris_adata): - adata_imputed = mice_forest_impute(impute_iris_adata) + adata_not_imputed = impute_iris_adata.copy() + mice_forest_impute(impute_iris_adata) - assert id(impute_iris_adata) == id(adata_imputed) + _base_check_imputation(adata_not_imputed, impute_iris_adata) @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def test_miceforest_impute_copy(impute_iris_adata): adata_imputed = mice_forest_impute(impute_iris_adata, copy=True) + _base_check_imputation(impute_iris_adata, adata_imputed) assert id(impute_iris_adata) != id(adata_imputed) @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def 
test_miceforest_impute_non_numerical_data(impute_titanic_adata): - adata_imputed = mice_forest_impute(impute_titanic_adata) - - assert not (np.all([item != item for item in adata_imputed.X])) + with pytest.raises(ValueError): + mice_forest_impute(impute_titanic_adata) @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def test_miceforest_impute_numerical_data(impute_iris_adata): - adata_imputed = mice_forest_impute(impute_iris_adata) - - assert not (np.all([item != item for item in adata_imputed.X])) - - -@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") -def test_miceforest_impute_list_str(impute_titanic_adata): - adata_imputed = mice_forest_impute(impute_titanic_adata, var_names=["Cabin", "Age"]) + adata_not_imputed = impute_iris_adata.copy() + mice_forest_impute(impute_iris_adata) - assert not (np.all([item != item for item in adata_imputed.X])) + _base_check_imputation(adata_not_imputed, impute_iris_adata) def test_explicit_impute_all(impute_num_adata): warnings.filterwarnings("ignore", category=FutureWarning) adata_imputed = explicit_impute(impute_num_adata, replacement=1011, copy=True) - assert (adata_imputed.X == 1011).sum() == 3 + _base_check_imputation(impute_num_adata, adata_imputed) + assert np.sum([adata_imputed.X == 1011]) == 3 def test_explicit_impute_subset(impute_adata): adata_imputed = explicit_impute(impute_adata, replacement={"strcol": "REPLACED", "intcol": 1011}, copy=True) - assert (adata_imputed.X == 1011).sum() == 1 - assert (adata_imputed.X == "REPLACED").sum() == 1 + _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=("strcol", "intcol")) + assert np.sum([adata_imputed.X == 1011]) == 1 + assert np.sum([adata_imputed.X == "REPLACED"]) == 1 def test_warning(impute_num_adata): diff --git a/tests/core/_test_tool_available.py b/tests/utils/test_utils_available.py similarity index 90% rename from tests/core/_test_tool_available.py 
rename to tests/utils/test_utils_available.py index ebcd0f88..7e8044a5 100644 --- a/tests/core/_test_tool_available.py +++ b/tests/utils/test_utils_available.py @@ -1,6 +1,4 @@ -import pytest - -from ehrapy.core._tool_available import _check_module_importable, _shell_command_accessible +from ehrapy._utils_available import _check_module_importable, _shell_command_accessible def test_check_module_importable_true():