Skip to content

Commit

Permalink
Merge pull request #10 from msosav/dev
Browse files Browse the repository at this point in the history
Model and training upgrades
  • Loading branch information
msosav authored Jan 11, 2025
2 parents c783dc0 + ee4f723 commit 98cd7dc
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 72 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/formatting.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Formats pull requests that target main with Black and isort, then commits
# any resulting changes back to the pull request's head branch.
name: Formatter

on:
  pull_request:
    branches:
      - main

# The last step pushes a commit, so the workflow token needs write access to
# repository contents.
permissions:
  contents: write

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          # Check out the PR's head branch so the formatting commit can be
          # pushed back to it.
          ref: ${{ github.head_ref }}

      - name: Set up Python 3.11
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML keeps the version a string instead of a float.
          python-version: "3.11"

      # black and isort are presumably listed in requirements.txt - confirm.
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
      - name: Black formatter
        run: |
          black .
      - name: Isort formatter
        run: |
          isort .
      - name: Commit changes
        uses: stefanzweifel/git-auto-commit-action@v5
        with:
          commit_message: Fix formatting
          skip_fetch: true
          branch: ${{ github.head_ref }}
151 changes: 123 additions & 28 deletions config/gym.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete
from gymnasium.spaces import Box, Dict, Discrete
from pyboy import PyBoy
from pyboy.utils import WindowEvent

Expand All @@ -15,33 +15,55 @@ def __init__(self, config: dict, debug=False):

assert self.rom_path is not None, "ROM path is required"

self.pyboy = PyBoy(self.rom_path)
self.pyboy = PyBoy(self.rom_path, sound=config["game_with_sound"])

self._fitness = 0
self._previous_fitness = 0
self.debug = debug

self.action_freq = config["action_freq"]

if not self.debug:
self.pyboy.set_emulation_speed(0)

self.valid_actions = [
"",
"a",
"b",
"left",
"right",
"up",
"down",
"start",
"select",
WindowEvent.PRESS_ARROW_DOWN,
WindowEvent.PRESS_ARROW_LEFT,
WindowEvent.PRESS_ARROW_RIGHT,
WindowEvent.PRESS_ARROW_UP,
WindowEvent.PRESS_BUTTON_A,
WindowEvent.PRESS_BUTTON_B,
WindowEvent.PRESS_BUTTON_START,
]

self.release_actions = [
WindowEvent.RELEASE_ARROW_DOWN,
WindowEvent.RELEASE_ARROW_LEFT,
WindowEvent.RELEASE_ARROW_RIGHT,
WindowEvent.RELEASE_ARROW_UP,
WindowEvent.RELEASE_BUTTON_A,
WindowEvent.RELEASE_BUTTON_B,
WindowEvent.RELEASE_BUTTON_START,
]

