Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
Klaus Weinbauer committed Apr 29, 2024
2 parents 4110cda + c27b6a4 commit 86c5eed
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
48 changes: 44 additions & 4 deletions synrbl/SynCmd/cmd_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import logging
import pandas as pd
import rdkit.Chem.rdChemReactions as rdChemReactions

from synrbl import Balancer
from synrbl.SynUtils import normalize_smiles, wc_similarity
Expand Down Expand Up @@ -50,29 +51,55 @@ def _sr(v, c):
)


def check_columns(reactions, reaction_col, result_col, passthrouh_columns=()):
    """Validate the benchmark input records before rebalancing.

    Only the first record is inspected; it is treated as representative of
    the whole input.

    Args:
        reactions: Sequence of dict-like records, one per reaction.
        reaction_col: Key of the reaction SMILES in each record.
        result_col: Key of the expected (reference) reaction in each record.
        passthrouh_columns: Extra keys that must exist in the input. Defaults
            to no extra keys so three-argument calls remain valid.

    Raises:
        ValueError: If ``reactions`` is empty, or the first reaction value
            cannot be parsed as a reaction SMILES. In the latter case the
            parser's exception (if any) is attached as ``__cause__``.
        KeyError: If ``reaction_col``, ``result_col`` or any pass-through
            column is missing from the first record.
        TypeError: If the first reaction value is not a string.
    """
    if len(reactions) == 0:
        raise ValueError("No reactions found in input.")

    first = reactions[0]
    cols = first.keys()

    if reaction_col not in cols:
        raise KeyError("No column '{}' found in input.".format(reaction_col))

    value = first[reaction_col]
    if not isinstance(value, str):
        raise TypeError(
            "Reaction column '{}' must be of type string not '{}'.".format(
                reaction_col, type(value)
            )
        )

    rxn = None
    parse_error = None
    try:
        rxn = rdChemReactions.ReactionFromSmarts(value, useSmiles=True)
    except Exception as err:
        # The parser may signal failure by raising or by returning None;
        # keep the exception so it is not silently lost.
        parse_error = err
    if rxn is None:
        raise ValueError(
            "Value '{}...' in reaction column '{}' is not a valid SMILES.".format(
                value[0:30], reaction_col
            )
        ) from parse_error

    if result_col not in cols:
        raise KeyError("No column '{}' found in input.".format(result_col))

    for c in passthrouh_columns:
        if c not in cols:
            raise KeyError("Column '{}' not found.".format(c))


def run(args):
columns = args.columns if isinstance(args.columns, list) else [args.columns]
synrbl_cols = [c for c in columns if c in ["mcs"]]
passthrough_cols = [c for c in columns if c not in synrbl_cols]
input_reactions = pd.read_csv(args.inputfile).to_dict("records")
logger.info(
"Run benchmark on {} containing {} reactions.".format(
args.inputfile, len(input_reactions)
)
)
check_columns(input_reactions, args.col, args.result_col)
check_columns(input_reactions, args.col, args.result_col, passthrough_cols)

stats = {}
synrbl = Balancer(
reaction_col=args.col, confidence_threshold=args.min_confidence, n_jobs=args.p
)
synrbl.columns.extend(synrbl_cols)
rbl_reactions = synrbl.rebalance(input_reactions, output_dict=True, stats=stats)

rb_correct = 0
Expand Down Expand Up @@ -100,7 +127,9 @@ def run(args):

if args.o is not None:
for in_r, out_r in zip(input_reactions, rbl_reactions):
out_r[args.result_col] = in_r[args.result_col]
for c in columns + [args.result_col]:
if c in in_r.keys():
out_r[c] = in_r[c]
df = pd.DataFrame(rbl_reactions)
df.to_csv(args.o)

Expand All @@ -117,9 +146,13 @@ def __str__(self):
return "[{}, {}]".format(self.start, self.end)


def list_of_strings(arg):
    """Parse a comma separated command line value into a list of names.

    Surrounding whitespace is stripped from every entry and empty entries
    (from an empty value, doubled commas or a trailing comma) are dropped, so
    they are never treated as column names.

    Args:
        arg: The raw argument string, e.g. ``"col1,col2,col3"``.

    Returns:
        The non-empty, stripped entries in their original order.
    """
    return [part.strip() for part in arg.split(",") if part.strip()]


def configure_argparser(argparser: argparse._SubParsersAction):
default_similarity_method = "pathway"
default_similarity_threshold = 0.85
default_similarity_threshold = 1
default_p = -1
default_col = "reaction"
default_result_col = "expected_reaction"
Expand Down Expand Up @@ -156,6 +189,13 @@ def configure_argparser(argparser: argparse._SubParsersAction):
help="The reactions column name for in the expected output. "
+ "(Default: {})".format(default_result_col),
)
test_parser.add_argument(
"--columns",
default=[],
type=list_of_strings,
help="A comma separated list of columns from the input that should "
+ "be added to the output. (e.g.: col1,col2,col3)",
)
test_parser.add_argument(
"--min-confidence",
type=float,
Expand Down
4 changes: 2 additions & 2 deletions synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def fit(

if method == "MCIS":
params = rdFMCS.MCSParameters()
params.Timeout = Timeout
params.Timeout = 1 # Timeout
params.BondCompareParameters.RingMatchesRingOnly = RingMatchesRingOnly
params.BondCompareParameters.CompleteRingsOnly = CompleteRingsOnly
if ignore_bond_order:
Expand All @@ -213,7 +213,7 @@ def fit(
params = rdRascalMCES.RascalOptions()
params.singleLargestFrag = False
params.returnEmptyMCES = True
params.timeout = Timeout
params.timeout = 1 # Timeout
params.similarityThreshold = similarityThreshold

else:
Expand Down

0 comments on commit 86c5eed

Please sign in to comment.