Merge pull request #256 from alexhernandezgarcia/db-loss
Detailed Balance loss
alexhernandezgarcia authored Nov 28, 2023
2 parents 21e38ac + 798b5ef commit 7a6bcdb
Showing 3 changed files with 82 additions and 1 deletion.
9 changes: 9 additions & 0 deletions config/gflownet/detailedbalance.yaml
@@ -0,0 +1,9 @@
defaults:
- gflownet
- state_flow: mlp

optimizer:
loss: detailedbalance
lr: 0.0001
lr_decay_period: 1000000
lr_decay_gamma: 0.5
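
A usage sketch (assuming the repository's usual Hydra entry point, main.py, which is not shown in this diff): the new loss could then be selected from the command line with

python main.py gflownet=detailedbalance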
7 changes: 7 additions & 0 deletions config/policy/mlp_detailedbalance.yaml
@@ -0,0 +1,7 @@
defaults:
- mlp

backward:
shared_weights: True
checkpoint: null
reload_ckpt: False
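
With shared_weights: True, the backward policy presumably reuses the forward MLP's hidden layers instead of training a separate network; this is an assumption about the repo's Policy class, not something stated in this diff. Under that assumption, a sketch of the alternative, training an independent backward MLP, would be:

backward:
shared_weights: False
checkpoint: null
reload_ckpt: False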
67 changes: 66 additions & 1 deletion gflownet/gflownet.py
@@ -80,6 +80,9 @@ def __init__(
elif optimizer.loss in ["trajectorybalance", "tb"]:
self.loss = "trajectorybalance"
self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64)
elif optimizer.loss in ["detailedbalance", "db"]:
self.loss = "detailedbalance"
self.logZ = None
elif optimizer.loss in ["forwardlooking", "fl"]:
self.loss = "forwardlooking"
self.logZ = None
@@ -198,7 +201,7 @@ def parameters(self):
raise ValueError("Backward Policy cannot be a model in flowmatch.")
parameters += list(self.backward_policy.model.parameters())
if self.state_flow is not None:
if self.loss != "forwardlooking":
if self.loss not in ["detailedbalance", "forwardlooking"]:
raise ValueError(f"State flow cannot be trained with {self.loss} loss.")
parameters += list(self.state_flow.model.parameters())
return parameters
@@ -679,6 +682,66 @@ def trajectorybalance_loss(self, it, batch):
)
return loss, loss, loss

def detailedbalance_loss(self, it, batch):
"""
Computes the Detailed Balance GFlowNet loss of a batch
Reference : https://arxiv.org/pdf/2201.13259.pdf (eq 11)
Args
----
it : int
Iteration
batch : Batch
A batch of data, containing all the states in the trajectories.
Returns
-------
loss : float
term_loss : float
Loss of the terminal nodes only
nonterm_loss : float
Loss of the intermediate nodes only
"""

assert batch.is_valid()
# Get necessary tensors from batch
states = batch.get_states(policy=False)
states_policy = batch.get_states(policy=True)
actions = batch.get_actions()
parents = batch.get_parents(policy=False)
parents_policy = batch.get_parents(policy=True)
done = batch.get_done()
rewards = batch.get_terminating_rewards(sort_by="insertion")

# Get logprobs
masks_f = batch.get_masks_forward(of_parents=True)
policy_output_f = self.forward_policy(parents_policy)
logprobs_f = self.env.get_logprobs(
policy_output_f, actions, masks_f, parents, is_backward=False
)
masks_b = batch.get_masks_backward()
policy_output_b = self.backward_policy(states_policy)
logprobs_b = self.env.get_logprobs(
policy_output_b, actions, masks_b, states, is_backward=True
)

# Get logflows
logflow_states = self.state_flow(states_policy)
logflow_states[done.eq(1)] = torch.log(rewards)
# TODO: Optimise by reusing logflow_states and batch.get_parent_indices
logflow_parents = self.state_flow(parents_policy)

# Detailed balance loss
loss_all = (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2)
loss = loss_all.mean()
loss_terminating = loss_all[done].mean()
loss_intermediate = loss_all[~done].mean()
return loss, loss_terminating, loss_intermediate
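
For reference, the transition-level objective implemented above (Eq. 11 of the paper cited in the docstring), where F is the state flow, P_F and P_B the forward and backward policies, and F(s') is replaced by the reward R(s') at terminal states:

\mathcal{L}_{DB}(s \rightarrow s') = \left( \log \frac{F(s)\, P_F(s' \mid s)}{F(s')\, P_B(s \mid s')} \right)^2

In the code, logflow_parents + logprobs_f is the log of the numerator and logflow_states + logprobs_b the log of the denominator.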

def forwardlooking_loss(self, it, batch):
"""
Computes the Forward-Looking GFlowNet loss of a batch
@@ -957,6 +1020,8 @@ def train(self):
losses = self.trajectorybalance_loss(
it * self.ttsr + j, batch
) # returns (opt loss, *metrics)
elif self.loss == "detailedbalance":
losses = self.detailedbalance_loss(it * self.ttsr + j, batch)
elif self.loss == "forwardlooking":
losses = self.forwardlooking_loss(it * self.ttsr + j, batch)
else:
