Skip to content

Commit

Permalink
Implement preprocessing of the environment
Browse files Browse the repository at this point in the history
  • Loading branch information
msosav authored Aug 25, 2024
2 parents db12646 + 2b4a63d commit 927a368
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 30 deletions.
47 changes: 39 additions & 8 deletions config/gym.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import gymnasium as gym
from gymnasium import spaces
from gymnasium.spaces import Box, Discrete
import numpy as np
from pyboy import PyBoy
from pyboy.utils import WindowEvent
from .memory_addresses import *


Expand All @@ -27,9 +26,26 @@ def __init__(self, config: dict, debug=False):
self.valid_actions = ['', 'a', 'b', 'left', 'right',
'up', 'down', 'start', 'select']

self.observation_space = spaces.Box(
low=0, high=255, shape=(16, 20), dtype=np.uint8)
self.action_space = spaces.Discrete(len(self.valid_actions))
self.observation_space = Box(
low=0, high=255, shape=(144, 160, 3), dtype=np.uint8)

self.action_space = Discrete(len(self.valid_actions))

self.items = {
'01': False, # Sword
'02': False, # Bombs
'03': False, # Power bracelet
'04': False, # Shield
'05': False, # Bow
'06': False, # Hookshot
'07': False, # Fire rod
'08': False, # Pegasus boots
'09': False, # Ocarina
'0A': False, # Feather
'0B': False, # Shovel
'0C': False, # Magic powder
'0D': False # Boomrang
}

