Skip to content

Commit

Permalink
log_logistic update
Browse files Browse the repository at this point in the history
  • Loading branch information
aGuyLearning committed Dec 18, 2024
1 parent 22d190a commit 742d38c
Showing 1 changed file with 56 additions and 18 deletions.
74 changes: 56 additions & 18 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,13 +585,22 @@ def log_logistic_aft(
adata: AnnData,
duration_col: str,
*,
inplace: bool = True,
key_added_prefix: str | None = None,
alpha: float = 0.05,
fit_intercept: bool = True,
penalizer: float | np.ndarray = 0.0,
l1_ratio: float = 0.0,
model_ancillary: bool = False,
event_col: str = None,
entry_col: str = None,
event_col: str | None = None,
ancillary: bool | pd.DataFrame | str | None = None,
show_progress: bool = False,
weights_col: str | None = None,
robust: bool = False,
initial_point=None,
entry_col: str | None = None,
formula: str | None = None,
fit_options: dict | None = None,
) -> LogLogisticAFTFitter:
"""Fit the log logistic accelerated failure time regression for the survival function.
The Log-Logistic Accelerated Failure Time (AFT) survival regression model is a powerful statistical tool employed in the analysis of time-to-event data.
Expand All @@ -603,9 +612,29 @@ def log_logistic_aft(
Args:
adata: AnnData object with necessary columns `duration_col` and `event_col`.
duration_col: Name of the column in the AnnData object that contains the subjects’ lifetimes.
event_col: Name of the column in anndata that contains the subjects’ death observation.
inplace: Whether to modify the AnnData object in place.
key_added_prefix: Prefix to add to the column names in the AnnData object. An underscore will be added between the prefix and the column name.
alpha: The alpha value in the confidence intervals.
fit_intercept: Whether to fit an intercept term in the model.
penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates.
l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above.
model_ancillary: Set the model instance to always model the ancillary parameter with the supplied DataFrame. This is useful for grid-search optimization.
event_col: Name of the column in anndata that contains the subjects’ death observation. 1 if observed, 0 else (censored).
If left as None, assume all individuals are uncensored.
ancillary: Choose to model the ancillary parameters.
If None or False, explicitly do not fit the ancillary parameters using any covariates.
If True, model the ancillary parameters with the same covariates as ``df``.
If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``.
If str, it should be a formula.
show_progress: Since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing.
weights_col: The name of the column in DataFrame that contains the weights for each subject.
robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there is a high number of ties, results may significantly differ.
initial_point: Set the starting point for the iterative solver.
entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/
If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.)
fit_options: Additional keyword arguments to pass into the estimator.
Returns:
Fitted LogLogisticAFTFitter.
Expand All @@ -617,26 +646,35 @@ def log_logistic_aft(
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> llf = ep.tl.log_logistic_aft(adata, "mort_day_censored", event_col="censor_flg")
"""
df = _regression_model_data_frame_preparation(adata, duration_col, accept_zero_duration=False)

return _regression_model(
LogLogisticAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False
log_logistic_aft = LogLogisticAFTFitter(
alpha=alpha,
fit_intercept=fit_intercept,
penalizer=penalizer,
l1_ratio=l1_ratio,
model_ancillary=model_ancillary,
)

log_logistic_aft.fit(
df,
duration_col=duration_col,
event_col=event_col,
entry_col=entry_col,
ancillary=ancillary,
show_progress=show_progress,
weights_col=weights_col,
robust=robust,
initial_point=initial_point,
formula=formula,
fit_options=fit_options,
)

def _regression_model(
model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True
):
"""Convenience function for regression models."""
df = anndata_to_df(adata)
df = df.dropna()

if not accept_zero_duration:
df.loc[df[duration_col] == 0, duration_col] += 1e-5

model = model_class()
model.fit(df, duration_col, event_col, entry_col=entry_col)
# Add the results to the AnnData object
if inplace:
_regression_model_populate_adata(adata, log_logistic_aft.summary, key_added_prefix)

return model
return log_logistic_aft


def _univariate_model(
Expand Down

0 comments on commit 742d38c

Please sign in to comment.