Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent lattice squishing by sampling from a cuboid #73

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions diffusion/diffusion_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,33 @@ def __init__(self, num_steps=1000, s=0.0001, power=2, clipmax=0.999):
self.register_buffer("betas", betas)
self.register_buffer("sigmas", sigmas)

def forward(self, h0, t):
alpha_bar = self.alpha_bars[t]
eps = torch.randn_like(h0)
ht = (
torch.sqrt(alpha_bar).view(-1, 1) * h0
+ torch.sqrt(1 - alpha_bar).view(-1, 1) * eps
def forward(self, l0_vector, t, num_atoms):
alpha_bar = self.alpha_bars[t].view(-1, 1)
# when t is high, alpha_bar is close to 0
eps = torch.randn_like(l0_vector)

identity_matrix = (
torch.eye(3, device=l0_vector.device)
.reshape((1, 3, 3))
.repeat(l0_vector.shape[0], 1, 1)
)
mean_cell = (
self.normalizing_mean_constant(num_atoms).view(-1, 1, 1) * identity_matrix
)
mean_cell_vector = symmetric_matrix_to_vector(mean_cell)
mean = (
torch.sqrt(alpha_bar) * l0_vector
+ (1 - torch.sqrt(alpha_bar)) * mean_cell_vector
)
return ht, eps

variance = torch.sqrt(1 - alpha_bar).view(-1, 1) * eps
# variance = (
# torch.sqrt(1 - alpha_bar)
# * self.normalizing_variance_constant(num_atoms).view(-1, 1)
# * eps
# )
ht = mean + variance
return ht, eps, mean_cell

def reverse(self, lt, predicted_symmetric_vector_noise, t):
alpha = 1 - self.betas[t]
Expand All @@ -180,15 +199,21 @@ def reverse(self, lt, predicted_symmetric_vector_noise, t):
* predicted_symmetric_vector_noise
) + sigma * z

# def normalizing_mean_constant(self, n: torch.Tensor):
# avg_density_of_dataset = 0.05539856385043283
# c = 1 / avg_density_of_dataset
# return torch.pow(n * c, 1 / 3)

# def normalizing_variance_constant(self, n: torch.Tensor):
# v = 152.51649752530176 # assuming that v is the average volume of the dataset
# v = v / 6 # This is an adjustment I think will lead to more stable volumes
# return torch.pow(n * v, 1 / 3)
def normalizing_mean_constant(
    self,
    n: torch.Tensor,
    *,
    avg_density: float = 0.05539856385043283,
) -> torch.Tensor:
    """Return the side length of the cubic "mean cell" for crystals with ``n`` atoms.

    Derivation: volume = mass / density, where the atom count stands in for
    mass. The mean cell is a cube, so its side length is the cube root of
    that volume, i.e. ``cube_root(n / avg_density)``.

    Args:
        n: Number of atoms per crystal. The result has the same shape.
        avg_density: Average density (atoms per unit volume) of the
            dataset. The default is the constant previously hard-coded
            here; pass a different value when working with another dataset.

    Returns:
        Tensor of mean-cell side lengths, one per entry of ``n``.

    Raises:
        ValueError: If ``avg_density`` is not strictly positive.
    """
    if avg_density <= 0:
        raise ValueError(f"avg_density must be positive, got {avg_density}")

    volume_per_atom = 1 / avg_density
    return torch.pow(n * volume_per_atom, 1 / 3)

def normalizing_variance_constant(
    self,
    n: torch.Tensor,
    *,
    avg_volume: float = 152.51649752530176,
    volume_divisor: float = 6.0,
) -> torch.Tensor:
    """Return the per-crystal scale ``cube_root(n * avg_volume / volume_divisor)``.

    Args:
        n: Number of atoms per crystal. The result has the same shape.
        avg_volume: Assumed to be the average volume of the dataset (the
            original author notes the reference paper was not specific
            about how this value is obtained).
        volume_divisor: Empirical adjustment dividing ``avg_volume``. The
            original author chose 6 for more stable volumes and reports
            that, with it, the first and third quartiles of the generated
            angles at t=999 are around 60 and 120 degrees.

    Returns:
        Tensor of scale factors, one per entry of ``n``.

    Raises:
        ValueError: If ``avg_volume`` or ``volume_divisor`` is not
            strictly positive.
    """
    if avg_volume <= 0 or volume_divisor <= 0:
        raise ValueError(
            "avg_volume and volume_divisor must be positive, "
            f"got {avg_volume} and {volume_divisor}"
        )

    adjusted_volume = avg_volume / volume_divisor
    return torch.pow(n * adjusted_volume, 1 / 3)


