Skip to content

Commit

Permalink
Fix formatting
Browse files: browse the repository at this point in the history
  • Loading branch information
msosav authored and github-actions[bot] committed Jan 11, 2025
1 parent ae1308b commit ee4f723
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 35 deletions.
49 changes: 28 additions & 21 deletions config/gym.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -46,14 +46,18 @@ def __init__(self, config: dict, debug=False):
WindowEvent.RELEASE_BUTTON_START,
]

self.observation_space = Dict({
'screen': Box(low=0, high=255, shape=(144, 160, 3), dtype=np.uint8),
'current_room_layout': Box(low=0, high=255, shape=(156,), dtype=np.uint8),
'items_in_hand': Box(low=0, high=255, shape=(2,), dtype=np.uint8),
'items_in_inventory': Box(low=0, high=255, shape=(9,)),
'health': Box(low=0, high=16, shape=(1,), dtype=np.uint8),
'rupees': Box(low=0, high=999, shape=(1,), dtype=np.uint8),
})
self.observation_space = Dict(
{
"screen": Box(low=0, high=255, shape=(144, 160, 3), dtype=np.uint8),
"current_room_layout": Box(
low=0, high=255, shape=(156,), dtype=np.uint8
),
"items_in_hand": Box(low=0, high=255, shape=(2,), dtype=np.uint8),
"items_in_inventory": Box(low=0, high=255, shape=(9,)),
"health": Box(low=0, high=16, shape=(1,), dtype=np.uint8),
"rupees": Box(low=0, high=999, shape=(1,), dtype=np.uint8),
}
)

self.action_space = Discrete(len(self.valid_actions))

Expand Down Expand Up @@ -113,8 +117,9 @@ def _calculate_fitness(self):
self._fitness = 0

self._fitness += self._check_new_items() * self.reward_scale
self._fitness += self._check_new_locations() * self.reward_scale * \
self.exploration_reward
self._fitness += (
self._check_new_locations() * self.reward_scale * self.exploration_reward
)

if self.moving_things_in_inventory:
self._fitness -= 0.1 * self.reward_scale
Expand Down Expand Up @@ -176,7 +181,11 @@ def _check_new_items(self):
self.items[item_in_hand] = True
items_in_hand_count += 1

if items_in_hand_count < 2 and items_in_inventory_count >= items_in_hand_count and items_in_inventory_count != 0:
if (
items_in_hand_count < 2
and items_in_inventory_count >= items_in_hand_count
and items_in_inventory_count != 0
):
self.moving_things_in_inventory = True
else:
self.moving_things_in_inventory = False
Expand Down Expand Up @@ -206,19 +215,17 @@ def _get_observation(self):

rupees = [self._check_rupees()]

items_in_inventory = [self.pyboy.memory[addr]
for addr in ADDR_INVENTORY]
items_in_inventory = [self.pyboy.memory[addr] for addr in ADDR_INVENTORY]

items_in_hand = [self.pyboy.memory[addr]
for addr in ADDR_HELD_ITEMS]
items_in_hand = [self.pyboy.memory[addr] for addr in ADDR_HELD_ITEMS]

obs = {
'screen': screen,
'current_room_layout': current_room_layout,
'items_in_hand': items_in_hand,
'items_in_inventory': items_in_inventory,
'health': health,
'rupees': rupees
"screen": screen,
"current_room_layout": current_room_layout,
"items_in_hand": items_in_hand,
"items_in_inventory": items_in_inventory,
"health": health,
"rupees": rupees,
}

return obs
Expand Down
13 changes: 11 additions & 2 deletions main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,8 +24,17 @@
mode = sys.argv[1]

if mode == "train":
model = PPO("MultiInputPolicy", env, verbose=1, n_steps=2048,
batch_size=512, n_epochs=1, gamma=0.997, ent_coef=0.01, tensorboard_log=config["log_dir"])
model = PPO(
"MultiInputPolicy",
env,
verbose=1,
n_steps=2048,
batch_size=512,
n_epochs=1,
gamma=0.997,
ent_coef=0.01,
tensorboard_log=config["log_dir"],
)

model.learn(total_timesteps=1000000, callback=callback)
elif mode == "test":
Expand Down
24 changes: 12 additions & 12 deletions utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,8 +3,7 @@
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box
from stable_baselines3.common.callbacks import \
BaseCallback
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

from config.gym import ZeldaGymEnv
Expand Down Expand Up @@ -34,8 +33,9 @@ def __init__(self, env: gym.Env, keep_dim: bool = False):
super().__init__(env)
self.keep_dim = keep_dim

assert isinstance(env.observation_space,
gym.spaces.Dict), "Observation space must be a Dict"
assert isinstance(
env.observation_space, gym.spaces.Dict
), "Observation space must be a Dict"

obs_shape = env.observation_space["screen"].shape[:2]

Expand All @@ -51,8 +51,7 @@ def __init__(self, env: gym.Env, keep_dim: bool = False):
def observation(self, observation):
import cv2

observation["screen"] = cv2.cvtColor(
observation["screen"], cv2.COLOR_RGB2GRAY)
observation["screen"] = cv2.cvtColor(observation["screen"], cv2.COLOR_RGB2GRAY)

if self.keep_dim:
observation["screen"] = np.expand_dims(observation["screen"], -1)
Expand All @@ -74,12 +73,13 @@ def _init_callback(self) -> None:

def _on_step(self):
# Log episode info
if self.locals.get('done'):
self.logger.record('game/episode_reward',
self.locals.get('rewards')[0])
self.logger.record('game/episode_length', self.n_calls)
self.logger.record('game/current_health',
self.training_env.get_attr('pyboy')[0].memory[0xDB5A])
if self.locals.get("done"):
self.logger.record("game/episode_reward", self.locals.get("rewards")[0])
self.logger.record("game/episode_length", self.n_calls)
self.logger.record(
"game/current_health",
self.training_env.get_attr("pyboy")[0].memory[0xDB5A],
)

# Save model checkpoint
if self.n_calls % self.check_freq == 0:
Expand Down

0 comments on commit ee4f723

Please sign in to comment.