def step(self, action):
assert self.action_space.contains(
Expand All @@ -47,7 +63,8 @@ def step(self, action):
self._calculate_fitness()
reward = self._fitness-self._previous_fitness

observation = self.pyboy.game_area()
observation = self.pyboy.screen.ndarray

info = {}
truncated = False

Expand All @@ -61,9 +78,12 @@ def __game_over(self):
def _calculate_fitness(self):
    """Recompute the fitness score while remembering the previous one.

    The current score is moved into ``_previous_fitness`` and the new score
    is rebuilt from zero by adding each reward component in turn.
    """
    # Archive the old score and restart the running total in a single step.
    self._previous_fitness, self._fitness = self._fitness, 0

    # TODO: Implement reward logic
    self._fitness += self._check_new_items()

    # TODO: Sword and shield level

def reset(self, **kwargs):
try:
with open(self.state_path, 'rb') as state_file:
Expand All @@ -74,7 +94,8 @@ def reset(self, **kwargs):
self._fitness = 0
self._previous_fitness = 0

observation = self.pyboy.game_area()
observation = self.pyboy.screen.ndarray

info = {}
return observation, info

Expand All @@ -83,3 +104,13 @@ def render(self, mode='human'):

def close(self):
    """Stop the PyBoy emulator owned by this environment."""
    emulator = self.pyboy
    emulator.stop()

def _check_new_items(self):
    """Count the inventory slots that currently hold a known item.

    Every address in ``ADDR_INVENTORY`` is read from emulator memory. A
    memory read yields an integer, whereas ``self.items`` is keyed by
    two-digit uppercase hex strings ('01' ... '0D'), so each byte is
    formatted to that key form before the lookup. Matching items are also
    flagged as obtained in ``self.items``.

    Returns:
        int: Number of inventory slots holding an item listed in
        ``self.items``.
    """
    items_in_inventory_count = 0
    for inventory_address in ADDR_INVENTORY:
        # Raw byte 10 -> key '0A'; comparing the bare int against the
        # string keys would never match.
        item_id = f"{self.pyboy.memory[inventory_address]:02X}"
        if item_id in self.items:
            self.items[item_id] = True
            items_in_inventory_count += 1

    return items_in_inventory_count
15 changes: 15 additions & 0 deletions config/memory_addresses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,18 @@
# 20: Owl talked, 80: Visited
# For example, visiting the first dungeon's screen (80) and opening it with the key (10) would put that byte at 90
# One address per world-map byte, 0xD800 through 0xD8FF inclusive.
ADDR_WORLD_MAP_STATUS = list(range(0xD800, 0xD900))

"""01 Sword
02 Bombs
03 Power bracelet
04 Shield
05 Bow
06 Hookshot
07 Fire rod
08 Pegasus boots
09 Ocarina
0A Feather
0B Shovel
0C Magic powder
0D Boomerang """
# NOTE(review): per the game's RAM map the inventory occupies 12 bytes:
# 0xDB00 (B button), 0xDB01 (A button) and ten subscreen slots 0xDB02-0xDB0B.
# range() excludes its stop value, so it must be 0xDB0C to include 0xDB0B.
ADDR_INVENTORY = list(range(0xDB00, 0xDB0C))  # Also contains held items
24 changes: 24 additions & 0 deletions config/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from gymnasium.wrappers import gray_scale_observation
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from matplotlib import pyplot as plt
from config.gym import ZeldaGymEnv
from gym.spaces import Box


def PreprocessEnv(config: dict) -> VecFrameStack:
    """
    Build the Zelda environment and wrap it for reinforcement learning.

    The base environment is created with debug enabled, passed through the
    grayscale observation wrapper (keep_dim=True), placed in a
    single-environment DummyVecEnv, and finally wrapped so that four
    consecutive frames are stacked.

    Args:
        config (dict): Configuration parameters forwarded to ZeldaGymEnv.
    Returns:
        VecFrameStack: The vectorized, frame-stacked environment.
    """
    base_env = ZeldaGymEnv(config, debug=True)
    gray_env = gray_scale_observation.GrayScaleObservation(
        base_env, keep_dim=True)
    # Each stage has its own name, so the factory closes over a binding
    # that is never reassigned.
    vec_env = DummyVecEnv([lambda: gray_env])

    return VecFrameStack(vec_env, n_stack=4)
10 changes: 4 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from config.gym import ZeldaGymEnv

from config.preprocess import PreprocessEnv

if __name__ == "__main__":
    # Drive the preprocessed environment with random actions.
    config = {
        'rom_path': 'roms/ZeldaLinksAwakening.gb',
        'state_path': 'roms/ZeldaLinksAwakening.gb.state'
    }

    # Only the wrapped, vectorized environment is created: constructing a raw
    # ZeldaGymEnv first would start an extra emulator that is never closed.
    env = PreprocessEnv(config)
    try:
        done = True
        for step in range(100000):
            if done:
                env.reset()
            # A vectorized env takes a batch of actions (one per sub-env) and
            # returns a 4-tuple; there is no separate `truncated` value.
            observation, reward, done, info = env.step(
                [env.action_space.sample()])
            env.render()
    finally:
        # Release the emulator even if the loop is interrupted or raises.
        env.close()
72 changes: 56 additions & 16 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,60 @@
black==24.4.2
cfgv==3.4.0
absl-py==2.1.0
ale-py==0.8.1
AutoROM==0.6.1
AutoROM.accept-rom-license==0.6.1
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
distlib==0.3.8
filelock==3.15.4
identify==2.5.36
isort==5.13.2
mypy-extensions==1.0.0
nodeenv==1.9.1
numpy==1.26.4
cloudpickle==3.0.0
contourpy==1.2.1
cycler==0.12.1
Farama-Notifications==0.0.4
filelock==3.13.1
fonttools==4.53.1
fsspec==2024.2.0
grpcio==1.66.0
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
idna==3.8
importlib_resources==6.4.4
Jinja2==3.1.3
kiwisolver==1.4.5
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.2
mdurl==0.1.2
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.3
opencv-python==4.10.0.84
packaging==24.1
pathspec==0.12.1
pillow==10.4.0
platformdirs==4.2.2
pre-commit==3.7.1
pyboy==2.2.0
pandas==2.2.2
pillow==10.2.0
protobuf==5.27.3
psutil==6.0.0
pyboy==2.2.2
pygame==2.6.0
Pygments==2.18.0
pyparsing==3.1.4
PySDL2==0.9.16
pysdl2-dll==2.30.2
PyYAML==6.0.1
virtualenv==20.26.3
python-dateutil==2.9.0.post0
pytz==2024.1
requests==2.32.3
rich==13.7.1
Shimmy==1.3.0
six==1.16.0
stable_baselines3==2.3.2
sympy==1.12
tensorboard==2.17.1
tensorboard-data-server==0.7.2
torch==2.4.0+cpu
torchaudio==2.4.0+cpu
torchvision==0.19.0+cpu
tqdm==4.66.5
typing_extensions==4.9.0
tzdata==2024.1
urllib3==2.2.2
Werkzeug==3.0.4

0 comments on commit 927a368

Please sign in to comment.