Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coxphfitter #643

Merged
merged 14 commits into from
Jan 23, 2024
Merged
1 change: 1 addition & 0 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret
tools.kmf
tools.test_kmf_logrank
tools.test_nested_f_statistic
tools.cox_ph
```

### Causal Inference
Expand Down
2 changes: 1 addition & 1 deletion ehrapy/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ehrapy.tools._sa import anova_glm, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic
from ehrapy.tools._sa import anova_glm, cox_ph, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic
from ehrapy.tools._scanpy_tl_api import * # noqa: F403
from ehrapy.tools.causal._dowhy import causal_inference
from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups
Expand Down
68 changes: 53 additions & 15 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from lifelines import KaplanMeierFitter
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.statistics import StatisticalResult, logrank_test
from scipy import stats

from ehrapy.anndata import anndata_to_df

if TYPE_CHECKING:
from collections.abc import Iterable

Expand All @@ -26,13 +28,14 @@ def ols(
"""Create a Ordinary Least Squares (OLS) Model from a formula and AnnData.

See https://www.statsmodels.org/stable/generated/statsmodels.formula.api.ols.html#statsmodels.formula.api.ols
Internally use the statsmodel to create a OLS Model from a formula and dataframe.

Args:
adata: The AnnData object for the OLS model.
var_names: A list of var names indicating which columns are for the OLS model.
formula: The formula specifying the model.
missing: Available options are 'none', 'drop', and 'raise'. If 'none', no nan checking is done. If 'drop', any observations with nans are dropped. If 'raise', an error is raised. Default is 'none'.
missing: Available options are 'none', 'drop', and 'raise'.
If 'none', no nan checking is done. If 'drop', any observations with nans are dropped.
If 'raise', an error is raised. Defaults to 'none'.

Returns:
The OLS model instance.
Expand Down Expand Up @@ -64,7 +67,6 @@ def glm(
"""Create a Generalized Linear Model (GLM) from a formula, a distribution, and AnnData.

See https://www.statsmodels.org/stable/generated/statsmodels.formula.api.glm.html#statsmodels.formula.api.glm
Internally use the statsmodel to create a GLM Model from a formula, a distribution, and dataframe.

Args:
adata: The AnnData object for the GLM model.
Expand All @@ -74,7 +76,7 @@ def glm(
Defaults to 'Gaussian'.
missing: Available options are 'none', 'drop', and 'raise'. If 'none', no nan checking is done.
If 'drop', any observations with nans are dropped. If 'raise', an error is raised (default: 'none').
ascontinus: A list of var names indicating which columns are continuous rather than categorical.
as_continuous: A list of var names indicating which columns are continuous rather than categorical.
The corresponding columns will be set as type float.

Returns:
Expand All @@ -86,7 +88,7 @@ def glm(
>>> formula = 'day_28_flg ~ age'
>>> var_names = ['day_28_flg', 'age']
>>> family = 'Binomial'
>>> glm = ep.tl.glmglm(adata, var_names, formula, family, missing = 'drop', ascontinus = ['age'])
>>> glm = ep.tl.glm(adata, var_names, formula, family, missing = 'drop', ascontinus = ['age'])
"""
family_dict = {
"Gaussian": sm.families.Gaussian(),
Expand Down Expand Up @@ -120,15 +122,18 @@ def kmf(
) -> KaplanMeierFitter:
"""Fit the Kaplan-Meier estimate for the survival function.

See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter
Class for fitting the Kaplan-Meier estimate for the survival function.
The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data.
In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment.

See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator
https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter

Args:
durations: length n -- duration (relative to subject's birth) the subject was alive for.
event_observed: True if the the death was observed, False if the event was lost (right-censored). Defaults to all True if event_observed==None.
event_observed: True if the death was observed, False if the event was lost (right-censored). Defaults to all True if event_observed==None.
timeline: return the best estimate at the values in timelines (positively increasing)
entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
If None, all members of the population entered study when they were "born".
If None, all members of the population entered study when they were "born".
label: A string to name the column of the estimate.
alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
Expand All @@ -143,9 +148,7 @@ def kmf(
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Because in MIMIC-II database, `censor_fl` is censored or death (binary: 0 = death, 1 = censored).
>>> # While in KaplanMeierFitter, `event_observed` is True if the the death was observed, False if the event was lost (right-censored).
>>> # So we need to flip `censor_fl` when pass `censor_fl` to KaplanMeierFitter
>>> # Flip 'censor_fl' because 0 = death and 1 = censored
>>> adata[:, ['censor_flg']].X = np.where(adata[:, ['censor_flg']].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata[:, ['mort_day_censored']].X, adata[:, ['censor_flg']].X)
"""
Expand Down Expand Up @@ -184,12 +187,12 @@ def test_kmf_logrank(
) -> StatisticalResult:
"""Calculates the p-value for the logrank test comparing the survival functions of two groups.

See https://lifelines.readthedocs.io/en/latest/lifelines.statistics.html

Measures and reports on whether two intensity processes are different.
That is, given two event series, determines whether the data generating processes are statistically different.
The test-statistic is chi-squared under the null hypothesis.

See https://lifelines.readthedocs.io/en/latest/lifelines.statistics.html

Args:
kmf_A: The first KaplanMeierFitter object containing the durations and events.
kmf_B: The second KaplanMeierFitter object containing the durations and events.
Expand Down Expand Up @@ -262,3 +265,38 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_
}
dataframe = pd.DataFrame(data=table)
return dataframe


def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> CoxPHFitter:
"""Fit the Cox’s proportional hazard for the survival function.
fatisati marked this conversation as resolved.
Show resolved Hide resolved

fatisati marked this conversation as resolved.
Show resolved Hide resolved
The Cox proportional hazards model (CoxPH) examines the relationship between the survival time of subjects and one or more predictor variables.
It models the hazard rate as a product of a baseline hazard function and an exponential function of the predictors, assuming proportional hazards over time.

See https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html

Args:
adata: adata: AnnData object with necessary columns `duration_col` and `event_col`.
duration_col: the name of the column in the AnnData objects that contains the subjects’ lifetimes.
event_col: the name of the column in anndata that contains the subjects’ death observation. If left as None, assume all individuals are uncensored.
entry_col: a column denoting when a subject entered the study, i.e. left-truncation.

Returns:
Fitted CoxPHFitter

Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Flip 'censor_fl' because 0 = death and 1 = censored
>>> adata[:, ['censor_flg']].X = np.where(adata[:, ['censor_flg']].X == 0, 1, 0)
>>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg")
"""
df = anndata_to_df(adata)
keys = [duration_col, event_col]
if entry_col:
keys.append(entry_col)
df = df[keys]
cph = CoxPHFitter()
cph.fit(df, duration_col, event_col, entry_col=entry_col)

return cph
11 changes: 10 additions & 1 deletion tests/tools/test_sa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import statsmodels
from lifelines import KaplanMeierFitter
from lifelines import CoxPHFitter, KaplanMeierFitter

import ehrapy as ep

Expand Down Expand Up @@ -75,3 +75,12 @@ def test_anova_glm(self):
assert dataframe.shape == (2, 6)
assert dataframe.iloc[1, 4] == 2
assert pytest.approx(dataframe.iloc[1, 5], 0.1) == 0.103185

def test_cox_ph(self):
adata = ep.dt.mimic_2(encoded=False)
adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg")

assert isinstance(cph, CoxPHFitter)
assert len(cph.durations) == 1776
assert sum(cph.event_observed) == 497
Loading