Skip to content

Commit

Permalink
remove env from the cnn policy
Browse files Browse the repository at this point in the history
  • Loading branch information
engmubarak48 committed Sep 23, 2024
1 parent c9ec03f commit 9fa9381
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions gflownet/policy/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ def __init__(self, **kwargs):
self.channels = config.get("channels", [16] * self.n_layers)
self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
self.strides = config.get("strides", [(1, 1)] * self.n_layers)
# Environment
# TODO: rethink whether storing the whole environment is needed
self.env = env
# Base init
super().__init__(**kwargs)

Expand Down Expand Up @@ -68,7 +65,7 @@ def make_model(self):
current_channels = self.channels[i]

dummy_input = torch.ones(
(1, 1, self.env.height, self.env.width)
(1, 1, self.height, self.width)
) # (batch_size, channels, height, width)
try:
in_channels = conv_module(dummy_input).numel()
Expand Down

0 comments on commit 9fa9381

Please sign in to comment.