From 18901fa99280dbdc6e5a0fa7486a8cdab384e154 Mon Sep 17 00:00:00 2001 From: ashuaibi7 Date: Sat, 23 Nov 2024 08:17:55 -0500 Subject: [PATCH] implemented optimization with scipy for interaction class --- src/dialect/models/interaction.py | 39 +++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/dialect/models/interaction.py b/src/dialect/models/interaction.py index 8209c2c..3393f94 100644 --- a/src/dialect/models/interaction.py +++ b/src/dialect/models/interaction.py @@ -240,3 +240,42 @@ def compute_rho(self): ) logging.info(f"Computed rho for interaction {self.name}: {rho}") return rho + + # ---------------------------------------------------------------------------- # + # Parameter Estimation Methods # + # ---------------------------------------------------------------------------- # + def estimate_tau_with_optimization_using_scipy( + self, tau_init=[0.25, 0.25, 0.25, 0.25], alpha=1e-13 + ): + """ + Estimate the tau parameters using the L-BFGS-B optimization scheme. + + :param tau_init (list): Initial guesses for the tau parameters (default: [0.25, 0.25, 0.25, 0.25]). + :param alpha (float): Small value to avoid edge cases at 0 or 1 (default: 1e-13). + :return (tuple): The optimized values of (tau_00, tau_01, tau_10, tau_11). + """ + logging.info(f"Estimating tau params for {self.name} using L-BFGS-B.") + + def negative_log_likelihood(tau): + return -self.compute_log_likelihood(tau) + + bounds = 4 * [(alpha, 1 - alpha)] + constraints = {"type": "eq", "fun": lambda tau: sum(tau) - 1} + result = minimize( + negative_log_likelihood, + x0=tau_init, + bounds=bounds, + constraints=constraints, + method="L-BFGS-B", + ) + if not result.success: + logging.warning( + f"Optimization failed for interaction {self.name}: {result.message}" + ) + raise ValueError(f"Optimization failed: {result.message}") + + self.tau_00, self.tau_01, self.tau_10, self.tau_11 = result.x + logging.info( + f"Estimated tau parameters for interaction {self.name}: tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}" + ) + return self.tau_00, self.tau_01, self.tau_10, self.tau_11