-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathassemble.py
118 lines (102 loc) · 4.63 KB
/
assemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import get_molecules
import json
from sevenn_runner import SevenNetCalculator
import math
import numpy as np
def create_single_chain():
    """Grow a single polymer chain, relax it, and write the trajectory to ``relaxation.json``.

    The chain is seeded with two monomers from ``get_molecules.grow_two_molecules``
    and then extended ``num_monomers - 2`` times with ``get_molecules.grow_on_chain``.
    Per the loop comment, each call presumably adds one monomer; confirm against
    ``get_molecules``. A final ``get_molecules.relax`` pass (``max_steps=100``)
    follows, starting from the last recorded coordinates.

    Every entry of ``coords_log`` is written out paired with the atom list that
    was current when it was recorded, as
    ``{"frames": [{"atomic_nums": [...], "coords": [...]}, ...]}``.
    """
    sevennet_0_cal = SevenNetCalculator("7net-0", device="auto")  # 7net-0, SevenNet-0, 7net-0_22May2024, 7net-0_11July2024 ...
    print(f"running on device {sevennet_0_cal.device}")
    smiles = "NCC(=O)NCCCCCC(=O)"
    atomic_nums, coords_log, last_non_hydrogen_idx_on_main_chain = get_molecules.grow_two_molecules(
        sevennet_0_cal, smiles
    )
    # Each batch pairs an atom list with len(coords_log) at the moment that
    # atom list was final.
    # Frames with index below relax_len, and not claimed by an earlier batch,
    # are rendered with this batch's atoms.
    relax_batches = [{
        "atomic_nums": atomic_nums.tolist(),
        "relax_len": len(coords_log),
    }]
    num_monomers = 5
    for _ in range(num_monomers - 2):  # -2 since we already have the first 2 monomers
        atomic_nums, coords_log, last_non_hydrogen_idx_on_main_chain = get_molecules.grow_on_chain(
            sevennet_0_cal,
            relax_batches[-1]["atomic_nums"],
            coords_log,
            last_non_hydrogen_idx_on_main_chain,
            smiles,
        )
        relax_batches.append({
            "atomic_nums": atomic_nums.tolist(),
            "relax_len": len(coords_log),
        })

    # Final relaxation of the fully grown chain; the atom list is unchanged,
    # so the new frames reuse the last batch's atomic numbers.
    coords_log = coords_log + get_molecules.relax(
        sevennet_0_cal, relax_batches[-1]["atomic_nums"], coords_log[-1], max_steps=100
    )
    relax_batches.append({
        "atomic_nums": relax_batches[-1]["atomic_nums"],
        "relax_len": len(coords_log),
    })

    relaxation = {"frames": []}
    frames = relaxation["frames"]
    for batch in relax_batches:
        # len(frames) is the number of frames already emitted by earlier
        # batches; this batch covers the rest up to its relax_len.
        for frame_idx in range(len(frames), batch["relax_len"]):
            frames.append({
                "atomic_nums": batch["atomic_nums"],
                "coords": coords_log[frame_idx].tolist(),
            })
    # Compact separators minimise output size. The context manager guarantees
    # the file is flushed and closed instead of relying on garbage collection.
    with open("relaxation.json", "w", encoding="utf-8") as f:
        json.dump(relaxation, f, separators=(',', ':'))
def get_initial_coords(num_polymers: int, distance_between_chains: float = 7):
    """Return starting positions for ``num_polymers`` chains on a square grid in the z = 0 plane.

    The grid is ``k x k`` with ``k * k == num_polymers``. Adjacent positions
    are ``distance_between_chains`` apart along x and y. Positions are ordered
    with the x index outermost and the y index innermost.

    Args:
        num_polymers: Number of chains; must be a non-negative perfect square.
        distance_between_chains: Grid spacing, in the same (presumably
            Angstrom — confirm) units as the caller's coordinates. Defaults to 7.

    Returns:
        ``np.ndarray`` of shape ``(num_polymers, 3)``. When ``num_polymers`` is 0
        the result is an empty array of shape ``(0,)``.

    Raises:
        ValueError: if ``num_polymers`` is negative or not a perfect square.
    """
    if num_polymers < 0:
        raise ValueError(f"num_polymers must be non-negative, got {num_polymers}")
    # math.isqrt is exact for arbitrarily large integers. The previous
    # float-based math.sqrt(...) % 1 test could accept non-squares
    # (e.g. 2**54 + 1) once the value exceeded float precision.
    side_len = math.isqrt(int(num_polymers))
    # Compare against the original value so non-integral inputs are rejected.
    if side_len * side_len != num_polymers:
        raise ValueError(f"num_polymers must be a perfect square, got {num_polymers}")
    positions = [
        [i * distance_between_chains, j * distance_between_chains, 0]
        for i in range(side_len)
        for j in range(side_len)
    ]
    return np.array(positions)
