Skip to content

Commit

Permalink
Merge pull request #21 from TieuLongPhan/main
Browse files (browse the repository at this point in the history)
v0.0.11
  • Loading branch information
klausweinbauer authored May 13, 2024
2 parents 5492c3e + 7a1c731 commit c737ada
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 28 deletions.
4 changes: 2 additions & 2 deletions Test/SynMCSImputer/SubStructure/test_mcs_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ def test_timeout():
]
result = ensemble_mcs(data, conditions, n_jobs=2)
assert "timeout" in result[0][0]["issue"]
assert [] == result[0][0]["mcs_results"]
assert [] == result[0][0]["sorted_reactants"]
assert [''] == result[0][0]["mcs_results"]
assert "C[Si](C)(Br)O[Si](C)(C)O[Si]" in result[0][0]["sorted_reactants"][0]
18 changes: 18 additions & 0 deletions Test/test_balancing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from synrbl import Balancer


def test_e2e_1():
    """End-to-end rebalancing of a long, repetitive chain reaction.

    The reactant is 100 repeated ``[Si](C)(C)O`` units with a ``[Br]`` prefix
    and a ``[Si](C)(C)[Br]`` suffix; the product uses the same chain with
    ``O`` and ``[Si](C)(C)O`` as prefix and suffix. The rebalanced reaction
    is expected to gain ``O.O`` on the reactant side and ``Br.Br`` on the
    product side.
    """
    repeat_count = 100
    chain = repeat_count * "[Si](C)(C)O"
    lhs = "[Br]" + chain + "[Si](C)(C)[Br]"
    rhs = "O" + chain + "[Si](C)(C)O"
    unbalanced = "{}>>{}".format(lhs, rhs)
    expected = "{}.{}>>{}.{}".format(lhs, "O.O", rhs, "Br.Br")

    balancer = Balancer()
    # Set n_jobs to 1 on each stage, in the same order as before
    # (presumably to keep the run serial -- confirm against Balancer).
    for stage_name in (
        "rb_method",
        "rb_validator",
        "mcs_validator",
        "input_validator",
    ):
        getattr(balancer, stage_name).n_jobs = 1

    rebalanced = balancer.rebalance(unbalanced)
    assert expected == rebalanced[0]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "synrbl"
