From 925e4eb0fc3cb32282aaa3e97ed09b3dec3a442b Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 09:32:18 +0200 Subject: [PATCH 01/10] Revert "remove charges in normalize smiles" This reverts commit e928aad1a210c304285e4cb0491058cad5efdccf. --- Test/SynUtils/test_chem_utils.py | 21 --------------------- synrbl/SynUtils/chem_utils.py | 14 -------------- 2 files changed, 35 deletions(-) diff --git a/Test/SynUtils/test_chem_utils.py b/Test/SynUtils/test_chem_utils.py index 896c57e..64d6038 100644 --- a/Test/SynUtils/test_chem_utils.py +++ b/Test/SynUtils/test_chem_utils.py @@ -5,7 +5,6 @@ remove_atom_mapping, normalize_smiles, count_atoms, - remove_charge, ) @@ -85,16 +84,6 @@ def test_remove_stereochemistry(self): result = normalize_smiles(smiles) self.assertEqual("CCC", result) - def test_remove_charge1(self): - smiles = "[N-:13]#[C+:19]" - result = remove_charge(smiles) - self.assertEqual("[N:13]#[C:19]", result) - - def test_remove_charge2(self): - smiles = "[N-]#[C+]" - result = remove_charge(smiles) - self.assertEqual("N#C", result) - def test_edge_case_1(self): smiles = "F[Sb@OH12](F)(F)(F)(F)F" result = normalize_smiles(smiles) @@ -110,16 +99,6 @@ def test_ordering_1(self): result = normalize_smiles(smiles) self.assertEqual("C=O.[HH]", result) - def test_remove_charges1(self): - smiles = "[N-:13]#[C+:19]" - result = normalize_smiles(smiles) - self.assertEqual("C#N", result) - - # def test_remove_charges2(self): - # smiles = "C[N-:13]#[C+:19]" - # result = normalize_smiles(smiles) - # self.assertEqual("CN#C", result) - @pytest.mark.parametrize( "smiles,exp_atom_cnt", [("O=C", 2), ("CO", 2), ("HH", 0), ("c1ccccc1", 6)] diff --git a/synrbl/SynUtils/chem_utils.py b/synrbl/SynUtils/chem_utils.py index 1d08303..cd5386e 100644 --- a/synrbl/SynUtils/chem_utils.py +++ b/synrbl/SynUtils/chem_utils.py @@ -133,12 +133,6 @@ def calculate_net_charge(sublist: list[dict[str, Union[str, int]]]) -> int: return total_charge -def remove_unnecessary_brackets(smiles: str) -> str: - pattern = re.compile(r"\[(?P(B|C|N|O|P|S|F|Cl|Br|I){1,2})\]") - smiles = pattern.sub(r"\g", smiles) - return smiles - - def remove_atom_mapping(smiles: str) -> str: pattern = re.compile(r":\d+") smiles = pattern.sub("", smiles) @@ -153,13 +147,6 @@ def remove_stereo_chemistry(smiles: str) -> str: return smiles -def remove_charge(smiles: str) -> str: - smiles = smiles.replace("+", "") - smiles = smiles.replace("-", "") - smiles = remove_unnecessary_brackets(smiles) - return smiles - - def count_atoms(smiles: str) -> int: pattern = re.compile(r"(B|C|N|O|P|S|F|Cl|Br|I|c|n|o)") return len(pattern.findall(smiles)) @@ -167,7 +154,6 @@ def count_atoms(smiles: str) -> int: def normalize_smiles(smiles: str) -> str: smiles = remove_stereo_chemistry(smiles) - smiles = remove_charge(smiles) if ">>" in smiles: return ">>".join([normalize_smiles(t) for t in smiles.split(">>")]) elif "." in smiles: From cbdc68a4634aeebaf69b56012b51398770a57645 Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 11:39:37 +0200 Subject: [PATCH 02/10] add issue if final reaction is unbalanced --- synrbl/balancing.py | 9 ++++++++- synrbl/postprocess.py | 4 +++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/synrbl/balancing.py b/synrbl/balancing.py index 4f6ad6d..0cded74 100644 --- a/synrbl/balancing.py +++ b/synrbl/balancing.py @@ -1,5 +1,6 @@ import copy import logging +import traceback from synrbl.preprocess import preprocess from synrbl.postprocess import Validator @@ -189,6 +190,11 @@ def __run_pipeline(self, reactions, stats=None): self.__post_process(reactions) self.rb_method.run(reactions) self.mcs_validator.check(reactions, override_unsolved=True) + self.mcs_validator.check( + reactions, + override_unsolved=True, + override_issue_msg="Final reaction is unbalanced.", + ) self.conf_predictor.predict( reactions, stats=stats, threshold=self.confidence_threshold @@ -265,7 +271,8 @@ def __rebalance_batch(self, batch, cache_manager): ) logger.info("Cached new results. (Key: {})".format(cache_key[:8])) except Exception as e: - logger.error("Pipeline execution failed: {}".format(e)) + traceback.print_exc() + logger.error("Pipeline execution failed: {}".format(type(e))) return result, batch_stats diff --git a/synrbl/postprocess.py b/synrbl/postprocess.py index a56ac9b..f365b20 100644 --- a/synrbl/postprocess.py +++ b/synrbl/postprocess.py @@ -25,7 +25,7 @@ def __init__( self.issue_col = issue_col self.n_jobs = n_jobs - def check(self, reactions, override_unsolved=False): + def check(self, reactions, override_unsolved=False, override_issue_msg=None): update_reactants_and_products(reactions, self.reaction_col) decompose = RSMIDecomposer( smiles=None, # type: ignore @@ -69,4 +69,6 @@ def check(self, reactions, override_unsolved=False): reaction[self.solved_method_col] = self.method if override_unsolved and not reaction[self.solved_col]: reaction[self.reaction_col] = reaction["input_reaction"] + if override_issue_msg is not None and reaction[self.issue_col] == "": + reaction[self.issue_col] = override_issue_msg return reactions From fd0bdcd1b0ad1e7b287e654fbda3f140dc3f028e Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 11:39:59 +0200 Subject: [PATCH 03/10] fix --- synrbl/balancing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synrbl/balancing.py b/synrbl/balancing.py index 0cded74..6cf3081 100644 --- a/synrbl/balancing.py +++ b/synrbl/balancing.py @@ -189,7 +189,6 @@ def __run_pipeline(self, reactions, stats=None): ) self.__post_process(reactions) self.rb_method.run(reactions) - self.mcs_validator.check(reactions, override_unsolved=True) self.mcs_validator.check( reactions, override_unsolved=True, From 4d31842674ba2ea48fb00c57649aa263fa0c44dd Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 13:34:50 +0200 Subject: [PATCH 04/10] fix closed compound passthrough in merge --- Test/SynMCSImputer/test_merge.py | 7 +++++++ synrbl/SynMCSImputer/merge.py | 4 +++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/Test/SynMCSImputer/test_merge.py b/Test/SynMCSImputer/test_merge.py index d247b38..049ce61 100644 --- a/Test/SynMCSImputer/test_merge.py +++ b/Test/SynMCSImputer/test_merge.py @@ -306,3 +306,10 @@ def test_ignore_water_in_passthrough(self): cm = merge.merge(cset) self.assertEqual("remove_water_catalyst", cm.rules[0].name) self.assertEqual("CO", cm.smiles) + + def test_catalyst_passthrough(self): + cset = CompoundSet() + cset.add_compound("C", src_mol="C") + cm = merge.merge(cset) + self.assertEqual(0, len(cm.rules)) + self.assertEqual("C", cm.smiles) diff --git a/synrbl/SynMCSImputer/merge.py b/synrbl/SynMCSImputer/merge.py index c6c42fd..0519315 100644 --- a/synrbl/SynMCSImputer/merge.py +++ b/synrbl/SynMCSImputer/merge.py @@ -102,7 +102,9 @@ def merge(compound_set: CompoundSet) -> Compound: else: comps_with_boundaries.append(c) - if len(comps_with_boundaries) == 1: + if len(comps_with_boundaries) == 0 and len(comps_without_boundaries) > 0: + merged_compound = comps_without_boundaries.pop() + elif len(comps_with_boundaries) == 1: merged_compound = _merge_one_compound(comps_with_boundaries[0]) elif len(comps_with_boundaries) == 2: merged_compound = _merge_two_compounds( From bcf5c5ed5f84c4f13008fec35576e7e6be63c313 Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 13:35:49 +0200 Subject: [PATCH 05/10] improve mcs vis debug --- synrbl/SynMCSImputer/mcs_based_method.py | 5 +-- synrbl/SynVis/vis_debug.py | 39 ++++++++++++++++++------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/synrbl/SynMCSImputer/mcs_based_method.py b/synrbl/SynMCSImputer/mcs_based_method.py index 3d16f93..eccd274 100644 --- a/synrbl/SynMCSImputer/mcs_based_method.py +++ b/synrbl/SynMCSImputer/mcs_based_method.py @@ -108,7 +108,7 @@ def __init__( smiles_standardizer=[], ): self.reaction_col = reaction_col - self.output_col = output_col + self.output_col = output_col if isinstance(output_col, list) else [output_col] self.mcs_data_col = mcs_data_col self.issue_col = issue_col self.rules_col = rules_col @@ -134,7 +134,8 @@ def run(self, reactions: list[dict], stats=None): carbon_balance_col=self.carbon_balance_col, smiles_standardizer=self.smiles_standardizer, ) - reaction[self.output_col] = result + for col in self.output_col: + reaction[col] = result reaction[self.rules_col] = rules mcs_solved += 1 except Exception as e: diff --git a/synrbl/SynVis/vis_debug.py b/synrbl/SynVis/vis_debug.py index efda492..2037e18 100644 --- a/synrbl/SynVis/vis_debug.py +++ b/synrbl/SynVis/vis_debug.py @@ -45,7 +45,8 @@ class MCSDebug: def __init__(self): self.fontsize = 9 self.balancer = synrbl.Balancer() - self.balancer.columns.append("mcs") + self.balancer.columns.extend(["mcs", "mcs_based_result"]) + self.balancer.mcs_method.output_col.append("mcs_based_result") self.cairosize = (1600, 900) self.highlight_color = (0.4, 0.9, 0.6, 1) @@ -70,10 +71,18 @@ def plot(self, smiles, verbose=True): ) fig = plt.figure() - gs = fig.add_gridspec(3, len(mols)) - ax1 = fig.add_subplot(gs[0, :]) - axs2 = [fig.add_subplot(gs[1, i]) for i in range(len(mols))] - ax3 = fig.add_subplot(gs[2, :]) + ax_mcs = None + if "mcs_based_result" in result.keys(): + gs = fig.add_gridspec(4, len(mols)) + ax1 = fig.add_subplot(gs[0, :]) + axs2 = [fig.add_subplot(gs[1, i]) for i in range(len(mols))] + ax_mcs = fig.add_subplot(gs[2, :]) + ax_final = fig.add_subplot(gs[3, :]) + else: + gs = fig.add_gridspec(3, len(mols)) + ax1 = fig.add_subplot(gs[0, :]) + axs2 = [fig.add_subplot(gs[1, i]) for i in range(len(mols))] + ax_final = fig.add_subplot(gs[2, :]) rxnvis = RxnVis(cairosize=self.cairosize) img = rxnvis.get_rxn_img(result["input_reaction"]) @@ -118,16 +127,26 @@ def plot(self, smiles, verbose=True): ax.axis("off") ax.set_title(title, fontsize=self.fontsize) + rxnvis = RxnVis(cairosize=self.cairosize) + if ax_mcs is not None: + img = rxnvis.get_rxn_img(result["mcs_based_result"]) + ax_mcs.imshow(img) + ax_mcs.set_title( + "MCS-Based Result\nRules: {}".format(result.get("rules", None)), + fontsize=self.fontsize, + ) + ax_mcs.axis("off") + rxnvis = RxnVis(cairosize=self.cairosize) img = rxnvis.get_rxn_img(result["reaction"]) - ax3.imshow(img) - ax3.set_title( - "Result (Confidence: {:.1%})\nRules: {}\nIssue: {}".format( - result.get("confidence", 0), result.get("rules", None), result["issue"] + ax_final.imshow(img) + ax_final.set_title( + "Result (Confidence: {:.1%})\nIssue: {}".format( + result.get("confidence", 0), result["issue"] ), fontsize=self.fontsize, ) - ax3.axis("off") + ax_final.axis("off") plt.tight_layout() plt.show() From 078ce0b209450812a864c1ef6e996756f4c53cfe Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 14:23:45 +0200 Subject: [PATCH 06/10] fix explicit hydrogens in mol standardizer --- Test/SynChemImputer/test_molecule_standardizer.py | 7 +++++++ synrbl/SynChemImputer/molecule_standardizer.py | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/Test/SynChemImputer/test_molecule_standardizer.py b/Test/SynChemImputer/test_molecule_standardizer.py index b51f708..3f0b4ba 100644 --- a/Test/SynChemImputer/test_molecule_standardizer.py +++ b/Test/SynChemImputer/test_molecule_standardizer.py @@ -23,6 +23,13 @@ def test_hemiketal_transformation(self): result, expected, "Hemiketal transformation failed or incorrect" ) + def test_hemiketal_transformation_with_aam(self): + smiles = "[C:1]([O:2])([OH:3])" + atom_indices = [0, 1, 2] + result = MoleculeStandardizer.standardize_hemiketal(smiles, atom_indices) + expected = ["[C:1]=[O:2].[OH2:3]", "[C:1]=[O:3].[OH2:2]"] + self.assertIn(result, expected) + def test_MoleculeStandardizer(self): smiles = "C(O)(O)C=CO" standardizer = MoleculeStandardizer() diff --git a/synrbl/SynChemImputer/molecule_standardizer.py b/synrbl/SynChemImputer/molecule_standardizer.py index d2e6f68..c2cfc38 100644 --- a/synrbl/SynChemImputer/molecule_standardizer.py +++ b/synrbl/SynChemImputer/molecule_standardizer.py @@ -118,25 +118,28 @@ def standardize_hemiketal(smiles: str, atom_indices: List[int]) -> str: """ # Load the molecule from SMILES and create an editable molecule object mol = Chem.MolFromSmiles(smiles) - emol = Chem.EditableMol(mol) # Initialize indices c_idx, o1_idx, o2_idx = None, None, None for i in atom_indices: - atom_symbol = mol.GetAtomWithIdx(i).GetSymbol() + atom = mol.GetAtomWithIdx(i) + atom_symbol = atom.GetSymbol() if atom_symbol == "C": c_idx = i elif atom_symbol == "O": if o1_idx is None: o1_idx = i # Assume the first oxygen encountered is O1 + atom.SetNumExplicitHs(0) else: o2_idx = i # The next oxygen is O2 + atom.SetNumExplicitHs(2) # Check if all indices are assigned if None in [c_idx, o1_idx, o2_idx]: return "Invalid atom indices provided. Please check the input." # Attempt to modify the molecule structure + emol = Chem.EditableMol(mol) try: # Remove existing bonds if they exist emol.RemoveBond(c_idx, o1_idx) From f27e72b3aecec1c9415fe0f10e9b0e1ed0cc1d85 Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 14:58:10 +0200 Subject: [PATCH 07/10] remove aam for all reactions --- synrbl/balancing.py | 2 ++ synrbl/preprocess.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/synrbl/balancing.py b/synrbl/balancing.py index 6cf3081..c737411 100644 --- a/synrbl/balancing.py +++ b/synrbl/balancing.py @@ -52,6 +52,7 @@ def __init__( self.__issue_col = "issue" self.__n_jobs = n_jobs + self.remove_aam = True self.batch_size = batch_size self.cache = cache self.cache_dir = cache_dir @@ -171,6 +172,7 @@ def __run_pipeline(self, reactions, stats=None): self.__id_col, self.__solved_col, self.__input_col, + remove_aam=self.remove_aam, ) rxn_cnt = len(reactions) self.input_validator.check(reactions) diff --git a/synrbl/preprocess.py b/synrbl/preprocess.py index 0898770..f1aafe3 100644 --- a/synrbl/preprocess.py +++ b/synrbl/preprocess.py @@ -1,9 +1,21 @@ import pandas as pd from synrbl.SynProcessor import RSMIProcessing +from synrbl.SynUtils import remove_atom_mapping -def preprocess(reactions, reaction_col, index_col, solved_col, input_col, n_jobs=1): +def preprocess( + reactions, + reaction_col, + index_col, + solved_col, + input_col, + n_jobs=1, + remove_aam=False, +): + if remove_aam: + for r in reactions: + r[reaction_col] = remove_atom_mapping(r[reaction_col]) df = pd.DataFrame(reactions) df[solved_col] = False From 4a7e2e9ccfcce8b59aae05714199f54085e5e16f Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 15:56:10 +0200 Subject: [PATCH 08/10] fix batch_size default init --- synrbl/SynCmd/cmd_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synrbl/SynCmd/cmd_run.py b/synrbl/SynCmd/cmd_run.py index 829315d..0a13fc2 100644 --- a/synrbl/SynCmd/cmd_run.py +++ b/synrbl/SynCmd/cmd_run.py @@ -137,7 +137,7 @@ def run(args): args.out_columns if isinstance(args.out_columns, list) else [args.out_columns] ) batch_size = None - if len(args.batch_size) > 0: + if args.batch_size is not None and len(args.batch_size) > 0: batch_size = int(args.batch_size) impute( From 058daf28579b2516bc309e095f9fd59d017aedf2 Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 16:07:00 +0200 Subject: [PATCH 09/10] update test cases --- Test/test_balancing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Test/test_balancing.py b/Test/test_balancing.py index 6ae825d..1ef0fa0 100644 --- a/Test/test_balancing.py +++ b/Test/test_balancing.py @@ -5,7 +5,7 @@ def test_e2e_1(): n = 100 - reactant = "[Br]" + n * "[Si](C)(C)O" + "[Si](C)(C)[Br]" + reactant = "Br" + n * "[Si](C)(C)O" + "[Si](C)(C)Br" product = "O" + n * "[Si](C)(C)O" + "[Si](C)(C)O" reaction = "{}>>{}".format(reactant, product) exp_result = "{}.{}>>{}.{}".format(reactant, "O.O", product, "Br.Br") @@ -26,8 +26,8 @@ def test_e2e_1(): ["CC(=O)C>>CC(O)C", "CC(=O)C.[HH]>>CC(O)C"], [ "CCO.[O]>>CC=O", - "CCO.O=[Cr](Cl)(-[O-])=O.c1cc[nH+]cc1.O>>" - + "CC=O.O.O=[Cr](O)O.c1cc[nH+]cc1.[Cl-]", + "CCO.O.O=[Cr](Cl)(-[O-])=O.c1cc[nH+]cc1.O>>" + + "CC=O.O.O.O=[Cr](O)O.c1cc[nH+]cc1.[Cl-]", ], ], ) From bbd7a1ef1951f89261a2cacb3abc1ba07692201f Mon Sep 17 00:00:00 2001 From: Klaus Weinbauer Date: Wed, 22 May 2024 16:09:50 +0200 Subject: [PATCH 10/10] prepare release v0.0.21 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 012cfdd..ee8b71a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synrbl" -version = "0.0.20" +version = "0.0.21" authors = [ {name="Tieu Long Phan", email="long.tieu_phan@uni-leipzig.de"}, {name="Klaus Weinbauer", email="klaus@bioinf.uni-leipzig.de"}