Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
zechengz committed Sep 30, 2024
1 parent beed5d8 commit d304462
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions examples/ngcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def sparse_dropout(self, row: Tensor, col: Tensor, value: Tensor,
rate: float, nnz: int) -> SparseTensor:
rand = 1 - rate
rand += torch.rand(nnz)
assert isinstance(rand, Tensor)
dropout_mask = torch.floor(rand).type(torch.bool)
adj = SparseTensor(
row=row[dropout_mask],
Expand All @@ -105,9 +106,10 @@ def sparse_dropout(self, row: Tensor, col: Tensor, value: Tensor,
)
return adj

def get_embedding(self, norm_adj: Tensor, device=torch.device) -> Tensor:
def get_embedding(self, norm_adj: SparseTensor,
device=torch.device) -> Tensor:
ego_emb = self.emb.weight
all_embs = [ego_emb]
all_embs: List[Tensor] = [ego_emb]
if self.node_dropout > 0 and self.training:
row, col, value = norm_adj.coo()
adj = self.sparse_dropout(
Expand All @@ -129,7 +131,7 @@ def get_embedding(self, norm_adj: Tensor, device=torch.device) -> Tensor:
ego_emb = F.dropout(ego_emb)
norm_emb = F.normalize(ego_emb, p=2, dim=1)
all_embs += [norm_emb]
all_embs = torch.cat(all_embs, 1)
all_embs: Tensor = torch.cat(all_embs, 1)
return all_embs

def recommendation_loss(
Expand Down

0 comments on commit d304462

Please sign in to comment.