From 2bf438a2e76e3b5ba6f57047e05908fe380f83a8 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Mon, 23 Sep 2024 14:56:20 -0400 Subject: [PATCH] init the cnn env's height and width in the policy --- gflownet/policy/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 5d26e0c1..053eee51 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -53,6 +53,11 @@ def __init__( self.type = config.get("type", "uniform") # Checkpoint, defaults to None self.checkpoint = config.get("checkpoint", None) + # TODO: This could be done better? We could store this only when using CNN policy. e.g. self.type could be "cnn" + if hasattr(env, 'height'): + self.height = env.height + if hasattr(env, 'width'): + self.width = env.width # Instantiate the model self.model, self.is_model = self.make_model()