diff --git a/Test/SynMCSImputer/SubStructure/test_mcs_process.py b/Test/SynMCSImputer/SubStructure/test_mcs_process.py index c71f297..5bca0dd 100644 --- a/Test/SynMCSImputer/SubStructure/test_mcs_process.py +++ b/Test/SynMCSImputer/SubStructure/test_mcs_process.py @@ -63,5 +63,5 @@ def test_timeout(): ] result = ensemble_mcs(data, conditions, n_jobs=2) assert "timeout" in result[0][0]["issue"] - assert [] == result[0][0]["mcs_results"] - assert [] == result[0][0]["sorted_reactants"] + assert [''] == result[0][0]["mcs_results"] + assert "C[Si](C)(Br)O[Si](C)(C)O[Si]" in result[0][0]["sorted_reactants"][0] diff --git a/Test/test_balancing.py b/Test/test_balancing.py new file mode 100644 index 0000000..12a8e23 --- /dev/null +++ b/Test/test_balancing.py @@ -0,0 +1,18 @@ +from synrbl import Balancer + + +def test_e2e_1(): + n = 100 + reactant = "[Br]" + n * "[Si](C)(C)O" + "[Si](C)(C)[Br]" + product = "O" + n * "[Si](C)(C)O" + "[Si](C)(C)O" + reaction = "{}>>{}".format(reactant, product) + exp_result = "{}.{}>>{}.{}".format(reactant, "O.O", product, "Br.Br") + + balancer = Balancer() + balancer.rb_method.n_jobs = 1 + balancer.rb_validator.n_jobs = 1 + balancer.mcs_validator.n_jobs = 1 + balancer.input_validator.n_jobs = 1 + + result = balancer.rebalance(reaction) + assert exp_result == result[0] diff --git a/pyproject.toml b/pyproject.toml index 715deee..07d902c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synrbl" -version = "0.0.10" +version = "0.0.11" authors = [ {name="Tieu Long Phan", email="long.tieu_phan@uni-leipzig.de"}, {name="Klaus Weinbauer", email="klaus@bioinf.uni-leipzig.de"} diff --git a/synrbl/SynMCSImputer/MissingGraph/find_graph_dict.py b/synrbl/SynMCSImputer/MissingGraph/find_graph_dict.py index d655d4f..4fe31e0 100644 --- a/synrbl/SynMCSImputer/MissingGraph/find_graph_dict.py +++ b/synrbl/SynMCSImputer/MissingGraph/find_graph_dict.py @@ -4,6 +4,8 @@ from rdkit.rdBase import BlockLogs from joblib import Parallel, delayed from typing import List +import multiprocessing +import multiprocessing.pool from synrbl.SynMCSImputer.MissingGraph.find_missing_graphs import FindMissingGraphs from synrbl.SynMCSImputer.MissingGraph.uncertainty_graph import GraphMissingUncertainty @@ -86,23 +88,89 @@ def find_single_graph_parallel(mcs_mol_list, sorted_reactants_mol_list, n_jobs=4 - 'issue' (str): Any issues encountered during processing. """ - def process_single_pair(reactant_mol, mcs_mol): + # def process_single_pair(reactant_mol, mcs_mol): + # try: + # block = BlockLogs() + # ( + # mols, + # boundary_atoms_products, + # nearest_neighbor_products, + # ) = FindMissingGraphs.find_missing_parts_pairs(reactant_mol, mcs_mol) + # del block + # return { + # "smiles": [ + # Chem.MolToSmiles(mol) if mol is not None else None for mol in mols + # ], + # "boundary_atoms_products": boundary_atoms_products, + # "nearest_neighbor_products": nearest_neighbor_products, + # "issue": "", + # } + # except Exception as e: + # return { + # "smiles": [], + # "boundary_atoms_products": [], + # "nearest_neighbor_products": [], + # "issue": str(e), + # } + + # def process_single_pair_safe(reactant_mol, mcs_mol, job_timeout=5): + # pool = multiprocessing.Pool(1) + # async_result = pool.apply_async( + # process_single_pair, + # ( + # reactant_mol, + # mcs_mol, + # ), + # ) + # try: + # return async_result.get(job_timeout) + # except multiprocessing.TimeoutError: + # return { + # "smiles": [], + # "boundary_atoms_products": [], + # "nearest_neighbor_products": [], + # "issue": "Find Missing Graph terminated by timeout.", + # } + # finally: + # pool.terminate() # Terminate the pool to release resources + def process_single_pair(reactant_mol, mcs_mol, job_timeout=2): try: block = BlockLogs() - ( - mols, - boundary_atoms_products, - nearest_neighbor_products, - ) = FindMissingGraphs.find_missing_parts_pairs(reactant_mol, mcs_mol) + pool = multiprocessing.Pool(1) + async_result = pool.apply_async( + FindMissingGraphs.find_missing_parts_pairs, + ( + reactant_mol, + mcs_mol, + ), + ) + result = async_result.get(job_timeout) + pool.terminate() # Terminate the pool to release resources del block return { "smiles": [ - Chem.MolToSmiles(mol) if mol is not None else None for mol in mols + Chem.MolToSmiles(mol) if mol is not None else None + for mol in result[0] ], - "boundary_atoms_products": boundary_atoms_products, - "nearest_neighbor_products": nearest_neighbor_products, + "boundary_atoms_products": result[1], + "nearest_neighbor_products": result[2], "issue": "", } + except multiprocessing.TimeoutError: + pool.terminate() # Terminate the pool in case of timeout + + result = FindMissingGraphs.find_missing_parts_pairs( + reactant_mol, mcs_mol, False + ) + return { + "smiles": [ + Chem.MolToSmiles(mol) if mol is not None else None + for mol in result[0] + ], + "boundary_atoms_products": result[1], + "nearest_neighbor_products": result[2], + "issue": "Find Missing Graph terminated by timeout", + } except Exception as e: return { "smiles": [], @@ -122,6 +190,9 @@ def find_graph_dict(mcs_dict, n_jobs: int = 4): """ Function to find missing graphs for a given MCS dictionary. """ + if len(mcs_dict) == 0: + return [] + msc_df = pd.DataFrame(mcs_dict) mcs_results = msc_df["mcs_results"].to_list() diff --git a/synrbl/SynMCSImputer/MissingGraph/find_missing_graphs.py b/synrbl/SynMCSImputer/MissingGraph/find_missing_graphs.py index b4429e9..4d36d92 100644 --- a/synrbl/SynMCSImputer/MissingGraph/find_missing_graphs.py +++ b/synrbl/SynMCSImputer/MissingGraph/find_missing_graphs.py @@ -41,7 +41,9 @@ def __init__(self): @staticmethod def find_missing_parts_pairs( - mol_list: List[Chem.Mol], mcs_list: Optional[List[Chem.Mol]] = None + mol_list: List[Chem.Mol], + mcs_list: Optional[List[Chem.Mol]] = None, + substructure_optimize: bool = True, ) -> Tuple[ Optional[List[Chem.Mol]], List[List[Dict[str, int]]], List[List[Dict[str, int]]] ]: @@ -93,10 +95,13 @@ def find_missing_parts_pairs( if not substructure_match: substructure_match = mol.GetSubstructMatch(mcs_mol) else: - analyzer = SubstructureAnalyzer() - substructure_match = analyzer.identify_optimal_substructure( - parent_mol=mol, child_mol=mcs_mol - ) + if substructure_optimize: + analyzer = SubstructureAnalyzer() + substructure_match = analyzer.identify_optimal_substructure( + parent_mol=mol, child_mol=mcs_mol + ) + else: + substructure_match = mol.GetSubstructMatch(mcs_mol) if substructure_match: atoms_to_remove.update(substructure_match) diff --git a/synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py b/synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py index b028e4b..9337c94 100644 --- a/synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py +++ b/synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py @@ -50,7 +50,8 @@ def IterativeMCSReactionPairs( method="MCIS", sort="MCIS", remove_substructure=True, - maxNodes=80, + maxNodes=200, + substructure_optimize=True, ): """ Find the MCS for each reactant fragment with the product, updating the @@ -144,12 +145,20 @@ def IterativeMCSReactionPairs( # Conditional substructure removal if remove_substructure: # Identify the optimal substructure - analyzer = SubstructureAnalyzer() - optimal_substructure = analyzer.identify_optimal_substructure( - parent_mol=current_product, - child_mol=mcs_mol, - maxNodes=maxNodes, - ) + if substructure_optimize: + analyzer = SubstructureAnalyzer() + optimal_substructure = ( + analyzer.identify_optimal_substructure( + parent_mol=current_product, + child_mol=mcs_mol, + maxNodes=maxNodes, + ) + ) + else: + optimal_substructure = current_product.GetSubstructMatch( + mcs_mol + ) + if optimal_substructure: rw_mol = Chem.RWMol(current_product) # Remove atoms in descending order of their indices @@ -185,6 +194,7 @@ def fit( ignore_atom_map=False, ignore_bond_order=False, maxNodes=80, + substructure_optimize=True, ): """ Process a reaction dictionary to find MCS, missing parts in reactants @@ -247,6 +257,7 @@ def fit( sort=sort, remove_substructure=remove_substructure, maxNodes=maxNodes, + substructure_optimize=substructure_optimize, ) return mcs_list, sorted_parents, reactant_mol_list, product_mol diff --git a/synrbl/SynMCSImputer/SubStructure/mcs_process.py b/synrbl/SynMCSImputer/SubStructure/mcs_process.py index 2f2d2c1..a97332f 100644 --- a/synrbl/SynMCSImputer/SubStructure/mcs_process.py +++ b/synrbl/SynMCSImputer/SubStructure/mcs_process.py @@ -26,7 +26,8 @@ def single_mcs( similarityThreshold=0.5, remove_substructure=True, ignore_bond_order=True, - maxNodes=100, + maxNodes=200, + substructure_optimize: bool = True, ): """ Performs MCS on a single reaction data entry and captures any issues encountered. @@ -52,7 +53,8 @@ def single_mcs( timeout=timeout, similarityThreshold=similarityThreshold, ignore_bond_order=ignore_bond_order, - maxNodes=maxNodes + maxNodes=maxNodes, + substructure_optimize=substructure_optimize, ) if len(reactant_mol_list) != len(sorted_reactants): @@ -88,6 +90,9 @@ def single_mcs_safe(data_dict, job_timeout=2, id_col="id", issue_col="issue", ** try: return async_result.get(job_timeout) except multiprocessing.TimeoutError: + mcs_data = single_mcs( + data_dict=data_dict, mcs_data=mcs_data, substructure_optimize=False + ) mcs_data[issue_col] = "MCS search terminated by timeout." return mcs_data diff --git a/synrbl/SynMCSImputer/SubStructure/substructure_analyzer.py b/synrbl/SynMCSImputer/SubStructure/substructure_analyzer.py index 7aae4f6..e28c3a0 100644 --- a/synrbl/SynMCSImputer/SubStructure/substructure_analyzer.py +++ b/synrbl/SynMCSImputer/SubStructure/substructure_analyzer.py @@ -63,7 +63,7 @@ def sort_substructures_by_fragment_count( return [pair[0] for pair in paired_list] def identify_optimal_substructure( - self, parent_mol: Mol, child_mol: Mol, maxNodes: int = 80 + self, parent_mol: Mol, child_mol: Mol, maxNodes: int = 200 ) -> Tuple[int, ...]: """ Identifies the most relevant substructure within a parent molecule diff --git a/synrbl/mcs_search.py b/synrbl/mcs_search.py index 1a0bf97..6a036a3 100644 --- a/synrbl/mcs_search.py +++ b/synrbl/mcs_search.py @@ -83,7 +83,6 @@ def find(self, reactions): ) largest_conditions = ExtractMCS.get_largest_condition(*condition_results) - print(largest_conditions) mcs_results = find_graph_dict(largest_conditions)