Skip to content

Commit

Permalink
cox_ph add all arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
aGuyLearning committed Dec 18, 2024
1 parent 419f2b2 commit b1d36b8
Showing 1 changed file with 149 additions and 12 deletions.
161 changes: 149 additions & 12 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np # This package is implicitly used
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
Expand All @@ -23,6 +22,7 @@
if TYPE_CHECKING:
from collections.abc import Iterable

import numpy as np
from anndata import AnnData
from statsmodels.genmod.generalized_linear_model import GLMResultsWrapper

Expand Down Expand Up @@ -347,23 +347,43 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_
return dataframe


def _regression_model(
model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True
):
def _regression_model_data_frame_preparation(adata: AnnData, duration_col: str, accept_zero_duration=True):
"""Convenience function for regression models."""
df = anndata_to_df(adata)
df = df.dropna()

if not accept_zero_duration:
df.loc[df[duration_col] == 0, duration_col] += 1e-5

model = model_class()
model.fit(df, duration_col, event_col, entry_col=entry_col)

return model
return df


def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> CoxPHFitter:
def cox_ph(
adata: AnnData,
duration_col: str,
*,
inplace: bool = True,
key_added_prefix: str | None = None,
alpha: float = 0.05,
label: str | None = None,
baseline_estimation_method: Literal["breslow", "spline", "piecewise"] = "breslow",
penalizer: float | np.ndarray = 0.0,
l1_ratio: float = 0.0,
strata: list[str] | str | None = None,
n_baseline_knots: int = 4,
knots: list[float] | None = None,
breakpoints: list[float] | None = None,
event_col: str = None,
weights_col: str | None = None,
cluster_col: str | None = None,
entry_col: str = None,
robust: bool = False,
formula: str = None,
batch_mode: bool = None,
show_progress: bool = False,
initial_point: np.ndarray | None = None,
fit_options: dict | None = None,
) -> CoxPHFitter:
"""Fit the Cox’s proportional hazard for the survival function.
The Cox proportional hazards model (CoxPH) examines the relationship between the survival time of subjects and one or more predictor variables.
Expand All @@ -376,7 +396,26 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N
duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes.
event_col: The name of the column in anndata that contains the subjects’ death observation.
If left as None, assume all individuals are uncensored.
inplace: Whether to modify the AnnData object in place.
alpha: The alpha value in the confidence intervals.
label: A string to name the column of the estimate.
baseline_estimation_method: The method used to estimate the baseline hazard. Options are 'breslow', 'spline', and 'piecewise'.
penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates.
l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above.
strata: specify a list of columns to use in stratification. This is useful if a categorical covariate does not obey the proportional hazard assumption. This is used similar to the strata expression in R. See http://courses.washington.edu/b515/l17.pdf.
n_baseline_knots: Used when baseline_estimation_method="spline". Set the number of knots (interior & exterior) in the baseline hazard, which will be placed evenly along the time axis. Should be at least 2. Royston et al., the authors of this model, suggest 4 to start, but any value between 2 and 8 is reasonable. If you need to customize the timestamps used to calculate the curve, use the knots parameter instead.
knots: When baseline_estimation_method="spline", this allows customizing the points in the time axis for the baseline hazard curve. To use evenly-spaced points in time, the n_baseline_knots parameter can be employed instead.
breakpoints: Used when baseline_estimation_method="piecewise". Set the positions of the baseline hazard breakpoints.
event_col: The name of the column in the DataFrame that contains the subjects’ death observation. If left as None, assume all individuals are uncensored.
weights_col: The name of the column in DataFrame that contains the weights for each subject.
cluster_col: The name of the column in DataFrame that contains the cluster variable. Using this forces the sandwich estimator (robust variance estimator) to be used.
entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there is a high number of ties, results may differ significantly.
formula: A Wilkinson formula, like in R and statsmodels, for the right-hand side. If left as None, all columns not assigned as durations, weights, etc. are used. Uses the library Formulaic for parsing.
batch_mode: enabling batch_mode can be faster for datasets with a large number of ties. If left as None, lifelines will choose the best option.
show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing.
initial_point: set the starting point for the iterative solver.
fit_options: Additional keyword arguments to pass into the estimator.
Returns:
Fitted CoxPHFitter.
Expand All @@ -388,10 +427,80 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg")
"""
return _regression_model(CoxPHFitter, adata, duration_col, event_col, entry_col)
df = _regression_model_data_frame_preparation(adata, duration_col)
cox_ph = CoxPHFitter(
alpha=alpha,
label=label,
strata=strata,
baseline_estimation_method=baseline_estimation_method,
penalizer=penalizer,
l1_ratio=l1_ratio,
n_baseline_knots=n_baseline_knots,
knots=knots,
breakpoints=breakpoints,
)
cox_ph.fit(
df,
duration_col=duration_col,
event_col=event_col,
entry_col=entry_col,
robust=robust,
initial_point=initial_point,
weights_col=weights_col,
cluster_col=cluster_col,
batch_mode=batch_mode,
formula=formula,
fit_options=fit_options,
show_progress=show_progress,
)