def frac_to_cart_coords(
Expand Down
10 changes: 6 additions & 4 deletions diffusion/diffusion_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,17 @@ def phi(
pred_symmetric_vector_noise,
)

def diffuse_lattice_params(self, lattice: torch.Tensor, t_int: torch.Tensor):
def diffuse_lattice_params(
self, lattice: torch.Tensor, t_int: torch.Tensor, num_atoms: torch.Tensor
):
# the diffusion happens on the symmetric positive-definite matrix part, but we will pass in vectors and receive vectors out from the model.
# This is so the model can use vector features for the equivariance

rotation_matrix, symmetric_matrix = polar_decomposition(lattice)
symmetric_matrix_vector = symmetric_matrix_to_vector(symmetric_matrix)

noisy_symmetric_vector, noise_vector = self.lattice_diffusion(
symmetric_matrix_vector, t_int
noisy_symmetric_vector, noise_vector, _mean_cell = self.lattice_diffusion(
symmetric_matrix_vector, t_int, num_atoms
)
noisy_symmetric_matrix = vector_to_symmetric_matrix(noisy_symmetric_vector)
noisy_lattice = rotation_matrix @ noisy_symmetric_matrix
Expand Down Expand Up @@ -242,7 +244,7 @@ def __call__(self, model, batch, t_emb_weights, t_int=None):
noisy_lattice,
noisy_symmetric_vector,
symmetric_vector_noise,
) = self.diffuse_lattice_params(lattice, t_int)
) = self.diffuse_lattice_params(lattice, t_int, num_atoms)

# Compute the prediction.
(pred_frac_eps_x, predicted_h0_logits, pred_symmetric_vector) = self.phi(
Expand Down
46 changes: 44 additions & 2 deletions diffusion/inference/visualize_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def plot_edges(fig, edges, color):
)


def plot_with_parallelopied(fig, L):
def plot_with_parallelopied(fig, L, color="#0d5d85"):
v1 = L[0]
v2 = L[1]
v3 = L[2]
Expand All @@ -40,7 +40,7 @@ def plot_with_parallelopied(fig, L):
(tuple(points[3]), tuple(points[7])),
]
# Plot the edges using the helper function
plot_edges(fig, edges, "#0d5d85")
plot_edges(fig, edges, color)

return points

Expand Down Expand Up @@ -78,3 +78,45 @@ def visualize_lattice(lattice: torch.Tensor, out_path: str):
# Save the plot as a PNG file
fig.write_image(out_path)
print(f"Saved {out_path}")


def visualize_multiple_lattices(lattices: list[torch.Tensor], out_path: str):
    """Draw several lattices as wireframe cells in one 3D figure and save it.

    The first lattice is drawn in blue ("#0d5d85"), the second in red
    ("#ff0000"), and every further lattice in green ("#00ff00"). The axis
    ranges are set to the bounding box of all plotted corner points.

    Args:
        lattices: Lattice tensors to overlay. ``squeeze(0)`` is applied to
            each before plotting (presumably shape (1, 3, 3) -- TODO confirm
            against callers).
        out_path: Destination path handed to ``fig.write_image``.

    Returns:
        The Plotly figure, so callers can display or post-process it.

    Raises:
        ValueError: If ``lattices`` is empty; there are no points from
            which to compute axis ranges.
    """
    # Fail fast with a clear message instead of the opaque zero-size
    # reduction error np.min would raise further down.
    if len(lattices) == 0:
        raise ValueError("visualize_multiple_lattices needs at least one lattice")

    palette = ("#0d5d85", "#ff0000")
    fallback_color = "#00ff00"

    # Create a Plotly figure
    fig = go.Figure()
    corner_points = []
    for index, lattice in enumerate(lattices):
        color = palette[index] if index < len(palette) else fallback_color
        cell_points = plot_with_parallelopied(fig, lattice.squeeze(0), color=color)
        corner_points.extend(cell_points.tolist())

    corner_points = np.array(corner_points)
    smallest = np.min(corner_points, axis=0)
    largest = np.max(corner_points, axis=0)

    # Set the layout for the 3D plot
    fig.update_layout(
        title="Crystal Structure",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    # Clamp each axis to the bounding box of all lattice corners.
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=[smallest[0], largest[0]]),
            yaxis=dict(range=[smallest[1], largest[1]]),
            zaxis=dict(range=[smallest[2], largest[2]]),
        )
    )

    # Save the plot as a PNG file
    fig.write_image(out_path)
    print(f"Saved {out_path}")
    return fig
