
Commit

added cnn policy, and flatten flag should be set to false when using cnn policy
engmubarak48 committed Jun 28, 2024
1 parent 9474150 commit 83a4904
Showing 3 changed files with 109 additions and 0 deletions.
2 changes: 2 additions & 0 deletions config/env/tetris.yaml
@@ -11,6 +11,8 @@ height: 20
pieces: ["I", "J", "L", "O", "S", "T", "Z"]
# Allowed rotations
rotations: [0, 90, 180, 270]
# Don't flatten if using CNN
flatten: True
# Other config
allow_redundant_rotations: False
allow_eos_before_full: False
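The flatten flag above is what the commit message refers to: the CNN policy needs the board kept two-dimensional, so flatten should be set to False when the cnn policy is used. A rough sketch of the shape difference, assuming a batch of 32 states and a 20 x 10 board (the width and batch size are assumptions; only the height appears in this diff):

import torch

batch_flat = torch.zeros(32, 200)     # flatten: True  -> (batch, height * width)
batch_grid = torch.zeros(32, 20, 10)  # flatten: False -> (batch, height, width)

# CNNPolicy.__call__ adds a channel dimension before the Conv2d stack,
# so only the unflattened layout yields the 4D input that Conv2d expects.
print(batch_grid.unsqueeze(1).shape)  # torch.Size([32, 1, 20, 10])
print(batch_flat.unsqueeze(1).shape)  # torch.Size([32, 1, 200]) -- Conv2d would reject this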
16 changes: 16 additions & 0 deletions config/policy/cnn.yaml
@@ -0,0 +1,16 @@
_target_: gflownet.policy.cnn.CNNPolicy

shared: null

forward:
  n_layers: 1
  channels: [16]
  kernel_sizes: [[3, 3], [2, 2], [1, 1]] # Each tuple represents (height, width)
  strides: [[2, 2], [2, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride)
  checkpoint: null
  reload_ckpt: False

backward:
  shared_weights: True
  checkpoint: null
  reload_ckpt: False
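For reference, make_cnn in the new cnn.py (below) infers the input size of the final dense layer by pushing a dummy board through the convolutional stack. It also checks that len(kernel_sizes) matches n_layers, so depending on how the forward section is parsed, this config (n_layers: 1 with three kernel and stride entries) may trip that check. A minimal sketch of the output-size arithmetic for just the first configured layer, assuming a 20 x 10 Tetris board:

import torch
from torch import nn

# First configured layer: 16 output channels, 3x3 kernel, stride (2, 2).
# The 20 x 10 board size is an assumption for illustration.
conv = nn.Sequential(nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2)), nn.ReLU())
out = conv(torch.ones(1, 1, 20, 10))  # dummy (batch, channels, height, width) input
print(out.shape)    # torch.Size([1, 16, 9, 4])
print(out.numel())  # 576 -> in_features of the final nn.Linear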
91 changes: 91 additions & 0 deletions gflownet/policy/cnn.py
@@ -0,0 +1,91 @@
import torch
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class CNNPolicy(Policy):
    def __init__(self, config, env, device, float_precision, base=None):
        self.env = env
        super().__init__(
            config=config,
            env=env,
            device=device,
            float_precision=float_precision,
            base=base,
        )
        self.is_model = True

    def make_cnn(self):
        """
        Defines a CNN with no top-layer activation.
        """
        if self.shared_weights and self.base is not None:
            # Reuse the base (forward) model's layers and re-create only the
            # final linear head, so the convolutional trunk is shared.
            layers = list(self.base.model.children())[:-1]
            last_layer = nn.Linear(
                self.base.model[-1].in_features, self.base.model[-1].out_features
            )

            model = nn.Sequential(*layers, last_layer).to(self.device)
            return model

        current_channels = 1
        conv_module = nn.Sequential()

        if len(self.kernel_sizes) != self.n_layers:
            raise ValueError(
                f"Inconsistent dimensions kernel_sizes != n_layers, {len(self.kernel_sizes)} != {self.n_layers}"
            )

        for i in range(self.n_layers):
            conv_module.add_module(
                f"conv_{i}",
                nn.Conv2d(
                    in_channels=current_channels,
                    out_channels=self.channels[i],
                    kernel_size=tuple(self.kernel_sizes[i]),
                    stride=tuple(self.strides[i]),
                    padding=0,
                    padding_mode="zeros",  # Constant zero padding
                ),
            )
            conv_module.add_module(f"relu_{i}", nn.ReLU())
            current_channels = self.channels[i]

        # Run a dummy forward pass to infer the flattened feature count that
        # the final linear layer will receive.
        dummy_input = torch.ones(
            (1, 1, self.env.height, self.env.width)
        )  # (batch_size, channels, height, width)
        try:
            in_channels = conv_module(dummy_input).numel()
            if in_channels >= 500_000:  # TODO: this could better be handled
                raise RuntimeWarning(
                    "Input channels for the dense layer are too big, this will increase number of parameters"
                )
        except RuntimeError as e:
            raise RuntimeError(
                "Failed during convolution operation. Ensure that the kernel sizes and strides are appropriate for the input dimensions."
            ) from e

        model = nn.Sequential(
            conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim)
        )
        return model.to(self.device)

    def parse_config(self, config):
        if config is None:
            config = OmegaConf.create()
        self.checkpoint = config.get("checkpoint", None)
        self.shared_weights = config.get("shared_weights", False)
        self.reload_ckpt = config.get("reload_ckpt", False)
        self.n_layers = config.get("n_layers", 3)
        self.channels = config.get("channels", [16] * self.n_layers)
        self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
        self.strides = config.get("strides", [(1, 1)] * self.n_layers)

    def instantiate(self):
        self.model = self.make_cnn()

    def __call__(self, states):
        # Add a channel dimension: (batch_size, height, width) ->
        # (batch_size, channels, height, width)
        states = states.unsqueeze(1)
        return self.model(states)
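A short sketch of the shared_weights path used by the backward policy (shared_weights: True in config/policy/cnn.yaml): make_cnn reuses the base model's layers and re-creates only the final linear head. The stand-in base model and its sizes below are assumptions for illustration, not what the repo builds:

import torch
from torch import nn

# Stand-in base (forward) model: conv trunk + Flatten + Linear head.
base = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2)),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(576, 8),
)
layers = list(base.children())[:-1]  # same module objects -> weights are shared
head = nn.Linear(base[-1].in_features, base[-1].out_features)  # fresh, unshared head
backward_model = nn.Sequential(*layers, head)
print(backward_model(torch.ones(1, 1, 20, 10)).shape)  # torch.Size([1, 8])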
