Skip to content

Commit

Permalink
added functionality to find set of overlapping samples between two ge…
Browse files Browse the repository at this point in the history
…nes and created a user-friendly analysis script to test for any pair of genes
  • Loading branch information
ashuaibi7 committed Dec 14, 2024
1 parent 50d540b commit a469ddd
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 13 deletions.
88 changes: 88 additions & 0 deletions analysis/interaction_cooccurring_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
This script analyzes co-occurring samples between two genes using user-provided
gene pairs and count matrix data.
"""

# ---------------------------------------------------------------------------- #
# IMPORTS #
# ---------------------------------------------------------------------------- #
import logging
from dialect.utils.identify import load_cnt_mtx_and_bmr_pmfs
from dialect.models.gene import Gene
from dialect.models.interaction import Interaction


# ---------------------------------------------------------------------------- #
# HELPER FUNCTIONS #
# ---------------------------------------------------------------------------- #
def initialize_gene_objects(cnt_df, bmr_dict):
"""
Create a dictionary mapping gene names to Gene objects.
"""
gene_objects = {}
for gene_name in cnt_df.columns:
counts = cnt_df[gene_name].values
bmr_pmf = {i: bmr_dict[gene_name][i] for i in range(len(bmr_dict[gene_name]))}
gene_objects[gene_name] = Gene(
name=gene_name, samples=cnt_df.index, counts=counts, bmr_pmf=bmr_pmf
)
logging.info(f"Initialized {len(gene_objects)} Gene objects.")
return gene_objects


def get_cooccurring_samples(gene_a, gene_b):
"""
Get the set of co-occurring samples for two genes.
"""
interaction = Interaction(gene_a, gene_b)
return interaction.get_set_of_cooccurring_samples()


# ---------------------------------------------------------------------------- #
# MAIN FUNCTION #
# ---------------------------------------------------------------------------- #
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
print("Analyze Co-occurring Samples Between Gene Pairs")

# Prompt for count matrix and BMR PMFs file
cnt_mtx_path = input("Enter the path to the count matrix file: ").strip()
bmr_pmfs_path = input("Enter the path to the BMR PMFs file: ").strip()

cnt_df, bmr_dict = load_cnt_mtx_and_bmr_pmfs(cnt_mtx_path, bmr_pmfs_path)

# Initialize gene objects
gene_objects = initialize_gene_objects(cnt_df, bmr_dict)

print("\nType 'exit' to quit the program at any time.")
while True:
# Prompt for gene A
gene_a_name = input("Enter the name of Gene A: ").strip()
if gene_a_name.lower() == "exit":
print("Exiting the program. Goodbye!")
break
if gene_a_name not in gene_objects:
print(f"Gene '{gene_a_name}' does not exist. Try again.")
continue

# Prompt for gene B
gene_b_name = input("Enter the name of Gene B: ").strip()
if gene_b_name.lower() == "exit":
print("Exiting the program. Goodbye!")
break
if gene_b_name not in gene_objects:
print(f"Gene '{gene_b_name}' does not exist. Try again.")
continue

# Get Gene objects
gene_a = gene_objects[gene_a_name]
gene_b = gene_objects[gene_b_name]

# Get co-occurring samples
try:
cooccurring_samples = get_cooccurring_samples(gene_a, gene_b)
print(
f"Co-occurring samples between {gene_a_name} and {gene_b_name}: {cooccurring_samples}"
)
except Exception as e:
print(f"An error occurred while processing genes: {e}")
46 changes: 33 additions & 13 deletions src/dialect/models/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,36 @@ def __str__(self):
f"Contingency Table:{cm_info}"
)

# ---------------------------------------------------------------------------- #
# UTILITY METHODS #
# ---------------------------------------------------------------------------- #

def compute_contingency_table(self):
"""
Compute the contingency table (confusion matrix) for binarized counts
between gene_a and gene_b.
:return: A 2x2 numpy array representing the contingency table.
"""
gene_a_mutations = (self.gene_a.counts > 0).astype(int)
gene_b_mutations = (self.gene_b.counts > 0).astype(int)
cm = confusion_matrix(gene_a_mutations, gene_b_mutations, labels=[0, 1])
return cm

def get_set_of_cooccurring_samples(self):
"""
Get the list of samples in which both genes have at least one mutation.
:return: (list) List of sample indices where both genes have at least one mutation.
"""
sample_names = self.gene_a.samples
cooccurring_samples = [
sample_names[i]
for i in range(len(sample_names))
if self.gene_a.counts[i] > 0 and self.gene_b.counts[i] > 0
]
return sorted(cooccurring_samples)

# ---------------------------------------------------------------------------- #
# DATA VALIDATION & LOGGING #
# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -116,18 +146,6 @@ def verify_pi_values(self, pi_a, pi_b):
# ---------------------------------------------------------------------------- #
# TODO (LOW PRIORITY): Add additional metrics (KL, MI, etc.)

def compute_contingency_table(self):
"""
Compute the contingency table (confusion matrix) for binarized counts
between gene_a and gene_b.
:return: A 2x2 numpy array representing the contingency table.
"""
gene_a_mutations = (self.gene_a.counts > 0).astype(int)
gene_b_mutations = (self.gene_b.counts > 0).astype(int)
cm = confusion_matrix(gene_a_mutations, gene_b_mutations, labels=[0, 1])
return cm

def compute_joint_probability(self, tau, u, v):
joint_probability = np.array(
[
Expand Down Expand Up @@ -195,7 +213,8 @@ def compute_log_likelihood(self, taus):
:raises ValueError: If `bmr_pmf` or `counts` are not defined for either gene, or if `tau` is invalid.
"""

logging.info(f"Computing log likelihood for {self.name}. Taus: {taus}")
# TODO: add verbose option for logging
# logging.info(f"Computing log likelihood for {self.name}. Taus: {taus}")

self.verify_bmr_pmf_and_counts_exist()
self.verify_taus_are_valid(taus)
Expand Down Expand Up @@ -511,6 +530,7 @@ def estimate_tau_with_em_from_scratch(
tau_10,
tau_11,
)
logging.info(" EM algorithm converged after {} iterations.".format(it + 1))
logging.info(
f"Estimated tau parameters for interaction {self.name}: tau_00={self.tau_00}, tau_01={self.tau_01}, tau_10={self.tau_10}, tau_11={self.tau_11}"
)
Expand Down

0 comments on commit a469ddd

Please sign in to comment.