From c95e749e692134200d7742e5b2cc860f8035d031 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Tue, 21 Sep 2021 21:30:20 +0800 Subject: [PATCH] fix linkpred models --- graphgallery/gallery/linkpred/pyg/gae.py | 2 +- graphgallery/nn/models/pyg/autoencoder/autoencoder.py | 6 +++--- graphgallery/nn/models/pytorch/autoencoder/autoencoder.py | 2 +- graphgallery/nn/models/pytorch/autoencoder/vgae.py | 5 +++-- graphgallery/nn/models/torch_keras.py | 2 +- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/graphgallery/gallery/linkpred/pyg/gae.py b/graphgallery/gallery/linkpred/pyg/gae.py index 8074c4cd..f08dafa0 100644 --- a/graphgallery/gallery/linkpred/pyg/gae.py +++ b/graphgallery/gallery/linkpred/pyg/gae.py @@ -35,7 +35,7 @@ def model_step(self, lr=0.01, bias=False): - model = get_model("autoencoder.VGAE", self.backend) + model = get_model("autoencoder.GAE", self.backend) model = model(self.graph.num_node_attrs, out_features=out_features, hids=hids, diff --git a/graphgallery/nn/models/pyg/autoencoder/autoencoder.py b/graphgallery/nn/models/pyg/autoencoder/autoencoder.py index ed3e94e5..5578fa7d 100644 --- a/graphgallery/nn/models/pyg/autoencoder/autoencoder.py +++ b/graphgallery/nn/models/pyg/autoencoder/autoencoder.py @@ -20,7 +20,7 @@ def train_step_on_batch(self, self.train() optimizer = self.optimizer optimizer.zero_grad() - x = to_device(x, device=device) + x, _ = to_device(x, device=device) z = self.encode(*x) # here `out_index` maybe pos_edge_index # or (pos_edge_index, neg_edge_index) @@ -65,7 +65,7 @@ def test_step_on_batch(self, device="cpu"): self.eval() metrics = self.metrics - x = to_device(x, device=device) + x, _ = to_device(x, device=device) z = self.encode(*x) pred = self.decode(z, out_index) @@ -78,7 +78,7 @@ def test_step_on_batch(self, @torch.no_grad() def predict_step_on_batch(self, x, out_index=None, device="cpu"): self.eval() - x = to_device(x, device=device) + x, _ = to_device(x, device=device) z = self.encode(*x) pred = self.decode(z, out_index) return pred.cpu().detach() diff --git a/graphgallery/nn/models/pytorch/autoencoder/autoencoder.py b/graphgallery/nn/models/pytorch/autoencoder/autoencoder.py index 06853539..d22b5d04 100644 --- a/graphgallery/nn/models/pytorch/autoencoder/autoencoder.py +++ b/graphgallery/nn/models/pytorch/autoencoder/autoencoder.py @@ -27,7 +27,7 @@ def test_step_on_batch(self, x, y = to_device(x, y, device=device) z = self.encode(*x) out = self.decode(z, out_index) - loss = self.compute_loss(out, y) + loss, out = self.compute_loss(out, y) self.update_metrics(out, y) if loss is not None: diff --git a/graphgallery/nn/models/pytorch/autoencoder/vgae.py b/graphgallery/nn/models/pytorch/autoencoder/vgae.py index 57070f80..614827b3 100644 --- a/graphgallery/nn/models/pytorch/autoencoder/vgae.py +++ b/graphgallery/nn/models/pytorch/autoencoder/vgae.py @@ -63,11 +63,12 @@ def forward(self, x, adj): out = self.decode(z) return out - def compute_loss(self, out, y): + def compute_loss(self, out, y, out_index=None): + out = self.index_select(out, out_index=out_index) if self.training: mu = self.cache.pop('mu') logstd = self.cache.pop('logstd') kl_loss = -0.5 / mu.size(0) * torch.mean(torch.sum(1 + 2 * logstd - mu.pow(2) - logstd.exp().pow(2), dim=1)) else: kl_loss = 0. - return self.loss(out, y) + kl_loss + return self.loss(out, y) + kl_loss, out diff --git a/graphgallery/nn/models/torch_keras.py b/graphgallery/nn/models/torch_keras.py index 02ad98f6..c9ed1f65 100644 --- a/graphgallery/nn/models/torch_keras.py +++ b/graphgallery/nn/models/torch_keras.py @@ -115,7 +115,7 @@ def test_step_on_batch(self, @torch.no_grad() def predict_step_on_batch(self, x, out_index=None, device="cpu"): self.eval() - x = to_device(x, device=device) + x, _ = to_device(x, device=device) out = self.index_select(self(*x), out_index=out_index) return out.cpu().detach()