diff --git a/ehrapy/tools/_sa.py b/ehrapy/tools/_sa.py index fed63b9e..241e5dee 100644 --- a/ehrapy/tools/_sa.py +++ b/ehrapy/tools/_sa.py @@ -3,7 +3,6 @@ import warnings from typing import TYPE_CHECKING, Literal -import numpy as np # This package is implicitly used import pandas as pd import statsmodels.api as sm import statsmodels.formula.api as smf @@ -23,6 +22,7 @@ if TYPE_CHECKING: from collections.abc import Iterable + import numpy as np from anndata import AnnData from statsmodels.genmod.generalized_linear_model import GLMResultsWrapper @@ -347,9 +347,7 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_ return dataframe -def _regression_model( - model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True -): +def _regression_model_data_frame_preparation(adata: AnnData, duration_col: str, accept_zero_duration=True): """Convenience function for regression models.""" df = anndata_to_df(adata) df = df.dropna() @@ -357,13 +355,35 @@ def _regression_model( if not accept_zero_duration: df.loc[df[duration_col] == 0, duration_col] += 1e-5 - model = model_class() - model.fit(df, duration_col, event_col, entry_col=entry_col) - - return model + return df -def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> CoxPHFitter: +def cox_ph( + adata: AnnData, + duration_col: str, + *, + inplace: bool = True, + key_added_prefix: str | None = None, + alpha: float = 0.05, + label: str | None = None, + baseline_estimation_method: Literal["breslow", "spline", "piecewise"] = "breslow", + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + strata: list[str] | str | None = None, + n_baseline_knots: int = 4, + knots: list[float] | None = None, + breakpoints: list[float] | None = None, + event_col: str = None, + weights_col: str | None = None, + cluster_col: str | None = None, + entry_col: str = None, + robust: bool = False, + formula: str = None, + batch_mode: bool = None, + show_progress: bool = False, + initial_point: np.ndarray | None = None, + fit_options: dict | None = None, +) -> CoxPHFitter: """Fit the Cox’s proportional hazard for the survival function. The Cox proportional hazards model (CoxPH) examines the relationship between the survival time of subjects and one or more predictor variables. @@ -376,7 +396,26 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. event_col: The name of the column in anndata that contains the subjects’ death observation. If left as None, assume all individuals are uncensored. + inplace: Whether to modify the AnnData object in place. + alpha: The alpha value in the confidence intervals. + label: A string to name the column of the estimate. + baseline_estimation_method: The method used to estimate the baseline hazard. Options are 'breslow', 'spline', and 'piecewise'. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. + l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above. + strata: specify a list of columns to use in stratification. This is useful if a categorical covariate does not obey the proportional hazard assumption. This is used similar to the strata expression in R. See http://courses.washington.edu/b515/l17.pdf. + n_baseline_knots: Used when baseline_estimation_method="spline". Set the number of knots (interior & exterior) in the baseline hazard, which will be placed evenly along the time axis. Should be at least 2. Royston et. al, the authors of this model, suggest 4 to start, but any values between 2 and 8 are reasonable. If you need to customize the timestamps used to calculate the curve, use the knots parameter instead. + knots: When baseline_estimation_method="spline", this allows customizing the points in the time axis for the baseline hazard curve. To use evenly-spaced points in time, the n_baseline_knots parameter can be employed instead. + breakpoints: Used when baseline_estimation_method="piecewise". Set the positions of the baseline hazard breakpoints. + event_col: he name of the column in DataFrame that contains the subjects’ death observation. If left as None, assume all individuals are uncensored. + weights_col: The name of the column in DataFrame that contains the weights for each subject. + cluster_col: The name of the column in DataFrame that contains the cluster variable. Using this forces the sandwich estimator (robust variance estimator) to be used. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ. + formula: an Wilkinson formula, like in R and statsmodels, for the right-hand-side. If left as None, all columns not assigned as durations, weights, etc. are used. Uses the library Formulaic for parsing. + batch_mode: enabling batch_mode can be faster for datasets with a large number of ties. If left as None, lifelines will choose the best option. + show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + initial_point: set the starting point for the iterative solver. + fit_options: Additional keyword arguments to pass into the estimator. Returns: Fitted CoxPHFitter. @@ -388,10 +427,80 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) >>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg") """ - return _regression_model(CoxPHFitter, adata, duration_col, event_col, entry_col) + df = _regression_model_data_frame_preparation(adata, duration_col) + cox_ph = CoxPHFitter( + alpha=alpha, + label=label, + strata=strata, + baseline_estimation_method=baseline_estimation_method, + penalizer=penalizer, + l1_ratio=l1_ratio, + n_baseline_knots=n_baseline_knots, + knots=knots, + breakpoints=breakpoints, + ) + cox_ph.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + robust=robust, + initial_point=initial_point, + weights_col=weights_col, + cluster_col=cluster_col, + batch_mode=batch_mode, + formula=formula, + fit_options=fit_options, + show_progress=show_progress, + ) + + # Add the results to the AnnData object + if inplace: + if key_added_prefix is None: + key_added_prefix = "" + else: + key_added_prefix = key_added_prefix + "_" + + cox_ph_summary = cox_ph.summary + print(cox_ph_summary) + + full_results = pd.DataFrame(index=adata.var.index) + + # Populate with CoxPH summary data + for key in cox_ph_summary.columns: + full_results[key_added_prefix + key] = cox_ph_summary[key] + + # Add a boolean column indicating rows populated by this function + full_results[key_added_prefix + "cox_ph_populated"] = full_results.notna().any(axis=1) + + # Assign results back to adata.var + for col in full_results.columns: + adata.var[col] = full_results[col] + + return cox_ph -def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> WeibullAFTFitter: +def weibull_aft( + adata: AnnData, + duration_col: str, + *, + inplace: bool = True, + key_added_prefix: str | None = None, + alpha: float = 0.05, + fit_intercept: bool = True, + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + model_ancillary: bool = True, + event_col: str | None = None, + ancillary: bool | pd.DataFrame | None = None, + show_progress: bool = False, + weights_col: str | None = None, + robust: bool = False, + initial_point=None, + entry_col: str | None = None, + formula: str | None = None, + fit_options: dict | None = None, +) -> WeibullAFTFitter: """Fit the Weibull accelerated failure time regression for the survival function. The Weibull Accelerated Failure Time (AFT) survival regression model is a statistical method used to analyze time-to-event data, @@ -417,10 +526,22 @@ def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: st >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) >>> aft = ep.tl.weibull_aft(adata, "mort_day_censored", "censor_flg") """ + return _regression_model(WeibullAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False) -def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> LogLogisticAFTFitter: +def log_logistic_aft( + adata: AnnData, + duration_col: str, + *, + alpha: float = 0.05, + fit_intercept: bool = True, + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + model_ancillary: bool = False, + event_col: str = None, + entry_col: str = None, +) -> LogLogisticAFTFitter: """Fit the log logistic accelerated failure time regression for the survival function. The Log-Logistic Accelerated Failure Time (AFT) survival regression model is a powerful statistical tool employed in the analysis of time-to-event data. This model operates under the assumption that the logarithm of survival time adheres to a log-logistic distribution, offering a flexible framework for understanding the impact of covariates on survival times. @@ -450,6 +571,22 @@ def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_co ) +def _regression_model( + model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True +): + """Convenience function for regression models.""" + df = anndata_to_df(adata) + df = df.dropna() + + if not accept_zero_duration: + df.loc[df[duration_col] == 0, duration_col] += 1e-5 + + model = model_class() + model.fit(df, duration_col, event_col, entry_col=entry_col) + + return model + + def _univariate_model( adata: AnnData, duration_col: str,