def create_bulk_polymer():
    """Build several polymer chains side by side, relax the system, and write the trajectory to ``relaxation.json``.

    1. ``num_chains`` (4) chains are each seeded with two monomers via
       ``get_molecules.grow_two_molecules``, starting at the grid positions from
       ``get_initial_coords``. The running ``atomic_nums`` / ``coords_log`` are
       passed into each call, presumably so each new chain is added to the
       system built so far (TODO confirm against ``get_molecules``).
    2. Every chain is then extended ``num_monomers - 2`` times with
       ``get_molecules.grow_on_chain``.
    3. A final ``get_molecules.relax`` pass (``max_steps=50``) runs on the whole
       system.

    Every entry of ``coords_log`` is written out paired with the atom list that
    was current when it was recorded, as
    ``{"frames": [{"atomic_nums": [...], "coords": [...]}, ...]}``.
    """
    sevennet_0_cal = SevenNetCalculator("7net-0", device="auto")  # 7net-0, SevenNet-0, 7net-0_22May2024, 7net-0_11July2024 ...
    print(f"running on device {sevennet_0_cal.device}")
    smiles = "NCC(=O)NCCCCCC(=O)"
    coords_log = None
    atomic_nums = None
    relax_batches = []
    num_chains = 4
    for initial_coord in get_initial_coords(num_chains):
        atomic_nums, coords_log, last_non_hydrogen_idx_on_main_chain = get_molecules.grow_two_molecules(
            sevennet_0_cal, smiles, initial_coord, atomic_nums, coords_log
        )
        relax_batches.append({
            "atomic_nums": atomic_nums.tolist(),
            "relax_len": len(coords_log),
            "last_non_hydrogen_idx_on_main_chain": last_non_hydrogen_idx_on_main_chain,
        })

    num_monomers = 3
    for _ in range(num_monomers - 2):  # -2 since we already have the first 2 monomers
        # The last num_chains batches hold the most recent growth point of
        # each chain, in chain order. The slice is a copy, so appending inside
        # the loop does not affect the iteration.
        # NOTE(review): reusing a stored index after other chains have grown
        # assumes grow_on_chain never renumbers existing atoms — confirm.
        for chain in relax_batches[-num_chains:]:
            atomic_nums, coords_log, last_non_hydrogen_idx_on_main_chain = get_molecules.grow_on_chain(
                sevennet_0_cal,
                atomic_nums,
                coords_log,
                chain["last_non_hydrogen_idx_on_main_chain"],
                smiles,
            )
            relax_batches.append({
                "atomic_nums": atomic_nums.tolist(),
                "relax_len": len(coords_log),
                "last_non_hydrogen_idx_on_main_chain": last_non_hydrogen_idx_on_main_chain,
            })

    # final relaxation
    print("running final relaxation")
    coords_log = coords_log + get_molecules.relax(
        sevennet_0_cal, relax_batches[-1]["atomic_nums"], coords_log[-1], max_steps=50
    )
    relax_batches.append({
        "atomic_nums": relax_batches[-1]["atomic_nums"],
        "relax_len": len(coords_log),
    })

    relaxation = {"frames": []}
    frames = relaxation["frames"]
    for batch in relax_batches:
        # len(frames) is the number of frames already emitted by earlier
        # batches; this batch covers the rest up to its relax_len.
        for frame_idx in range(len(frames), batch["relax_len"]):
            frames.append({
                "atomic_nums": batch["atomic_nums"],
                "coords": coords_log[frame_idx].tolist(),
            })
    # Compact separators minimise output size. The context manager guarantees
    # the file is flushed and closed instead of relying on garbage collection.
    with open("relaxation.json", "w", encoding="utf-8") as f:
        json.dump(relaxation, f, separators=(',', ':'))
# Script entry point: running this module directly builds the bulk-polymer
# system and writes its trajectory via create_bulk_polymer().
if __name__ == "__main__":
    create_bulk_polymer()