2nd release #9

Merged 4 commits on May 16, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -13,4 +13,4 @@ data/
[._]ss[a-gi-z]
[._]sw[a-p]

-**/.DS_Store
\ No newline at end of file
+**/.DS_Store
19 changes: 19 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,19 @@
repos:
- repo: local
hooks:
- id: unittests
name: run unit tests
entry: python -m unittest
language: system
pass_filenames: false
args: ["discover"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.3.0
hooks:
- id: black
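With this file in place, the hooks run on every git commit: the local unittests hook executes python -m unittest discover over the repository (pass_filenames: false, so it always runs the full suite rather than just the staged files), followed by the standard whitespace/YAML fixers and black. First-time setup uses the stock pre-commit CLI; these commands are generic, not specific to this PR:

    pip install pre-commit
    pre-commit install
    pre-commit run --all-files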
22 changes: 22 additions & 0 deletions CITATION.cff
@@ -1,6 +1,7 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
title: "RNA3DB: A dataset for training and benchmarking deep learning models for RNA structure prediction"
version: 1.1
authors:
- given-names: "Marcell"
family-names: "Szikszai"
@@ -15,3 +16,24 @@ authors:
- given-names: "Elena
family-names: Rivas"
url: "https://github.com/marcellszi/rna3db"
doi: "10.1016/j.jmb.2024.168552"
date-released: 2024-04-26
preferred-citation:
type: article
authors:
- given-names: "Marcell"
family-names: "Szikszai"
- given-names: "Marcin"
family-names: "Magnus"
- given-names: "Siddhant"
family-names: "Sanghi"
- given-names: "Sachin"
family-names: "Kadyan"
- given-names: "Nazim"
family-names: "Bouatta"
- given-names: "Elena"
family-names: Rivas"
doi: "10.1016/j.jmb.2024.168552"
journal: "Journal of Molecular Biology"
title: "RNA3DB: A structurally-dissimilar dataset split for training and benchmarking deep learning models for RNA structure prediction"
year: 2024
23 changes: 19 additions & 4 deletions rna3db/__main__.py
@@ -132,7 +132,16 @@ def main(args):
args.input, args.output, args.tbl_dir, args.structural_e_value_cutoff
)
elif args.command == "split":
-        split(args.input, args.output, args.train_percentage, args.force_zero_test)
+        split(
+            args.input,
+            args.output,
+            splits=[
+                args.train_ratio,
+                args.valid_ratio,
+                1 - args.train_ratio - args.valid_ratio,
+            ],
+            force_zero_last=args.force_zero_test,
+        )
else:
raise ValueError

@@ -246,10 +255,16 @@ def main(args):
split_parser.add_argument("input", type=Path, help="Input JSON file")
split_parser.add_argument("output", type=Path, help="Output JSON file")
split_parser.add_argument(
"--train_percentage",
"--train_ratio",
type=float,
-        default=0.3,
-        help="Percentage of data for the train set",
+        default=0.7,
+        help="Ratio of data to use for the training set",
)
+    split_parser.add_argument(
+        "--valid_ratio",
+        type=float,
+        default=0.0,
+        help="Ratio of the data to use for the validation set",
+    )
split_parser.add_argument(
"--force_zero_test",
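Together these flags replace the old --train_percentage option. As a quick illustration (file names hypothetical), a 70/10/20 train/valid/test split would be requested as:

    python -m rna3db split cluster.json split.json --train_ratio 0.7 --valid_ratio 0.1

The test ratio is implicit: main() passes splits=[train_ratio, valid_ratio, 1 - train_ratio - valid_ratio] to split(), so whatever remains after the train and validation ratios (here 0.2) goes to the test set.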
180 changes: 180 additions & 0 deletions rna3db/parser.py
@@ -135,6 +135,10 @@ def __getitem__(self, idx):
def __len__(self):
return len(self.residues)

@property
def has_atoms(self):
return any([not res.is_missing for res in self])

def add_residue(self, res: Residue):
"""Add a residue to the chain.

@@ -341,6 +345,182 @@ def __repr__(self):
f"resolution={self.resolution}, release_date={self.release_date}, structure_method={self.structure_method})"
)

@staticmethod
def _gen_mmcif_loop_str(name: str, headers: Sequence[str], values: Sequence[tuple]):
s = "#\nloop_\n"
for header in headers:
s += f"_{name}.{header}\n"

max_widths = {k: 0 for k in headers}
for V in values:
for k, v in zip(headers, V):
max_widths[k] = max(max_widths[k], len(str(v)))

for V in values:
row = ""
for k, v in zip(headers, V):
row += f"{str(v):<{max_widths[k]}} "
s += row + "\n"

return s
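For reference (not part of the diff), the helper pads every cell to its column's maximum width, so a call shaped like the entity_poly_seq one below would produce an aligned loop block:

    print(StructureFile._gen_mmcif_loop_str(
        "entity_poly_seq",
        ["entity_id", "num", "mon_id", "heter"],
        [(1, 1, "G", "n"), (1, 10, "C", "n")],
    ))
    # #
    # loop_
    # _entity_poly_seq.entity_id
    # _entity_poly_seq.num
    # _entity_poly_seq.mon_id
    # _entity_poly_seq.heter
    # 1 1  G n
    # 1 10 C n

(each data row also carries a trailing space from the padding loop).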

def write_mmcif_chain(self, output_path, author_id):
if not self[author_id].has_atoms:
raise ValueError(
f"Did not find any atoms for chain {author_id}. Did you set `include_atoms=True`?"
)
# extract needed info
entity_poly_seq_data = []
atom_site_data = []
for i, res in enumerate(self[author_id]):
entity_poly_seq_data.append((1, res.index + 1, res.code, "n"))
for idx, (atom_name, atom_coords) in enumerate(res.atoms.items()):
x, y, z = atom_coords
atom_site_data.append(
(
"ATOM",
idx + 1,
atom_name[0],
atom_name,
".",
res.code,
author_id,
"?",
i + 1,
"?",
x,
y,
z,
1.0,
0.0,
"?",
i + 1,
res.code,
author_id,
atom_name,
1,
)
)

# build required strings
header_str = (
f"# generated by rna3db\n"
f"#\n"
f"data_{self.pdb_id}_{author_id}\n"
f"_entry.id {self.pdb_id}_{author_id}\n"
f"_pdbx_database_status.recvd_initial_deposition_date {self.release_date}\n"
f"_exptl.method '{self.structure_method.upper()}'\n"
f"_reflns.d_resolution_high {self.resolution}\n"
f"_entity_poly.pdbx_seq_one_letter_code_can {self[author_id].sequence}\n"
)
struct_asym_str = StructureFile._gen_mmcif_loop_str(
"_struct_asym",
[
"id",
"pdbx_blank_PDB_chainid_flag",
"pdbx_modified",
"entity_id",
"details",
],
[("A", "N", "N", 1, "?")],
)
chem_comp_str = StructureFile._gen_mmcif_loop_str(
"_chem_comp",
[
"id",
"type",
"mon_nstd_flag",
"pdbx_synonyms",
"formula",
"formula_weight",
],
[
(
"A",
"'RNA linking'",
"y",
'"ADENOSINE-5\'-MONOPHOSPHATE"',
"?",
"'C10 H14 N5 O7 P'",
347.221,
),
(
"C",
"'RNA linking'",
"y",
'"CYTIDINE-5\'-MONOPHOSPHATE"',
"?",
"'C9 H14 N3 O8 P'",
323.197,
),
(
"G",
"'RNA linking'",
"y",
'"GUANOSINE-5\'-MONOPHOSPHATE"',
"?",
"'C9 H13 N2 O9 P'",
363.221,
),
(
"U",
"'RNA linking'",
"y",
'"URIDINE-5\'-MONOPHOSPHATE"',
"?",
"'C9 H13 N2 O9 P'",
324.181,
),
("T", "'RNA linking'", "y", '"T"', "?", "''", 0),
("N", "'RNA linking'", "y", '"N"', "?", "''", 0),
],
)
entity_poly_seq_str = StructureFile._gen_mmcif_loop_str(
"entity_poly_seq",
[
"entity_id",
"num",
"mon_id",
"heter",
],
entity_poly_seq_data,
)
atom_site_str = StructureFile._gen_mmcif_loop_str(
"atom_site",
[
"group_PDB",
"id",
"type_symbol",
"label_atom_id",
"label_alt_id",
"label_comp_id",
"label_asym_id",
"label_entity_id",
"label_seq_id",
"pdbx_PDB_ins_code",
"Cartn_x",
"Cartn_y",
"Cartn_z",
"occupancy",
"B_iso_or_equiv",
"pdbx_formal_charge",
"auth_seq_id",
"auth_comp_id",
"auth_asym_id",
"auth_atom_id",
"pdbx_PDB_model_num",
],
atom_site_data,
)

# write to file
with open(output_path, "w") as f:
f.write(header_str)
f.write(struct_asym_str)
f.write(chem_comp_str)
f.write(entity_poly_seq_str)
f.write(atom_site_str)
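A minimal usage sketch for the new export path (the mmCIFParser call signature and the PDB/chain IDs here are assumptions for illustration, not taken from this diff):

    # parse a structure with coordinates, then write one chain back out as mmCIF
    structure = mmCIFParser("1ffk.cif", include_atoms=True).parse()
    structure.write_mmcif_chain("1ffk_0.cif", author_id="0")

The include_atoms flag matters because write_mmcif_chain raises a ValueError when the requested chain has no atoms.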


class mmCIFParser:
def __init__(
79 changes: 51 additions & 28 deletions rna3db/split.py
@@ -1,11 +1,32 @@
import random

from typing import Sequence

from rna3db.utils import PathLike, read_json, write_json


def find_optimal_components(lengths_dict, capacity):
component_name = list(lengths_dict.keys())
lengths = list(lengths_dict.values())

dp = [0] * (capacity + 1)
trace = [[] for i in range(capacity + 1)]
for i in range(len(lengths)):
for j in range(capacity, lengths[i] - 1, -1):
if dp[j] < dp[j - lengths[i]] + lengths[i]:
dp[j] = dp[j - lengths[i]] + lengths[i]
trace[j] = trace[j - lengths[i]] + [component_name[i]]

return set(trace[capacity])
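find_optimal_components is a 0/1-knapsack dynamic program: dp[j] holds the largest total length achievable within capacity j, trace[j] records which components achieve it, and the returned set fills the requested capacity as closely as possible without exceeding it. A small illustrative call (component names and sizes made up):

    find_optimal_components({"component_1": 6, "component_2": 5, "component_3": 4}, 10)
    # -> {"component_1", "component_3"}, since 6 + 4 fills the capacity of 10 exactly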


def split(
input_path: PathLike,
output_path: PathLike,
-    train_size: float = 0.7,
-    force_zero_test: bool = True,
+    splits: Sequence[float] = [0.7, 0.0, 0.3],
+    split_names: Sequence[str] = ["train_set", "valid_set", "test_set"],
+    shuffle: bool = False,
+    force_zero_last: bool = False,
):
"""A function that splits a JSON of components into a train/test set.

@@ -16,35 +37,37 @@ def split(
Args:
input_path (PathLike): path to JSON containing components
output_path (PathLike): path to output JSON
-        train_size (float): percentage of data to use as training set
-        force_zero_test (bool): whether to force component_0 into the test set
"""
+    if sum(splits) != 1.0:
+        raise ValueError("Sum of splits must equal 1.0.")

# read json
cluster_json = read_json(input_path)

-    # count number of repr sequences
-    total_repr_clusters = sum(len(v) for v in cluster_json.values())
-
-    # figure out which components need to go into training set
-    train_components = set()
-    train_set_length = 0
-    i = 1 if force_zero_test else 0
-    while train_set_length / total_repr_clusters < train_size:
-        # skip if it's not a real component (should only happen with 0)
-        if f"component_{i}" not in cluster_json:
-            i += 1
-            continue
-        train_components.add(f"component_{i}")
-        train_set_length += len(cluster_json[f"component_{i}"].keys())
-        i += 1
-
-    # test_components are just total-train_components
-    test_components = set(cluster_json.keys()) - train_components
-
-    # actually build JSON
-    output = {"train_set": {}, "test_set": {}}
-    for k in sorted(train_components):
-        output["train_set"][k] = cluster_json[k]
-    for k in sorted(test_components):
-        output["test_set"][k] = cluster_json[k]
+    lengths = {k: len(v) for k, v in cluster_json.items()}
+    total_repr_clusters = sum(lengths.values())
+
+    # shuffle component order if we want to add randomness
+    if shuffle:
+        items = list(lengths.items())
+        random.shuffle(items)
+        lengths = dict(items)
+
+    output = {k: {} for k in split_names}
+
+    if force_zero_last:
+        output[split_names[-1]]["component_0"] = cluster_json["component_0"]
+        lengths.pop("component_0")
+
+    capacities = [round(total_repr_clusters * ratio) for ratio in splits]
+    for name, capacity in zip(split_names, capacities):
+        components = find_optimal_components(lengths, capacity)
+        for k in sorted(components):
+            lengths.pop(k)
+            output[name][k] = cluster_json[k]
+
+    assert len(lengths) == 0

write_json(output, output_path)
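For example, a hypothetical direct library call (paths illustrative) equivalent to a 70/10/20 split with component_0 pinned to the last (test) split:

    from rna3db.split import split

    split(
        "cluster.json",
        "split.json",
        splits=[0.7, 0.1, 0.2],
        force_zero_last=True,
    )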