Skip to content

Commit

Permalink
Improve survival analysis interface (#825)
Browse files Browse the repository at this point in the history
* updated kmf to match method signature

* updated notebook

* updated ehrapy tutorial commit

* updated docu for new method signature

* added outputs to survival analysis

* correctly passing on fitting options

* pull request fixes.

- removed kwargs
- updated documentation

* added legacy suport

* added kmf function legacy support in tests and added new kaplan_meier function in line with new signature

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updated notebook

* added stacklevel to deprecation warning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added deprecation warning in comment

* Update ehrapy/plot/_survival_analysis.py

* Update ehrapy/plot/_survival_analysis.py

* Update ehrapy/plot/_survival_analysis.py

* Update ehrapy/plot/_survival_analysis.py

* Update tests/tools/test_sa.py

* doc adjustments

* change name of kmf plot to kaplan_meier, some adjustments

* introduce keyword only for univariate sa

* correct docstring

* update submodule

* add lifelines intersphinx mappings

* Update ehrapy/tools/_sa.py

* Update ehrapy/tools/_sa.py

* Update ehrapy/tools/_sa.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>
Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com>
Co-authored-by: eroell <eljas.roellin@ikmail.com>
  • Loading branch information
5 people authored Dec 2, 2024
1 parent ee84d9e commit 861d762
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"flax": ("https://flax.readthedocs.io/en/latest/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"lamin": ("https://lamin.ai/docs", None),
"lifelines": ("https://lifelines.readthedocs.io/en/latest/", None),
}

language = "en"
Expand Down
2 changes: 1 addition & 1 deletion docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ and [prettier][prettier-editors].
## Writing tests

```{note}
Remember to first install the package with `pip install -e "[dev,test,docs]"`
Remember to first install the package with `pip install -e ".[dev,test,docs]"`
```

This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
4 changes: 2 additions & 2 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret
tools.ols
tools.glm
tools.kmf
tools.kaplan_meier
tools.test_kmf_logrank
tools.test_nested_f_statistic
tools.cox_ph
Expand Down Expand Up @@ -368,7 +368,7 @@ Methods that extract and visualize tool-specific annotation in an AnnData object
:nosignatures:
plot.ols
plot.kmf
plot.kaplan_meier
```

### Causal Inference
Expand Down
2 changes: 1 addition & 1 deletion ehrapy/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from ehrapy.plot._colormaps import * # noqa: F403
from ehrapy.plot._missingno_pl_api import * # noqa: F403
from ehrapy.plot._scanpy_pl_api import * # noqa: F403
from ehrapy.plot._survival_analysis import kmf, ols
from ehrapy.plot._survival_analysis import kaplan_meier, kmf, ols
from ehrapy.plot.causal_inference._dowhy import causal_effect
from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised
70 changes: 57 additions & 13 deletions ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -38,7 +39,7 @@ def ols(
ax: Axes | None = None,
title: str | None = None,
**kwds,
):
) -> Axes | None:
"""Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot.
Args:
Expand Down Expand Up @@ -134,6 +135,8 @@ def ols(

if not show:
return ax
else:
return None


def kmf(
Expand All @@ -152,7 +155,48 @@ def kmf(
figsize: tuple[float, float] | None = None,
show: bool | None = None,
title: str | None = None,
):
) -> Axes | None:
warnings.warn(
"This function is deprecated and will be removed in the next release. Use `ep.pl.kaplan_meier` instead.",
DeprecationWarning,
stacklevel=2,
)
return kaplan_meier(
kmfs=kmfs,
ci_alpha=ci_alpha,
ci_force_lines=ci_force_lines,
ci_show=ci_show,
ci_legend=ci_legend,
at_risk_counts=at_risk_counts,
color=color,
grid=grid,
xlim=xlim,
ylim=ylim,
xlabel=xlabel,
ylabel=ylabel,
figsize=figsize,
show=show,
title=title,
)


def kaplan_meier(
kmfs: Sequence[KaplanMeierFitter],
ci_alpha: list[float] | None = None,
ci_force_lines: list[Boolean] | None = None,
ci_show: list[Boolean] | None = None,
ci_legend: list[Boolean] | None = None,
at_risk_counts: list[Boolean] | None = None,
color: list[str] | None | None = None,
grid: Boolean | None = False,
xlim: tuple[float, float] | None = None,
ylim: tuple[float, float] | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
figsize: tuple[float, float] | None = None,
show: bool | None = None,
title: str | None = None,
) -> Axes | None:
"""Plots a pretty figure of the Fitted KaplanMeierFitter model
See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html
Expand Down Expand Up @@ -186,23 +230,21 @@ def kmf(
# So we need to flip `censor_fl` when pass `censor_fl` to KaplanMeierFitter
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
>>> ep.pl.kmf(
>>> kmf = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg")
>>> ep.pl.kaplan_meier(
... [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True
... )
.. image:: /_static/docstring_previews/kmf_plot_1.png
>>> T = adata[:, ["mort_day_censored"]].X
>>> E = adata[:, ["censor_flg"]].X
>>> groups = adata[:, ["service_unit"]].X
>>> ix1 = groups == "FICU"
>>> ix2 = groups == "MICU"
>>> ix3 = groups == "SICU"
>>> kmf_1 = ep.tl.kmf(T[ix1], E[ix1], label="FICU")
>>> kmf_2 = ep.tl.kmf(T[ix2], E[ix2], label="MICU")
>>> kmf_3 = ep.tl.kmf(T[ix3], E[ix3], label="SICU")
>>> ep.pl.kmf([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'],
>>> adata_ficu = adata[groups == "FICU"]
>>> adata_micu = adata[groups == "MICU"]
>>> adata_sicu = adata[groups == "SICU"]
>>> kmf_1 = ep.tl.kaplan_meier(adata_ficu, "mort_day_censored", "censor_flg", label="FICU")
>>> kmf_2 = ep.tl.kaplan_meier(adata_micu, "mort_day_censored", "censor_flg", label="MICU")
>>> kmf_3 = ep.tl.kaplan_meier(adata_sicu, "mort_day_censored", "censor_flg", label="SICU")
>>> ep.pl.kaplan_meier([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'],
>>> xlim=[0, 750], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived")
.. image:: /_static/docstring_previews/kmf_plot_2.png
Expand Down Expand Up @@ -251,3 +293,5 @@ def kmf(

if not show:
return ax
else:
return None
2 changes: 2 additions & 0 deletions ehrapy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
anova_glm,
cox_ph,
glm,
kaplan_meier,
kmf,
log_logistic_aft,
nelson_aalen,
Expand Down Expand Up @@ -31,6 +32,7 @@
"cox_ph",
"glm",
"kmf",
"kaplan_meier",
"log_logistic_aft",
"nelson_aalen",
"ols",
Expand Down
Loading

0 comments on commit 861d762

Please sign in to comment.