-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_diffusion_generate.py
115 lines (97 loc) · 4.01 KB
/
main_diffusion_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
from typing import Optional
import torch
from diffusion.diffusion_loss import SampleResult
from diffusion.inference.create_gif import generate_gif
from diffusion.inference.process_generated_crystals import save_sample_results_to_hdf5
from diffusion.inference.visualize_crystal import VisualizationSetting
from lightning_wrappers.diffusion import PONITA_DIFFUSION
import numpy as np
# Root directory for every artifact this script writes (the HDF5 results and the GIF).
OUT_DIR = "out"
# Source directory handed to generate_gif as src_img_dir; presumably the place
# model.sample writes its visualization frames (TODO confirm).
DIFFUSION_DIR = OUT_DIR + "/diffusion"
# Forwarded as show_bonds to every model.sample call in this script.
SHOW_BONDS = False
def parse_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace whose ``model_path`` attribute holds the value of
        the mandatory ``--model_path`` option (argparse exits if it is absent).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_path",
        required=True,
        type=str,
        help="Path to the model file",
    )
    return cli.parse_args()
def get_model() -> PONITA_DIFFUSION:
    """Restore the diffusion model from the checkpoint given by ``--model_path``.

    Side effect: sets torch's process-wide default dtype to float64 before the
    checkpoint is loaded, so any later tensor creation in this process also
    defaults to float64.

    Returns:
        The ``PONITA_DIFFUSION`` instance returned by ``load_from_checkpoint``.
    """
    checkpoint_path = parse_args().model_path

    # Done before loading; presumably so parameters/buffers are built in
    # float64 to match how the model was trained (TODO confirm).
    torch.set_default_dtype(torch.float64)

    # strict=False tolerates key mismatches between checkpoint and model.
    return PONITA_DIFFUSION.load_from_checkpoint(checkpoint_path, strict=False)
def generate_single_crystal(
    num_atoms: int,
    visualization_setting: VisualizationSetting,
    use_constant_atomic_symbols: Optional[list[str]],
):
    """Sample one crystal, optionally animate it, and save it to HDF5.

    Args:
        num_atoms: Number of atoms in the sampled crystal.
        visualization_setting: Passed to ``model.sample``. Any value other than
            ``VisualizationSetting.NONE`` also triggers creation of
            ``<OUT_DIR>/crystal.gif`` from the files in ``DIFFUSION_DIR``.
        use_constant_atomic_symbols: Passed unchanged to ``model.sample``.

    The sample is written to ``<OUT_DIR>/crystals.h5``.
    """
    model = get_model()
    sample = model.sample(
        num_atoms_per_sample=num_atoms,
        num_samples_in_batch=1,
        visualization_setting=visualization_setting,
        show_bonds=SHOW_BONDS,
        use_constant_atomic_symbols=use_constant_atomic_symbols,
    )

    wants_animation = visualization_setting != VisualizationSetting.NONE
    if wants_animation:
        generate_gif(src_img_dir=DIFFUSION_DIR, output_file=f"{OUT_DIR}/crystal.gif")

    # Only one crystal in the batch, so its atoms start at flat index 0.
    # NOTE(review): num_atoms is not set here, unlike generate_n_crystals;
    # presumably model.sample already fills it in (TODO confirm).
    sample.idx_start = np.array([0])
    save_sample_results_to_hdf5(sample, f"{OUT_DIR}/crystals.h5")
