
Commit

Merge pull request #260 from alexhernandezgarcia/fl-loss-ahg
FL loss final cosmetics
alexhernandezgarcia authored Nov 28, 2023
2 parents 7a6bcdb + a777e99 commit f7e5ec3
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions gflownet/gflownet.py
@@ -730,13 +730,13 @@ def detailedbalance_loss(self, it, batch):
         )

         # Get logflows
-        logflow_states = self.state_flow(states_policy)
-        logflow_states[done.eq(1)] = torch.log(rewards)
-        # TODO: Optimise by reusing logflow_states and batch.get_parent_indices
-        logflow_parents = self.state_flow(parents_policy)
+        logflows_states = self.state_flow(states_policy)
+        logflows_states[done.eq(1)] = torch.log(rewards)
+        # TODO: Optimise by reusing logflows_states and batch.get_parent_indices
+        logflows_parents = self.state_flow(parents_policy)

         # Detailed balance loss
-        loss_all = (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2)
+        loss_all = (logflows_parents + logprobs_f - logflows_states - logprobs_b).pow(2)
         loss = loss_all.mean()
         loss_terminating = loss_all[done].mean()
         loss_intermediate = loss_all[~done].mean()
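For context, the detailed balance objective computed in this hunk penalizes, for each transition s -> s', the squared residual log F(s) + log P_F(s'|s) - log F(s') - log P_B(s|s'), with log F(s') replaced by log R(s') at terminating transitions. Below is a minimal standalone sketch of that computation on dummy tensors; all names, shapes and values are illustrative assumptions, not the repository's API.

import torch

# Minimal sketch of the detailed-balance loss above, on dummy tensors.
# Names, shapes and values are illustrative assumptions only.
n_transitions = 8
logflows_parents = torch.randn(n_transitions)  # log F(s) of each parent
logflows_states = torch.randn(n_transitions)   # log F(s') of each state
logprobs_f = torch.randn(n_transitions)        # log P_F(s' | s)
logprobs_b = torch.randn(n_transitions)        # log P_B(s | s')
rewards = torch.rand(n_transitions) + 0.1      # R(s'), kept positive
done = torch.zeros(n_transitions, dtype=torch.bool)
done[-1] = True                                # mark one terminating transition

# At terminating transitions the state log-flow is pinned to log R(s')
logflows_states[done] = torch.log(rewards[done])

loss_all = (logflows_parents + logprobs_f - logflows_states - logprobs_b).pow(2)
loss = loss_all.mean()
loss_terminating = loss_all[done].mean()
loss_intermediate = loss_all[~done].mean()
print(loss.item(), loss_terminating.item(), loss_intermediate.item())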
@@ -769,50 +769,49 @@ def forwardlooking_loss(self, it, batch):

         assert batch.is_valid()
         # Get necessary tensors from batch
-        states_policy = batch.get_states(policy=True)
         states = batch.get_states(policy=False)
+        states_policy = batch.get_states(policy=True)
         actions = batch.get_actions()
-        parents_policy = batch.get_parents(policy=True)
         parents = batch.get_parents(policy=False)
-        traj_indices = batch.get_trajectory_indices(consecutive=True)
+        parents_policy = batch.get_parents(policy=True)
+        rewards_states = batch.get_rewards(do_non_terminating=True)
+        rewards_parents = batch.get_rewards_parents()
         done = batch.get_done()

-        masks_b = batch.get_masks_backward()
-        policy_output_b = self.backward_policy(states_policy)
-        logprobs_bkw = self.env.get_logprobs(
-            policy_output_b, actions, masks_b, states, is_backward=True
-        )
+        # Get logprobs
         masks_f = batch.get_masks_forward(of_parents=True)
         policy_output_f = self.forward_policy(parents_policy)
-        logprobs_fwd = self.env.get_logprobs(
+        logprobs_f = self.env.get_logprobs(
             policy_output_f, actions, masks_f, parents, is_backward=False
         )
+        masks_b = batch.get_masks_backward()
+        policy_output_b = self.backward_policy(states_policy)
+        logprobs_b = self.env.get_logprobs(
+            policy_output_b, actions, masks_b, states, is_backward=True
+        )

-        states_log_flflow = self.state_flow(states_policy)
-        # forward-looking flow is 1 in the terminal states
-        states_log_flflow[done.eq(1)] = 0.0
-        # Can be optimised by reusing states_log_flflow and batch.get_parent_indices
-        parents_log_flflow = self.state_flow(parents_policy)
-
-        rewards_states = batch.get_rewards(do_non_terminating=True)
-        rewards_parents = batch.get_rewards_parents()
-        energies_states = -torch.log(rewards_states)
-        energies_parents = -torch.log(rewards_parents)
-
-        per_node_loss = (
-            parents_log_flflow
-            - states_log_flflow
-            + logprobs_fwd
-            - logprobs_bkw
-            + energies_states
-            - energies_parents
+        # Get FL logflows
+        logflflows_states = self.state_flow(states_policy)
+        # Log FL flow of terminal states is 0 (eq. 9 of paper)
+        logflflows_states[done.eq(1)] = 0.0
+        # TODO: Optimise by reusing logflows_states and batch.get_parent_indices
+        logflflows_parents = self.state_flow(parents_policy)
+
+        # Get energies transitions
+        energies_transitions = torch.log(rewards_parents) - torch.log(rewards_states)
+
+        # Forward-looking loss
+        loss_all = (
+            logflflows_parents
+            - logflflows_states
+            + logprobs_f
+            - logprobs_b
+            + energies_transitions
         ).pow(2)

-        term_loss = per_node_loss[done].mean()
-        nonterm_loss = per_node_loss[~done].mean()
-        loss = per_node_loss.mean()
-
-        return loss, term_loss, nonterm_loss
+        loss = loss_all.mean()
+        loss_terminating = loss_all[done].mean()
+        loss_intermediate = loss_all[~done].mean()
+        return loss, loss_terminating, loss_intermediate

     @torch.no_grad()
     def estimate_logprobs_data(
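For reference, the refactored forward-looking loss is algebraically identical to the version it replaces: the old terms energies_states - energies_parents, with energies defined as -log(rewards), equal log(rewards_parents) - log(rewards_states), i.e. the new energies_transitions. Below is a minimal standalone sketch of the per-transition FL loss on dummy tensors; all names, shapes and values are illustrative assumptions, not the repository's API.

import torch

# Minimal sketch of the forward-looking (FL) loss above, on dummy tensors.
# Names, shapes and values are illustrative assumptions only.
n_transitions = 8
logflflows_parents = torch.randn(n_transitions)   # FL log-flow of parent states
logflflows_states = torch.randn(n_transitions)    # FL log-flow of child states
logprobs_f = torch.randn(n_transitions)           # log P_F(s' | s)
logprobs_b = torch.randn(n_transitions)           # log P_B(s | s')
rewards_parents = torch.rand(n_transitions) + 0.1 # R(s), kept positive
rewards_states = torch.rand(n_transitions) + 0.1  # R(s'), kept positive
done = torch.zeros(n_transitions, dtype=torch.bool)
done[-1] = True                                   # mark one terminating transition

# Log FL flow of terminal states is 0
logflflows_states[done] = 0.0

# Energy difference of each transition: -log R(s') - (-log R(s))
energies_transitions = torch.log(rewards_parents) - torch.log(rewards_states)

loss_all = (
    logflflows_parents
    - logflflows_states
    + logprobs_f
    - logprobs_b
    + energies_transitions
).pow(2)
loss = loss_all.mean()
loss_terminating = loss_all[done].mean()
loss_intermediate = loss_all[~done].mean()
print(loss.item(), loss_terminating.item(), loss_intermediate.item())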

