Skip to content

Commit

Permalink
Refactor get_sample_space_and_reward
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Jul 22, 2024
1 parent aebc5f9 commit 2a07ca2
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions gflownet/gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,16 +1266,17 @@ def train(self):
if self.use_context is False:
self.logger.end()

def get_sample_space_and_reward(self):
def get_sample_space(self):
"""
Returns samples representative of the env state space with their rewards
Obtains and returns samples representative of the env state space, in
environment format.
This method sets self.sample_space_batch.
Returns
-------
sample_space_batch : tensor
Repressentative terminating states for the environment
rewards_sample_space : tensor
Rewards associated with the tates in sample_space_batch
sample_space_batch : list, tensor, array
Representative terminating states (in environment format) for the environment.
"""
if not hasattr(self, "sample_space_batch"):
if hasattr(self.env, "get_all_terminating_states"):
Expand All @@ -1290,11 +1291,34 @@ def get_sample_space_and_reward(self):
"environment must implement either get_all_terminating_states() "
"or get_grid_terminating_states()"
)
self.sample_space_batch = self.env.states2proxy(self.sample_space_batch)
if not hasattr(self, "rewards_sample_space"):
self.rewards_sample_space = self.proxy.rewards(self.sample_space_batch)
return self.sample_space_batch

def get_sample_space_and_reward(self, return_states_proxy: bool = False):
"""
Returns samples representative of the env state space with their rewards.
Parameters
----------
return_states_proxy : bool
If True, returns the states in proxy format.
return self.sample_space_batch, self.rewards_sample_space
Returns
-------
sample_space_batch : list, tensor, array
Representative terminating states for the environment. If
return_states_proxy, the format returned will be the proxy format.
Otherwise, states will be returned in environment fomat.
rewards_sample_space : tensor
Rewards associated with the tates in sample_space_batch
"""
if return_states_proxy or not hasattr(self, "rewards_sample_space"):
sample_space_proxy = self.env.states2proxy(self.get_sample_space())
if not hasattr(self, "rewards_sample_space"):
self.rewards_sample_space = self.proxy.rewards(sample_space_proxy)
if return_states_proxy:
return sample_space_proxy, self.rewards_sample_space
else:
return self.sample_space_batch, self.rewards_sample_space

# TODO: implement other proposal distributions
# TODO: rethink whether it is needed to convert to reward
Expand Down

0 comments on commit 2a07ca2

Please sign in to comment.