Skip to content

Commit

Permalink
Merge pull request #251 from alexhernandezgarcia/fix-get_logprobs-base
Browse files Browse the repository at this point in the history
Fix: get_logprobs of base env (reason why TB was not training well)
  • Loading branch information
alexhernandezgarcia authored Nov 17, 2023
2 parents 619281d + 7ea8f9d commit cb4bf18
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def get_logprobs(
"""
device = policy_outputs.device
ns_range = torch.arange(policy_outputs.shape[0]).to(device)
logits = policy_outputs.clone().detach()
logits = policy_outputs.clone()
if mask is not None:
logits[mask] = -torch.inf
action_indices = (
Expand Down

0 comments on commit cb4bf18

Please sign in to comment.