Skip to content

Commit

Permalink
Implement gray scale observation
Browse files Browse the repository at this point in the history
  • Loading branch information
msosav committed Aug 25, 2024
1 parent 6e70c74 commit e673b09
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 57 deletions.
20 changes: 12 additions & 8 deletions config/gym.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import gymnasium as gym
from gymnasium import spaces
import gym
from gym.spaces import Box, Discrete
import numpy as np
from pyboy import PyBoy
from pyboy.utils import WindowEvent
from .memory_addresses import *


Expand All @@ -27,9 +26,10 @@ def __init__(self, config: dict, debug=False):
self.valid_actions = ['', 'a', 'b', 'left', 'right',
'up', 'down', 'start', 'select']

self.observation_space = spaces.Box(
low=0, high=255, shape=(16, 20), dtype=np.uint8)
self.action_space = spaces.Discrete(len(self.valid_actions))
self.observation_space = Box(
low=0, high=255, shape=(144, 160, 3), dtype=np.uint8)

self.action_space = Discrete(len(self.valid_actions))

self.items = {
'01': False, # Sword
Expand Down Expand Up @@ -63,7 +63,8 @@ def step(self, action):
self._calculate_fitness()
reward = self._fitness-self._previous_fitness

observation = self.pyboy.game_area()
observation = self.pyboy.screen.ndarray

info = {}
truncated = False

Expand All @@ -81,6 +82,8 @@ def _calculate_fitness(self):

self._fitness += self._check_new_items()

# TODO: Sword and shield level

def reset(self, **kwargs):
try:
with open(self.state_path, 'rb') as state_file:
Expand All @@ -91,7 +94,8 @@ def reset(self, **kwargs):
self._fitness = 0
self._previous_fitness = 0

observation = self.pyboy.game_area()
observation = self.pyboy.screen.ndarray

info = {}
return observation, info

Expand Down
14 changes: 14 additions & 0 deletions config/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from gym.wrappers import gray_scale_observation
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from matplotlib import pyplot as plt
from config.gym import ZeldaGymEnv
from gym.spaces import Box


def Preprocess(config: dict) -> DummyVecEnv:
env = ZeldaGymEnv(config, debug=True)
env = gray_scale_observation.GrayScaleObservation(env, keep_dim=True)
# env = DummyVecEnv([lambda: env])
# env = VecFrameStack(env, n_stack=4, channel_order='last')

return env
6 changes: 3 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from config.gym import ZeldaGymEnv

from config.preprocess import Preprocess

if __name__ == "__main__":
config = {
'rom_path': 'roms/ZeldaLinksAwakening.gb',
'state_path': 'roms/ZeldaLinksAwakening.gb.state'
}

env = ZeldaGymEnv(config, debug=True)
env = Preprocess(config)

done = True
for step in range(100000):
if done:
env.reset()
observation, reward, done, truncated, info = env.step(
env.action_space.sample())
print(observation.shape)
env.render()
env.close()
80 changes: 34 additions & 46 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,72 +1,60 @@
asttokens==2.4.1
black==24.4.2
cfgv==3.4.0
absl-py==2.1.0
ale-py==0.8.1
AutoROM==0.6.1
AutoROM.accept-rom-license==0.6.1
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
comm==0.2.2
contourpy==1.2.1
cycler==0.12.1
debugpy==1.8.2
decorator==5.1.1
distlib==0.3.8
executing==2.0.1
Farama-Notifications==0.0.4
filelock==3.15.4
filelock==3.13.1
fonttools==4.53.1
fsspec==2024.2.0
grpcio==1.66.0
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
identify==2.5.36
ipykernel==6.29.5
ipython==8.26.0
isort==5.13.2
jedi==0.19.1
idna==3.8
importlib_resources==6.4.4
Jinja2==3.1.3
jupyter_client==8.6.2
jupyter_core==5.7.2
kiwisolver==1.4.5
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
matplotlib==3.9.2
mdurl==0.1.2
mpmath==1.3.0
mypy-extensions==1.0.0
nes_py==8.2.1
nest-asyncio==1.6.0
networkx==3.2.1
nodeenv==1.9.1
numpy==1.26.4
numpy==1.26.3
opencv-python==4.10.0.84
packaging==24.1
parso==0.8.4
pathspec==0.12.1
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.2.2
pre-commit==3.7.1
prompt_toolkit==3.0.47
pandas==2.2.2
pillow==10.2.0
protobuf==5.27.3
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pyboy==2.2.0
pyglet==1.5.21
pyboy==2.2.2
pygame==2.6.0
Pygments==2.18.0
pyparsing==3.1.2
pyparsing==3.1.4
PySDL2==0.9.16
pysdl2-dll==2.30.2
python-dateutil==2.9.0.post0
PyYAML==6.0.1
pyzmq==26.0.3
ruff==0.5.6
setuptools==72.1.0
pytz==2024.1
requests==2.32.3
rich==13.7.1
Shimmy==1.3.0
six==1.16.0
stack-data==0.6.3
stable_baselines3==2.3.2
sympy==1.12
tensorboard==2.17.1
tensorboard-data-server==0.7.2
torch==2.4.0+cpu
torchaudio==2.4.0+cpu
torchvision==0.19.0+cpu
tornado==6.4.1
tqdm==4.66.4
traitlets==5.14.3
typing_extensions==4.12.2
virtualenv==20.26.3
wcwidth==0.2.13
wheel==0.43.0
tqdm==4.66.5
typing_extensions==4.9.0
tzdata==2024.1
urllib3==2.2.2
Werkzeug==3.0.4

0 comments on commit e673b09

Please sign in to comment.