2 changes: 1 addition & 1 deletion diffusion/lattice_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def matrix_to_params(matrix: torch.Tensor) -> torch.Tensor:
1.0,
)
)
# angles = angles * 180.0 / torch.pi # convert radians to degrees
# angles = angles * 180.0 / torch.pi # convert radians to degrees
return torch.cat([lengths, angles], dim=1)


Expand Down
42 changes: 28 additions & 14 deletions exploration/verify_vp_limited_mean_and_var.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,57 @@
import pathlib
from diffusion.diffusion_helpers import (
VP_limited_mean_and_var,
VP_lattice,
polar_decomposition,
symmetric_matrix_to_vector,
vector_to_symmetric_matrix,
)
import torch
import os
from diffusion.inference.visualize_lattice import visualize_multiple_lattices

from diffusion.inference.visualize_lattice import visualize_lattice
from diffusion.lattice_helpers import matrix_to_params

OUT_DIR = f"{pathlib.Path(__file__).parent.resolve()}/../out/vp_limited_mean_and_var"


# we want to sample many lattices at a high time step to see if the lattices look realistic
def main():
os.makedirs(OUT_DIR, exist_ok=True)
vp = VP_limited_mean_and_var(num_steps=1000, s=0.0001, power=2, clipmax=0.999)
t = 999 # sample from a very high time step for maximal variance

square_lattice = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
vp = VP_lattice(num_steps=1000, s=0.0001, power=2, clipmax=0.999) # noqa: F821
t = 800 # sample from a very high time step for maximal variance

real_lattice = torch.tensor(
[
[7.864230632781982, -0.028291359543800354, 0.010975549928843975],
[-2.929015636444092, 7.298478603363037, -0.010975549928843975],
[-2.7713515758514404, -3.428903341293335, 6.512240409851074],
]
).unsqueeze(0)

for i in range(30):
rotation_matrix, symmetric_matrix = polar_decomposition(square_lattice)
num_samples = 10000
all_angles = []
for i in range(num_samples):
rotation_matrix, symmetric_matrix = polar_decomposition(real_lattice)
symmetric_matrix_vector = symmetric_matrix_to_vector(symmetric_matrix)
num_atoms = torch.tensor([15])
noisy_symmetric_vector, _symmetric_vector_noise = vp(
num_atoms = torch.tensor([8])
noisy_symmetric_vector, _symmetric_vector_noise, mean_cell = vp(
symmetric_matrix_vector, t, num_atoms
)

noisy_symmetric_matrix = vector_to_symmetric_matrix(noisy_symmetric_vector)
noisy_lattice = rotation_matrix @ noisy_symmetric_matrix
params = matrix_to_params(noisy_lattice)
angles = params[:, 3:]
all_angles.append(angles.squeeze())

visualize_lattice(noisy_lattice, f"{OUT_DIR}/{i}.png")
print(
f"noisy_symmetric_vector: {noisy_symmetric_vector} noisy_symmetric_matrix: {noisy_symmetric_matrix}"
fig = visualize_multiple_lattices(
[real_lattice, noisy_lattice, mean_cell], f"{OUT_DIR}/{i}.png"
)
fig.show()
quantiles = torch.quantile(
torch.stack(all_angles), torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9])
)
print("quantiles", quantiles)


if __name__ == "__main__":
Expand Down
23 changes: 10 additions & 13 deletions exploration/view_alexandria_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import pathlib
from diffusion.inference.visualize_crystal import (
visualize_and_save_crystal,
)
import torch
from diffusion.lattice_dataset import CrystalDataset
import os

Expand All @@ -22,20 +18,21 @@ def main():
)
os.makedirs(dataset_vis_dir, exist_ok=True)

for i in range(50):
for i in range(20000, 25000):
print(f"sample {i}")
ith_sample = dataset[i]

atomic_num = torch.argmax(ith_sample.A0, dim=1)
atomic_num = ith_sample.A0
lattice = ith_sample.L0.numpy()
frac_x = ith_sample.X0.numpy()
visualize_and_save_crystal(
atomic_num,
lattice,
frac_x,
name=f"{dataset_vis_dir}/{i}",
show_bonds=False,
)
# visualize_and_save_crystal(
# atomic_num,
# lattice,
# frac_x,
# name=f"{dataset_vis_dir}/{i}",
# show_bonds=False,
# )
print(lattice.tolist())


if __name__ == "__main__":
Expand Down