Skip to content

Commit

Permalink
init the cnn env's height and width in the policy
Browse files Browse the repository at this point in the history
  • Loading branch information
engmubarak48 committed Sep 23, 2024
1 parent 9fa9381 commit 2bf438a
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions gflownet/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def __init__(
self.type = config.get("type", "uniform")
# Checkpoint, defaults to None
self.checkpoint = config.get("checkpoint", None)
# TODO: This could be done better? We could store this only when using CNN policy. e.g. self.type could be "cnn"
if hasattr(env, 'height'):
self.height = env.height
if hasattr(env, 'width'):
self.width = env.width
# Instantiate the model
self.model, self.is_model = self.make_model()

Expand Down

0 comments on commit 2bf438a

Please sign in to comment.