Skip to content

Commit

Permalink
add mean cell back to the return value
Browse files Browse the repository at this point in the history
  • Loading branch information
curtischong committed Apr 30, 2024
1 parent 08a8e83 commit 54bcd51
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion diffusion/diffusion_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def forward(self, l0_vector, t, num_atoms):
# * eps
# )
ht = mean + variance
return ht, eps
return ht, eps, mean_cell

def reverse(self, lt, predicted_symmetric_vector_noise, t):
alpha = 1 - self.betas[t]
Expand Down
2 changes: 1 addition & 1 deletion diffusion/diffusion_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def diffuse_lattice_params(
rotation_matrix, symmetric_matrix = polar_decomposition(lattice)
symmetric_matrix_vector = symmetric_matrix_to_vector(symmetric_matrix)

noisy_symmetric_vector, noise_vector = self.lattice_diffusion(
noisy_symmetric_vector, noise_vector, _mean_cell = self.lattice_diffusion(
symmetric_matrix_vector, t_int, num_atoms
)
noisy_symmetric_matrix = vector_to_symmetric_matrix(noisy_symmetric_vector)
Expand Down
2 changes: 1 addition & 1 deletion exploration/verify_vp_limited_mean_and_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main():
symmetric_matrix_vector = symmetric_matrix_to_vector(symmetric_matrix)
num_atoms = torch.tensor([8])
noisy_symmetric_vector, _symmetric_vector_noise, mean_cell = vp(
symmetric_matrix_vector, t, num_atoms, real_lattice
symmetric_matrix_vector, t, num_atoms
)

noisy_symmetric_matrix = vector_to_symmetric_matrix(noisy_symmetric_vector)
Expand Down

0 comments on commit 54bcd51

Please sign in to comment.