diff --git a/benchmark/relbench_multi_vae.py b/benchmark/relbench_multi_vae.py
index ee6cb48..0bba9f9 100644
--- a/benchmark/relbench_multi_vae.py
+++ b/benchmark/relbench_multi_vae.py
@@ -14,12 +14,17 @@
 from torch_geometric.utils import coalesce
 from tqdm import tqdm
 
-total_optimization_steps = 0
+
+# The metric for link prediction
 LINK_PREDICTION_METRIC = 'link_prediction_map'
+# The absolute tolerance on the validation MAP for early stopping
 VAL_MAP_ATOL = 0.001
-# At least, run MIN_EPOCHS epochs before early stopping
+# The minimum number of epochs to run before early stopping
 MIN_EPOCHS = 10
+# The total number of optimization steps, used for annealing
+total_optimization_steps = 0
+
 
 
 class MultiVAE(torch.nn.Module):
     def __init__(
@@ -46,7 +51,7 @@ def __init__(
             torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(
                 self.q_dims[:-1],
                 # NOTE: Double the last dim of the encoder for mean and logvar,
-                # i.e., [q0, ..., qn] -> [q0, ..., qn*2].
+                # i.e., [q1, ..., qn] -> [q1, ..., qn*2].
                 self.q_dims[1:-1] + [self.q_dims[-1] * 2],
             )
         ])
@@ -75,7 +80,7 @@ def encode(self, x: Tensor) -> tuple[Tensor, Tensor]:
         return mu, logvar
 
     def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
-        """Returns a sample and mean of z from q(z|x) during training and
+        """Returns a sample from q(z|x) and the mean during training and
         inference, respectively.
         """
         if self.training:
@@ -83,8 +88,7 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
             eps = torch.randn_like(std)  # sample more if necessary
             return mu + eps * std
         else:
-            # Use mean for inference
-            return mu
+            return mu  # Use mean for inference
 
     def decode(self, z: Tensor) -> Tensor:
         """Decodes z to a probability distribution over items."""