self.observation_space = Box(
low=0, high=255, shape=(144, 160, 3), dtype=np.uint8
self.observation_space = Dict(
{
"screen": Box(low=0, high=255, shape=(144, 160, 3), dtype=np.uint8),
"current_room_layout": Box(
low=0, high=255, shape=(156,), dtype=np.uint8
),
"items_in_hand": Box(low=0, high=255, shape=(2,), dtype=np.uint8),
"items_in_inventory": Box(low=0, high=255, shape=(9,)),
"health": Box(low=0, high=16, shape=(1,), dtype=np.uint8),
"rupees": Box(low=0, high=999, shape=(1,), dtype=np.uint8),
}
)

self.action_space = Discrete(len(self.valid_actions))

self.exploration_reward = config["exploration_reward"]
self.reward_scale = config["reward_scale"]

self.items = {
"01": False, # Sword
"02": False, # Bombs
Expand All @@ -59,30 +81,31 @@ def __init__(self, config: dict, debug=False):
}

def step(self, action):
assert self.action_space.contains(action), "%r (%s) invalid" % (
action,
type(action),
)

if action == 0:
pass
else:
self.pyboy.button(self.valid_actions[action])

self.pyboy.tick()
self.run_action(action)

done = self.__game_over()

self._calculate_fitness()
reward = self._fitness - self._previous_fitness

observation = self.pyboy.screen.ndarray
observation = self._get_observation()

info = {}
truncated = False

return observation, reward, done, truncated, info

def run_action(self, action):
    """Press the button for ``action``, release it, and finish the action window.

    The event at index ``action`` of ``self.valid_actions`` is sent and held
    for a fixed number of frames, then the event at the same index of
    ``self.release_actions`` is sent. The emulator is ticked for the rest of
    the ``self.action_freq`` frame window, with the final frame issued as a
    separate ``tick(1, True)`` call.

    Args:
        action: Index into ``self.valid_actions`` / ``self.release_actions``.
    """
    # Number of frames the button stays held before the release is sent.
    press_step = 8

    self.pyboy.send_input(self.valid_actions[action])
    self.pyboy.tick(press_step)
    self.pyboy.send_input(self.release_actions[action])

    # The window length is the instance's action_freq rather than a
    # hard-coded frame count, so the configured value takes effect.
    idle_frames = self.action_freq - press_step - 1
    if idle_frames > 0:
        # Skipped for very short windows so tick() never receives a zero or
        # negative frame count.
        self.pyboy.tick(idle_frames)

    # Final frame of the window; the second argument is presumably PyBoy's
    # render flag - confirm against the PyBoy version in use.
    self.pyboy.tick(1, True)

def __game_over(self):
if self.pyboy.memory[ADDR_CURRENT_HEALTH] == 0:
return True
Expand All @@ -93,10 +116,24 @@ def _calculate_fitness(self):

self._fitness = 0

self._fitness += self._check_new_items()
self._fitness += self._check_new_items() * self.reward_scale
self._fitness += (
self._check_new_locations() * self.reward_scale * self.exploration_reward
)

if self.moving_things_in_inventory:
self._fitness -= 0.1 * self.reward_scale

# TODO: Sword and shield level

def _check_new_locations(self):
explored_locations = 0
for addr in ADDR_WORLD_MAP_STATUS:
if self.pyboy.memory[addr] == 0x80:
explored_locations += 1

return explored_locations

def start_sequence(self):
self.pyboy.button("start")
self.pyboy.tick()
Expand All @@ -117,7 +154,7 @@ def reset(self, **kwargs):
self._fitness = 0
self._previous_fitness = 0

observation = self.pyboy.screen.ndarray
observation = self._get_observation()

info = {}
return observation, info
Expand All @@ -136,4 +173,62 @@ def _check_new_items(self):
self.items[item_in_inventory] = True
items_in_inventory_count += 1

return items_in_inventory_count
items_in_hand_count = 0
for held_address in ADDR_HELD_ITEMS:
item_in_hand = self.pyboy.memory[held_address]

if item_in_hand in self.items:
self.items[item_in_hand] = True
items_in_hand_count += 1

if (
items_in_hand_count < 2
and items_in_inventory_count >= items_in_hand_count
and items_in_inventory_count != 0
):
self.moving_things_in_inventory = True
else:
self.moving_things_in_inventory = False

return items_in_inventory_count + items_in_hand_count

def _check_rupees(self):
rupees = 0
for addr in ADDR_RUPEES:
rupees += self.pyboy.memory[addr]

return rupees

def _get_observation(self):
# Image observation
screen = self._get_screen()

room_type = self.pyboy.memory[ADDR_DESTINATION_BYTE_1]

room_number = self.pyboy.memory[ADDR_DESTINATION_BYTE_3]

current_room_layout = [
self.pyboy.memory[addr] for addr in ADDR_CURRENTLY_LOADED_MAP
]

health = [self.pyboy.memory[ADDR_CURRENT_HEALTH] / 8]

rupees = [self._check_rupees()]

items_in_inventory = [self.pyboy.memory[addr] for addr in ADDR_INVENTORY]

items_in_hand = [self.pyboy.memory[addr] for addr in ADDR_HELD_ITEMS]

obs = {
"screen": screen,
"current_room_layout": current_room_layout,
"items_in_hand": items_in_hand,
"items_in_inventory": items_in_inventory,
"health": health,
"rupees": rupees,
}

return obs

def _get_screen(self):
return self.pyboy.screen.ndarray
89 changes: 70 additions & 19 deletions config/memory_addresses.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,76 @@
# Addresses from https://datacrystal.romhacking.net/wiki/The_Legend_of_Zelda:_Link%27s_Awakening_(Game_Boy)/RAM_map

# Memory Addresses and Their Purpose

# Destination data
ADDR_DESTINATION_BYTE_1 = 0xD401 # 00: Overworld, 01: Dungeon, 02: Side view area
# Values from 00 to 1F accepted. FF is Color Dungeon
ADDR_DESTINATION_BYTE_2 = 0xD402
# Room number. Must appear on map or it will lead to an empty room
ADDR_DESTINATION_BYTE_3 = 0xD403
ADDR_DESTINATION_COORD_X = 0xD404 # Destination X coordinate
ADDR_DESTINATION_COORD_Y = 0xD405 # Destination Y coordinate

# Map Data
ADDR_CURRENTLY_LOADED_MAP = [i for i in range(0xD700, 0xD79C)]
ADDR_WORLD_MAP_STATUS = [i for i in range(0xD800, 0xD900)]
# 00: Unexplored, 10: Changed from initial status (e.g., sword taken on the beach or dungeon opened with key)
# 20: Owl talked, 80: Visited
# Example: Visiting the first dungeon's screen (80) and opening it with the key (10) would put that byte at 90

# Inventory and Items
ADDR_HELD_ITEMS = [0xDB00, 0xDB01] # Your currently held items
# 01: Sword, 02: Bombs, 03: Power Bracelet
# 04: Shield, 05: Bow, 06: Hookshot, 07: Fire Rod
# 08: Pegasus Boots, 09: Ocarina, 0A: Feather
# 0B: Shovel, 0C: Magic Powder, 0D: Boomerang
ADDR_INVENTORY = [i for i in range(0xDB02, 0xDB0B)]
ADDR_FLIPPERS = 0xDB0C # 01 = Have
ADDR_POTION = 0xDB0D # 01 = Have
# Current item in trading game (01 = Yoshi, 0E = Magnifier)
ADDR_TRADING_GAME_ITEM = 0xDB0E
ADDR_SECRET_SHELLS = 0xDB0F # Number of secret shells
ADDR_DUNGEON_ENTRANCE_KEYS = [i for i in range(0xDB10, 0xDB15)] # 01 = Have
ADDR_GOLDEN_LEAVES = 0xDB15 # Number of golden leaves

# Dungeon Item Flags
ADDR_DUNGEON_ITEM_FLAGS = [
[i for i in range(0xDB16 + 5 * d, 0xDB1B + 5 * d)] for d in range(10)
] # 5 bytes per dungeon, 5th byte = quantity of keys

# Equipment
ADDR_POWER_BRACELET_LEVEL = 0xDB43
ADDR_SHIELD_LEVEL = 0xDB44
ADDR_SWORD_LEVEL = 0xDB4E

# Quantities
ADDR_ARROWS = 0xDB45 # Number of arrows
ADDR_BOMBS = 0xDB4D # Number of bombs
ADDR_MAGIC_POWDER = 0xDB4C # Magic powder quantity
ADDR_MAX_MAGIC_POWDER = 0xDB76
ADDR_MAX_BOMBS = 0xDB77
ADDR_MAX_ARROWS = 0xDB78

# Ocarina
ADDR_OCARINA_SONGS = 0xDB49 # 3-bit mask: 0 = No songs, 7 = All songs
ADDR_OCARINA_SELECTED_SONG = 0xDB4A

# Health
# Each increment of 08h = one full heart, 04h = one-half heart
ADDR_CURRENT_HEALTH = 0xDB5A
# Number of hearts in hex (max recommended: 0Eh = 14 hearts)
ADDR_MAX_HEALTH = 0xDB5B
ADDR_POSITION_8X8 = 0xDBAE

# 00: Unexplored, 10: Changed from initial status (for example sword taken on the beach or dungeon opened with key)
# 20: Owl talked, 80: Visited
# For example, visiting the first dungeon's screen (80) and opening it with the key (10) would put that byte at 90
ADDR_WORLD_MAP_STATUS = [i for i in range(0xD800, 0xD900)]
# Rupees
ADDR_RUPEES = [0xDB5D, 0xDB5E] # Number of rupees (e.g., 0999 for 999 rupees)

# Instruments
# 00 = No instrument, 03 = Have instrument
ADDR_DUNGEON_INSTRUMENTS = [i for i in range(0xDB65, 0xDB6D)]

# Dungeon Position
ADDR_DUNGEON_POSITION = 0xDBAE # Position on the 8x8 dungeon grid
ADDR_DUNGEON_KEYS = 0xDBD0 # Quantity of keys in possession

"""01 Sword
02 Bombs
03 Power bracelet
04 Shield
05 Bow
06 Hookshot
07 Fire rod
08 Pegasus boots
09 Ocarina
0A Feather
0B Shovel
0C Magic powder
0D Boomerang """
ADDR_INVENTORY = [i for i in range(0xDB00, 0xDB0B)] # Also contains held items
# Save Slot Death Count
ADDR_DEATH_COUNT = [0xDB56 + i for i in range(3)] # One byte per save slot
15 changes: 11 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
"rom_path": "roms/ZeldaLinksAwakening.gb",
"checkpoint_dir": "checkpoints/",
"log_dir": "logs/",
"action_freq": 24,
"exploration_reward": 0.25,
"reward_scale": 1,
"game_with_sound": True,
}

callback = CheckpointAndLoggingCallback(
check_freq=1000, save_path=config["checkpoint_dir"]
check_freq=5000, save_path=config["checkpoint_dir"]
)

env = PreprocessEnv(config)
Expand All @@ -21,12 +25,15 @@

if mode == "train":
model = PPO(
"CnnPolicy",
"MultiInputPolicy",
env,
verbose=1,
n_steps=2048,
batch_size=512,
n_epochs=1,
gamma=0.997,
ent_coef=0.01,
tensorboard_log=config["log_dir"],
learning_rate=0.000001,
n_steps=512,
)

model.learn(total_timesteps=1000000, callback=callback)
Expand Down
Loading

0 comments on commit 98cd7dc

Please sign in to comment.