Skip to content

Commit

Permalink
Merge pull request #34 from TieuLongPhan/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
klausweinbauer authored May 22, 2024
2 parents 2b244d4 + bbd7a1e commit 77964d0
Show file tree
Hide file tree
Showing 14 changed files with 85 additions and 59 deletions.
7 changes: 7 additions & 0 deletions Test/SynChemImputer/test_molecule_standardizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def test_hemiketal_transformation(self):
result, expected, "Hemiketal transformation failed or incorrect"
)

def test_hemiketal_transformation_with_aam(self):
    """Atom-mapped hemiketal input must yield a mapped carbonyl plus water.

    Either oxygen may end up as the carbonyl oxygen, so both mapped
    variants are accepted.
    """
    mapped_smiles = "[C:1]([O:2])([OH:3])"
    acceptable = ["[C:1]=[O:2].[OH2:3]", "[C:1]=[O:3].[OH2:2]"]
    outcome = MoleculeStandardizer.standardize_hemiketal(mapped_smiles, [0, 1, 2])
    self.assertIn(outcome, acceptable)

def test_MoleculeStandardizer(self):
smiles = "C(O)(O)C=CO"
standardizer = MoleculeStandardizer()
Expand Down
7 changes: 7 additions & 0 deletions Test/SynMCSImputer/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,10 @@ def test_ignore_water_in_passthrough(self):
cm = merge.merge(cset)
self.assertEqual("remove_water_catalyst", cm.rules[0].name)
self.assertEqual("CO", cm.smiles)

def test_catalyst_passthrough(self):
    """A lone compound added with only ``src_mol`` merges unchanged, with no rules."""
    compounds = CompoundSet()
    compounds.add_compound("C", src_mol="C")
    merged = merge.merge(compounds)
    # No merge rule is expected to be recorded for a pass-through compound.
    self.assertEqual(0, len(merged.rules))
    self.assertEqual("C", merged.smiles)
21 changes: 0 additions & 21 deletions Test/SynUtils/test_chem_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
remove_atom_mapping,
normalize_smiles,
count_atoms,
remove_charge,
)


Expand Down Expand Up @@ -85,16 +84,6 @@ def test_remove_stereochemistry(self):
result = normalize_smiles(smiles)
self.assertEqual("CCC", result)

def test_remove_charge1(self):
    """Charge signs are stripped while atom-map numbers stay in place."""
    stripped = remove_charge("[N-:13]#[C+:19]")
    self.assertEqual("[N:13]#[C:19]", stripped)

def test_remove_charge2(self):
    """Without atom maps, stripping the charges also drops the brackets."""
    stripped = remove_charge("[N-]#[C+]")
    self.assertEqual("N#C", stripped)

def test_edge_case_1(self):
smiles = "F[Sb@OH12](F)(F)(F)(F)F"
result = normalize_smiles(smiles)
Expand All @@ -110,16 +99,6 @@ def test_ordering_1(self):
result = normalize_smiles(smiles)
self.assertEqual("C=O.[HH]", result)

def test_remove_charges1(self):
    """Normalizing a charged, atom-mapped SMILES is expected to give ``C#N``."""
    normalized = normalize_smiles("[N-:13]#[C+:19]")
    self.assertEqual("C#N", normalized)

# def test_remove_charges2(self):
# smiles = "C[N-:13]#[C+:19]"
# result = normalize_smiles(smiles)
# self.assertEqual("CN#C", result)


