diff --git a/diffusion/diffusion_helpers.py b/diffusion/diffusion_helpers.py index a3e1c84..69bcbec 100644 --- a/diffusion/diffusion_helpers.py +++ b/diffusion/diffusion_helpers.py @@ -150,13 +150,13 @@ def __init__(self, num_steps=1000, s=0.0001, power=2, clipmax=0.999): self.register_buffer("betas", betas) self.register_buffer("sigmas", sigmas) - def forward(self, h0, t): + def forward(self, h0, t, num_atoms): alpha_bar = self.alpha_bars[t] eps = torch.randn_like(h0) - ht = ( - torch.sqrt(alpha_bar).view(-1, 1) * h0 - + torch.sqrt(1 - alpha_bar).view(-1, 1) * eps - ) + mean = torch.sqrt(alpha_bar).view(-1, 1) * h0 + # variance = torch.sqrt(1 - alpha_bar).view(-1, 1)* eps + variance = (1 - alpha_bar).view(-1, 1) * eps + ht = mean + variance return ht, eps def reverse(self, lt, predicted_symmetric_vector_noise, t): @@ -172,21 +172,26 @@ def reverse(self, lt, predicted_symmetric_vector_noise, t): torch.zeros_like(lt), ) + # return (1.0 / torch.sqrt(alpha + EPSILON)).view(-1, 1) * ( + # lt + # - ((1 - alpha) / torch.sqrt(1 - alpha_bar + EPSILON)).view(-1, 1) + # * predicted_symmetric_vector_noise + # ) + sigma * z return (1.0 / torch.sqrt(alpha + EPSILON)).view(-1, 1) * ( lt - - ((1 - alpha) / torch.sqrt(1 - alpha_bar + EPSILON)).view(-1, 1) + - ((1 - alpha) / (1 - alpha_bar)).view(-1, 1) * predicted_symmetric_vector_noise ) + sigma * z - # def normalizing_mean_constant(self, n: torch.Tensor): - # avg_density_of_dataset = 0.05539856385043283 - # c = 1 / avg_density_of_dataset - # return torch.pow(n * c, 1 / 3) + def normalizing_mean_constant(self, n: torch.Tensor): + avg_density_of_dataset = 0.05539856385043283 + c = 1 / avg_density_of_dataset + return torch.pow(n * c, 1 / 3) - # def normalizing_variance_constant(self, n: torch.Tensor): - # v = 152.51649752530176 # assuming that v is the average volume of the dataset - # v = v / 6 # This is an adjustment I think will lead to more stable volumes - # return torch.pow(n * v, 1 / 3) + def normalizing_variance_constant(self, n: torch.Tensor): + v = 152.51649752530176 # assuming that v is the average volume of the dataset + # v = v / 6 # This is an adjustment I think will lead to more stable volumes + return torch.pow(n * v, 1 / 3) def frac_to_cart_coords( diff --git a/exploration/verify_vp_limited_mean_and_var.py b/exploration/verify_vp_limited_mean_and_var.py index f2af00d..07248bd 100644 --- a/exploration/verify_vp_limited_mean_and_var.py +++ b/exploration/verify_vp_limited_mean_and_var.py @@ -26,7 +26,7 @@ def main(): for i in range(30): rotation_matrix, symmetric_matrix = polar_decomposition(square_lattice) symmetric_matrix_vector = symmetric_matrix_to_vector(symmetric_matrix) - num_atoms = torch.tensor([15]) + num_atoms = torch.tensor([8]) noisy_symmetric_vector, _symmetric_vector_noise = vp( symmetric_matrix_vector, t, num_atoms )