Skip to content

Commit

Permalink
implemented EM from scratch for pairwise interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
ashuaibi7 committed Dec 12, 2024
1 parent ec8e9af commit fb807c1
Showing 1 changed file with 127 additions and 5 deletions.
132 changes: 127 additions & 5 deletions src/dialect/models/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ def verify_bmr_pmf_and_counts_exist(self):
if self.gene_a.counts is None or self.gene_b.counts is None:
raise ValueError("Counts are not defined for one or both genes.")

def verify_taus_are_valid(self, taus):
def verify_taus_are_valid(self, taus, tol=1e-6):
"""
Verify that tau parameters are valid (0 <= tau_i <= 1 and sum(tau) == 1).
:param taus: (list of float) Tau parameters to validate.
:param tol: (float) Tolerance for the sum of tau parameters (default: 1e-1).
:raises ValueError: If any or all tau parameters are invalid.
"""
if not all(0 <= t <= 1 for t in taus) or not np.isclose(sum(taus), 1):
if not all(0 <= t <= 1 for t in taus) or not np.isclose(sum(taus), 1, atol=tol):
logging.info(f"Invalid tau parameters: {taus}")
raise ValueError(
"Invalid tau parameters. Ensure 0 <= tau_i <= 1 and sum(tau) == 1."
Expand Down Expand Up @@ -77,6 +78,41 @@ def verify_pi_values(self, pi_a, pi_b):
# ---------------------------------------------------------------------------- #
# TODO: Add additional metrics (KL, MI, etc.) for further exploration

def compute_joint_probability(self, tau, u, v):
joint_probability = np.array(
[
tau
* self.gene_a.bmr_pmf.get(c_a - u, 0)
* self.gene_b.bmr_pmf.get(c_b - v, 0)
for c_a, c_b in zip(self.gene_a.counts, self.gene_b.counts)
]
)
return joint_probability

def compute_total_probability(self, tau_00, tau_01, tau_10, tau_11):
total_probabilities = np.array(
[
sum(
(
tau_00
* self.gene_a.bmr_pmf.get(c_a, 0)
* self.gene_b.bmr_pmf.get(c_b, 0),
tau_01
* self.gene_a.bmr_pmf.get(c_a, 0)
* self.gene_b.bmr_pmf.get(c_b - 1, 0),
tau_10
* self.gene_a.bmr_pmf.get(c_a - 1, 0)
* self.gene_b.bmr_pmf.get(c_b, 0),
tau_11
* self.gene_a.bmr_pmf.get(c_a - 1, 0)
* self.gene_b.bmr_pmf.get(c_b - 1, 0),
)
)
for c_a, c_b in zip(self.gene_a.counts, self.gene_b.counts)
]
)
return total_probabilities

def compute_log_likelihood(self, taus):
"""
Compute the complete data log-likelihood for the interaction given the parameters \( \tau \).
Expand Down Expand Up @@ -303,6 +339,7 @@ def compute_rho(self, taus):
def estimate_tau_with_optimization_using_scipy(
self, tau_init=[0.25, 0.25, 0.25, 0.25], alpha=1e-13
):
# TODO: tau parameters fail verification due to optimization scheme. try different optimization scheme
"""
Estimate the tau parameters using the SLSQP optimization scheme.
Expand Down Expand Up @@ -336,15 +373,100 @@ def negative_log_likelihood(tau):
logging.info(
f"Estimated tau parameters for interaction {self.name}: tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}"
)
return self.tau_00, self.tau_01, self.tau_10, self.tau_11

# TODO: Implement this method
def estimate_tau_with_em_from_scratch(self):
def estimate_tau_with_em_from_scratch(
self, max_iter=1000, tol=1e-6, tau_init=[0.25, 0.25, 0.25, 0.25]
):
"""
Estimate the tau parameters for interaction using the Expectation-Maximization (EM) algorithm.
This method iteratively updates the tau parameters, \( \tau = (\tau_{00}, \tau_{01}, \tau_{10}, \tau_{11}) \),
to maximize the likelihood of the observed mutation count data for two interacting genes.
**Algorithm Steps**:
1. **E-Step**:
At iteration \( t \), given the estimated driver mutation probabilities
\( \tau^{(t)} = (\tau_{00}^{(t)}, \tau_{01}^{(t)}, \tau_{10}^{(t)}, \tau_{11}^{(t)}) \),
compute the responsibilities \( z_{i,uv}^{(t)} \) for each pair \( (u,v) \in \{0,1\}^2 \)
and sample \( i = 1, \dots, N \) as:
.. math::
z_{i,uv}^{(t)} = \\frac{\\tau_{uv}^{(t)} \\cdot \\mathbb{P}(P_i = c_i - u) \\cdot \\mathbb{P}(P_i' = c_i' - v)}
{\\sum_{(x,y) \\in \\{0,1\\}^2} \\left( \\tau_{xy}^{(t)} \\cdot \\mathbb{P}(P_i = c_i - x) \\cdot \\mathbb{P}(P_i' = c_i' - y) \\right)}
where \( P_i \) and \( P_i' \) represent passenger mutation probabilities for the two genes,
and \( c_i, c_i' \) are the observed mutation counts.
2. **M-Step**:
Given the responsibilities \( \\bm{z}_i^{(t)} = (z_{i,00}^{(t)}, z_{i,01}^{(t)}, z_{i,10}^{(t)}, z_{i,11}^{(t)}) \),
update the tau parameters at iteration \( t+1 \) as:
.. math::
\\tau_{uv}^{(t+1)} = \\frac{1}{N} \\sum_{i=1}^{N} z_{i,uv}^{(t)}
for each pair \( (u,v) \\in \\{0,1\\}^2 \).
**Parameters**:
:param max_iter: (int) Maximum number of iterations for the EM algorithm (default: 1000).
:param tol: (float) Convergence threshold for log-likelihood improvement (default: 1e-6).
:param tau_init: (list of float) Initial guesses for the tau parameters (default: [0.25, 0.25, 0.25, 0.25]).
**Returns**:
:return: (tuple) The estimated values of \( (\\tau_{00}, \\tau_{01}, \\tau_{10}, \\tau_{11}) \).
"""
logging.info("Estimating tau parameters using EM algorithm from scratch.")

self.verify_bmr_pmf_and_counts_exist()

raise NotImplementedError("Method is not yet implemented.")
# TODO: handle nonzero probability counts and cases where counts are not in BMR PMF

tau_00, tau_01, tau_10, tau_11 = tau_init
for it in range(max_iter):
# E-Step: Compute responsibilities
total_probabilities = self.compute_total_probability(
tau_00, tau_01, tau_10, tau_11
) # denominator in E-Step equation
z_i_00 = self.compute_joint_probability(tau_00, 0, 0) / total_probabilities
z_i_01 = self.compute_joint_probability(tau_01, 0, 1) / total_probabilities
z_i_10 = self.compute_joint_probability(tau_10, 1, 0) / total_probabilities
z_i_11 = self.compute_joint_probability(tau_11, 1, 1) / total_probabilities

# M-Step: Update tau parameters
curr_tau_00 = np.mean(z_i_00)
curr_tau_01 = np.mean(z_i_01)
curr_tau_10 = np.mean(z_i_10)
curr_tau_11 = np.mean(z_i_11)

# Check for convergence
prev_log_likelihood = self.compute_log_likelihood(
(tau_00, tau_01, tau_10, tau_11)
)
curr_log_likelihood = self.compute_log_likelihood(
(curr_tau_00, curr_tau_01, curr_tau_10, curr_tau_11)
)
if abs(curr_log_likelihood - prev_log_likelihood) < tol:
break

tau_00, tau_01, tau_10, tau_11 = (
curr_tau_00,
curr_tau_01,
curr_tau_10,
curr_tau_11,
)

self.tau_00, self.tau_01, self.tau_10, self.tau_11 = (
tau_00,
tau_01,
tau_10,
tau_11,
)
logging.info(
f"Estimated tau parameters for interaction {self.name}: tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}"
)

# TODO: Implement below to increase speed relative to from-scratch EM
def estimate_tau_with_em_using_pomegranate(self):
Expand Down

0 comments on commit fb807c1

Please sign in to comment.