Skip to content

Commit

Permalink
Refactor (#629)
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <lukas.heumos@posteo.net>
  • Loading branch information
Zethson authored Dec 18, 2023
1 parent 923e1b3 commit 204a3a7
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 315 deletions.
1 change: 0 additions & 1 deletion ehrapy/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from ehrapy.plot._missingno_pl_api import * # noqa: F403
from ehrapy.plot._qc import qc_metrics
from ehrapy.plot._scanpy_pl_api import * # noqa: F403
from ehrapy.plot._survival_analysis import kmf, ols
from ehrapy.plot._util import * # noqa: F403
Expand Down
58 changes: 0 additions & 58 deletions ehrapy/plot/_qc.py

This file was deleted.

5 changes: 1 addition & 4 deletions ehrapy/preprocessing/_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ def undo_encoding(
else:
raise ValueError(f"Cannot decode object of type {type(data)}. Can only decode AnnData objects!")

return None


def _encode(
adata: AnnData,
Expand Down Expand Up @@ -787,8 +785,7 @@ def _reorder_encodings(adata: AnnData, new_encodings: dict[str, list[list[str]]
# if encoding mode is
if not encoded_categoricals_with_mode:
del adata.uns["encoding_to_var"][encode_mode]
logg.info("Re-encoded the AnnData object.")
# return updated encodings

return _update_new_encode_modes(new_encodings, adata.uns["encoding_to_var"])


Expand Down
44 changes: 13 additions & 31 deletions ehrapy/preprocessing/_quality_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from rich import print
from thefuzz import process

from ehrapy import logging as logg

if TYPE_CHECKING:
from collections.abc import Collection

Expand All @@ -36,44 +34,32 @@ def qc_metrics(
Observation level metrics include:
`missing_values_abs`
Absolute amount of missing values.
`missing_values_pct`
Relative amount of missing values in percent.
- `missing_values_abs`: Absolute amount of missing values.
- `missing_values_pct`: Relative amount of missing values in percent.
Feature level metrics include:
`missing_values_abs`
Absolute amount of missing values.
`missing_values_pct`
Relative amount of missing values in percent.
`mean`
Mean value of the features.
`median`
Median value of the features.
`std`
Standard deviation of the features.
`min`
Minimum value of the features.
`max`
Maximum value of the features.
- `missing_values_abs`: Absolute amount of missing values.
- `missing_values_pct`: Relative amount of missing values in percent.
- `mean`: Mean value of the features.
- `median`: Median value of the features.
- `std`: Standard deviation of the features.
- `min`: Minimum value of the features.
- `max`: Maximum value of the features.
Examples:
>>> import ehrapy as ep
>>> import seaborn as sns
>>> import matplotlib.pyplot as plt
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.pp.qc_metrics(adata)
>>> sns.displot(adata.obs["missing_values_abs"])
>>> plt.show()
"""
obs_metrics = _obs_qc_metrics(adata, layer, qc_vars)
var_metrics = _var_qc_metrics(adata, layer)

if inplace:
adata.obs[obs_metrics.columns] = obs_metrics
adata.var[var_metrics.columns] = var_metrics
logg.info("Added the calculated metrics to AnnData's `obs` and `var`.")

return obs_metrics, var_metrics

Expand All @@ -91,10 +77,8 @@ def _missing_values(
Returns:
Absolute or relative amount of missing values.
"""
# Absolute number of missing values
if shape is None:
return pd.isnull(arr).sum()
# Relative number of missing values in percent
else:
n_rows, n_cols = shape
if df_type == "obs":
Expand Down Expand Up @@ -256,7 +240,7 @@ def qc_lab_measurements(
If you want to specify your own table as a Pandas DataFrame please examine the existing default table.
Ethnicity and age columns can be added.
https://github.com/theislab/ehrapy/ehrapy/preprocessing/laboratory_reference_tables/laposata.tsv
https://github.com/theislab/ehrapy/blob/main/ehrapy/preprocessing/laboratory_reference_tables/laposata.tsv
Args:
adata: Annotated data matrix.
Expand All @@ -267,13 +251,13 @@ def qc_lab_measurements(
threshold: Minimum required matching confidence score of the fuzzysearch.
0 = no matches, 100 = all must match. Defaults to 20.
age_col: Column containing age values.
age_range: The inclusive age-range to filter for. e.g. 5-99
age_range: The inclusive age-range to filter for such as 5-99.
sex_col: Column containing sex values. Column must contain 'U', 'M' or 'F'.
sex: Sex to filter the reference values for. Use U for unisex which uses male values when male and female conflict.
Defaults to 'U|M'
Defaults to 'U|M'.
ethnicity_col: Column containing ethnicity values.
ethnicity: Ethnicity to filter for.
copy: Whether to return a copy. Defaults to False .
copy: Whether to return a copy. Defaults to False.
verbose: Whether to have verbose stdout. Notifies user of matched columns and value ranges.
Returns:
Expand Down Expand Up @@ -323,7 +307,6 @@ def qc_lab_measurements(
f"ethnicity columns and their values."
)

# Fetch reference values
try:
if age_col:
min_age, max_age = age_range.split("-")
Expand All @@ -344,7 +327,6 @@ def qc_lab_measurements(
except TypeError:
print(f"[bold yellow]Unable to find specified reference values for {measurement}.")

# Check whether the measurements are inside the reference ranges
check = reference_values[reference_column].values
check_str: str = np.array2string(check)
check_str = check_str.replace("[", "").replace("]", "").replace("'", "")
Expand Down
2 changes: 1 addition & 1 deletion ehrapy/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ehrapy.tools._sa import anova_glm, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic
from ehrapy.tools._scanpy_tl_api import * # noqa: F403
from ehrapy.tools.causal._dowhy import causal_inference
from ehrapy.tools.feature_ranking._rank_features_groups import rank_features_groups
from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups

try: # pragma: no cover
from ehrapy.tools.nlp._medcat import (
Expand Down
49 changes: 0 additions & 49 deletions ehrapy/tools/_scanpy_tl_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,52 +679,3 @@ def ingest(
inplace=inplace,
**kwargs,
)


def filter_rank_features_groups(
    adata: AnnData,
    key: str = "rank_features_groups",
    groupby=None,
    key_added: str = "rank_features_groups_filtered",
    min_in_group_fraction: float = 0.25,
    min_fold_change: float = 1,
    max_out_group_fraction: float = 0.5,
) -> None:  # pragma: no cover
    """Filters out features based on fold change and the in-group and out-group fractions of the `groupby` categories.

    Thin wrapper that forwards all arguments to :func:`scanpy.tl.filter_rank_genes_groups` with `use_raw=False`.
    See :func:`~ehrapy.tl.rank_features_groups` for how the ranking stored under `key` is computed.

    Results are stored in `adata.uns[key_added]` (default: 'rank_features_groups_filtered').
    To preserve the original structure of `adata.uns[key]`, filtered features are set to `NaN`.

    Args:
        adata: Annotated data matrix.
        key: Key previously added by :func:`~ehrapy.tl.rank_features_groups`.
        groupby: The key of the observations grouping to consider.
        key_added: The key in `adata.uns` the filtered result is saved to.
        min_in_group_fraction: Minimum in group fraction (default: 0.25).
        min_fold_change: Minimum fold change (default: 1).
        max_out_group_fraction: Maximum out group fraction (default: 0.5).

    Returns:
        The return value of :func:`scanpy.tl.filter_rank_genes_groups` (annotated as `None`).
        The filtered ranking, with filtered feature names set to `nan`, is written to `adata.uns[key_added]`.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.tl.rank_features_groups(adata, "service_unit")
        >>> ep.tl.filter_rank_features_groups(adata)
        >>> ep.pl.rank_features_groups(adata)
    """
    # use_raw is pinned to False; every other option is passed through unchanged.
    return sc.tl.filter_rank_genes_groups(
        adata=adata,
        key=key,
        groupby=groupby,
        use_raw=False,
        key_added=key_added,
        min_in_group_fraction=min_in_group_fraction,
        min_fold_change=min_fold_change,
        max_out_group_fraction=max_out_group_fraction,
    )
59 changes: 55 additions & 4 deletions ehrapy/tools/feature_ranking/_rank_features_groups.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -346,8 +346,11 @@ def rank_features_groups(
Used only for statistical tests (e.g. doesn't work for "logreg" `num_cols_method`)
tie_correct: Use tie correction for `'wilcoxon'` scores. Used only for `'wilcoxon'`.
layer: Key from `adata.layers` whose value will be used to perform tests on.
field_to_rank: Set to `layer` to rank variables in `adata.X` or `adata.layers[layer]` (default), `obs` to rank `adata.obs`, or `layer_and_obs` to rank both. Layer needs to be None if this is not 'layer'.
columns_to_rank: Subset of columns to rank. If 'all', all columns are used. If a dictionary, it must have keys 'var_names' and/or 'obs_names' and values must be iterables of strings. E.g. {'var_names': ['glucose'], 'obs_names': ['age', 'height']}.
field_to_rank: Set to `layer` to rank variables in `adata.X` or `adata.layers[layer]` (default), `obs` to rank `adata.obs`, or `layer_and_obs` to rank both.
Layer needs to be None if this is not 'layer'.
columns_to_rank: Subset of columns to rank. If 'all', all columns are used.
If a dictionary, it must have keys 'var_names' and/or 'obs_names' and values must be iterables of strings
such as {'var_names': ['glucose'], 'obs_names': ['age', 'height']}.
**kwds: Are passed to test methods. Currently this affects only parameters that
are passed to :class:`sklearn.linear_model.LogisticRegression`.
For instance, you can pass `penalty='l1'` to try to come up with a
Expand Down Expand Up @@ -568,7 +571,6 @@ def rank_features_groups(
adata_orig.uns[key_added] = adata.uns[key_added]
adata = adata_orig

# Adjust p values
if "pvals" in adata.uns[key_added]:
adata.uns[key_added]["pvals_adj"] = _adjust_pvalues(
adata.uns[key_added]["pvals"], corr_method=correction_method
Expand All @@ -581,3 +583,52 @@ def rank_features_groups(
_sort_features(adata, key_added)

return adata if copy else None


def filter_rank_features_groups(
    adata: AnnData,
    key: str = "rank_features_groups",
    groupby: str | None = None,
    key_added: str = "rank_features_groups_filtered",
    min_in_group_fraction: float = 0.25,
    min_fold_change: float = 1,
    max_out_group_fraction: float = 0.5,
) -> None:  # pragma: no cover
    """Filters out features based on fold change and the in-group and out-group fractions of the `groupby` categories.

    Thin wrapper that forwards all arguments to :func:`scanpy.tl.filter_rank_genes_groups` with `use_raw=False`.
    See :func:`~ehrapy.tl.rank_features_groups` for how the ranking stored under `key` is computed.

    Results are stored in `adata.uns[key_added]` (default: 'rank_features_groups_filtered').
    To preserve the original structure of `adata.uns[key]`, filtered features are set to `NaN`.

    Args:
        adata: Annotated data matrix.
        key: Key previously added by :func:`~ehrapy.tl.rank_features_groups`.
        groupby: The key of the observations grouping to consider.
        key_added: The key in `adata.uns` the filtered result is saved to.
        min_in_group_fraction: Minimum in group fraction (default: 0.25).
        min_fold_change: Minimum fold change (default: 1).
        max_out_group_fraction: Maximum out group fraction (default: 0.5).

    Returns:
        The return value of :func:`scanpy.tl.filter_rank_genes_groups` (annotated as `None`).
        The filtered ranking, with filtered feature names set to `nan`, is written to `adata.uns[key_added]`.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.tl.rank_features_groups(adata, "service_unit")
        >>> ep.tl.filter_rank_features_groups(adata)
        >>> ep.pl.rank_features_groups(adata)
    """
    # use_raw is pinned to False; every other option is passed through unchanged.
    return sc.tl.filter_rank_genes_groups(
        adata=adata,
        key=key,
        groupby=groupby,
        use_raw=False,
        key_added=key_added,
        min_in_group_fraction=min_in_group_fraction,
        min_fold_change=min_fold_change,
        max_out_group_fraction=max_out_group_fraction,
    )
Loading

0 comments on commit 204a3a7

Please sign in to comment.