Skip to content

Commit

Permalink
Add ValueError checker for msms scoring method and return method when…
Browse files Browse the repository at this point in the history
… calling the function for logging
  • Loading branch information
bkieft-usa committed Jan 31, 2025
1 parent efeb08f commit 5d6fbb8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 20 deletions.
2 changes: 1 addition & 1 deletion metatlas/plots/dill2plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def __init__(self,
"""
logger.debug("Initializing new instance of %s.", self.__class__.__name__)
self.data = data
self.msms_hits = sp.sort_msms_hits(msms_hits)
self.msms_hits, _ = sp.sort_msms_hits(msms_hits)
self.color_me = or_default(color_me, [('black', '')])
self.compound_idx = compound_idx
self.width = width
Expand Down
15 changes: 6 additions & 9 deletions metatlas/tools/fastanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def make_stats_table(workflow_name: str = "JGI-HILIC", input_fname: Optional[Pat
msms_hits_df = msms_hits.copy()
msms_hits_df.reset_index(inplace=True)

#msms_hits_sorted_list = []
msms_hits_sorted_list = []
for compound_idx, compound_name in enumerate(compound_names):
ref_rt_peak = dataset[0][compound_idx]['identification'].rt_references[0].rt_peak
ref_mz = dataset[0][compound_idx]['identification'].mz_references[0].mz
Expand All @@ -136,8 +136,8 @@ def make_stats_table(workflow_name: str = "JGI-HILIC", input_fname: Optional[Pat
& ((abs(msms_hits_df['measured_precursor_mz'].values.astype(float) - mz_theoretical)/mz_theoretical) \
<= cid.mz_references[0].mz_tolerance*1e-6)]

comp_msms_hits = sp.sort_msms_hits(comp_msms_hits)
#msms_hits_sorted_list.append(comp_msms_hits)
comp_msms_hits, sorting_method = sp.sort_msms_hits(comp_msms_hits)
msms_hits_sorted_list.append(comp_msms_hits)
file_idxs, scores, msv_sample_list, msv_ref_list, rt_list = [], [], [], [], []
if len(comp_msms_hits) > 0 and not np.isnan(np.concatenate(comp_msms_hits['msv_ref_aligned'].values, axis=1)).all():
file_idxs = [file_names.index(f) for f in comp_msms_hits['file_name'] if f in file_names]
Expand Down Expand Up @@ -428,12 +428,9 @@ def make_stats_table(workflow_name: str = "JGI-HILIC", input_fname: Optional[Pat
dfs['msms_score'].iat[compound_idx, file_idx] = rows.loc[rows['score'].astype(float).idxmax()]['score']
dfs['num_frag_matches'].iat[compound_idx, file_idx] = rows.loc[rows['score'].astype(float).idxmax()]['num_matches']

# method = 'numeric_hierarchy'
# msms_hits_sorted_df = pd.concat(msms_hits_sorted_list)
# msms_hits_sorted_df.to_csv(f"/out/msms_hits_sorted_{method}.csv")
# best_hits = msms_hits_sorted_df.groupby(['inchi_key', 'adduct']).first().reset_index()
# best_hits.to_csv(f"/out/best_hits_{method}.csv")
# best_hits.to_json(f"/out/best_hits_{method}.json")
logger.info(f"Finished processing all compounds and sorted MSMS scores by method '{sorting_method}'.")
#msms_hits_sorted_df = pd.concat(msms_hits_sorted_list)
#msms_hits_sorted_df.to_csv(f"/out/msms_hits_sorted_by_{sorting_method}.csv") # Use this for examining hits with different scoring methods

passing['msms_score'] = (np.nan_to_num(dfs['msms_score'].values) >= min_msms_score).astype(float)
passing['num_frag_matches'] = (np.nan_to_num(dfs['num_frag_matches'].values) >= min_num_frag_matches).astype(float)
Expand Down
16 changes: 6 additions & 10 deletions metatlas/tools/spectralprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,11 +1579,15 @@ def next2pow(x):
return np.array([MA[cutoff_idx], ptA[cutoff_idx]]), contributions


def sort_msms_hits(msms_hits: pd.DataFrame, sorting_method: str = 'score') -> pd.DataFrame:
def sort_msms_hits(msms_hits: pd.DataFrame, sorting_method: str = 'score') -> tuple[pd.DataFrame, str]:
"""
Takes an msms hits dataframe and returns a sorted version of it based on the sorting method. Typically
this function is called while iterating over compounds, so each dataframe input will likely be for a single compound.
"""
allowed_methods = ['score', 'sums', 'weighted', 'numeric_hierarchy', 'quantile_hierarchy']
if sorting_method not in allowed_methods:
raise ValueError(f"Invalid sorting method: {sorting_method}. Allowed methods are: {allowed_methods}")

if sorting_method == "score":
sorted_msms_hits = msms_hits.sort_values('score', ascending=False)

Expand All @@ -1595,8 +1599,6 @@ def sort_msms_hits(msms_hits: pd.DataFrame, sorting_method: str = 'score') -> pd
(sorted_msms_hits['score'])
)
sorted_msms_hits = sorted_msms_hits.sort_values('summed_ratios_and_score', ascending=False)
# droppable_columns = ['summed_ratios_and_score', 'data_frags', 'ref_frags']
# sorted_msms_hits.drop(columns=droppable_columns, inplace=True)

elif sorting_method == "weighted":
sorted_msms_hits = msms_hits.copy()
Expand All @@ -1610,25 +1612,19 @@ def sort_msms_hits(msms_hits: pd.DataFrame, sorting_method: str = 'score') -> pd
((sorted_msms_hits['num_matches'] / sorted_msms_hits['ref_frags']) * weights['match_to_ref_frag_ratio'])
)
sorted_msms_hits = sorted_msms_hits.sort_values('weighted_score', ascending=False)
# droppable_columns = ['weighted_score', 'data_frags', 'ref_frags']
# sorted_msms_hits.drop(columns=droppable_columns, inplace=True)

elif sorting_method == "numeric_hierarchy":
sorted_msms_hits = msms_hits.copy()
bins = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.97, 0.99, 1]
labels = ['0-0.1', '0.1-0.2', '0.2-0.3', '0.3-0.4', '0.4-0.5', '0.5-0.6', '0.6-0.7', '0.7-0.8', '0.8-0.9', '0.9-0.95', '0.95-0.97', '0.97-0.99', '0.99-1']
sorted_msms_hits['score_bin'] = pd.cut(sorted_msms_hits['score'], bins=bins, labels=labels, right=False)
sorted_msms_hits = sorted_msms_hits.sort_values(by=['score_bin', 'num_matches', 'score'], ascending=[False, False, False])
# droppable_columns = ['score_bin', 'data_frags', 'ref_frags']
# sorted_msms_hits.drop(columns=droppable_columns, inplace=True)

elif sorting_method == "quantile_hierarchy":
sorted_msms_hits = msms_hits.copy()
sorted_msms_hits = sorted_msms_hits.dropna(subset=['score'])
sorted_msms_hits['score'] += np.random.normal(0, 1e-8, size=len(sorted_msms_hits)) # Add small noise to handle duplicates
sorted_msms_hits['score_bin'] = pd.qcut(sorted_msms_hits['score'], duplicates='drop', q=5)
sorted_msms_hits = sorted_msms_hits.sort_values(by=['score_bin', 'num_matches', 'score'], ascending=[False, False, False])
# droppable_columns = ['score_bin', 'data_frags', 'ref_frags']
# sorted_msms_hits.drop(columns=droppable_columns, inplace=True)

return sorted_msms_hits
return sorted_msms_hits, sorting_method

0 comments on commit 5d6fbb8

Please sign in to comment.