Added tests and small upgrades
- Added CandyCrunch tests for GPST000017, GPST000029, GPST000307, GPST000350, sheep_milk, and PMC8950484_CHO
- Added CandyCrumbs test for glycans and glycopeptides
- Improved CandyCrumbs performance by ignoring subgraphs with more termini than allowed cleavages
- Generalized charges in mass_check
- Allowed specific input charges in mass_check
- Impute and supplement_prediction now use the charge in each row
- Removed filter_top_frag_annotations
urbj committed Jan 29, 2025
1 parent f2751f5 commit 57d68bb
Showing 38 changed files with 14,757 additions and 105 deletions.
3 changes: 3 additions & 0 deletions CandyCrunch/analysis.py
@@ -668,6 +668,9 @@ def generate_atomic_frags(nx_mono, global_mods, special_residues, allowed_X_clea
  nx_deg = nx_mono.degree
  for i,subg in enumerate(subgraphs):
    terminals = get_terminals(nx_deg,subg)
    new_terminals = [x for x in terminals if x not in all_other_terminals]
    if len(new_terminals)>max_cleavages:
      continue
    other_terminals = [x for x in subg.nodes if x in all_other_terminals and x not in terminals]
    terminals = terminals+other_terminals
    inner_mass = sum([mono_attributes[node_dict_basic[m]]['mass'][node_dict_basic[m]] for m in subg.nodes() if m not in terminals])
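The guard added above prunes candidate fragment subgraphs before any masses are enumerated: a subgraph that exposes more newly created termini than max_cleavages cannot result from the allowed number of cleavages, so it is skipped outright. A rough, self-contained sketch of the idea, using a toy graph and a hypothetical helper name (this is not CandyCrunch's actual get_terminals logic):

import networkx as nx

def too_many_new_termini(subg, known_terminals, max_cleavages):
  # Leaves of the candidate fragment subgraph are potential cleavage points.
  sub_terminals = {n for n in subg.nodes if subg.degree(n) <= 1}
  # Termini that were not already terminal elsewhere must each stem from a fresh
  # cleavage; exceeding the cleavage budget means the subgraph can be discarded
  # without computing any fragment masses.
  new_terminals = sub_terminals - set(known_terminals)
  return len(new_terminals) > max_cleavages

# Toy linear "glycan" 0-1-2-3 in which node 3 is already a terminus.
g = nx.path_graph(4)
frag = g.subgraph([1, 2, 3])
print(too_many_new_termini(frag, known_terminals={3}, max_cleavages=0))  # True -> skip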
43 changes: 19 additions & 24 deletions CandyCrunch/prediction.py
@@ -26,7 +26,7 @@

from CandyCrunch.model import (CandyCrunch_CNN, SimpleDataset, transform_mz,
transform_prec, transform_rt)
from analysis import CandyCrumbs
from CandyCrunch.analysis import CandyCrumbs

this_dir, this_filename = os.path.split(__file__)
data_path = os.path.join(this_dir, 'glycans.pkl')
@@ -285,8 +285,8 @@ def get_topk(dataloader, model, glycans, k = 25, temp = False, temperature = tem
return preds, conf.tolist()


def mass_check(mass, glycan, mode = 'negative', modification = 'reduced', mass_tag = None,
double_thresh = 900, triple_thresh = 1500, quadruple_thresh = 3500, mass_thresh = 0.5):
def mass_check(mass, glycan, mode = 'negative', modification = 'reduced', mass_tag = None, double_thresh = 900,
triple_thresh = 1500, quadruple_thresh = 3500, mass_thresh = 0.5,permitted_charges = [1,2,3,4]):
"""determine whether glycan could explain m/z\n
| Arguments:
| :-
@@ -298,11 +298,14 @@ def mass_check(mass, glycan, mode = 'negative', modification = 'reduced', mass_t
| double_thresh (float): mass threshold over which to consider doubly-charged ions; default:900
| triple_thresh (float): mass threshold over which to consider triply-charged ions; default:1500
| quadruple_thresh (float): mass threshold over which to consider quadruply-charged ions; default:3500
| mass_thresh (float): maximum allowed mass difference to return True; default:0.5\n
| mass_thresh (float): maximum allowed mass difference to return True; default:0.5
| permitted_charges (list): charges of ions used to check mass against; default:[1,2,3,4]\n
| Returns:
| :-
| Returns True if glycan could explain mass and False if not
"""
  threshold_dict = {2:double_thresh, 3:triple_thresh, 4:quadruple_thresh}
  greater_charges = [x for x in permitted_charges if x>1]
  try:
    mz = glycan_to_mass(glycan, sample_prep= modification if modification in ["permethylated", "peracetylated"] else 'underivatized') if isinstance(glycan, str) else glycan
  except:
@@ -313,10 +316,11 @@ def mass_check(mass, glycan, mode = 'negative', modification = 'reduced', mass_t
    mz += mass_tag
  adduct_list = ['Acetonitrile', 'Acetate', 'Formate', 'HCO3-'] if mode == 'negative' else ['Na+', 'K+', 'NH4+']
  og_list = [mz] + [mz + mass_dict.get(adduct, 999) for adduct in adduct_list]
  charge_adjustments = [-0.5, -0.66, -0.75] if mode == 'negative' else [0.5, 0.66, 0.75]
  thresholds = [double_thresh, triple_thresh, quadruple_thresh]
  mz_list = og_list + [
    (m / z + charge_adjust) for z, threshold, charge_adjust in zip([2, 3, 4], thresholds, charge_adjustments)
  single_list = og_list if 1 in permitted_charges else []
  charge_adjustments = [-(1-(1/x)) for x in greater_charges] if mode == 'negative' else [(1-(1/x)) for x in greater_charges]
  thresholds = [threshold_dict[x] for x in greater_charges]
  mz_list = single_list + [
    (m / z + charge_adjust) for z, threshold, charge_adjust in zip(greater_charges, thresholds, charge_adjustments)
    for m in og_list if m > threshold
  ]
  return [m for m in mz_list if abs(mass - m) < mass_thresh]
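In short, the previously hard-coded charge adjustments (-0.5, -0.66, -0.75) are now derived as -(1 - 1/z) for every permitted charge z > 1 (sign flipped in positive mode), so callers can restrict the check to specific charge states instead of always trying charges 1-4. A hedged usage sketch; the m/z value and glycan below are purely illustrative, not a verified match:

from CandyCrunch.prediction import mass_check

# Only test the observed m/z against doubly-charged ions, e.g. when the
# precursor charge of the spectrum is already known to be 2.
matches = mass_check(1040.4,  # observed m/z (illustrative placeholder)
                     "Neu5Ac(a2-3)Gal(b1-4)GlcNAc(b1-2)Man(a1-3)[Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
                     mode = 'negative', modification = 'reduced',
                     permitted_charges = [2])  # new argument introduced in this commit
print(bool(matches))  # a non-empty list means the structure can explain the m/z at that charge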
@@ -567,27 +571,16 @@ def assign_candidate_structures(df_in,df_glycan_in,comp_struct_map,topo_struct_m
return df_in


def filter_top_frag_annotations(ccrumbs_out):
  filtered_annotations = []
  for k,v in ccrumbs_out.items():
    if v:
      for ant in [a for a in v['Domon-Costello nomenclatures'] if len(a)<3]:
        prefs = [a.split('_')[0][-1] for a in ant]
        if 'A' in prefs or 'X' in prefs:
          if len(prefs)>1:
            continue
        filtered_annotations.append(ant)
  return filtered_annotations


def assign_annotation_scores_pooled(df_in,multiplier,mass_tag,mass_tolerance):
  unq_structs = df_in[df_in['candidate_structure'].notnull()].groupby('candidate_structure').first().reset_index()
  for struct,comp in zip(unq_structs.candidate_structure,unq_structs.composition):
    try:
      if '][GlcNAc(b1-4)]' in struct:
        continue
      rounded_mass_rows = [[np.round(y,1) for y in x][:15] for x in df_in[df_in['candidate_structure'] == struct].peak_d]
      row_charge = max(df_in[df_in['candidate_structure'] == struct].charge)
      unq_rounded_masses = set([x for y in rounded_mass_rows for x in y])
      cc_out = CandyCrumbs(struct, unq_rounded_masses,mass_tolerance,simplify=False,charge=int(multiplier*abs(row_charge)),disable_global_mods=True,max_cleavages=3,mass_tag=mass_tag)
      cc_out = CandyCrumbs(struct, unq_rounded_masses,mass_tolerance,simplify=False,charge=int(multiplier*abs(row_charge)),disable_global_mods=True,disable_X_cross_rings=True,max_cleavages=2,mass_tag=mass_tag)
      tester_mass_scores=score_top_frag_masses(cc_out)
      row_scores = [sum([tester_mass_scores[x] for x in y])/sum(comp.values()) for y in rounded_mass_rows]
      secondary_mass_scores = score_top_frag_masses(cc_out,simple_frags_only=True)
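The pooled annotation scoring now disables cross-ring X fragments and tightens the cleavage budget from 3 to 2. A hedged sketch of an equivalent standalone CandyCrumbs call; the structure, fragment masses, and charge are illustrative placeholders, and omitted arguments fall back to their defaults:

from CandyCrunch.analysis import CandyCrumbs

annotations = CandyCrumbs("Gal(b1-3)[Neu5Ac(a2-6)]GalNAc",  # illustrative O-glycan
                          [290.1, 425.1],                    # illustrative rounded fragment m/z values
                          0.5,                               # mass tolerance
                          simplify = False,
                          charge = -1,  # sign convention follows the wrapper call above
                          disable_global_mods = True,
                          disable_X_cross_rings = True,      # newly enabled at this call site
                          max_cleavages = 2)                 # reduced from 3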
@@ -886,6 +879,7 @@ def impute(df_out, pred_thresh,mode = 'negative', modification = 'reduced', mass
"""
predictions_list = df_out.predictions.values.tolist()
index_list = df_out.index.tolist()
charge_list = df_out.charge.tolist()
seqs = [p[0][0] for p in predictions_list if p and ("Neu5Ac" in p[0][0] or "Neu5Gc" in p[0][0])]
variants = set(unwrap([get_all_variants(s,'Neu5Ac','Neu5Gc') for s in seqs]))
if glycan_class == "O":
@@ -894,7 +888,7 @@ def impute(df_out, pred_thresh,mode = 'negative', modification = 'reduced', mass
  for i, k in enumerate(predictions_list):
    if len(k) < 1:
      for v in variants:
        if mass_check(index_list[i], v, mode = mode, modification = modification, mass_tag = mass_tag):
        if mass_check(index_list[i], v, mode = mode, modification = modification, mass_tag = mass_tag,permitted_charges=[abs(charge_list[i])]):
          df_out.iat[i, 0] = [(v, pred_thresh)]
          break
  return df_out
@@ -1861,10 +1855,11 @@ def supplement_prediction(df_in, glycan_class, mode = 'negative', modification =
  net = evoprune_network(net)
  unexplained_idx = [idx for idx, pred in enumerate(df['predictions']) if not pred]
  unexplained = df.index[unexplained_idx].tolist()
  charges = df.charge[unexplained_idx].tolist()
  preds_set = set(preds)
  new_nodes = [k for k in net.nodes() if k not in preds_set]
  explained_idx = [[unexplained_idx[k] for k, check in enumerate([mass_check(j, node,
      modification = modification, mode = mode, mass_tag = mass_tag) for j in unexplained]) if check] for node in new_nodes]
      modification = modification, mode = mode, mass_tag = mass_tag,permitted_charges=[abs(c)]) for j,c in zip(unexplained,charges)]) if check] for node in new_nodes]
  new_nodes = [(node, idx) for node, idx in zip(new_nodes, explained_idx) if idx]
  explained = {k: [] for k in set(unwrap(explained_idx))}
  for node, indices in new_nodes:
159 changes: 159 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,159 @@
# conftest.py
from tabulate import tabulate
import numpy as np
from collections import defaultdict
import pytest
import json
import os
from datetime import datetime

class ResultCollector:
    def __init__(self):
        self.results = defaultdict(list)
        self.param_names = None
        self.current_dict = None
        self.dict_results = defaultdict(lambda: defaultdict(list))
        self.log_file = "test_results_log.json"
        self.previous_results = self.load_previous_results()

    def load_previous_results(self):
        if os.path.exists(self.log_file):
            with open(self.log_file, 'r') as f:
                return json.load(f)
        return {}

    def save_current_results(self):
        # Calculate final averages for each test_dict
        final_results = {}
        for dict_name, param_results in self.dict_results.items():
            averaged_results = {
                str(params): np.mean(scores)  # Convert tuple to str for JSON serialization
                for params, scores in param_results.items()
            }
            final_results[dict_name] = {
                'scores': averaged_results,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }

        # Save to log file
        with open(self.log_file, 'w') as f:
            json.dump(final_results, f, indent=4)

    def check_performance(self, test_dict_name, param_key, current_score):
        """Check if current score is at least as good as previous best"""
        if test_dict_name in self.previous_results:
            prev_scores = self.previous_results[test_dict_name]['scores']
            param_key_str = str(param_key)
            if param_key_str in prev_scores:
                prev_score = prev_scores[param_key_str]
                if current_score < prev_score:
                    raise AssertionError(
                        f"\nPerformance regression detected for {test_dict_name}!"
                        f"\nPrevious score: {prev_score:.3f}"
                        f"\nCurrent score: {current_score:.3f}"
                        f"\nParameters: {param_key}")

    def add_result(self, params, score):
        test_dict_name = params['test_dict']['name']
        param_key = tuple(
            params[key] if key != 'test_dict' else params['test_dict']['name']
            for key in self.param_names.keys()
        )

        # Append the individual score to the list of scores for this parameter combination
        self.dict_results[test_dict_name][param_key[1:]].append(score)

        # If we've switched to a new test_dict, print the results for the previous one
        if self.current_dict != test_dict_name and self.current_dict is not None:
            self.print_dict_results(self.current_dict)

        self.current_dict = test_dict_name

    def print_dict_results(self, dict_name):
        print(f"\n=== Results for {dict_name} ===")

        # Get results for this test_dict
        dict_specific_results = self.dict_results[dict_name]

        # Generate headers (exclude test_dict since it's the same for all rows)
        headers = list(self.param_names.values())[1:] + ['F1 Score']
        table_data = []

        # Calculate averages and sort by F1 score
        averaged_results = {
            params: np.mean(scores)
            for params, scores in dict_specific_results.items()
        }

        # Sort by average F1 score and create table rows
        for params, avg_f1 in sorted(averaged_results.items(), key=lambda x: x[1], reverse=True):
            row = list(params) + [f"{avg_f1:.3f}"]
            table_data.append(row)

        print(tabulate(table_data, headers=headers, tablefmt='grid'))
        print()

    def get_table(self):
        if not self.dict_results:
            return "No results collected!"

        # Print final results for the last test_dict
        if self.current_dict:
            self.print_dict_results(self.current_dict)

        self.save_current_results()

        # Print overall results
        print("\n=== Overall Performance Results ===")
        headers = list(self.param_names.values()) + ['F1 Score']
        table_data = []

        # Combine results from all test_dicts
        overall_results = defaultdict(list)
        for dict_name, param_results in self.dict_results.items():
            for params, scores in param_results.items():
                overall_key = (dict_name,) + params
                overall_results[overall_key].extend(scores)

        # Calculate averages and sort
        averaged_overall = {
            params: np.mean(scores)
            for params, scores in overall_results.items()
        }

        for params, avg_f1 in sorted(averaged_overall.items(), key=lambda x: x[1], reverse=True):
            row = list(params) + [f"{avg_f1:.3f}"]
            table_data.append(row)

        return tabulate(table_data, headers=headers, tablefmt='grid')

# Rest remains the same
collector = ResultCollector()

@pytest.fixture(scope="session")
def result_collector():
    return collector

def pytest_configure(config):
    config.collector = collector


def pytest_sessionfinish():
    print("\n=== Final Summary ===")

    # Print tables for each test_dict
    for dict_name in collector.dict_results.keys():
        collector.print_dict_results(dict_name)

    # Print overall results
    print("\n=== Overall Performance Results ===")
    print(collector.get_table())
    print()

    if collector.previous_results:
        print("\n=== Comparison with Previous Run ===")
        for dict_name in collector.dict_results.keys():
            if dict_name in collector.previous_results:
                prev_timestamp = collector.previous_results[dict_name]['timestamp']
                prev_score = collector.previous_results[dict_name]['scores']
                print(f"\n{dict_name} (Previous run: {prev_timestamp}, Score: {prev_score})")
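The collector is exposed to tests through the session-scoped result_collector fixture and compares each run's averaged F1 scores against test_results_log.json, raising an AssertionError on regressions. A hedged sketch of how a parametrized test could feed it; the file name, dataset names, parameter grid, and placeholder score are hypothetical, not the actual tests added in this commit:

# hypothetical tests/test_scoring.py
import pytest

TEST_DICTS = [{"name": "GPST000017"}, {"name": "sheep_milk"}]

@pytest.mark.parametrize("mass_tolerance", [0.3, 0.5])
@pytest.mark.parametrize("test_dict", TEST_DICTS, ids=lambda d: d["name"])
def test_annotation_f1(result_collector, test_dict, mass_tolerance):
    # param_names maps parameter keys to the column headers used in the report tables.
    result_collector.param_names = {"test_dict": "Dataset", "mass_tolerance": "Mass tol."}
    f1 = 0.9  # placeholder: a real test would score predictions against ground truth
    result_collector.add_result({"test_dict": test_dict, "mass_tolerance": mass_tolerance}, f1)
    # Fail if this parameter combination scores worse than the logged previous best.
    result_collector.check_performance(test_dict["name"], (mass_tolerance,), f1)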