Skip to content

Commit

Permalink
Merge pull request #29 from TieuLongPhan/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
klausweinbauer authored May 16, 2024
2 parents d6ae8a2 + cf33476 commit 597f7f3
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 98 deletions.
17 changes: 5 additions & 12 deletions Test/SynMCSImputer/MissingGraph/test_find_graph_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest
from rdkit import Chem
from synrbl.SynMCSImputer.MissingGraph.find_graph_dict import (
find_single_graph,
find_single_graph_parallel,
convert_smiles_to_mols,
smiles_to_mol_parallel,
Expand All @@ -11,20 +10,14 @@
class TestFindGraphFunctions(unittest.TestCase):
def setUp(self):
# Example molecules for testing
self.mcs_mol_list = [Chem.MolFromSmiles("CC"), Chem.MolFromSmiles("C")]
self.mcs_mol_list = [[Chem.MolFromSmiles("CC"), Chem.MolFromSmiles("C")]]
self.sorted_reactants_mol_list = [
Chem.MolFromSmiles("CCO"),
Chem.MolFromSmiles("CO"),
[
Chem.MolFromSmiles("CCO"),
Chem.MolFromSmiles("CO"),
]
]

def test_find_single_graph(self):
result = find_single_graph(self.mcs_mol_list, self.sorted_reactants_mol_list)
self.assertIsInstance(result, dict)
self.assertIn("smiles", result)
self.assertIn("boundary_atoms_products", result)
self.assertIn("nearest_neighbor_products", result)
self.assertIn("issue", result)

def test_find_single_graph_parallel(self):
result = find_single_graph_parallel(
self.mcs_mol_list, self.sorted_reactants_mol_list, n_jobs=2
Expand Down
1 change: 0 additions & 1 deletion Test/test_mcs_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def test_simple_mcs():
_add(data, "COC(C)=O>>OC(C)=O")

results = mcs.find(data)
print(results)

assert 1 == len(results)
result = results[0]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "synrbl"
version = "0.0.16"
version = "0.0.17"
authors = [
{name="Tieu Long Phan", email="long.tieu_phan@uni-leipzig.de"},
{name="Klaus Weinbauer", email="klaus@bioinf.uni-leipzig.de"}
Expand Down
40 changes: 24 additions & 16 deletions synrbl/SynCmd/cmd_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ def _sr(r):
rb_s = stats["rb_solved"]
rb_a = stats["rb_applied"]
mcs_a = stats["mcs_applied"]
mcs_cth = stats["confident_cnt"]
total_solved = rb_s + mcs_cth
mcs_s = stats["confident_cnt"]
total_solved = rb_s + mcs_s
total_correct = rb_correct + mcs_correct
output_stats["total_solved"] = total_solved
output_stats["total_correct"] = total_correct
rb_suc = _r(rb_s, rb_a)
mcs_suc = _r(mcs_cth, mcs_a)
suc = _r(rb_s + mcs_cth, rxn_cnt)
mcs_suc = _r(mcs_s, mcs_a)
suc = _r(rb_s + mcs_s, rxn_cnt)
output_stats["rb_suc"] = rb_suc
output_stats["mcs_suc"] = mcs_suc
output_stats["success"] = suc
rb_acc = _r(rb_correct, rb_s)
mcs_acc = _r(mcs_correct, mcs_cth)
acc = _r(rb_correct + mcs_correct, rb_s + mcs_cth)
mcs_acc = _r(mcs_correct, mcs_s)
acc = _r(rb_correct + mcs_correct, rb_s + mcs_s)
output_stats["rb_acc"] = rb_acc
output_stats["mcs_acc"] = mcs_acc
output_stats["accuracy"] = acc
Expand All @@ -49,7 +49,7 @@ def _sr(r):
logger.info("-" * len(header))

logger.info(line_fmt.format("Input", str(rb_a), str(mcs_a), str(rxn_cnt)))
logger.info(line_fmt.format("Solved", str(rb_s), str(mcs_cth), str(total_solved)))
logger.info(line_fmt.format("Solved", str(rb_s), str(mcs_s), str(total_solved)))
logger.info(
line_fmt.format(
"Correct",
Expand Down Expand Up @@ -127,24 +127,32 @@ def run(args):
for i, entry in enumerate(dataset):
if not entry["solved"]:
continue

if (
entry["solved_by"] == "mcs-based"
and entry["confidence"] >= args.min_confidence
):
mcs_cth += 1

exp = entry[args.target_col]
if pd.isna(exp):
logger.warning(
"Missing expected reaction ({}) in line {}.".format(args.target_col, i)
)
continue

exp_reaction = normalize_smiles(exp)
act_reaction = normalize_smiles(entry[args.col])
if entry["confidence"] >= args.min_confidence:
mcs_cth += 1
if (
wc_similarity(exp_reaction, act_reaction, args.similarity_method)
>= args.similarity_threshold
):
if entry["solved_by"] == "rule-based":
rb_correct += 1
elif entry["solved_by"] == "mcs-based":
wc_sim = wc_similarity(exp_reaction, act_reaction, args.similarity_method)
if entry["solved_by"] == "mcs-based":
if (
entry["confidence"] >= args.min_confidence
and wc_sim >= args.similarity_threshold
):
mcs_correct += 1
elif entry["solved_by"] == "rule-based":
if wc_sim >= args.similarity_threshold:
rb_correct += 1

stats["confident_cnt"] = mcs_cth
output_result(stats, rb_correct, mcs_correct, file=args.o)
Expand Down
38 changes: 38 additions & 0 deletions synrbl/SynCmd/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import logging
import pandas as pd
import rdkit.Chem.rdChemReactions as rdChemReactions

from synrbl import Balancer

Expand Down Expand Up @@ -53,6 +54,42 @@ def _sr(v, c):
)


def check_columns(reactions, reaction_col, required_cols=None):
    """Validate the shape of the input records before processing them.

    Only the first record is inspected; it is treated as representative
    of the whole input.

    Parameters:
    - reactions (list of dict): Input records, one dict per row.
    - reaction_col (str): Key that holds the reaction SMILES.
    - required_cols (iterable of str, optional): Additional keys that must
      be present in the first record. ``None`` (the default) or an empty
      iterable disables this check.

    Raises:
    - ValueError: If ``reactions`` is empty, or if the value in the reaction
      column cannot be parsed as a reaction SMILES.
    - KeyError: If ``reaction_col`` or one of ``required_cols`` is missing.
    - TypeError: If the value in the reaction column is not a string.
    """
    if len(reactions) == 0:
        raise ValueError("No reactions found in input.")

    first_record = reactions[0]
    if reaction_col not in first_record.keys():
        raise KeyError("No column '{}' found in input.".format(reaction_col))

    reaction_value = first_record[reaction_col]
    if not isinstance(reaction_value, str):
        raise TypeError(
            "Reaction column '{}' must be of type string not '{}'.".format(
                reaction_col, type(reaction_value)
            )
        )

    mol = None
    try:
        mol = rdChemReactions.ReactionFromSmarts(reaction_value, useSmiles=True)
    except Exception:
        # Deliberately swallowed: any parser failure leaves ``mol`` as None
        # and is reported just below as a uniform ValueError.
        pass
    if mol is None:
        raise ValueError(
            "Value '{}...' in reaction column '{}' is not a valid SMILES.".format(
                reaction_value[0:30], reaction_col
            )
        )

    # ``required_cols`` may be None (e.g. when the caller forwards an unset
    # optional argument); treat that the same as "nothing required".
    for c in required_cols or ():
        if c not in first_record.keys():
            raise KeyError(
                (
                    "Required column '{}' not found. The input to benchmark "
                    + "should be the output from a rebalancing run."
                ).format(c)
            )


def impute(
src_file,
output_file,
Expand All @@ -62,6 +99,7 @@ def impute(
n_jobs=-1,
):
input_reactions = pd.read_csv(src_file).to_dict("records")
check_columns(input_reactions, reaction_col, required_cols=passthrough_cols)

synrbl = Balancer(
reaction_col=reaction_col, confidence_threshold=min_confidence, n_jobs=n_jobs
Expand Down
69 changes: 1 addition & 68 deletions synrbl/SynMCSImputer/MissingGraph/find_graph_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,58 +9,6 @@
from synrbl.SynMCSImputer.MissingGraph.uncertainty_graph import GraphMissingUncertainty


def find_single_graph(mcs_mol_list, sorted_reactants_mol_list):
    """
    Compute the missing parts of each reactant relative to its MCS.

    The two input lists are walked in lockstep (iteration stops at the
    shorter one). Each pair is handed to
    FindMissingGraphs.find_missing_parts_pairs(reactant, mcs); the returned
    molecules are converted to SMILES with Chem.MolToSmiles.

    Parameters:
    - mcs_mol_list (list): MCS entries, one per entry of
      sorted_reactants_mol_list.
    - sorted_reactants_mol_list (list): Reactant entries to analyze.

    Returns:
    - dict with four parallel lists, one element per processed pair:
        - 'smiles': list of SMILES strings of the missing parts
          (empty list on failure).
        - 'boundary_atoms_products': boundary atoms as returned by
          find_missing_parts_pairs (empty list on failure).
        - 'nearest_neighbor_products': nearest neighbors as returned by
          find_missing_parts_pairs (empty list on failure).
        - 'issue': an empty list on success, otherwise a string holding
          the error message.
    """
    smiles_per_pair = []
    boundary_per_pair = []
    neighbors_per_pair = []
    issue_per_pair = []

    for reactant_entry, mcs_entry in zip(sorted_reactants_mol_list, mcs_mol_list):
        try:
            fragments, boundary, neighbors = (
                FindMissingGraphs.find_missing_parts_pairs(reactant_entry, mcs_entry)
            )
            fragment_smiles = [Chem.MolToSmiles(fragment) for fragment in fragments]
        except Exception as err:
            # Any failure for this pair yields empty placeholders so the
            # four result lists stay aligned with the input order.
            fragment_smiles, boundary, neighbors = [], [], []
            issue = "FindMissingGraphs.find_missing_parts() failed:" + str(err)
        else:
            issue = []

        smiles_per_pair.append(fragment_smiles)
        boundary_per_pair.append(boundary)
        neighbors_per_pair.append(neighbors)
        issue_per_pair.append(issue)

    return {
        "smiles": smiles_per_pair,
        "boundary_atoms_products": boundary_per_pair,
        "nearest_neighbor_products": neighbors_per_pair,
        "issue": issue_per_pair,
    }


def find_single_graph_parallel(mcs_mol_list, sorted_reactants_mol_list, n_jobs=4):
"""
Find missing parts, boundary atoms, and nearest neighbors for a list of
Expand Down Expand Up @@ -109,26 +57,11 @@ def process_single_pair(reactant_mol, mcs_mol, job_timeout=2):
"issue": "",
}
except multiprocessing.TimeoutError:
pool.terminate() # Terminate the pool in case of timeout

result = FindMissingGraphs.find_missing_parts_pairs(
reactant_mol, mcs_mol, False
)
return {
"smiles": [
Chem.MolToSmiles(mol) if mol is not None else None
for mol in result[0]
],
"boundary_atoms_products": result[1],
"nearest_neighbor_products": result[2],
"issue": "Find Missing Graph terminated by timeout",
}
except Exception as e:
return {
"smiles": [],
"boundary_atoms_products": [],
"nearest_neighbor_products": [],
"issue": str(e),
"issue": "Find Missing Graph terminated by timeout",
}

results = Parallel(n_jobs=n_jobs, verbose=0)(
Expand Down

0 comments on commit 597f7f3

Please sign in to comment.