Skip to content

Commit

Permalink
add method to compute confusion matrix between binarized counts and m…
Browse files Browse the repository at this point in the history
…odify todo statement comments
  • Loading branch information
ashuaibi7 committed Dec 13, 2024
1 parent a883c74 commit eecb81b
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/dialect/models/interaction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import numpy as np
from scipy.optimize import minimize
from sklearn.metrics import confusion_matrix

from dialect.models.gene import Gene

Expand Down Expand Up @@ -76,7 +77,19 @@ def verify_pi_values(self, pi_a, pi_b):
# ---------------------------------------------------------------------------- #
# Likelihood & Metric Evaluation #
# ---------------------------------------------------------------------------- #
# TODO: Add additional metrics (KL, MI, etc.) for further exploration
# TODO (LOW PRIORITY): Add additional metrics (KL, MI, etc.)

def compute_contingency_table(self):
"""
Compute the contingency table (confusion matrix) for binarized counts
between gene_a and gene_b.
:return: A 2x2 numpy array representing the contingency table.
"""
gene_a_mutations = (self.gene_a.counts > 0).astype(int)
gene_b_mutations = (self.gene_b.counts > 0).astype(int)
cm = confusion_matrix(gene_a_mutations, gene_b_mutations, labels=[0, 1])
return cm

def compute_joint_probability(self, tau, u, v):
joint_probability = np.array(
Expand Down Expand Up @@ -153,6 +166,7 @@ def compute_log_likelihood(self, taus):
a_counts, b_counts = self.gene_a.counts, self.gene_b.counts
a_bmr_pmf, b_bmr_pmf = self.gene_a.bmr_pmf, self.gene_b.bmr_pmf
tau_00, tau_01, tau_10, tau_11 = taus
# TODO: Moddify all passenger key access to not default to 0
log_likelihood = sum(
np.log(
a_bmr_pmf.get(c_a, 0) * b_bmr_pmf.get(c_b, 0) * tau_00
Expand Down Expand Up @@ -374,7 +388,6 @@ def negative_log_likelihood(tau):
f"Estimated tau parameters for interaction {self.name}: tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}"
)

# TODO: Implement this method
def estimate_tau_with_em_from_scratch(
self, max_iter=1000, tol=1e-6, tau_init=[0.25, 0.25, 0.25, 0.25]
):
Expand Down

0 comments on commit eecb81b

Please sign in to comment.