Skip to content

Commit

Permalink
Merge pull request #259 from alexhernandezgarcia/fl-loss-ahg
Browse files Browse the repository at this point in the history
FL loss cosmetics
  • Loading branch information
alexhernandezgarcia authored Nov 28, 2023
2 parents 1b67f32 + 660d5ef commit 21e38ac
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 15 deletions.
48 changes: 36 additions & 12 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,27 +525,37 @@ def get_parents(

def get_parents_indices(self):
"""
Returns indices of the parents of the states in the batch.
Each index corresponds to the position of the parent in the self.states tensor, if it is present there.
If a parent is not present in self.states (i.e. it is source), the corresponding index is -1
Returns the indices of the parents of the states in the batch.
Each i-th item in the returned list contains the index in self.states that
contains the parent of self.states[i], if it is present there. If a parent
is not present in self.states (because it is the source), the index is -1.
Returns
-------
self.parents_indices
The indices in self.states of the parents of self.states.
"""
if self.parents_available is False:
self._compute_parents()
return self.parents_indices

def _compute_parents(self):
"""
Obtains the parent (single parent for each state) of all states in the batch and its index.
Obtains the parent (single parent for each state) of all states in the batch
and its index.
The parents are computed, obtaining all necessary components, if they are not
readily available. Missing components and newly computed components are added
to the batch (self.component is set). The following variable is stored:
to the batch (self.component is set). The following variables are stored:
- self.parents: the parent of each state in the batch. It will be the same type
as self.states (list of lists or tensor)
Length: n_states
Shape: [n_states, state_dims]
- self.parents_indices: the position of each parent in self.states tensor.
If a parent is not present in self.states (i.e. it is source), the corresponding index is -1
- self.parents_indices: the position of each parent in self.states tensor. If a
parent is not present in self.states (i.e. it is source), the corresponding
index is -1.
self.parents_available is set to True.
"""
Expand Down Expand Up @@ -887,16 +897,23 @@ def _compute_rewards(self, do_non_terminating: Optional[bool] = False):

def get_rewards_parents(self) -> TensorType["n_states"]:
"""
Returns the rewards of all parents in the batch
Returns the rewards of all parents in the batch.
Returns
-------
self.rewards_parents
A tensor containing the rewards of the parents of self.states.
"""
if not self.rewards_parents_available:
self._compute_rewards_parents()
return self.rewards_parents

def _compute_rewards_parents(self):
"""
Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards).
Stores the result in self.rewards_parents
Computes the rewards of self.parents by reusing the rewards of the states
(self.rewards).
Stores the result in self.rewards_parents.
"""
# TODO: this may return zero rewards for all parents if the rewards of the
# states were previously computed with do_non_terminating=False
Expand All @@ -914,15 +931,22 @@ def _compute_rewards_parents(self):
def get_rewards_source(self) -> TensorType["n_states"]:
"""
Returns rewards of the corresponding source states for each state in the batch.
Returns
-------
self.rewards_source
A tensor containing the rewards of the source states.
"""
if not self.rewards_source_available:
self._compute_rewards_source()
return self.rewards_source

def _compute_rewards_source(self):
"""
Computes a tensor of length len(self.states) with rewards of the corresponding source states.
Stores the result in self.rewards_source
Computes a tensor of length len(self.states) with the rewards of the
corresponding source states.
Stores the result in self.rewards_source.
"""
# This will not work if source is randomised
if not self.conditional:
Expand Down
6 changes: 3 additions & 3 deletions tests/gflownet/utils/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,7 @@ def test__make_indices_consecutive__multiplied_indices_become_consecutive(
[("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")],
)
# @pytest.mark.skip(reason="skip while developing other tests")
def test__get_rewards__single_env_returns_expected_non_terminal(
def test__get_rewards__single_env_returns_expected_non_terminating(
env, proxy, batch, request
):
env = request.getfixturevalue(env)
Expand Down Expand Up @@ -1371,7 +1371,7 @@ def test__get_rewards__single_env_returns_expected_non_terminal(
"env, proxy",
[("grid2d", "corners"), ("tetris6x4", "tetris_score_norm")],
)
def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal(
def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating(
env, proxy, batch, request
):
batch_size = BATCH_SIZE
Expand Down Expand Up @@ -1434,7 +1434,7 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal(
("ctorus2d5l", "corners"),
],
)
def test__get_rewards_parents_multiple_env_returns_expected_non_terminal(
def test__get_rewards_parents_multiple_env_returns_expected_non_terminating(
env, proxy, batch, request
):
batch_size = BATCH_SIZE
Expand Down

0 comments on commit 21e38ac

Please sign in to comment.