Skip to content

Commit

Permalink
Merge pull request #20 from TieuLongPhan/main
Browse files Browse the repository at this point in the history
v0.0.10
  • Loading branch information
klausweinbauer authored May 1, 2024
2 parents 354c5f7 + f958974 commit 5492c3e
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 26 deletions.
29 changes: 27 additions & 2 deletions Test/SynMCSImputer/SubStructure/test_mcs_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path

from unittest.mock import patch
from synrbl.SynMCSImputer.SubStructure.mcs_process import single_mcs
from synrbl.SynMCSImputer.SubStructure.mcs_process import single_mcs_safe, ensemble_mcs


class TestMCSFunctions(unittest.TestCase):
Expand Down Expand Up @@ -32,11 +32,36 @@ def test_single_mcs(self, mock_fit):
# Mocking MCSMissingGraphAnalyzer.fit to return predefined values
mock_fit.return_value = ([], [], [], None)

result = single_mcs(self.sample_reaction_data, id_col="R-id")
result = single_mcs_safe(self.sample_reaction_data, id_col="R-id")
self.assertEqual(result["R-id"], "example_id")
self.assertIsInstance(result["mcs_results"], list)
self.assertIsInstance(result["sorted_reactants"], list)


if __name__ == "__main__":
unittest.main()


def test_timeout():
conditions = [
{
"RingMatchesRingOnly": True,
"CompleteRingsOnly": True,
"method": "MCIS",
"sort": "MCIS",
"ignore_bond_order": True,
"maxNodes": 1000,
}
]
data = [
{
"id": 0,
"carbon_balance_check": "products",
"reactants": "[Br]" + 100 * "[Si](C)(C)O" + "[Si][Br]",
"products": "O" + 100 * "[Si](C)(C)O" + "[Si]O",
}
]
result = ensemble_mcs(data, conditions, n_jobs=2)
assert "timeout" in result[0][0]["issue"]
assert [] == result[0][0]["mcs_results"]
assert [] == result[0][0]["sorted_reactants"]
1 change: 1 addition & 0 deletions Test/test_mcs_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_simple_mcs():
_add(data, "COC(C)=O>>OC(C)=O")

results = mcs.find(data)
print(results)

