Skip to content

Commit

Permalink
updated equations to match the mattergen mean and variance
Browse files Browse the repository at this point in the history
  • Loading branch information
curtischong committed Apr 17, 2024
1 parent f11e482 commit ff1dbd8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
33 changes: 19 additions & 14 deletions diffusion/diffusion_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ def __init__(self, num_steps=1000, s=0.0001, power=2, clipmax=0.999):
self.register_buffer("betas", betas)
self.register_buffer("sigmas", sigmas)

def forward(self, h0, t):
def forward(self, h0, t, num_atoms):
alpha_bar = self.alpha_bars[t]
eps = torch.randn_like(h0)
ht = (
torch.sqrt(alpha_bar).view(-1, 1) * h0
+ torch.sqrt(1 - alpha_bar).view(-1, 1) * eps
)
mean = torch.sqrt(alpha_bar).view(-1, 1) * h0
# variance = torch.sqrt(1 - alpha_bar).view(-1, 1)* eps
variance = (1 - alpha_bar).view(-1, 1) * eps
ht = mean + variance
return ht, eps

def reverse(self, lt, predicted_symmetric_vector_noise, t):
Expand All @@ -172,21 +172,26 @@ def reverse(self, lt, predicted_symmetric_vector_noise, t):
torch.zeros_like(lt),
)

# return (1.0 / torch.sqrt(alpha + EPSILON)).view(-1, 1) * (
# lt
# - ((1 - alpha) / torch.sqrt(1 - alpha_bar + EPSILON)).view(-1, 1)
# * predicted_symmetric_vector_noise
# ) + sigma * z
return (1.0 / torch.sqrt(alpha + EPSILON)).view(-1, 1) * (
lt
- ((1 - alpha) / torch.sqrt(1 - alpha_bar + EPSILON)).view(-1, 1)
- ((1 - alpha) / (1 - alpha_bar)).view(-1, 1)
* predicted_symmetric_vector_noise
) + sigma * z

# def normalizing_mean_constant(self, n: torch.Tensor):
# avg_density_of_dataset = 0.05539856385043283
# c = 1 / avg_density_of_dataset
# return torch.pow(n * c, 1 / 3)
def normalizing_mean_constant(self, n: torch.Tensor):
avg_density_of_dataset = 0.05539856385043283
c = 1 / avg_density_of_dataset
return torch.pow(n * c, 1 / 3)

# def normalizing_variance_constant(self, n: torch.Tensor):
# v = 152.51649752530176 # assuming that v is the average volume of the dataset
# v = v / 6 # This is an adjustment I think will lead to more stable volumes
# return torch.pow(n * v, 1 / 3)
def normalizing_variance_constant(self, n: torch.Tensor):
v = 152.51649752530176 # assuming that v is the average volume of the dataset
# v = v / 6 # This is an adjustment I think will lead to more stable volumes
return torch.pow(n * v, 1 / 3)


def frac_to_cart_coords(
Expand Down
2 changes: 1 addition & 1 deletion exploration/verify_vp_limited_mean_and_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def main():
for i in range(30):
rotation_matrix, symmetric_matrix = polar_decomposition(square_lattice)
symmetric_matrix_vector = symmetric_matrix_to_vector(symmetric_matrix)
num_atoms = torch.tensor([15])
num_atoms = torch.tensor([8])
noisy_symmetric_vector, _symmetric_vector_noise = vp(
symmetric_matrix_vector, t, num_atoms
)
Expand Down

0 comments on commit ff1dbd8

Please sign in to comment.