diff --git a/synrbl/SynCmd/cmd_benchmark.py b/synrbl/SynCmd/cmd_benchmark.py index 23d5c9f..d99aed5 100644 --- a/synrbl/SynCmd/cmd_benchmark.py +++ b/synrbl/SynCmd/cmd_benchmark.py @@ -1,6 +1,7 @@ import argparse import logging import pandas as pd +import rdkit.Chem.rdChemReactions as rdChemReactions from synrbl import Balancer from synrbl.SynUtils import normalize_smiles, wc_similarity @@ -50,29 +51,55 @@ def _sr(v, c): ) -def check_columns(reactions, reaction_col, result_col): +def check_columns(reactions, reaction_col, result_col, passthrouh_columns): if len(reactions) == 0: raise ValueError("No reactions found in input.") cols = reactions[0].keys() if reaction_col not in cols: raise KeyError("No column '{}' found in input.".format(reaction_col)) + if not isinstance(reactions[0][reaction_col], str): + raise TypeError( + "Reaction column '{}' must be of type string not '{}'.".format( + reaction_col, type(reactions[0][reaction_col]) + ) + ) + mol = None + try: + mol = rdChemReactions.ReactionFromSmarts( + reactions[0][reaction_col], useSmiles=True + ) + except Exception: + pass + if mol is None: + raise ValueError( + "Value '{}...' in reaction column '{}' is not a valid SMILES.".format( + reactions[0][reaction_col][0:30], reaction_col + ) + ) if result_col not in cols: raise KeyError("No column '{}' found in input.".format(result_col)) + for c in passthrouh_columns: + if c not in reactions[0].keys(): + raise KeyError("Column '{}' not found.".format(c)) def run(args): + columns = args.columns if isinstance(args.columns, list) else [args.columns] + synrbl_cols = [c for c in columns if c in ["mcs"]] + passthrough_cols = [c for c in columns if c not in synrbl_cols] input_reactions = pd.read_csv(args.inputfile).to_dict("records") logger.info( "Run benchmark on {} containing {} reactions.".format( args.inputfile, len(input_reactions) ) ) - check_columns(input_reactions, args.col, args.result_col) + check_columns(input_reactions, args.col, args.result_col, passthrough_cols) stats = {} synrbl = Balancer( reaction_col=args.col, confidence_threshold=args.min_confidence, n_jobs=args.p ) + synrbl.columns.extend(synrbl_cols) rbl_reactions = synrbl.rebalance(input_reactions, output_dict=True, stats=stats) rb_correct = 0 @@ -100,7 +127,9 @@ def run(args): if args.o is not None: for in_r, out_r in zip(input_reactions, rbl_reactions): - out_r[args.result_col] = in_r[args.result_col] + for c in columns + [args.result_col]: + if c in in_r.keys(): + out_r[c] = in_r[c] df = pd.DataFrame(rbl_reactions) df.to_csv(args.o) @@ -117,9 +146,13 @@ def __str__(self): return "[{}, {}]".format(self.start, self.end) +def list_of_strings(arg): + return arg.split(",") + + def configure_argparser(argparser: argparse._SubParsersAction): default_similarity_method = "pathway" - default_similarity_threshold = 0.85 + default_similarity_threshold = 1 default_p = -1 default_col = "reaction" default_result_col = "expected_reaction" @@ -156,6 +189,13 @@ def configure_argparser(argparser: argparse._SubParsersAction): help="The reactions column name for in the expected output. " + "(Default: {})".format(default_result_col), ) + test_parser.add_argument( + "--columns", + default=[], + type=list_of_strings, + help="A comma separated list of columns from the input that should " + + "be added to the output. (e.g.: col1,col2,col3)", + ) test_parser.add_argument( "--min-confidence", type=float, diff --git a/synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py b/synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py index 8c7d6cc..b82340c 100644 --- a/synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py +++ b/synrbl/SynMCSImputer/SubStructure/mcs_graph_detector.py @@ -201,7 +201,7 @@ def fit( if method == "MCIS": params = rdFMCS.MCSParameters() - params.Timeout = Timeout + params.Timeout = 1 # Timeout params.BondCompareParameters.RingMatchesRingOnly = RingMatchesRingOnly params.BondCompareParameters.CompleteRingsOnly = CompleteRingsOnly if ignore_bond_order: @@ -213,7 +213,7 @@ def fit( params = rdRascalMCES.RascalOptions() params.singleLargestFrag = False params.returnEmptyMCES = True - params.timeout = Timeout + params.timeout = 1 # Timeout params.similarityThreshold = similarityThreshold else: