Merge pull request #31 from TieuLongPhan/dev
klausweinbauer authored May 17, 2024
2 parents 4e68b15 + de5d431 commit d965fca
Showing 15 changed files with 1,111 additions and 611 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@
*.pdf
tmp/
out/
cache/

# Byte-compiled / optimized / DLL files
__pycache__/
1,124 changes: 562 additions & 562 deletions Data/Validation_set/validation_set.csv

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions Test/SynUtils/test_batching.py
@@ -0,0 +1,30 @@
import pytest

from synrbl.SynUtils.batching import DataLoader, Dataset


@pytest.mark.parametrize(
"data,batch_size,exp_results",
[
([1], 1, [[1]]),
([1, 2], 1, [[1], [2]]),
([1, 2, 3, 4, 5], 2, [[1, 2], [3, 4], [5]]),
],
)
def test_data_loader(data, batch_size, exp_results):
loader = DataLoader(iter(data), batch_size=batch_size)
for data, exp_data in zip(loader, exp_results):
assert len(exp_data) == len(data)
for a, b in zip(data, exp_data):
print(a, b)
assert b == a
with pytest.raises(StopIteration):
next(loader)


def test_dataset_init_from_list():
data = ["A", "B", "C"]
dataset = Dataset(data)
assert "A" == next(dataset)
assert "B" == next(dataset)
assert "C" == next(dataset)
25 changes: 24 additions & 1 deletion Test/SynUtils/test_chem_utils.py
@@ -1,6 +1,11 @@
import pytest
import unittest
import itertools
from synrbl.SynUtils.chem_utils import remove_atom_mapping, normalize_smiles
from synrbl.SynUtils.chem_utils import (
remove_atom_mapping,
normalize_smiles,
count_atoms,
)


class TestRemoveAtomMapping(unittest.TestCase):
@@ -83,3 +88,21 @@ def test_edge_case_1(self):
smiles = "F[Sb@OH12](F)(F)(F)(F)F"
result = normalize_smiles(smiles)
self.assertEqual("F[Sb](F)(F)(F)(F)F", result)

def test_ordering_of_aromatic_compounds(self):
smiles = "[HH].c1ccccc1"
result = normalize_smiles(smiles)
self.assertEqual("c1ccccc1.[HH]", result)

def test_ordering_1(self):
smiles = "[HH].C=O"
result = normalize_smiles(smiles)
self.assertEqual("C=O.[HH]", result)


@pytest.mark.parametrize(
"smiles,exp_atom_cnt", [("O=C", 2), ("CO", 2), ("HH", 0), ("c1ccccc1", 6)]
)
def test_count_atoms(smiles, exp_atom_cnt):
atom_cnt = count_atoms(smiles)
assert exp_atom_cnt == atom_cnt
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "synrbl"
version = "0.0.18"
version = "0.0.19"
authors = [
{name="Tieu Long Phan", email="long.tieu_phan@uni-leipzig.de"},
{name="Klaus Weinbauer", email="klaus@bioinf.uni-leipzig.de"}
108 changes: 85 additions & 23 deletions synrbl/SynCmd/cmd_run.py
@@ -4,6 +4,8 @@
import pandas as pd
import rdkit.Chem.rdChemReactions as rdChemReactions

from argparse import RawTextHelpFormatter

from synrbl import Balancer

logger = logging.getLogger(__name__)
@@ -97,12 +99,20 @@ def impute(
passthrough_cols,
min_confidence,
n_jobs=-1,
cache=False,
cache_dir=None,
batch_size=None,
):
input_reactions = pd.read_csv(src_file).to_dict("records")
check_columns(input_reactions, reaction_col, required_cols=passthrough_cols)

synrbl = Balancer(
reaction_col=reaction_col, confidence_threshold=min_confidence, n_jobs=n_jobs
reaction_col=reaction_col,
confidence_threshold=min_confidence,
n_jobs=n_jobs,
cache=cache,
cache_dir=cache_dir,
batch_size=batch_size,
)
stats = {}
rbl_reactions = synrbl.rebalance(input_reactions, output_dict=True, stats=stats)
@@ -126,6 +136,9 @@ def run(args):
passthrough_cols = (
args.out_columns if isinstance(args.out_columns, list) else [args.out_columns]
)
batch_size = None
if args.batch_size is not None:
batch_size = int(args.batch_size)

impute(
args.inputfile,
@@ -134,6 +147,9 @@
passthrough_cols=passthrough_cols,
min_confidence=args.min_confidence,
n_jobs=args.p,
cache=args.cache,
cache_dir=args.cache_dir,
batch_size=batch_size,
)


@@ -154,45 +170,91 @@ def list_of_strings(arg):


def configure_argparser(argparser: argparse._SubParsersAction):
test_parser = argparser.add_parser(
"run", description="Try to rebalance chemical reactions."
default_cache_dir = "./cache"
default_min_confidence = 0
default_num_processes = -1
default_reaction_col = "reaction"

run_parser = argparser.add_parser(
"run",
description="Try to rebalance chemical reactions.",
formatter_class=RawTextHelpFormatter,
)

test_parser.add_argument(
run_parser.add_argument(
"inputfile",
help="Path to file containing reaction SMILES. "
+ " The file should be in csv format and the reaction SMILES column "
+ "can be specified with the --col parameter.",
help=(
"Path to file containing reaction SMILES. \n"
+ "Possible file formats are:\n"
+ " - csv: SMILES are expected to be in column '{}'\n".format(
default_reaction_col
)
+ " The column name can be changed with the --col\n"
+ " argument.\n"
+ " - json: The expected format is a list of objects:\n"
+ " [{...},{...}]\n"
+ " The reaction SMILES is expected as value of\n"
+ " '{}'. The reaction key can be changed with\n".format(
default_reaction_col
)
+ " the --col argument."
),
)
test_parser.add_argument("-o", default=None, help="Path to output file.")
test_parser.add_argument(
run_parser.add_argument("-o", default=None, help="Path to output file.")
run_parser.add_argument(
"-p",
default=-1,
default=default_num_processes,
type=int,
help="The number of parallel process. (Default -1 => # of processors)",
help=("The number of parallel process. (Default {} => # of processors)").format(
default_num_processes
),
)
test_parser.add_argument(
run_parser.add_argument(
"-b",
"--batch-size",
default=None,
help=("Number of reactions that should be processed at once."),
)
run_parser.add_argument(
"--col",
default="reaction",
help="The reaction SMILES column name for in the input .csv file. "
+ "(Default: 'reaction')",
default=default_reaction_col,
help=(
"The reaction SMILES column name (csv) or key (json) in the \n"
+ "input file. (Default: '{}')"
).format(default_reaction_col),
)
test_parser.add_argument(
run_parser.add_argument(
"--out-columns",
default=[],
type=list_of_strings,
help="A comma separated list of columns from the input that should "
+ "be added to the output. (e.g.: col1,col2,col3)",
help="A comma separated list of columns/keys from the input \n"
+ "that should be added to the output. (e.g.: col1,col2,col3)",
)
test_parser.add_argument(
run_parser.add_argument(
"--min-confidence",
type=float,
default=0,
default=default_min_confidence,
choices=[Range(0.0, 1.0)],
help=(
"Set a confidence threshold for the results "
+ "from the MCS-based method. (Default: 0)"
"Set a confidence threshold for the results from the\n"
+ "MCS-based method. (Default: {})"
).format(default_min_confidence),
)
run_parser.add_argument(
"--cache",
action="store_true",
help=(
"Flag to cache intermediate results. Use this together \n"
+ "with --batch-size."
),
)
run_parser.add_argument(
"--cache-dir",
default=default_cache_dir,
help=(
"The directory that is used to store intermediate results. \n"
+ "(Default: {})".format(default_cache_dir)
),
)

test_parser.set_defaults(func=run)
run_parser.set_defaults(func=run)
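
These options are forwarded to Balancer. As a minimal sketch, the equivalent library call might look as follows; the input file, threshold, and batch size are illustrative, while reaction_col, confidence_threshold, n_jobs, cache, cache_dir, and batch_size are the keywords used in impute() above:

import pandas as pd

from synrbl import Balancer

input_reactions = pd.read_csv("reactions.csv").to_dict("records")

synrbl = Balancer(
    reaction_col="reaction",
    confidence_threshold=0.5,
    n_jobs=-1,
    cache=True,  # persist intermediate results between runs
    cache_dir="./cache",
    batch_size=100,  # rebalance 100 reactions at a time
)
stats = {}
rbl_reactions = synrbl.rebalance(input_reactions, output_dict=True, stats=stats)
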
1 change: 1 addition & 0 deletions synrbl/SynUtils/__init__.py
@@ -17,3 +17,4 @@
normalize_smiles,
wc_similarity,
)
from .batching import Dataset, DataLoader
116 changes: 116 additions & 0 deletions synrbl/SynUtils/batching.py
@@ -0,0 +1,116 @@
import os
import csv
import json
import hashlib


def csv_reader(file):
with open(file, "r") as f:
reader = csv.reader(f)
header = next(reader)
for line in reader:
result_data = {}
for k, v in zip(header, line):
result_data[k] = v
yield result_data


def json_reader(file):
with open(file, "r") as f:
json_data = json.load(f)
if not isinstance(json_data, list):
raise ValueError(
"Top level json object sould be a list. "
+ r"Expected json structure: [{...},{...}]"
)
for json_entry in json_data:
yield json_entry


class Dataset:
def __init__(self, source):
if isinstance(source, list):
self.__data_reader = iter(source)
elif isinstance(source, str):
file_type = os.path.splitext(source)[1].replace(".", "").lower()
if file_type == "csv":
self.__data_reader = csv_reader(source)
elif file_type == "json":
self.__data_reader = json_reader(source)
else:
raise ValueError(
"File type '{}' is not supported as dataset source.".format(
file_type
)
)
else:
raise ValueError(
"'{}' is not a valid source for a dataset. "
+ "Use a file or list of data instead."
)

def __next__(self):
return next(self.__data_reader)

def __iter__(self):
return self.__data_reader


class DataLoader:
def __init__(self, data, batch_size=1):
self.__data = data
self.batch_size = batch_size
self.__iter_stopped = False

def __next__(self):
if self.__iter_stopped:
raise StopIteration
return_data = []
for _ in range(self.batch_size):
try:
data_item = next(self.__data)
return_data.append(data_item)
except StopIteration:
self.__iter_stopped = True
break
if len(return_data) == 0:
# The source was already exhausted; do not yield a trailing empty batch.
raise StopIteration
return return_data

def __iter__(self):
return self


class CacheManager:
def __init__(self, cache_dir="./cache", cache_ext="cache"):
self.__cache_dir = cache_dir
self.__cache_ext = cache_ext
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
self.__cache_refs = {}
for _, _, files in os.walk(cache_dir):
for file in files:
file_name, file_ext = os.path.splitext(file)
cache_key = os.path.basename(file_name)
file_path = os.path.join(os.path.abspath(cache_dir), file)
if file_ext.replace(".", "") == cache_ext.lower():
self.__cache_refs[cache_key] = file_path

def get_hash_key(self, data) -> str:
dhash = hashlib.sha256()
dhash.update(json.dumps(data, sort_keys=True).encode())
return dhash.hexdigest()

def is_cached(self, key) -> bool:
return key in self.__cache_refs.keys()

def load_cache(self, key):
file = self.__cache_refs[key]
with open(file, "r") as f:
data = json.load(f)
return data

def write_cache(self, key, data):
file = os.path.join(self.__cache_dir, "{}.{}".format(key, self.__cache_ext))
with open(file, "w") as f:
json.dump(data, f)
return key
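
For orientation, a minimal sketch of how the three utilities compose; the input file name and process_batch are hypothetical, while the method calls match the classes above:

from synrbl.SynUtils.batching import CacheManager, DataLoader, Dataset

dataset = Dataset("reactions.csv")  # or Dataset([...]) for in-memory data
loader = DataLoader(dataset, batch_size=100)
cache = CacheManager(cache_dir="./cache")

for batch in loader:
    key = cache.get_hash_key(batch)  # sha256 over the JSON-serialized batch
    if cache.is_cached(key):
        results = cache.load_cache(key)
    else:
        results = process_batch(batch)  # hypothetical downstream work
        cache.write_cache(key, results)
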
13 changes: 7 additions & 6 deletions synrbl/SynUtils/chem_utils.py
@@ -147,19 +147,20 @@ def remove_stereo_chemistry(smiles: str) -> str:
return smiles


def count_atoms(smiles: str) -> int:
pattern = re.compile(r"(B|C|N|O|P|S|F|Cl|Br|I|c|n|o)")
return len(pattern.findall(smiles))


def normalize_smiles(smiles: str) -> str:
smiles = remove_stereo_chemistry(smiles)
if ">>" in smiles:
return ">>".join([normalize_smiles(t) for t in smiles.split(">>")])
elif "." in smiles:
token = sorted(
smiles.split("."),
key=lambda x: (sum(1 for c in x if c.isupper()), sum(ord(c) for c in x)),
reverse=True,
)
token = smiles.split(".")
token = [normalize_smiles(t) for t in token]
token.sort(
key=lambda x: (sum(1 for c in x if c.isupper()), sum(ord(c) for c in x)),
key=lambda x: (count_atoms(x), sum(ord(c) for c in x)),
reverse=True,
)
return ".".join(token)