assert 1 == len(results)
result = results[0]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "synrbl"
version = "0.0.9"
version = "0.0.10"
authors = [
{name="Tieu Long Phan", email="long.tieu_phan@uni-leipzig.de"},
{name="Klaus Weinbauer", email="klaus@bioinf.uni-leipzig.de"}
Expand Down
13 changes: 9 additions & 4 deletions synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def IterativeMCSReactionPairs(
method="MCIS",
sort="MCIS",
remove_substructure=True,
maxNodes=80,
):
"""
Find the MCS for each reactant fragment with the product, updating the
Expand Down Expand Up @@ -145,7 +146,9 @@ def IterativeMCSReactionPairs(
# Identify the optimal substructure
analyzer = SubstructureAnalyzer()
optimal_substructure = analyzer.identify_optimal_substructure(
parent_mol=current_product, child_mol=mcs_mol
parent_mol=current_product,
child_mol=mcs_mol,
maxNodes=maxNodes,
)
if optimal_substructure:
rw_mol = Chem.RWMol(current_product)
Expand Down Expand Up @@ -174,13 +177,14 @@ def fit(
reaction_dict,
RingMatchesRingOnly=True,
CompleteRingsOnly=True,
Timeout=60,
timeout=1,
similarityThreshold=0.5,
sort="MCIS",
method="MCIS",
remove_substructure=True,
ignore_atom_map=False,
ignore_bond_order=False,
maxNodes=80,
):
"""
Process a reaction dictionary to find MCS, missing parts in reactants
Expand All @@ -201,7 +205,7 @@ def fit(

if method == "MCIS":
params = rdFMCS.MCSParameters()
params.Timeout = 1 # Timeout
params.Timeout = timeout
params.BondCompareParameters.RingMatchesRingOnly = RingMatchesRingOnly
params.BondCompareParameters.CompleteRingsOnly = CompleteRingsOnly
if ignore_bond_order:
Expand All @@ -213,7 +217,7 @@ def fit(
params = rdRascalMCES.RascalOptions()
params.singleLargestFrag = False
params.returnEmptyMCES = True
params.timeout = 1 # Timeout
params.timeout = timeout
params.similarityThreshold = similarityThreshold

else:
Expand Down Expand Up @@ -242,6 +246,7 @@ def fit(
method=method,
sort=sort,
remove_substructure=remove_substructure,
maxNodes=maxNodes,
)

return mcs_list, sorted_parents, reactant_mol_list, product_mol
Expand Down
47 changes: 34 additions & 13 deletions synrbl/SynMCSImputer/SubStructure/mcs_process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import datetime
import time
import logging
import multiprocessing
import multiprocessing.pool

import rdkit.Chem.rdmolfiles as rdmolfiles

Expand All @@ -14,16 +16,17 @@

def single_mcs(
data_dict,
id_col="id",
mcs_data,
issue_col="issue",
RingMatchesRingOnly=True,
CompleteRingsOnly=True,
Timeout=60,
timeout=1,
sort="MCES",
method="MCES",
similarityThreshold=0.5,
remove_substructure=True,
ignore_bond_order=True,
maxNodes=100,
):
"""
Performs MCS on a single reaction data entry and captures any issues encountered.
Expand All @@ -36,12 +39,6 @@ def single_mcs(
- dict: A dictionary containing MCS results and any sorted reactants encountered.
"""
block_logs = BlockLogs()
mcs_data = {
id_col: data_dict[id_col],
"mcs_results": [],
"sorted_reactants": [],
issue_col: "",
}

try:
analyzer = MCSMissingGraphAnalyzer()
Expand All @@ -52,13 +49,14 @@ def single_mcs(
sort=sort,
method=method,
remove_substructure=remove_substructure,
Timeout=Timeout,
timeout=timeout,
similarityThreshold=similarityThreshold,
ignore_bond_order=ignore_bond_order,
maxNodes=maxNodes
)

if len(reactant_mol_list) != len(sorted_reactants):
mcs_data["issue"] = "Uncertian MCS."
mcs_data[issue_col] = "Uncertian MCS."
else:
mcs_data["mcs_results"] = [rdmolfiles.MolToSmarts(mol) for mol in mcs_list]
mcs_data["sorted_reactants"] = [
Expand All @@ -71,8 +69,31 @@ def single_mcs(
return mcs_data


def single_mcs_safe(data_dict, job_timeout=2, id_col="id", issue_col="issue", **kwargs):
mcs_data = {
id_col: data_dict[id_col],
"mcs_results": [],
"sorted_reactants": [],
issue_col: "",
}
pool = multiprocessing.pool.ThreadPool(1)
async_result = pool.apply_async(
single_mcs,
(
data_dict,
mcs_data,
),
kwargs,
)
try:
return async_result.get(job_timeout)
except multiprocessing.TimeoutError:
mcs_data[issue_col] = "MCS search terminated by timeout."
return mcs_data


def ensemble_mcs(
data, conditions, id_col="id", issue_col="issue", n_jobs=-1, Timeout=60
data, conditions, id_col="id", issue_col="issue", n_jobs=-1, timeout=1
):
condition_results = []
start_time = time.time()
Expand All @@ -81,12 +102,12 @@ def ensemble_mcs(
all_results = [] # Accumulate results for each condition

p_generator = Parallel(n_jobs=n_jobs, verbose=0, return_as="generator")(
delayed(single_mcs)(
delayed(single_mcs_safe)(
data_dict,
id_col=id_col,
issue_col=issue_col,
**condition,
Timeout=Timeout,
timeout=timeout,
)
for data_dict in data
)
Expand Down
21 changes: 16 additions & 5 deletions synrbl/SynMCSImputer/SubStructure/substructure_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,33 @@ def sort_substructures_by_fragment_count(
return [pair[0] for pair in paired_list]

def identify_optimal_substructure(
self, parent_mol: Mol, child_mol: Mol
self, parent_mol: Mol, child_mol: Mol, maxNodes: int = 80
) -> Tuple[int, ...]:
"""
Identifies the most relevant substructure within a parent molecule
given a child molecule.
given a child molecule, with a timeout feature for the
substructure matching process. If the primary matching process times out,
a fallback search is attempted with a maximum of one match.
Parameters:
parent_mol (Mol): The parent molecule.
child_mol (Mol): The child molecule whose substructures are to be
analyzed.
child_mol (Mol): The child molecule whose substructures are to be analyzed.
timeout_sec (int): Timeout in seconds for the substructure search process.
Returns:
Tuple[int, ...]: The atom indices of the identified substructure
in the parent molecule.
Returns:
Tuple[int, ...]: The atom indices of the identified substructure in
the parent molecule.
"""
substructures = parent_mol.GetSubstructMatches(child_mol)

if child_mol.GetNumAtoms() <= maxNodes:
substructures = parent_mol.GetSubstructMatches(child_mol)
else:
substructures = parent_mol.GetSubstructMatches(child_mol, maxMatches=1)

if len(substructures) > 1:
fragment_counts = [
self.remove_substructure_atoms(parent_mol, substructure)
Expand Down
2 changes: 1 addition & 1 deletion synrbl/mcs_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ def find(self, reactions):
id_col=self.id_col,
issue_col=self.issue_col,
n_jobs=self.n_jobs,
Timeout=60,
)

largest_conditions = ExtractMCS.get_largest_condition(*condition_results)
print(largest_conditions)

mcs_results = find_graph_dict(largest_conditions)

Expand Down

0 comments on commit 5492c3e

Please sign in to comment.