version = "0.0.10"
version = "0.0.11"
authors = [
{name="Tieu Long Phan", email="long.tieu_phan@uni-leipzig.de"},
{name="Klaus Weinbauer", email="klaus@bioinf.uni-leipzig.de"}
Expand Down
89 changes: 80 additions & 9 deletions synrbl/SynMCSImputer/MissingGraph/find_graph_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from rdkit.rdBase import BlockLogs
from joblib import Parallel, delayed
from typing import List
import multiprocessing
import multiprocessing.pool
from synrbl.SynMCSImputer.MissingGraph.find_missing_graphs import FindMissingGraphs
from synrbl.SynMCSImputer.MissingGraph.uncertainty_graph import GraphMissingUncertainty

Expand Down Expand Up @@ -86,23 +88,89 @@ def find_single_graph_parallel(mcs_mol_list, sorted_reactants_mol_list, n_jobs=4
- 'issue' (str): Any issues encountered during processing.
"""

def process_single_pair(reactant_mol, mcs_mol):
# def process_single_pair(reactant_mol, mcs_mol):
# try:
# block = BlockLogs()
# (
# mols,
# boundary_atoms_products,
# nearest_neighbor_products,
# ) = FindMissingGraphs.find_missing_parts_pairs(reactant_mol, mcs_mol)
# del block
# return {
# "smiles": [
# Chem.MolToSmiles(mol) if mol is not None else None for mol in mols
# ],
# "boundary_atoms_products": boundary_atoms_products,
# "nearest_neighbor_products": nearest_neighbor_products,
# "issue": "",
# }
# except Exception as e:
# return {
# "smiles": [],
# "boundary_atoms_products": [],
# "nearest_neighbor_products": [],
# "issue": str(e),
# }

# def process_single_pair_safe(reactant_mol, mcs_mol, job_timeout=5):
# pool = multiprocessing.Pool(1)
# async_result = pool.apply_async(
# process_single_pair,
# (
# reactant_mol,
# mcs_mol,
# ),
# )
# try:
# return async_result.get(job_timeout)
# except multiprocessing.TimeoutError:
# return {
# "smiles": [],
# "boundary_atoms_products": [],
# "nearest_neighbor_products": [],
# "issue": "Find Missing Graph terminated by timeout.",
# }
# finally:
# pool.terminate() # Terminate the pool to release resources
def process_single_pair(reactant_mol, mcs_mol, job_timeout=2):
try:
block = BlockLogs()
(
mols,
boundary_atoms_products,
nearest_neighbor_products,
) = FindMissingGraphs.find_missing_parts_pairs(reactant_mol, mcs_mol)
pool = multiprocessing.Pool(1)
async_result = pool.apply_async(
FindMissingGraphs.find_missing_parts_pairs,
(
reactant_mol,
mcs_mol,
),
)
result = async_result.get(job_timeout)
pool.terminate() # Terminate the pool to release resources
del block
return {
"smiles": [
Chem.MolToSmiles(mol) if mol is not None else None for mol in mols
Chem.MolToSmiles(mol) if mol is not None else None
for mol in result[0]
],
"boundary_atoms_products": boundary_atoms_products,
"nearest_neighbor_products": nearest_neighbor_products,
"boundary_atoms_products": result[1],
"nearest_neighbor_products": result[2],
"issue": "",
}
except multiprocessing.TimeoutError:
pool.terminate() # Terminate the pool in case of timeout

result = FindMissingGraphs.find_missing_parts_pairs(
reactant_mol, mcs_mol, False
)
return {
"smiles": [
Chem.MolToSmiles(mol) if mol is not None else None
for mol in result[0]
],
"boundary_atoms_products": result[1],
"nearest_neighbor_products": result[2],
"issue": "Find Missing Graph terminated by timeout",
}
except Exception as e:
return {
"smiles": [],
Expand All @@ -122,6 +190,9 @@ def find_graph_dict(mcs_dict, n_jobs: int = 4):
"""
Function to find missing graphs for a given MCS dictionary.
"""
if len(mcs_dict) == 0:
return []

msc_df = pd.DataFrame(mcs_dict)

mcs_results = msc_df["mcs_results"].to_list()
Expand Down
15 changes: 10 additions & 5 deletions synrbl/SynMCSImputer/MissingGraph/find_missing_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def __init__(self):

@staticmethod
def find_missing_parts_pairs(
mol_list: List[Chem.Mol], mcs_list: Optional[List[Chem.Mol]] = None
mol_list: List[Chem.Mol],
mcs_list: Optional[List[Chem.Mol]] = None,
substructure_optimize: bool = True,
) -> Tuple[
Optional[List[Chem.Mol]], List[List[Dict[str, int]]], List[List[Dict[str, int]]]
]:
Expand Down Expand Up @@ -93,10 +95,13 @@ def find_missing_parts_pairs(
if not substructure_match:
substructure_match = mol.GetSubstructMatch(mcs_mol)
else:
analyzer = SubstructureAnalyzer()
substructure_match = analyzer.identify_optimal_substructure(
parent_mol=mol, child_mol=mcs_mol
)
if substructure_optimize:
analyzer = SubstructureAnalyzer()
substructure_match = analyzer.identify_optimal_substructure(
parent_mol=mol, child_mol=mcs_mol
)
else:
substructure_match = mol.GetSubstructMatch(mcs_mol)

if substructure_match:
atoms_to_remove.update(substructure_match)
Expand Down
25 changes: 18 additions & 7 deletions synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def IterativeMCSReactionPairs(
method="MCIS",
sort="MCIS",
remove_substructure=True,
maxNodes=80,
maxNodes=200,
substructure_optimize=True,
):
"""
Find the MCS for each reactant fragment with the product, updating the
Expand Down Expand Up @@ -144,12 +145,20 @@ def IterativeMCSReactionPairs(
# Conditional substructure removal
if remove_substructure:
# Identify the optimal substructure
analyzer = SubstructureAnalyzer()
optimal_substructure = analyzer.identify_optimal_substructure(
parent_mol=current_product,
child_mol=mcs_mol,
maxNodes=maxNodes,
)
if substructure_optimize:
analyzer = SubstructureAnalyzer()
optimal_substructure = (
analyzer.identify_optimal_substructure(
parent_mol=current_product,
child_mol=mcs_mol,
maxNodes=maxNodes,
)
)
else:
optimal_substructure = current_product.GetSubstructMatch(
mcs_mol
)

if optimal_substructure:
rw_mol = Chem.RWMol(current_product)
# Remove atoms in descending order of their indices
Expand Down Expand Up @@ -185,6 +194,7 @@ def fit(
ignore_atom_map=False,
ignore_bond_order=False,
maxNodes=80,
substructure_optimize=True,
):
"""
Process a reaction dictionary to find MCS, missing parts in reactants
Expand Down Expand Up @@ -247,6 +257,7 @@ def fit(
sort=sort,
remove_substructure=remove_substructure,
maxNodes=maxNodes,
substructure_optimize=substructure_optimize,
)

return mcs_list, sorted_parents, reactant_mol_list, product_mol
Expand Down
9 changes: 7 additions & 2 deletions synrbl/SynMCSImputer/SubStructure/mcs_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def single_mcs(
similarityThreshold=0.5,
remove_substructure=True,
ignore_bond_order=True,
maxNodes=100,
maxNodes=200,
substructure_optimize: bool = True,
):
"""
Performs MCS on a single reaction data entry and captures any issues encountered.
Expand All @@ -52,7 +53,8 @@ def single_mcs(
timeout=timeout,
similarityThreshold=similarityThreshold,
ignore_bond_order=ignore_bond_order,
maxNodes=maxNodes
maxNodes=maxNodes,
substructure_optimize=substructure_optimize,
)

if len(reactant_mol_list) != len(sorted_reactants):
Expand Down Expand Up @@ -88,6 +90,9 @@ def single_mcs_safe(data_dict, job_timeout=2, id_col="id", issue_col="issue", **
try:
return async_result.get(job_timeout)
except multiprocessing.TimeoutError:
mcs_data = single_mcs(
data_dict=data_dict, mcs_data=mcs_data, substructure_optimize=False
)
mcs_data[issue_col] = "MCS search terminated by timeout."
return mcs_data

Expand Down
2 changes: 1 addition & 1 deletion synrbl/SynMCSImputer/SubStructure/substructure_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def sort_substructures_by_fragment_count(
return [pair[0] for pair in paired_list]

def identify_optimal_substructure(
self, parent_mol: Mol, child_mol: Mol, maxNodes: int = 80
self, parent_mol: Mol, child_mol: Mol, maxNodes: int = 200
) -> Tuple[int, ...]:
"""
Identifies the most relevant substructure within a parent molecule
Expand Down
1 change: 0 additions & 1 deletion synrbl/mcs_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def find(self, reactions):
)

largest_conditions = ExtractMCS.get_largest_condition(*condition_results)
print(largest_conditions)

mcs_results = find_graph_dict(largest_conditions)

Expand Down

0 comments on commit c737ada

Please sign in to comment.