From 2b4a63d231d46bf9886f05e2df691f1feba26ba1 Mon Sep 17 00:00:00 2001
From: Miguel Sosa <85181687+msosav@users.noreply.github.com.>
Date: Sun, 25 Aug 2024 13:35:25 -0500
Subject: [PATCH] Implement preprocessing of the environment

---
 config/gym.py        |  4 ++--
 config/preprocess.py | 18 ++++++++++++++----
 main.py              | 10 ++++------
 3 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/config/gym.py b/config/gym.py
index a463990..e8ded5a 100644
--- a/config/gym.py
+++ b/config/gym.py
@@ -1,5 +1,5 @@
-import gym
-from gym.spaces import Box, Discrete
+import gymnasium as gym
+from gymnasium.spaces import Box, Discrete
 import numpy as np
 from pyboy import PyBoy
 from .memory_addresses import *
diff --git a/config/preprocess.py b/config/preprocess.py
index 910c67c..5a7164e 100644
--- a/config/preprocess.py
+++ b/config/preprocess.py
@@ -1,14 +1,24 @@
-from gym.wrappers import gray_scale_observation
+from gymnasium.wrappers import gray_scale_observation
 from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
 from matplotlib import pyplot as plt
 from config.gym import ZeldaGymEnv
 from gym.spaces import Box
 
 
-def Preprocess(config: dict) -> DummyVecEnv:
+def PreprocessEnv(config: dict) -> VecFrameStack:
+    """
+    Preprocesses the environment for reinforcement learning.
+
+    Args:
+        config (dict): Configuration parameters for the environment.
+
+    Returns:
+        VecFrameStack: Preprocessed environment with stacked frames.
+    """
+
     env = ZeldaGymEnv(config, debug=True)
     env = gray_scale_observation.GrayScaleObservation(env, keep_dim=True)
-    # env = DummyVecEnv([lambda: env])
-    # env = VecFrameStack(env, n_stack=4, channel_order='last')
+    env = DummyVecEnv([lambda: env])
+    env = VecFrameStack(env, n_stack=4)
 
     return env
diff --git a/main.py b/main.py
index 27dc4eb..dd0d3b8 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,4 @@
-from config.preprocess import Preprocess
+from config.preprocess import PreprocessEnv
 
 if __name__ == "__main__":
     config = {
@@ -6,14 +6,12 @@
         'state_path': 'roms/ZeldaLinksAwakening.gb.state'
     }
 
-    env = Preprocess(config)
-
+    env = PreprocessEnv(config)
     done = True
     for step in range(100000):
         if done:
             env.reset()
-        observation, reward, done, truncated, info = env.step(
-            env.action_space.sample())
-        print(observation.shape)
+        observation, reward, done, info = env.step(
+            [env.action_space.sample()])
         env.render()
     env.close()