# Add the results to the AnnData object
if inplace:
if key_added_prefix is None:
key_added_prefix = ""
else:
key_added_prefix = key_added_prefix + "_"

cox_ph_summary = cox_ph.summary
print(cox_ph_summary)

full_results = pd.DataFrame(index=adata.var.index)

# Populate with CoxPH summary data
for key in cox_ph_summary.columns:
full_results[key_added_prefix + key] = cox_ph_summary[key]

# Add a boolean column indicating rows populated by this function
full_results[key_added_prefix + "cox_ph_populated"] = full_results.notna().any(axis=1)

# Assign results back to adata.var
for col in full_results.columns:
adata.var[col] = full_results[col]

return cox_ph


def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> WeibullAFTFitter:
def weibull_aft(
adata: AnnData,
duration_col: str,
*,
inplace: bool = True,
key_added_prefix: str | None = None,
alpha: float = 0.05,
fit_intercept: bool = True,
penalizer: float | np.ndarray = 0.0,
l1_ratio: float = 0.0,
model_ancillary: bool = True,
event_col: str | None = None,
ancillary: bool | pd.DataFrame | None = None,
show_progress: bool = False,
weights_col: str | None = None,
robust: bool = False,
initial_point=None,
entry_col: str | None = None,
formula: str | None = None,
fit_options: dict | None = None,
) -> WeibullAFTFitter:
"""Fit the Weibull accelerated failure time regression for the survival function.
The Weibull Accelerated Failure Time (AFT) survival regression model is a statistical method used to analyze time-to-event data,
Expand All @@ -417,10 +526,22 @@ def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: st
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> aft = ep.tl.weibull_aft(adata, "mort_day_censored", "censor_flg")
"""

return _regression_model(WeibullAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False)


def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> LogLogisticAFTFitter:
def log_logistic_aft(
adata: AnnData,
duration_col: str,
*,
alpha: float = 0.05,
fit_intercept: bool = True,
penalizer: float | np.ndarray = 0.0,
l1_ratio: float = 0.0,
model_ancillary: bool = False,
event_col: str = None,
entry_col: str = None,
) -> LogLogisticAFTFitter:
"""Fit the log logistic accelerated failure time regression for the survival function.
The Log-Logistic Accelerated Failure Time (AFT) survival regression model is a powerful statistical tool employed in the analysis of time-to-event data.
This model operates under the assumption that the logarithm of survival time adheres to a log-logistic distribution, offering a flexible framework for understanding the impact of covariates on survival times.
Expand Down Expand Up @@ -450,6 +571,22 @@ def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_co
)


def _regression_model(
    model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True
):
    """Instantiate ``model_class`` with defaults and fit it on the data in ``adata``.

    The AnnData object is converted to a DataFrame via ``anndata_to_df`` and every
    row containing a missing value is dropped before fitting.

    Args:
        model_class: Fitter class; called with no arguments to build the model.
        adata: AnnData object holding the covariates, durations and events.
        duration_col: Name of the column containing the durations.
        event_col: Name of the column containing the event indicator.
        entry_col: Name of the column forwarded to ``fit`` as ``entry_col``.
        accept_zero_duration: If False, durations equal to 0 are increased by 1e-5
            before fitting.

    Returns:
        The fitted model instance.
    """
    frame = anndata_to_df(adata).dropna()

    if not accept_zero_duration:
        # NOTE(review): presumably some fitters reject durations of exactly 0 — confirm.
        zero_duration_rows = frame[duration_col] == 0
        frame.loc[zero_duration_rows, duration_col] += 1e-5

    fitter = model_class()
    fitter.fit(frame, duration_col, event_col, entry_col=entry_col)
    return fitter


def _univariate_model(
adata: AnnData,
duration_col: str,
Expand Down

0 comments on commit b1d36b8

Please sign in to comment.