diff --git a/config/gym.py b/config/gym.py
index 0cce226..9abec05 100644
--- a/config/gym.py
+++ b/config/gym.py
@@ -46,14 +46,18 @@ def __init__(self, config: dict, debug=False):
             WindowEvent.RELEASE_BUTTON_START,
         ]

-        self.observation_space = Dict({
-            'screen': Box(low=0, high=255, shape=(144, 160, 3), dtype=np.uint8),
-            'current_room_layout': Box(low=0, high=255, shape=(156,), dtype=np.uint8),
-            'items_in_hand': Box(low=0, high=255, shape=(2,), dtype=np.uint8),
-            'items_in_inventory': Box(low=0, high=255, shape=(9,)),
-            'health': Box(low=0, high=16, shape=(1,), dtype=np.uint8),
-            'rupees': Box(low=0, high=999, shape=(1,), dtype=np.uint8),
-        })
+        self.observation_space = Dict(
+            {
+                "screen": Box(low=0, high=255, shape=(144, 160, 3), dtype=np.uint8),
+                "current_room_layout": Box(
+                    low=0, high=255, shape=(156,), dtype=np.uint8
+                ),
+                "items_in_hand": Box(low=0, high=255, shape=(2,), dtype=np.uint8),
+                "items_in_inventory": Box(low=0, high=255, shape=(9,)),
+                "health": Box(low=0, high=16, shape=(1,), dtype=np.uint8),
+                "rupees": Box(low=0, high=999, shape=(1,), dtype=np.uint8),
+            }
+        )

         self.action_space = Discrete(len(self.valid_actions))

@@ -113,8 +117,9 @@ def _calculate_fitness(self):
         self._fitness = 0

         self._fitness += self._check_new_items() * self.reward_scale
-        self._fitness += self._check_new_locations() * self.reward_scale * \
-            self.exploration_reward
+        self._fitness += (
+            self._check_new_locations() * self.reward_scale * self.exploration_reward
+        )

         if self.moving_things_in_inventory:
             self._fitness -= 0.1 * self.reward_scale
@@ -176,7 +181,11 @@ def _check_new_items(self):
                 self.items[item_in_hand] = True
                 items_in_hand_count += 1

-        if items_in_hand_count < 2 and items_in_inventory_count >= items_in_hand_count and items_in_inventory_count != 0:
+        if (
+            items_in_hand_count < 2
+            and items_in_inventory_count >= items_in_hand_count
+            and items_in_inventory_count != 0
+        ):
             self.moving_things_in_inventory = True
         else:
             self.moving_things_in_inventory = False
@@ -206,19 +215,17 @@ def _get_observation(self):

         rupees = [self._check_rupees()]

-        items_in_inventory = [self.pyboy.memory[addr]
-                              for addr in ADDR_INVENTORY]
+        items_in_inventory = [self.pyboy.memory[addr] for addr in ADDR_INVENTORY]

-        items_in_hand = [self.pyboy.memory[addr]
-                         for addr in ADDR_HELD_ITEMS]
+        items_in_hand = [self.pyboy.memory[addr] for addr in ADDR_HELD_ITEMS]

         obs = {
-            'screen': screen,
-            'current_room_layout': current_room_layout,
-            'items_in_hand': items_in_hand,
-            'items_in_inventory': items_in_inventory,
-            'health': health,
-            'rupees': rupees
+            "screen": screen,
+            "current_room_layout": current_room_layout,
+            "items_in_hand": items_in_hand,
+            "items_in_inventory": items_in_inventory,
+            "health": health,
+            "rupees": rupees,
         }

         return obs
diff --git a/main.py b/main.py
index 862e3ec..b236504 100644
--- a/main.py
+++ b/main.py
@@ -24,8 +24,17 @@
 mode = sys.argv[1]

 if mode == "train":
-    model = PPO("MultiInputPolicy", env, verbose=1, n_steps=2048,
-                batch_size=512, n_epochs=1, gamma=0.997, ent_coef=0.01, tensorboard_log=config["log_dir"])
+    model = PPO(
+        "MultiInputPolicy",
+        env,
+        verbose=1,
+        n_steps=2048,
+        batch_size=512,
+        n_epochs=1,
+        gamma=0.997,
+        ent_coef=0.01,
+        tensorboard_log=config["log_dir"],
+    )
     model.learn(total_timesteps=1000000, callback=callback)

 elif mode == "test":
diff --git a/utils.py b/utils.py
index ff42c57..02c34d5 100644
--- a/utils.py
+++ b/utils.py
@@ -3,8 +3,7 @@
 import gymnasium as gym
 import numpy as np
 from gymnasium.spaces import Box
-from stable_baselines3.common.callbacks import \
-    BaseCallback
+from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

 from config.gym import ZeldaGymEnv
@@ -34,8 +33,9 @@ def __init__(self, env: gym.Env, keep_dim: bool = False):
         super().__init__(env)
         self.keep_dim = keep_dim

-        assert isinstance(env.observation_space,
-                          gym.spaces.Dict), "Observation space must be a Dict"
+        assert isinstance(
+            env.observation_space, gym.spaces.Dict
+        ), "Observation space must be a Dict"

         obs_shape = env.observation_space["screen"].shape[:2]

@@ -51,8 +51,7 @@ def __init__(self, env: gym.Env, keep_dim: bool = False):
     def observation(self, observation):
         import cv2

-        observation["screen"] = cv2.cvtColor(
-            observation["screen"], cv2.COLOR_RGB2GRAY)
+        observation["screen"] = cv2.cvtColor(observation["screen"], cv2.COLOR_RGB2GRAY)

         if self.keep_dim:
             observation["screen"] = np.expand_dims(observation["screen"], -1)
@@ -74,12 +73,13 @@ def _init_callback(self) -> None:

     def _on_step(self):
         # Log episode info
-        if self.locals.get('done'):
-            self.logger.record('game/episode_reward',
-                               self.locals.get('rewards')[0])
-            self.logger.record('game/episode_length', self.n_calls)
-            self.logger.record('game/current_health',
-                               self.training_env.get_attr('pyboy')[0].memory[0xDB5A])
+        if self.locals.get("done"):
+            self.logger.record("game/episode_reward", self.locals.get("rewards")[0])
+            self.logger.record("game/episode_length", self.n_calls)
+            self.logger.record(
+                "game/current_health",
+                self.training_env.get_attr("pyboy")[0].memory[0xDB5A],
+            )

         # Save model checkpoint
         if self.n_calls % self.check_freq == 0:
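For context, a minimal usage sketch of the code touched above. This is an assumption-laden illustration, not code from the repository: only the `ZeldaGymEnv(config, debug=...)` signature, the PPO hyperparameters, and the `config["log_dir"]` key are confirmed by the diff; the rest of the config contents and the omission of the checkpoint callback and observation wrappers are simplifications.

# Hypothetical sketch -- not part of the diff. Assumes `config` also carries
# whatever ROM/emulator settings ZeldaGymEnv needs; only "log_dir" is confirmed.
from stable_baselines3 import PPO

from config.gym import ZeldaGymEnv

config = {"log_dir": "./logs"}  # assumed minimal config
env = ZeldaGymEnv(config, debug=False)

# Same PPO setup as main.py above, minus the checkpoint/logging callback.
model = PPO(
    "MultiInputPolicy",
    env,
    verbose=1,
    n_steps=2048,
    batch_size=512,
    n_epochs=1,
    gamma=0.997,
    ent_coef=0.01,
    tensorboard_log=config["log_dir"],
)
model.learn(total_timesteps=1000000)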