Skip to content

Commit

Permalink
created postprocessing script to unify top ranking pairs identificati…
Browse files Browse the repository at this point in the history
…on by method based on ME or CO
  • Loading branch information
ashuaibi7 committed Jan 16, 2025
1 parent 3af068a commit d5d344b
Showing 1 changed file with 157 additions and 0 deletions.
157 changes: 157 additions & 0 deletions src/dialect/utils/postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import pandas as pd

# ---------------------------------------------------------------------------- #
# CONSTANTS #
# ---------------------------------------------------------------------------- #
# Numerator of the DIALECT marginal cutoff: the DIALECT filter keeps a pair
# only when both "Tau_1X" and "Tau_X1" exceed MIN_DRIVER_COUNT / num_samples.
MIN_DRIVER_COUNT = 10
# Strict upper bound (`<`) applied to the p-value columns of DISCOVER,
# Fisher's Exact Test and WeSME; rows whose p-value is not strictly below
# this value are dropped by the filter.
PVALUE_THRESHOLD = 1.0

# Method name -> results column used to rank pairs for mutual exclusivity (ME).
ME_COLUMN_MAP = {
"DIALECT": "Rho",
"DISCOVER": "Discover ME P-Val",
"Fisher's Exact Test": "Fisher's ME P-Val",
"MEGSA": "MEGSA S-Score (LRT)",
"WeSME": "WeSME P-Val",
}

# Method name -> results column used to rank pairs for co-occurrence (CO).
# A value of None marks a method that has no CO ranking (MEGSA).
CO_COLUMN_MAP = {
"DIALECT": "Rho",
"DISCOVER": "Discover CO P-Val",
"Fisher's Exact Test": "Fisher's CO P-Val",
"MEGSA": None,
"WeSME": "WeSCO P-Val",
}


# ---------------------------------------------------------------------------- #
# HELPER FUNCTIONS #
# ---------------------------------------------------------------------------- #
def get_sort_column(method: str, meco: str) -> str:
    """
    Look up the results column that ``method`` is ranked on.

    ``meco == "ME"`` selects the mutual-exclusivity map; any other value
    selects the co-occurrence map. The result is None when the method has
    no entry in the selected map or its entry is None (e.g. MEGSA for
    co-occurrence).
    """
    column_map = ME_COLUMN_MAP if meco == "ME" else CO_COLUMN_MAP
    return column_map.get(method)


def filter_by_method(
    top_ranking_pairs: pd.DataFrame, method: str, meco: str, num_samples: int
) -> pd.DataFrame:
    """
    Keep only the rows of ``top_ranking_pairs`` that are relevant for the
    given ``method`` and interaction type (``meco``: "ME" or anything else
    for CO).

    Returns None for MEGSA with CO, which is not applicable. A method name
    that is not recognised gets the input frame back unchanged.
    """
    want_me = meco == "ME"

    if method == "MEGSA":
        if meco == "CO":
            return None
        # MEGSA is ME-only: keep strictly positive S-Scores.
        positive_score = top_ranking_pairs["MEGSA S-Score (LRT)"] > 0
        return top_ranking_pairs[positive_score]

    if method == "DIALECT":
        # Both driver marginals must exceed the minimum driver frequency.
        min_marginal = MIN_DRIVER_COUNT / num_samples
        above_cutoff = (top_ranking_pairs["Tau_1X"] > min_marginal) & (
            top_ranking_pairs["Tau_X1"] > min_marginal
        )
        frequent_pairs = top_ranking_pairs[above_cutoff]
        # Negative Rho is kept for ME, positive Rho for CO.
        rho = frequent_pairs["Rho"]
        return frequent_pairs[rho < 0] if want_me else frequent_pairs[rho > 0]

    # Remaining methods are thresholded on an (ME column, CO column) p-value.
    pvalue_columns = {
        "DISCOVER": ("Discover ME P-Val", "Discover CO P-Val"),
        "Fisher's Exact Test": ("Fisher's ME P-Val", "Fisher's CO P-Val"),
        "WeSME": ("WeSME P-Val", "WeSCO P-Val"),
    }
    if method not in pvalue_columns:
        return top_ranking_pairs

    me_column, co_column = pvalue_columns[method]
    pvalue_column = me_column if want_me else co_column
    kept = top_ranking_pairs[top_ranking_pairs[pvalue_column] < PVALUE_THRESHOLD]

    if method == "WeSME" and not want_me:
        # The WeSME co-occurrence column is exposed under a shorter name.
        kept = kept.rename(columns={"WeSCO P-Val": "WeSCO"})
    return kept


def get_top_ranked_pairs_by_method(
    results_df,
    method,
    meco,
    num_pairs,
    num_samples,
):
    """
    Rank the combined ``results_df`` for one ``method`` and interaction
    type (``meco``), apply that method's filters, and return the first
    ``num_pairs`` rows.

    Returns None when the method has no ranking column for ``meco``
    (e.g. MEGSA + CO) or when no rows survive filtering.
    """
    rank_column = get_sort_column(method, meco)
    if rank_column is None:
        return None

    # Sort direction:
    #   * MEGSA ranks by LRT score, so always descending.
    #   * DIALECT ranks by rho: ascending for ME (negative rho indicates
    #     mutual exclusivity), descending for CO.
    #   * Every other method ranks by p-value, so ascending.
    sort_ascending = method != "MEGSA" and (method != "DIALECT" or meco == "ME")

    ranked = results_df.sort_values(by=rank_column, ascending=sort_ascending)
    survivors = filter_by_method(ranked, method, meco, num_samples)
    if survivors is None or survivors.empty:
        return None
    return survivors.head(num_pairs)


def generate_top_ranking_tables(
    results_df: pd.DataFrame, meco: str, num_pairs: int, num_samples: int
):
    """
    Build the top-ranked table of every supported method for ME or CO.

    return: dict { method_name : DataFrame or None }, with one entry per
    method in the order DIALECT, DISCOVER, Fisher's Exact Test, MEGSA, WeSME.
    """
    method_names = ("DIALECT", "DISCOVER", "Fisher's Exact Test", "MEGSA", "WeSME")
    return {
        name: get_top_ranked_pairs_by_method(
            results_df=results_df,
            method=name,
            meco=meco,
            num_pairs=num_pairs,
            num_samples=num_samples,
        )
        for name in method_names
    }

0 comments on commit d5d344b

Please sign in to comment.