diff --git a/src/aidsorb/modules.py b/src/aidsorb/modules.py index aa85531..46f9d34 100644 --- a/src/aidsorb/modules.py +++ b/src/aidsorb/modules.py @@ -207,7 +207,7 @@ def forward(self, x): x = self.dense_blocks(x) # Initialize the identity matrix. - identity = torch.eye(self.embed_dim, device=x.device, requires_grad=True).repeat(bs, 1, 1) + identity = torch.eye(self.embed_dim, device=x.device, requires_grad=x.requires_grad).repeat(bs, 1, 1) # Output has shape (B, self.embed_dim, self.embed_dim). x = x.view(-1, self.embed_dim, self.embed_dim) + identity