@pytest.mark.parametrize(
"smiles,exp_atom_cnt", [("O=C", 2), ("CO", 2), ("HH", 0), ("c1ccccc1", 6)]
Expand Down
6 changes: 3 additions & 3 deletions Test/test_balancing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def test_e2e_1():
n = 100
reactant = "[Br]" + n * "[Si](C)(C)O" + "[Si](C)(C)[Br]"
reactant = "Br" + n * "[Si](C)(C)O" + "[Si](C)(C)Br"
product = "O" + n * "[Si](C)(C)O" + "[Si](C)(C)O"
reaction = "{}>>{}".format(reactant, product)
exp_result = "{}.{}>>{}.{}".format(reactant, "O.O", product, "Br.Br")
Expand All @@ -26,8 +26,8 @@ def test_e2e_1():
["CC(=O)C>>CC(O)C", "CC(=O)C.[HH]>>CC(O)C"],
[
"CCO.[O]>>CC=O",
"CCO.O=[Cr](Cl)(-[O-])=O.c1cc[nH+]cc1.O>>"
+ "CC=O.O.O=[Cr](O)O.c1cc[nH+]cc1.[Cl-]",
"CCO.O.O=[Cr](Cl)(-[O-])=O.c1cc[nH+]cc1.O>>"
+ "CC=O.O.O.O=[Cr](O)O.c1cc[nH+]cc1.[Cl-]",
],
],
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "synrbl"
version = "0.0.20"
version = "0.0.21"
authors = [
{name="Tieu Long Phan", email="long.tieu_phan@uni-leipzig.de"},
{name="Klaus Weinbauer", email="klaus@bioinf.uni-leipzig.de"}
Expand Down
7 changes: 5 additions & 2 deletions synrbl/SynChemImputer/molecule_standardizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,25 +118,28 @@ def standardize_hemiketal(smiles: str, atom_indices: List[int]) -> str:
"""
# Load the molecule from SMILES and create an editable molecule object
mol = Chem.MolFromSmiles(smiles)
emol = Chem.EditableMol(mol)

# Initialize indices
c_idx, o1_idx, o2_idx = None, None, None
for i in atom_indices:
atom_symbol = mol.GetAtomWithIdx(i).GetSymbol()
atom = mol.GetAtomWithIdx(i)
atom_symbol = atom.GetSymbol()
if atom_symbol == "C":
c_idx = i
elif atom_symbol == "O":
if o1_idx is None:
o1_idx = i # Assume the first oxygen encountered is O1
atom.SetNumExplicitHs(0)
else:
o2_idx = i # The next oxygen is O2
atom.SetNumExplicitHs(2)

# Check if all indices are assigned
if None in [c_idx, o1_idx, o2_idx]:
return "Invalid atom indices provided. Please check the input."

# Attempt to modify the molecule structure
emol = Chem.EditableMol(mol)
try:
# Remove existing bonds if they exist
emol.RemoveBond(c_idx, o1_idx)
Expand Down
2 changes: 1 addition & 1 deletion synrbl/SynCmd/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def run(args):
args.out_columns if isinstance(args.out_columns, list) else [args.out_columns]
)
batch_size = None
if len(args.batch_size) > 0:
if args.batch_size is not None and len(args.batch_size) > 0:
batch_size = int(args.batch_size)

impute(
Expand Down
5 changes: 3 additions & 2 deletions synrbl/SynMCSImputer/mcs_based_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
smiles_standardizer=[],
):
self.reaction_col = reaction_col
self.output_col = output_col
self.output_col = output_col if isinstance(output_col, list) else [output_col]
self.mcs_data_col = mcs_data_col
self.issue_col = issue_col
self.rules_col = rules_col
Expand All @@ -134,7 +134,8 @@ def run(self, reactions: list[dict], stats=None):
carbon_balance_col=self.carbon_balance_col,
smiles_standardizer=self.smiles_standardizer,
)
reaction[self.output_col] = result
for col in self.output_col:
reaction[col] = result
reaction[self.rules_col] = rules
mcs_solved += 1
except Exception as e:
Expand Down
4 changes: 3 additions & 1 deletion synrbl/SynMCSImputer/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def merge(compound_set: CompoundSet) -> Compound:
else:
comps_with_boundaries.append(c)

if len(comps_with_boundaries) == 1:
if len(comps_with_boundaries) == 0 and len(comps_without_boundaries) > 0:
merged_compound = comps_without_boundaries.pop()
elif len(comps_with_boundaries) == 1:
merged_compound = _merge_one_compound(comps_with_boundaries[0])
elif len(comps_with_boundaries) == 2:
merged_compound = _merge_two_compounds(
Expand Down
14 changes: 0 additions & 14 deletions synrbl/SynUtils/chem_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,6 @@ def calculate_net_charge(sublist: list[dict[str, Union[str, int]]]) -> int:
return total_charge


def remove_unnecessary_brackets(smiles: str) -> str:
    """Unwrap bracket atoms that contain nothing but listed element symbols.

    A bracket expression is rewritten to its bare content only when the
    content consists of one or two symbols from the list in the pattern
    (for example ``[Br]`` becomes ``Br``). Brackets holding anything else,
    such as an atom-map number or a hydrogen count, are left as they are.
    """
    bare_atom_in_brackets = r"\[(?P<atom>(B|C|N|O|P|S|F|Cl|Br|I){1,2})\]"
    return re.sub(bare_atom_in_brackets, r"\g<atom>", smiles)


def remove_atom_mapping(smiles: str) -> str:
pattern = re.compile(r":\d+")
smiles = pattern.sub("", smiles)
Expand All @@ -153,21 +147,13 @@ def remove_stereo_chemistry(smiles: str) -> str:
return smiles


def remove_charge(smiles: str) -> str:
    """Delete every ``+`` and ``-`` from *smiles*, then drop redundant brackets.

    All occurrences of the two characters are removed, wherever they appear
    in the string. The result is passed through
    ``remove_unnecessary_brackets`` so atoms that no longer need brackets
    are written bare.
    """
    without_signs = smiles.translate(str.maketrans("", "", "+-"))
    return remove_unnecessary_brackets(without_signs)


def count_atoms(smiles: str) -> int:
    """Return how many listed element symbols occur in *smiles*.

    Matches are non-overlapping and scanned left to right over the symbols
    B, C, N, O, P, S, F, Cl, Br, I and the aromatic c, n, o. Hydrogen is not
    in the list, so a string such as ``HH`` counts as zero.
    """
    symbol_hits = re.finditer(r"(B|C|N|O|P|S|F|Cl|Br|I|c|n|o)", smiles)
    return sum(1 for _ in symbol_hits)


def normalize_smiles(smiles: str) -> str:
smiles = remove_stereo_chemistry(smiles)
smiles = remove_charge(smiles)
if ">>" in smiles:
return ">>".join([normalize_smiles(t) for t in smiles.split(">>")])
elif "." in smiles:
Expand Down
39 changes: 29 additions & 10 deletions synrbl/SynVis/vis_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class MCSDebug:
def __init__(self):
self.fontsize = 9
self.balancer = synrbl.Balancer()
self.balancer.columns.append("mcs")
self.balancer.columns.extend(["mcs", "mcs_based_result"])
self.balancer.mcs_method.output_col.append("mcs_based_result")
self.cairosize = (1600, 900)
self.highlight_color = (0.4, 0.9, 0.6, 1)

Expand All @@ -70,10 +71,18 @@ def plot(self, smiles, verbose=True):
)

fig = plt.figure()
gs = fig.add_gridspec(3, len(mols))
ax1 = fig.add_subplot(gs[0, :])
axs2 = [fig.add_subplot(gs[1, i]) for i in range(len(mols))]
ax3 = fig.add_subplot(gs[2, :])
ax_mcs = None
if "mcs_based_result" in result.keys():
gs = fig.add_gridspec(4, len(mols))
ax1 = fig.add_subplot(gs[0, :])
axs2 = [fig.add_subplot(gs[1, i]) for i in range(len(mols))]
ax_mcs = fig.add_subplot(gs[2, :])
ax_final = fig.add_subplot(gs[3, :])
else:
gs = fig.add_gridspec(3, len(mols))
ax1 = fig.add_subplot(gs[0, :])
axs2 = [fig.add_subplot(gs[1, i]) for i in range(len(mols))]
ax_final = fig.add_subplot(gs[2, :])

rxnvis = RxnVis(cairosize=self.cairosize)
img = rxnvis.get_rxn_img(result["input_reaction"])
Expand Down Expand Up @@ -118,16 +127,26 @@ def plot(self, smiles, verbose=True):
ax.axis("off")
ax.set_title(title, fontsize=self.fontsize)

rxnvis = RxnVis(cairosize=self.cairosize)
if ax_mcs is not None:
img = rxnvis.get_rxn_img(result["mcs_based_result"])
ax_mcs.imshow(img)
ax_mcs.set_title(
"MCS-Based Result\nRules: {}".format(result.get("rules", None)),
fontsize=self.fontsize,
)
ax_mcs.axis("off")

rxnvis = RxnVis(cairosize=self.cairosize)
img = rxnvis.get_rxn_img(result["reaction"])
ax3.imshow(img)
ax3.set_title(
"Result (Confidence: {:.1%})\nRules: {}\nIssue: {}".format(
result.get("confidence", 0), result.get("rules", None), result["issue"]
ax_final.imshow(img)
ax_final.set_title(
"Result (Confidence: {:.1%})\nIssue: {}".format(
result.get("confidence", 0), result["issue"]
),
fontsize=self.fontsize,
)
ax3.axis("off")
ax_final.axis("off")

plt.tight_layout()
plt.show()
12 changes: 10 additions & 2 deletions synrbl/balancing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import logging
import traceback

from synrbl.preprocess import preprocess
from synrbl.postprocess import Validator
Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
self.__issue_col = "issue"
self.__n_jobs = n_jobs

self.remove_aam = True
self.batch_size = batch_size
self.cache = cache
self.cache_dir = cache_dir
Expand Down Expand Up @@ -170,6 +172,7 @@ def __run_pipeline(self, reactions, stats=None):
self.__id_col,
self.__solved_col,
self.__input_col,
remove_aam=self.remove_aam,
)
rxn_cnt = len(reactions)
self.input_validator.check(reactions)
Expand All @@ -188,7 +191,11 @@ def __run_pipeline(self, reactions, stats=None):
)
self.__post_process(reactions)
self.rb_method.run(reactions)
self.mcs_validator.check(reactions, override_unsolved=True)
self.mcs_validator.check(
reactions,
override_unsolved=True,
override_issue_msg="Final reaction is unbalanced.",
)

self.conf_predictor.predict(
reactions, stats=stats, threshold=self.confidence_threshold
Expand Down Expand Up @@ -265,7 +272,8 @@ def __rebalance_batch(self, batch, cache_manager):
)
logger.info("Cached new results. (Key: {})".format(cache_key[:8]))
except Exception as e:
logger.error("Pipeline execution failed: {}".format(e))
traceback.print_exc()
logger.error("Pipeline execution failed: {}".format(type(e)))

return result, batch_stats

Expand Down
4 changes: 3 additions & 1 deletion synrbl/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
self.issue_col = issue_col
self.n_jobs = n_jobs

def check(self, reactions, override_unsolved=False):
def check(self, reactions, override_unsolved=False, override_issue_msg=None):
update_reactants_and_products(reactions, self.reaction_col)
decompose = RSMIDecomposer(
smiles=None, # type: ignore
Expand Down Expand Up @@ -69,4 +69,6 @@ def check(self, reactions, override_unsolved=False):
reaction[self.solved_method_col] = self.method
if override_unsolved and not reaction[self.solved_col]:
reaction[self.reaction_col] = reaction["input_reaction"]
if override_issue_msg is not None and reaction[self.issue_col] == "":
reaction[self.issue_col] = override_issue_msg
return reactions
14 changes: 13 additions & 1 deletion synrbl/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
import pandas as pd

from synrbl.SynProcessor import RSMIProcessing
from synrbl.SynUtils import remove_atom_mapping


def preprocess(reactions, reaction_col, index_col, solved_col, input_col, n_jobs=1):
def preprocess(
reactions,
reaction_col,
index_col,
solved_col,
input_col,
n_jobs=1,
remove_aam=False,
):
if remove_aam:
for r in reactions:
r[reaction_col] = remove_atom_mapping(r[reaction_col])
df = pd.DataFrame(reactions)
df[solved_col] = False

Expand Down

0 comments on commit 77964d0

Please sign in to comment.