Skip to content

Commit

Permalink
Merge pull request #305 from alexhernandezgarcia/speedup_batch
Browse files Browse the repository at this point in the history
Speedup batch _compute_parents()
  • Loading branch information
alexhernandezgarcia authored May 6, 2024
2 parents ba8129a + 788abce commit 2167136
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
17 changes: 13 additions & 4 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,10 @@ def _compute_parents(self):
"""
self.parents = []
self.parents_indices = []
indices = []

indices_dict = {}
indices_next = 0

# Iterate over the trajectories to obtain the parents from the states
for traj_idx, batch_indices in self.trajectories.items():
# parent is source
Expand All @@ -567,12 +570,18 @@ def _compute_parents(self):
# TODO: check if tensor and sort without iter
self.parents.extend([self.states[idx] for idx in batch_indices[:-1]])
self.parents_indices.extend(batch_indices[:-1])
indices.extend(batch_indices)

# Store the indices required to reorder the parents lists in the same
# order as the states
for b_idx in batch_indices:
indices_dict[b_idx] = indices_next
indices_next += 1

# Sort parents list in the same order as states
# TODO: check if tensor and sort without iter
self.parents = [self.parents[indices.index(idx)] for idx in range(len(self))]
self.parents = [self.parents[indices_dict[idx]] for idx in range(len(self))]
self.parents_indices = tlong(
[self.parents_indices[indices.index(idx)] for idx in range(len(self))],
[self.parents_indices[indices_dict[idx]] for idx in range(len(self))],
device=self.device,
)
self.parents_available = True
Expand Down
4 changes: 0 additions & 4 deletions tests/gflownet/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def test__step__returns_same_state_action_and_invalid_if_done(self, n_repeat=1):
def test__sample_actions__backward__returns_eos_if_done(
self, n_repeat=1, n_states=5
):

if _get_current_method_name() in self.n_states:
n_states = self.n_states[_get_current_method_name()]

Expand Down Expand Up @@ -96,7 +95,6 @@ def test__sample_actions__backward__returns_eos_if_done(
def test__get_logprobs__backward__returns_zero_if_done(
self, n_repeat=1, n_states=5
):

if _get_current_method_name() in self.n_states:
n_states = self.n_states[_get_current_method_name()]

Expand Down Expand Up @@ -161,7 +159,6 @@ def test__forward_actions_have_nonzero_backward_prob(self, n_repeat=1):
def test__backward_actions_have_nonzero_forward_prob(
self, n_repeat=1, n_states=100
):

if _get_current_method_name() in self.n_states:
n_states = self.n_states[_get_current_method_name()]

Expand Down Expand Up @@ -398,7 +395,6 @@ def test__state2readable__is_reversible(self, n_repeat=1):
def test__get_parents__returns_same_state_and_eos_if_done(
self, n_repeat=1, n_states=10
):

if _get_current_method_name() in self.n_states:
n_states = self.n_states[_get_current_method_name()]

Expand Down

0 comments on commit 2167136

Please sign in to comment.