Skip to content

Commit

Permalink
Implement preprocessing of the environment
Browse files Browse the repository at this point in the history
  • Loading branch information
msosav committed Aug 25, 2024
1 parent e673b09 commit 2b4a63d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
4 changes: 2 additions & 2 deletions config/gym.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import gym
from gym.spaces import Box, Discrete
import gymnasium as gym
from gymnasium.spaces import Box, Discrete
import numpy as np
from pyboy import PyBoy
from .memory_addresses import *
Expand Down
18 changes: 14 additions & 4 deletions config/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from gym.wrappers import gray_scale_observation
from gymnasium.wrappers import gray_scale_observation
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from matplotlib import pyplot as plt
from config.gym import ZeldaGymEnv
from gym.spaces import Box


def PreprocessEnv(config: dict) -> VecFrameStack:
    """
    Preprocesses the environment for reinforcement learning.

    Builds a ZeldaGymEnv from ``config``, converts its observations to
    grayscale, wraps it in a single-environment DummyVecEnv and stacks
    the last 4 frames.

    Args:
        config (dict): Configuration parameters for the environment;
            passed unchanged to ZeldaGymEnv.

    Returns:
        VecFrameStack: Preprocessed environment with stacked frames.
    """
    # NOTE(review): debug mode is hard-coded on; confirm whether it should
    # be driven by `config` instead.
    base_env = ZeldaGymEnv(config, debug=True)

    # keep_dim=True asks the wrapper to keep the channel axis on the
    # grayscale observation.
    gray_env = gray_scale_observation.GrayScaleObservation(
        base_env, keep_dim=True)

    # Each stage has its own name so the factory lambda closes over a
    # variable that is never rebound (the previous code captured `env`
    # while reassigning `env` in the same statement).
    vec_env = DummyVecEnv([lambda: gray_env])

    return VecFrameStack(vec_env, n_stack=4)
10 changes: 4 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from config.preprocess import Preprocess
from config.preprocess import PreprocessEnv

if __name__ == "__main__":
    # Smoke-test driver: build the preprocessed environment and drive it
    # with random actions while rendering.
    config = {
        'rom_path': 'roms/ZeldaLinksAwakening.gb',
        'state_path': 'roms/ZeldaLinksAwakening.gb.state'
    }

    env = PreprocessEnv(config)
    try:
        # Start as "done" so the first iteration resets the environment.
        done = True
        for step in range(100000):
            # NOTE(review): `done` comes back from the vectorized env as a
            # per-environment array; truth-testing it only works while
            # there is exactly one sub-environment. Vectorized envs
            # presumably auto-reset as well, making this reset redundant
            # after the first step; confirm before removing.
            if done:
                env.reset()
            # The vectorized env takes a batch of actions (one per
            # sub-environment) and returns a 4-tuple.
            observation, reward, done, info = env.step(
                [env.action_space.sample()])
            env.render()
    finally:
        # Release the environment even if stepping or rendering raises.
        env.close()

0 comments on commit 2b4a63d

Please sign in to comment.