Skip to content

Commit

Permalink
function to visualize the new lattice relative to the original lattice
Browse files Browse the repository at this point in the history
  • Loading branch information
curtischong committed Apr 17, 2024
1 parent ff1dbd8 commit 1a4f0f6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
43 changes: 41 additions & 2 deletions diffusion/inference/visualize_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def plot_edges(fig, edges, color):
)


def plot_with_parallelopied(fig, L):
def plot_with_parallelopied(fig, L, color="#0d5d85"):
v1 = L[0]
v2 = L[1]
v3 = L[2]
Expand All @@ -40,7 +40,7 @@ def plot_with_parallelopied(fig, L):
(tuple(points[3]), tuple(points[7])),
]
# Plot the edges using the helper function
plot_edges(fig, edges, "#0d5d85")
plot_edges(fig, edges, color)

return points

Expand Down Expand Up @@ -78,3 +78,42 @@ def visualize_lattice(lattice: torch.Tensor, out_path: str):
# Save the plot as a PNG file
fig.write_image(out_path)
print(f"Saved {out_path}")


def visualize_multiple_lattices(lattices: list[torch.Tensor], out_path: str):
# Create a Plotly figure
fig = go.Figure()
points = []
for i, lattice in enumerate(lattices):
if i == 0:
color = "#0d5d85"
else:
color = "#ffffff"
points.extend(
plot_with_parallelopied(fig, lattice.squeeze(0), color=color).tolist()
)
points = np.array(points)
smallest = np.min(points, axis=0)
largest = np.max(points, axis=0)

# Set the layout for the 3D plot
fig.update_layout(
title="Crystal Structure",
scene=dict(
xaxis_title="X",
yaxis_title="Y",
zaxis_title="Z",
),
margin=dict(l=0, r=0, b=0, t=0),
)
fig.update_layout(
scene=dict(
xaxis=dict(range=[smallest[0], largest[0]]),
yaxis=dict(range=[smallest[1], largest[1]]),
zaxis=dict(range=[smallest[2], largest[2]]),
)
)

# Save the plot as a PNG file
fig.write_image(out_path)
print(f"Saved {out_path}")
8 changes: 6 additions & 2 deletions exploration/verify_vp_limited_mean_and_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import torch
import os

from diffusion.inference.visualize_lattice import visualize_lattice
from diffusion.inference.visualize_lattice import (
visualize_multiple_lattices,
)

OUT_DIR = f"{pathlib.Path(__file__).parent.resolve()}/../out/vp_limited_mean_and_var"

Expand All @@ -34,7 +36,9 @@ def main():
noisy_symmetric_matrix = vector_to_symmetric_matrix(noisy_symmetric_vector)
noisy_lattice = rotation_matrix @ noisy_symmetric_matrix

visualize_lattice(noisy_lattice, f"{OUT_DIR}/{i}.png")
visualize_multiple_lattices(
[square_lattice, noisy_lattice], f"{OUT_DIR}/{i}.png"
)
print(
f"noisy_symmetric_vector: {noisy_symmetric_vector} noisy_symmetric_matrix: {noisy_symmetric_matrix}"
)
Expand Down

0 comments on commit 1a4f0f6

Please sign in to comment.