Skip to content

Commit

Permalink
Fix requires_grad for identity in TNet
Browse files Browse the repository at this point in the history
  • Loading branch information
adosar committed Dec 2, 2024
1 parent fed7f75 commit 5e8bda5
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/aidsorb/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def forward(self, x):
x = self.dense_blocks(x)

# Initialize the identity matrix.
identity = torch.eye(self.embed_dim, device=x.device, requires_grad=True).repeat(bs, 1, 1)
identity = torch.eye(self.embed_dim, device=x.device, requires_grad=x.requires_grad).repeat(bs, 1, 1)

# Output has shape (B, self.embed_dim, self.embed_dim).
x = x.view(-1, self.embed_dim, self.embed_dim) + identity
Expand Down

0 comments on commit 5e8bda5

Please sign in to comment.