Skip to content

Commit

Permalink
Fixes (cherry-pick)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Nov 28, 2023
1 parent 20285ff commit a777e99
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions gflownet/gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,10 +808,10 @@ def forwardlooking_loss(self, it, batch):
- logprobs_b
+ energies_transitions
).pow(2)
loss = per_node_loss.mean()
loss_terminating = per_node_loss[done].mean()
loss_intermediate = per_node_loss[~done].mean()
return loss, term_loss, nonterm_loss
loss = loss_all.mean()
loss_terminating = loss_all[done].mean()
loss_intermediate = loss_all[~done].mean()
return loss, loss_terminating, loss_intermediate

@torch.no_grad()
def estimate_logprobs_data(
Expand Down

0 comments on commit a777e99

Please sign in to comment.