Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Nov 14, 2024
1 parent d79107f commit 040780f
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions benchmark/relbench_multi_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
from torch_geometric.utils import coalesce
from tqdm import tqdm

total_optimization_steps = 0

# The metric for link prediction
LINK_PREDICTION_METRIC = 'link_prediction_map'
# The absolute tolerance for validation map for early stopping
VAL_MAP_ATOL = 0.001
# At least, run MIN_EPOCHS epochs before early stopping
# The minimum number of epochs to run before early stopping
MIN_EPOCHS = 10

# The total number of optimization step count for annealing
total_optimization_steps = 0


class MultiVAE(torch.nn.Module):
def __init__(
Expand All @@ -46,7 +51,7 @@ def __init__(
torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(
self.q_dims[:-1],
# NOTE: Double the last dim of the encoder for mean and logvar,
# i.e., [q0, ..., qn] -> [q0, ..., qn*2].
# i.e., [q1, ..., qn] -> [q1, ..., qn*2].
self.q_dims[1:-1] + [self.q_dims[-1] * 2],
)
])
Expand Down Expand Up @@ -75,16 +80,15 @@ def encode(self, x: Tensor) -> tuple[Tensor, Tensor]:
return mu, logvar

def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""Returns a sample and mean of z from q(z|x) during training and
"""Returns a sample from q(z|x) and the mean during training and
inference, respectively.
"""
if self.training:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std) # sample more if necessary
return mu + eps * std
else:
# Use mean for inference
return mu
return mu # Use mean for inference

def decode(self, z: Tensor) -> Tensor:
"""Decodes z to a probability distribution over items."""
Expand Down

0 comments on commit 040780f

Please sign in to comment.