From 742d38ceb16519bbf41b8bbe15c1c46bf7453ebb Mon Sep 17 00:00:00 2001 From: Carl Buchholz <32228189+aGuyLearning@users.noreply.github.com> Date: Wed, 18 Dec 2024 15:14:51 +0100 Subject: [PATCH] log_logistic update --- ehrapy/tools/_sa.py | 74 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/ehrapy/tools/_sa.py b/ehrapy/tools/_sa.py index 13b74a8f..395a54bf 100644 --- a/ehrapy/tools/_sa.py +++ b/ehrapy/tools/_sa.py @@ -585,13 +585,22 @@ def log_logistic_aft( adata: AnnData, duration_col: str, *, + inplace: bool = True, + key_added_prefix: str | None = None, alpha: float = 0.05, fit_intercept: bool = True, penalizer: float | np.ndarray = 0.0, l1_ratio: float = 0.0, model_ancillary: bool = False, - event_col: str = None, - entry_col: str = None, + event_col: str | None = None, + ancillary: bool | pd.DataFrame | str | None = None, + show_progress: bool = False, + weights_col: str | None = None, + robust: bool = False, + initial_point=None, + entry_col: str | None = None, + formula: str | None = None, + fit_options: dict | None = None, ) -> LogLogisticAFTFitter: """Fit the log logistic accelerated failure time regression for the survival function. The Log-Logistic Accelerated Failure Time (AFT) survival regression model is a powerful statistical tool employed in the analysis of time-to-event data. @@ -603,9 +612,29 @@ def log_logistic_aft( Args: adata: AnnData object with necessary columns `duration_col` and `event_col`. duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: Name of the column in anndata that contains the subjects’ death observation. + inplace: Whether to modify the AnnData object in place. + key_added_prefix: Prefix to add to the column names in the AnnData object. An underscore will be added between the prefix and the column + alpha: The alpha value in the confidence intervals. + alpha: The alpha value in the confidence intervals. + fit_intercept: Whether to fit an intercept term in the model. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. + l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above. + model_ancillary: set the model instance to always model the ancillary parameter with the supplied Dataframe. This is useful for grid-search optimization. + event_col: Name of the column in anndata that contains the subjects’ death observation. 1 if observed, 0 else (censored). If left as None, assume all individuals are uncensored. + ancillary: Choose to model the ancillary parameters. + If None or False, explicitly do not fit the ancillary parameters using any covariates. + If True, model the ancillary parameters with the same covariates as ``df``. + If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``. + If str, should be a formula + show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + weights_col: The name of the column in DataFrame that contains the weights for each subject. + robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ. + initial_point: set the starting point for the iterative solver. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/ + If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.) + fit_options: Additional keyword arguments to pass into the estimator. Returns: Fitted LogLogisticAFTFitter. @@ -617,26 +646,35 @@ def log_logistic_aft( >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) >>> llf = ep.tl.log_logistic_aft(adata, "mort_day_censored", "censor_flg") """ + df = _regression_model_data_frame_preparation(adata, duration_col, accept_zero_duration=False) - return _regression_model( - LogLogisticAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False + log_logistic_aft = LogLogisticAFTFitter( + alpha=alpha, + fit_intercept=fit_intercept, + penalizer=penalizer, + l1_ratio=l1_ratio, + model_ancillary=model_ancillary, ) + log_logistic_aft.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + ancillary=ancillary, + show_progress=show_progress, + weights_col=weights_col, + robust=robust, + initial_point=initial_point, + formula=formula, + fit_options=fit_options, + ) -def _regression_model( - model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True -): - """Convenience function for regression models.""" - df = anndata_to_df(adata) - df = df.dropna() - - if not accept_zero_duration: - df.loc[df[duration_col] == 0, duration_col] += 1e-5 - - model = model_class() - model.fit(df, duration_col, event_col, entry_col=entry_col) + # Add the results to the AnnData object + if inplace: + _regression_model_populate_adata(adata, log_logistic_aft.summary, key_added_prefix) - return model + return log_logistic_aft def _univariate_model(