From d3044624b5153220fb0cceebd07bdc7d592ce864 Mon Sep 17 00:00:00 2001 From: Zecheng Zhang Date: Mon, 30 Sep 2024 00:12:43 +0000 Subject: [PATCH] Update --- examples/ngcf.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/ngcf.py b/examples/ngcf.py index 64ffc19..7f70f64 100644 --- a/examples/ngcf.py +++ b/examples/ngcf.py @@ -96,6 +96,7 @@ def sparse_dropout(self, row: Tensor, col: Tensor, value: Tensor, rate: float, nnz: int) -> SparseTensor: rand = 1 - rate rand += torch.rand(nnz) + assert isinstance(rand, Tensor) dropout_mask = torch.floor(rand).type(torch.bool) adj = SparseTensor( row=row[dropout_mask], @@ -105,9 +106,10 @@ def sparse_dropout(self, row: Tensor, col: Tensor, value: Tensor, ) return adj - def get_embedding(self, norm_adj: Tensor, device=torch.device) -> Tensor: + def get_embedding(self, norm_adj: SparseTensor, + device=torch.device) -> Tensor: ego_emb = self.emb.weight - all_embs = [ego_emb] + all_embs: List[Tensor] = [ego_emb] if self.node_dropout > 0 and self.training: row, col, value = norm_adj.coo() adj = self.sparse_dropout( @@ -129,7 +131,7 @@ def get_embedding(self, norm_adj: Tensor, device=torch.device) -> Tensor: ego_emb = F.dropout(ego_emb) norm_emb = F.normalize(ego_emb, p=2, dim=1) all_embs += [norm_emb] - all_embs = torch.cat(all_embs, 1) + all_embs: Tensor = torch.cat(all_embs, 1) return all_embs def recommendation_loss(