def generate_n_crystals(
    num_crystals: int,
    num_atoms_per_sample: int,
    use_constant_atomic_symbols: Optional[list[str]],
    num_crystals_per_batch: int = 10,
):
    """Sample ``num_crystals`` crystals in batches and save them to one HDF5 file.

    Every crystal has ``num_atoms_per_sample`` atoms. The per-atom arrays
    (``frac_x``, ``atomic_numbers``) are therefore stored contiguously:
    crystal ``k`` occupies rows ``k * num_atoms_per_sample`` up to
    ``(k + 1) * num_atoms_per_sample``. ``idx_start`` and ``num_atoms``
    record that layout. The result is written to ``<OUT_DIR>/crystals.h5``.

    Args:
        num_crystals: Total number of crystals to generate (must be >= 1).
        num_atoms_per_sample: Atoms per crystal (must be >= 1).
        use_constant_atomic_symbols: Passed unchanged to ``model.sample``.
        num_crystals_per_batch: Maximum number of crystals requested per
            ``model.sample`` call (must be >= 1). When ``num_crystals`` is not
            a multiple of it, the final batch is smaller instead of the call
            being rejected.

    Raises:
        ValueError: If any count argument is smaller than 1. The check runs
            before the model is loaded.
    """
    # Explicit checks instead of `assert`, which disappears under `python -O`.
    if num_crystals < 1:
        raise ValueError(f"num_crystals must be >= 1, got {num_crystals}")
    if num_atoms_per_sample < 1:
        raise ValueError(
            f"num_atoms_per_sample must be >= 1, got {num_atoms_per_sample}"
        )
    if num_crystals_per_batch < 1:
        raise ValueError(
            f"num_crystals_per_batch must be >= 1, got {num_crystals_per_batch}"
        )

    total_num_atoms = num_crystals * num_atoms_per_sample
    model = get_model()

    # Preallocate the combined result; each batch fills its own slice below.
    crystals = SampleResult()
    crystals.frac_x = np.empty((total_num_atoms, 3))
    # NOTE(review): float64 buffer kept from the original; confirm whether the
    # HDF5 writer / downstream readers expect an integer dtype here.
    crystals.atomic_numbers = np.empty(total_num_atoms)
    crystals.lattice = np.empty((num_crystals, 3, 3))
    crystals.idx_start = np.arange(0, total_num_atoms, num_atoms_per_sample)
    crystals.num_atoms = np.full(num_crystals, num_atoms_per_sample)

    for first_crystal in range(0, num_crystals, num_crystals_per_batch):
        # The last batch may be partial when num_crystals is not a multiple
        # of num_crystals_per_batch.
        batch_size = min(num_crystals_per_batch, num_crystals - first_crystal)
        generated = model.sample(
            num_atoms_per_sample=num_atoms_per_sample,
            num_samples_in_batch=batch_size,
            visualization_setting=VisualizationSetting.NONE,
            show_bonds=SHOW_BONDS,
            use_constant_atomic_symbols=use_constant_atomic_symbols,
        )

        crystal_rows = slice(first_crystal, first_crystal + batch_size)
        atom_rows = slice(
            first_crystal * num_atoms_per_sample,
            (first_crystal + batch_size) * num_atoms_per_sample,
        )
        crystals.frac_x[atom_rows] = generated.frac_x
        crystals.atomic_numbers[atom_rows] = generated.atomic_numbers
        crystals.lattice[crystal_rows] = generated.lattice

    save_sample_results_to_hdf5(crystals, f"{OUT_DIR}/crystals.h5")
if __name__ == "__main__":
    # Optional fixed composition: one chemical symbol per atom, e.g.
    # ["Ac", "Ac", "Ir", "Ag"] or ["C"] * 8. None leaves the choice of
    # atom types to model.sample.
    use_constant_atomic_symbols = None

    # A fixed composition determines the atom count; otherwise use 4 atoms.
    num_atoms = (
        4 if use_constant_atomic_symbols is None else len(use_constant_atomic_symbols)
    )

    # Alternative run mode: replace the call below with
    #   generate_single_crystal(
    #       num_atoms=num_atoms,
    #       visualization_setting=VisualizationSetting.ALL,
    #       use_constant_atomic_symbols=use_constant_atomic_symbols,
    #   )
    # to sample a single crystal and also render <OUT_DIR>/crystal.gif.
    generate_n_crystals(
        num_crystals=10,
        num_atoms_per_sample=num_atoms,
        use_constant_atomic_symbols=use_constant_atomic_symbols,
    )