From 1a4f0f6cca6e32b13cee98fc591fd474b45e985f Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Wed, 17 Apr 2024 14:31:28 -0400 Subject: [PATCH] function to visualize the new lattice relative to the original lattice --- diffusion/inference/visualize_lattice.py | 43 ++++++++++++++++++- exploration/verify_vp_limited_mean_and_var.py | 8 +++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/diffusion/inference/visualize_lattice.py b/diffusion/inference/visualize_lattice.py index dff802d..08e5de9 100644 --- a/diffusion/inference/visualize_lattice.py +++ b/diffusion/inference/visualize_lattice.py @@ -17,7 +17,7 @@ def plot_edges(fig, edges, color): ) -def plot_with_parallelopied(fig, L): +def plot_with_parallelopied(fig, L, color="#0d5d85"): v1 = L[0] v2 = L[1] v3 = L[2] @@ -40,7 +40,7 @@ def plot_with_parallelopied(fig, L): (tuple(points[3]), tuple(points[7])), ] # Plot the edges using the helper function - plot_edges(fig, edges, "#0d5d85") + plot_edges(fig, edges, color) return points @@ -78,3 +78,42 @@ def visualize_lattice(lattice: torch.Tensor, out_path: str): # Save the plot as a PNG file fig.write_image(out_path) print(f"Saved {out_path}") + + +def visualize_multiple_lattices(lattices: list[torch.Tensor], out_path: str): + # Create a Plotly figure + fig = go.Figure() + points = [] + for i, lattice in enumerate(lattices): + if i == 0: + color = "#0d5d85" + else: + color = "#ffffff" + points.extend( + plot_with_parallelopied(fig, lattice.squeeze(0), color=color).tolist() + ) + points = np.array(points) + smallest = np.min(points, axis=0) + largest = np.max(points, axis=0) + + # Set the layout for the 3D plot + fig.update_layout( + title="Crystal Structure", + scene=dict( + xaxis_title="X", + yaxis_title="Y", + zaxis_title="Z", + ), + margin=dict(l=0, r=0, b=0, t=0), + ) + fig.update_layout( + scene=dict( + xaxis=dict(range=[smallest[0], largest[0]]), + yaxis=dict(range=[smallest[1], largest[1]]), + zaxis=dict(range=[smallest[2], largest[2]]), + ) + ) + + # Save the plot as a PNG file + fig.write_image(out_path) + print(f"Saved {out_path}") diff --git a/exploration/verify_vp_limited_mean_and_var.py b/exploration/verify_vp_limited_mean_and_var.py index 07248bd..f09f5df 100644 --- a/exploration/verify_vp_limited_mean_and_var.py +++ b/exploration/verify_vp_limited_mean_and_var.py @@ -8,7 +8,9 @@ import torch import os -from diffusion.inference.visualize_lattice import visualize_lattice +from diffusion.inference.visualize_lattice import ( + visualize_multiple_lattices, +) OUT_DIR = f"{pathlib.Path(__file__).parent.resolve()}/../out/vp_limited_mean_and_var" @@ -34,7 +36,9 @@ def main(): noisy_symmetric_matrix = vector_to_symmetric_matrix(noisy_symmetric_vector) noisy_lattice = rotation_matrix @ noisy_symmetric_matrix - visualize_lattice(noisy_lattice, f"{OUT_DIR}/{i}.png") + visualize_multiple_lattices( + [square_lattice, noisy_lattice], f"{OUT_DIR}/{i}.png" + ) print( f"noisy_symmetric_vector: {noisy_symmetric_vector} noisy_symmetric_matrix: {noisy_symmetric_matrix}" )