Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coxphfitter #643

Merged
merged 14 commits into from
Jan 23, 2024
14 changes: 13 additions & 1 deletion ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from lifelines import KaplanMeierFitter
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.statistics import StatisticalResult, logrank_test
from scipy import stats

Expand Down Expand Up @@ -262,3 +262,15 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_
}
dataframe = pd.DataFrame(data=table)
return dataframe


def cph(ad: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> KaplanMeierFitter:
fatisati marked this conversation as resolved.
Show resolved Hide resolved
"""Fit the Cox’s proportional hazard for the survival function.
fatisati marked this conversation as resolved.
Show resolved Hide resolved

fatisati marked this conversation as resolved.
Show resolved Hide resolved
See https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html
"""
df = ad.to_df()
fatisati marked this conversation as resolved.
Show resolved Hide resolved
df = df[[duration_col, event_col, entry_col]]
cph = CoxPHFitter()
cph.fit(df, duration_col, event_col, entry_col=entry_col)
return cph
11 changes: 10 additions & 1 deletion tests/tools/test_sa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import statsmodels
from lifelines import KaplanMeierFitter
from lifelines import CoxPHFitter, KaplanMeierFitter

import ehrapy as ep

Expand Down Expand Up @@ -75,3 +75,12 @@ def test_anova_glm(self):
assert dataframe.shape == (2, 6)
assert dataframe.iloc[1, 4] == 2
assert pytest.approx(dataframe.iloc[1, 5], 0.1) == 0.103185

def test_cph(self):
adata = ep.dt.mimic_2(encoded=False)
adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
cph = ep.tl.cph(adata, "mort_day_censored", "censor_flg")

assert isinstance(cph, CoxPHFitter)
assert len(cph.durations) == 1776
assert sum(cph.event_observed) == 497
Loading