Skip to content

Commit

Permalink
add mlp to device
Browse files Browse the repository at this point in the history
  • Loading branch information
engmubarak48 committed Sep 23, 2024
1 parent 2bf438a commit 9395e61
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions gflownet/policy/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def make_model(self, activation: nn.Module = nn.LeakyReLU()):
is_model : bool
True because an MLP is a model.
"""
activation.to(self.device)

if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
Expand Down Expand Up @@ -66,7 +65,7 @@ def make_model(self, activation: nn.Module = nn.LeakyReLU()):
+ self.tail
)
)
return mlp, True
return mlp.to(self.device), True
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
Expand Down

0 comments on commit 9395e61

Please sign in to comment.