diff --git a/zoo/atari/entry/atari_eval.py b/zoo/atari/entry/atari_eval.py
index 824ffd1ab..49a42b99c 100644
--- a/zoo/atari/entry/atari_eval.py
+++ b/zoo/atari/entry/atari_eval.py
@@ -1,30 +1,41 @@
-# According to the model you want to evaluate, import the corresponding config.
 from lzero.entry import eval_muzero
 import numpy as np
 
 if __name__ == "__main__":
-    """
-    model_path (:obj:`Optional[str]`): The pretrained model path, which should
-    point to the ckpt file of the pretrained model, and an absolute path is recommended.
-    In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
     """
-    # Take the config of sampled efficientzero as an example
-    from zoo.atari.config.atari_sampled_efficientzero_config import main_config, create_config
+    Overview:
+        Main script to evaluate the MuZero model on Atari games. The script will loop over multiple seeds,
+        evaluating a certain number of episodes per seed. Results are aggregated and printed.
 
-    model_path = "/path/ckpt/ckpt_best.pth.tar"
+    Variables:
+        - model_path (:obj:`Optional[str]`): The pretrained model path, pointing to the ckpt file of the pretrained model.
+            The path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+        - seeds (:obj:`List[int]`): List of seeds to use for the evaluations.
+        - num_episodes_each_seed (:obj:`int`): Number of episodes to evaluate for each seed.
+        - total_test_episodes (:obj:`int`): Total number of test episodes, calculated as num_episodes_each_seed * len(seeds).
+        - returns_mean_seeds (:obj:`np.ndarray`): Array of mean return values for each seed.
+        - returns_seeds (:obj:`np.ndarray`): Array of all return values for each seed.
+    """
+    # Take the config of MuZero as an example
+    from zoo.atari.config.atari_muzero_config import main_config, create_config
+
+    # model_path = "/path/ckpt/ckpt_best.pth.tar"
+    model_path = None
 
-    returns_mean_seeds = []
-    returns_seeds = []
     seeds = [0]
     num_episodes_each_seed = 1
     total_test_episodes = num_episodes_each_seed * len(seeds)
     create_config.env_manager.type = 'base'  # Visualization requires the 'type' to be set as base
     main_config.env.evaluator_env_num = 1  # Visualization requires the 'env_num' to be set as 1
     main_config.env.n_evaluator_episode = total_test_episodes
-    main_config.env.render_mode_human = True  # Whether to enable real-time rendering
-    main_config.env.save_video = True  # Whether to save the video, if save the video render_mode_human must to be True
-    main_config.env.save_path = '../config/'
-    main_config.env.eval_max_episode_steps = int(1e3)  # Adjust according to different environments
+    main_config.env.render_mode_human = False  # Whether to enable real-time rendering
+
+    main_config.env.save_replay = True  # Whether to save the replay video
+    main_config.env.replay_path = './video'
+    main_config.env.eval_max_episode_steps = int(20)  # Adjust according to different environments
+
+    returns_mean_seeds = []
+    returns_seeds = []
 
     for seed in seeds:
         returns_mean, returns = eval_muzero(
@@ -45,4 +56,4 @@
     print(f'We eval total {len(seeds)} seeds. 
In each seed, we eval {num_episodes_each_seed} episodes.') print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') print('In all seeds, reward_mean:', returns_mean_seeds.mean()) - print("=" * 20) + print("=" * 20) \ No newline at end of file diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 893c37e71..177c07625 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -1,6 +1,6 @@ import copy import sys -from typing import List +from typing import List, Any import gym import numpy as np @@ -14,56 +14,108 @@ @ENV_REGISTRY.register('atari_lightzero') class AtariLightZeroEnv(BaseEnv): + """ + Overview: + AtariLightZeroEnv is a derived class from BaseEnv and represents the environment for the Atari LightZero game. + This class provides the necessary interfaces to interact with the environment, including reset, step, seed, + close, etc. and manages the environment's properties such as observation_space, action_space, and reward_space. + Properties: + cfg, _init_flag, channel_last, clip_rewards, episode_life, _env, _observation_space, _action_space, + _reward_space, obs, _eval_episode_return, has_reset, _seed, _dynamic_seed + """ config = dict( + # (int) The number of environment instances used for data collection. collector_env_num=8, + # (int) The number of environment instances used for evaluator. evaluator_env_num=3, + # (int) The number of episodes to evaluate during each evaluation period. n_evaluator_episode=3, + # (str) The name of the Atari game environment. env_name='PongNoFrameskip-v4', + # (str) The type of the environment, here it's Atari. env_type='Atari', + # (tuple) The shape of the observation space, which is a stacked frame of 4 images each of 96x96 pixels. obs_shape=(4, 96, 96), + # (int) The maximum number of steps in each episode during data collection. collect_max_episode_steps=int(1.08e5), + # (int) The maximum number of steps in each episode during evaluation. eval_max_episode_steps=int(1.08e5), + # (bool) If True, the game is rendered in real-time. + render_mode_human=False, + # (bool) If True, a video of the game play is saved. + save_replay=False, + # (str) The path to save the video. + replay_path='./video', + # (bool) If set to True, the game screen is converted to grayscale, reducing the complexity of the observation space. gray_scale=True, + # (int) The number of frames to skip between each action. Higher values result in faster simulation. frame_skip=4, + # (bool) If True, the game ends when the agent loses a life, otherwise, the game only ends when all lives are lost. episode_life=True, + # (bool) If True, the rewards are clipped to a certain range, usually between -1 and 1, to reduce variance. clip_rewards=True, + # (bool) If True, the channels of the observation images are placed last (e.g., height, width, channels). channel_last=True, - render_mode_human=False, + # (bool) If True, the pixel values of the game frames are scaled down to the range [0, 1]. scale=True, + # (bool) If True, the game frames are preprocessed by cropping irrelevant parts and resizing to a smaller resolution. warp_frame=True, - save_video=False, + # (bool) If True, the game state is transformed into a string before being returned by the environment. transform2string=False, + # (bool) If True, additional wrappers for the game environment are used. game_wrapper=True, + # (dict) The configuration for the environment manager. 
If shared_memory is set to False, each environment instance + # runs in the same process as the trainer, otherwise, they run in separate processes. manager=dict(shared_memory=False, ), + # (int) The value of the cumulative reward at which the training stops. stop_value=int(1e6), ) @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + Return the default configuration for the Atari LightZero environment. + Arguments: + - cls (:obj:`type`): The class AtariLightZeroEnv. + Returns: + - cfg (:obj:`EasyDict`): The default configuration dictionary. + """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg - def __init__(self, cfg=None): + def __init__(self, cfg: EasyDict) -> None: + """ + Overview: + Initialize the Atari LightZero environment with the given configuration. + Arguments: + - cfg (:obj:`EasyDict`): The configuration dictionary. + """ self.cfg = cfg self._init_flag = False self.channel_last = cfg.channel_last self.clip_rewards = cfg.clip_rewards self.episode_life = cfg.episode_life - def _make_env(self): - return wrap_lightzero(self.cfg, episode_life=self.cfg.episode_life, clip_rewards=self.cfg.clip_rewards) - - def reset(self): + def reset(self) -> dict: + """ + Overview: + Reset the environment and return the initial observation. + Returns: + - obs (:obj:`dict`): The initial observation after reset. + """ if not self._init_flag: - self._env = self._make_env() + # Create and return the wrapped environment for Atari LightZero. + self._env = wrap_lightzero(self.cfg, episode_life=self.cfg.episode_life, clip_rewards=self.cfg.clip_rewards) self._observation_space = self._env.env.observation_space self._action_space = self._env.env.action_space self._reward_space = gym.spaces.Box( - low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1, ), dtype=np.float32 + low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1,), dtype=np.float32 ) self._init_flag = True + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: np_seed = 100 * np.random.randint(1, 1000) self._env.env.seed(self._seed + np_seed) @@ -73,29 +125,19 @@ def reset(self): obs = self._env.reset() self.obs = to_ndarray(obs) self._eval_episode_return = 0. - self.has_reset = True obs = self.observe() - # obs.shape: 96,96,1 return obs - def observe(self): + def step(self, action: int) -> BaseEnvTimestep: """ Overview: - add action_mask to obs to adapt with MCTS alg.. + Execute the given action and return the resulting environment timestep. + Arguments: + - action (:obj:`int`): The action to be executed. + Returns: + - timestep (:obj:`BaseEnvTimestep`): The environment timestep after executing the action. """ - observation = self.obs - - if not self.channel_last: - # move the channel dim to the fist axis - # (96, 96, 3) -> (3, 96, 96) - observation = np.transpose(observation, (2, 0, 1)) - - action_mask = np.ones(self._action_space.n, 'int8') - return {'observation': observation, 'action_mask': action_mask, 'to_play': -1} - - def step(self, action): obs, reward, done, info = self._env.step(action) - # self._env.render() self.obs = to_ndarray(obs) self.reward = np.array(reward).astype(np.float32) self._eval_episode_return += self.reward @@ -105,6 +147,23 @@ def step(self, action): return BaseEnvTimestep(observation, self.reward, done, info) + def observe(self) -> dict: + """ + Overview: + Return the current observation along with the action mask and to_play flag. 
+ Returns: + - observation (:obj:`dict`): The dictionary containing current observation, action mask, and to_play flag. + """ + observation = self.obs + + if not self.channel_last: + # move the channel dim to the fist axis + # (96, 96, 3) -> (3, 96, 96) + observation = np.transpose(observation, (2, 0, 1)) + + action_mask = np.ones(self._action_space.n, 'int8') + return {'observation': observation, 'action_mask': action_mask, 'to_play': -1} + @property def legal_actions(self): return np.arange(self._action_space.n) @@ -113,52 +172,41 @@ def random_action(self): action_list = self.legal_actions return np.random.choice(action_list) - def render(self, mode='human'): - self._env.render() - - def human_to_action(self): - """ - Overview: - For multiplayer games, ask the user for a legal action - and return the corresponding action number. - Returns: - An integer from the action space. - """ - while True: - try: - print(f"Current available actions for the player are:{self.legal_actions}") - choice = int(input(f"Enter the index of next action: ")) - if choice in self.legal_actions: - break - else: - print("Wrong input, try again") - except KeyboardInterrupt: - print("exit") - sys.exit(0) - except Exception as e: - print("Wrong input, try again") - return choice - def close(self) -> None: + """ + Close the environment, and set the initialization flag to False. + """ if self._init_flag: self._env.close() self._init_flag = False def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Set the seed for the environment's random number generator. Can handle both static and dynamic seeding. + """ self._seed = seed self._dynamic_seed = dynamic_seed np.random.seed(self._seed) @property def observation_space(self) -> gym.spaces.Space: + """ + Property to access the observation space of the environment. + """ return self._observation_space @property def action_space(self) -> gym.spaces.Space: + """ + Property to access the action space of the environment. + """ return self._action_space @property def reward_space(self) -> gym.spaces.Space: + """ + Property to access the reward space of the environment. + """ return self._reward_space def __repr__(self) -> str: diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py index d16ff28cb..4254f16b3 100644 --- a/zoo/atari/envs/atari_wrappers.py +++ b/zoo/atari/envs/atari_wrappers.py @@ -1,5 +1,6 @@ -# Borrow a lot from openai baselines: -# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +# Adapted from openai baselines: https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +from datetime import datetime +from typing import Optional import cv2 import gym @@ -8,9 +9,11 @@ ScaledFloatFrameWrapper, \ ClipRewardWrapper, FrameStackWrapper from ding.utils.compression_helper import jpeg_data_compressor +from easydict import EasyDict from gym.wrappers import RecordVideo +# only for reference now def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True): """Configure environment for DeepMind-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). @@ -42,6 +45,7 @@ def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, s return env +# only for reference now def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True): """Configure environment for DeepMind-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). 
@@ -73,18 +77,17 @@ def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4 return env -def wrap_lightzero(config, episode_life, clip_rewards): +def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> gym.Env: """ Overview: Configure environment for MuZero-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). Arguments: - - config (:obj:`Dict`): Dict containing configuration. - - wrap_frame (:obj:`bool`): - - save_video (:obj:`bool`): - - save_path (:obj:`bool`): + - config (:obj:`Dict`): Dict containing configuration parameters for the environment. + - episode_life (:obj:`bool`): If True, the agent starts with a set number of lives and loses them during the game. + - clip_rewards (:obj:`bool`): If True, the rewards are clipped to a certain range. Return: - - the wrapped atari environment. + - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations. """ if config.render_mode_human: env = gym.make(config.env_name, render_mode='human') @@ -103,13 +106,14 @@ def wrap_lightzero(config, episode_life, clip_rewards): env = ScaledFloatFrameWrapper(env) if clip_rewards: env = ClipRewardWrapper(env) - if config.save_video: - import random, string + if config.save_replay: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + video_name = f'{env.spec.id}-video-{timestamp}' env = RecordVideo( env, - video_folder=config.save_path, + video_folder=config.replay_path, episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(''.join(random.choice(string.ascii_lowercase) for i in range(5))), + name_prefix=video_name ) env = JpegWrapper(env, transform2string=config.transform2string) @@ -120,8 +124,17 @@ def wrap_lightzero(config, episode_life, clip_rewards): class TimeLimit(gym.Wrapper): + """ + Overview: + A wrapper that limits the maximum number of steps in an episode. + """ - def __init__(self, env, max_episode_steps=None): + def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None): + """ + Arguments: + - env (:obj:`gym.Env`): The environment to wrap. + - max_episode_steps (:obj:`Optional[int]`): Maximum number of steps per episode. If None, no limit is applied. + """ super(TimeLimit, self).__init__(env) self._max_episode_steps = max_episode_steps self._elapsed_steps = 0 @@ -140,12 +153,20 @@ def reset(self, **kwargs): class WarpFrame(gym.ObservationWrapper): + """ + Overview: + A wrapper that warps frames to 84x84 as done in the Nature paper and later work. + """ - def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None): + def __init__(self, env: gym.Env, width: int = 84, height: int = 84, grayscale: bool = True, + dict_space_key: Optional[str] = None): """ - Warp frames to 84x84 as done in the Nature paper and later work. - If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which - observation should be warped. + Arguments: + - env (:obj:`gym.Env`): The environment to wrap. + - width (:obj:`int`): The width to which the frames are resized. + - height (:obj:`int`): The height to which the frames are resized. + - grayscale (:obj:`bool`): If True, convert frames to grayscale. + - dict_space_key (:obj:`Optional[str]`): If specified, indicates which observation should be warped. """ super().__init__(env) self._width = width @@ -192,10 +213,16 @@ def observation(self, obs): class JpegWrapper(gym.Wrapper): + """ + Overview: + A wrapper that converts the observation into a string to save memory. 
+ """ - def __init__(self, env, transform2string=True): + def __init__(self, env: gym.Env, transform2string: bool = True): """ - Overview: convert the observation into string to save memory + Arguments: + - env (:obj:`gym.Env`): The environment to wrap. + - transform2string (:obj:`bool`): If True, transform the observations to string. """ super().__init__(env) self.transform2string = transform2string @@ -218,10 +245,15 @@ def reset(self, **kwargs): class GameWrapper(gym.Wrapper): + """ + Overview: + A wrapper to adapt the environment to the game interface. + """ - def __init__(self, env): + def __init__(self, env: gym.Env): """ - Overview: warp env to adapt the game interface + Arguments: + - env (:obj:`gym.Env`): The environment to wrap. """ super().__init__(env) diff --git a/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py b/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py index 35e6a1d38..2ae5919e9 100644 --- a/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py +++ b/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py @@ -9,7 +9,9 @@ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - model_path = "./ckpt/ckpt_best.pth.tar" + # model_path = "./ckpt/ckpt_best.pth.tar" + model_path = None + seeds = [0] num_episodes_each_seed = 5 # If True, you can play with the agent. diff --git a/zoo/box2d/lunarlander/entry/lunarlander_eval.py b/zoo/box2d/lunarlander/entry/lunarlander_eval.py index b74516ed0..06ca5e26f 100644 --- a/zoo/box2d/lunarlander/entry/lunarlander_eval.py +++ b/zoo/box2d/lunarlander/entry/lunarlander_eval.py @@ -1,20 +1,38 @@ -# According to the model you want to evaluate, import the corresponding config. +# Import the necessary libraries and configs based on the model you want to evaluate from zoo.box2d.lunarlander.config.lunarlander_disc_muzero_config import main_config, create_config from lzero.entry import eval_muzero import numpy as np if __name__ == "__main__": - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - model_path = './ckpt/ckpt_best.pth.tar' + Overview: + Evaluate the model performance by running multiple episodes with different seeds using the MuZero algorithm. + The evaluation results (returns and mean returns) are printed out for each seed and summarized for all seeds. + Variables: + - model_path (:obj:`str`): Path to the pretrained model's checkpoint file. Usually something like + "exp_name/ckpt/ckpt_best.pth.tar". Absolute path is recommended. + - seeds (:obj:`List[int]`): List of seeds to use for evaluation. Each seed will run for a specified number + of episodes. + - num_episodes_each_seed (:obj:`int`): Number of episodes to be run for each seed. + - main_config (:obj:`EasyDict`): Main configuration for the evaluation, imported from the model's config file. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[List[float]]`): List to store the returns for each episode from each seed. + Outputs: + Prints out the mean returns and returns for each seed, along with the overall mean return across all seeds. + + .. note:: + The eval_muzero function is used here for evaluation. For more details about this function and its parameters, + please refer to its own documentation. 
+    """
+    # model_path = './ckpt/ckpt_best.pth.tar'
+    model_path = None
+
     seeds = [0]
-    num_episodes_each_seed = 5
+    num_episodes_each_seed = 1
     main_config.env.evaluator_env_num = 1
     main_config.env.n_evaluator_episode = 1
     total_test_episodes = num_episodes_each_seed * len(seeds)
+    main_config.env.replay_path = './video'
     returns_mean_seeds = []
     returns_seeds = []
     for seed in seeds:
@@ -35,4 +53,4 @@
     print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.')
     print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}')
     print('In all seeds, reward_mean:', returns_mean_seeds.mean())
-    print("=" * 20)
+    print("=" * 20)
\ No newline at end of file
diff --git a/zoo/box2d/lunarlander/envs/lunarlander_env.py b/zoo/box2d/lunarlander/envs/lunarlander_env.py
index a6827d2b7..581bbedbd 100755
--- a/zoo/box2d/lunarlander/envs/lunarlander_env.py
+++ b/zoo/box2d/lunarlander/envs/lunarlander_env.py
@@ -1,19 +1,27 @@
 import copy
 import os
-from typing import List, Optional
+from datetime import datetime
+from typing import List, Optional, Dict
 
 import gym
 import numpy as np
-from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs import BaseEnvTimestep
 from ding.envs import ObsPlusPrevActRewWrapper
 from ding.envs.common import affine_transform
 from ding.torch_utils import to_ndarray
 from ding.utils import ENV_REGISTRY
 from easydict import EasyDict
 
+from zoo.classic_control.cartpole.envs.cartpole_lightzero_env import CartPoleEnv
+
 
 @ENV_REGISTRY.register('lunarlander')
-class LunarLanderEnv(BaseEnv):
+class LunarLanderEnv(CartPoleEnv):
+    """
+    Overview:
+        The LunarLander Environment class for LightZero algorithms. This class is a wrapper of the gym LunarLander environment, with additional
+        functionalities like replay saving and seed setting. The class is registered in ENV_REGISTRY with the key 'lunarlander'.
+    """
 
     config = dict(
         env_name="LunarLander-v2",
@@ -29,11 +37,23 @@ class LunarLanderEnv(BaseEnv):
 
     @classmethod
     def default_config(cls: type) -> EasyDict:
+        """
+        Overview:
+            Return the default configuration of the class.
+        Returns:
+            - cfg (:obj:`EasyDict`): Default configuration dict.
+        """
         cfg = EasyDict(copy.deepcopy(cls.config))
         cfg.cfg_type = cls.__name__ + 'Dict'
         return cfg
 
     def __init__(self, cfg: dict) -> None:
+        """
+        Overview:
+            Initialize the LunarLander environment.
+        Arguments:
+            - cfg (:obj:`dict`): Configuration dict. The dict should include keys like 'env_name', 'replay_path', etc.
+        """
         self._cfg = cfg
         self._init_flag = False
         # env_name options = {'LunarLander-v2', 'LunarLanderContinuous-v2'}
@@ -47,22 +67,30 @@ def __init__(self, cfg: dict) -> None:
         else:
             self._act_scale = False
 
-    def reset(self) -> np.ndarray:
+    def reset(self) -> Dict[str, np.ndarray]:
+        """
+        Overview:
+            Reset the environment and return the initial observation.
+        Returns:
+            - obs (:obj:`Dict[str, np.ndarray]`): The initial observation after resetting.
+ """ if not self._init_flag: self._env = gym.make(self._cfg.env_name) if self._replay_path is not None: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + video_name = f'{self._env.spec.id}-video-{timestamp}' self._env = gym.wrappers.RecordVideo( self._env, video_folder=self._replay_path, episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(id(self)) + name_prefix=video_name ) if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward: self._env = ObsPlusPrevActRewWrapper(self._env) self._observation_space = self._env.observation_space self._action_space = self._env.action_space self._reward_space = gym.spaces.Box( - low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 + low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1,), dtype=np.float32 ) self._init_flag = True if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: @@ -83,21 +111,16 @@ def reset(self) -> np.ndarray: obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} return obs - def close(self) -> None: - if self._init_flag: - self._env.close() - self._init_flag = False - - def render(self) -> None: - self._env.render() - - def seed(self, seed: int, dynamic_seed: bool = True) -> None: - self._seed = seed - self._dynamic_seed = dynamic_seed - np.random.seed(self._seed) - def step(self, action: np.ndarray) -> BaseEnvTimestep: - if action.shape == (1, ): + """ + Overview: + Take a step in the environment with the given action. + Arguments: + - action (:obj:`np.ndarray`): The action to be taken. + Returns: + - timestep (:obj:`BaseEnvTimestep`): The timestep information including observation, reward, done flag, and info. + """ + if action.shape == (1,): action = action.item() # 0-dim array if self._act_scale: action = affine_transform(action, min_val=-1, max_val=1) @@ -129,7 +152,13 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: return BaseEnvTimestep(obs, rew, done, info) @property - def legal_actions(self): + def legal_actions(self) -> np.ndarray: + """ + Overview: + Get the legal actions in the environment. + Returns: + - legal_actions (:obj:`np.ndarray`): An array of legal actions. + """ return np.arange(self._action_space.n) def enable_save_replay(self, replay_path: Optional[str] = None) -> None: @@ -152,23 +181,19 @@ def random_action(self) -> np.ndarray: random_action = to_ndarray([random_action], dtype=np.int64) return random_action - @property - def observation_space(self) -> gym.spaces.Space: - return self._observation_space - - @property - def action_space(self) -> gym.spaces.Space: - return self._action_space - - @property - def reward_space(self) -> gym.spaces.Space: - return self._reward_space - def __repr__(self) -> str: return "LightZero LunarLander Env." @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: + """ + Overview: + Create a list of environment configurations for the collector. + Arguments: + - cfg (:obj:`dict`): The base configuration dict. + Returns: + - cfgs (:obj:`List[dict]`): The list of environment configurations. + """ collector_env_num = cfg.pop('collector_env_num') cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.collect_max_episode_steps @@ -176,6 +201,14 @@ def create_collector_env_cfg(cfg: dict) -> List[dict]: @staticmethod def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + """ + Overview: + Create a list of environment configurations for the evaluator. 
+ Arguments: + - cfg (:obj:`dict`): The base configuration dict. + Returns: + - cfgs (:obj:`List[dict]`): The list of environment configurations. + """ evaluator_env_num = cfg.pop('evaluator_env_num') cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.eval_max_episode_steps diff --git a/zoo/classic_control/cartpole/entry/cartpole_eval.py b/zoo/classic_control/cartpole/entry/cartpole_eval.py index 32cb54907..e7f73ba4a 100644 --- a/zoo/classic_control/cartpole/entry/cartpole_eval.py +++ b/zoo/classic_control/cartpole/entry/cartpole_eval.py @@ -1,22 +1,39 @@ -from cartpole_muzero_config import main_config, create_config +from zoo.classic_control.cartpole.config.cartpole_muzero_config import main_config, create_config from lzero.entry import eval_muzero import numpy as np if __name__ == "__main__": - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - model_path = "./ckpt/ckpt_best.pth.tar" - seeds = [0] - num_episodes_each_seed = 5 - main_config.env.evaluator_env_num = 1 - main_config.env.n_evaluator_episode = 1 - total_test_episodes = num_episodes_each_seed * len(seeds) + Entry point for the evaluation of the MuZero model on the CartPole environment. + + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the + pretrained model. An absolute path is recommended. In LightZero, the path is usually something like + ``exp_name/ckpt/ckpt_best.pth.tar``. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[float]`): List to store the returns for each seed. + - seeds (:obj:`List[int]`): List of seeds for the environment. + - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of + seeds and the number of episodes per seed. + """ + # model_path = "./ckpt/ckpt_best.pth.tar" + model_path = None returns_mean_seeds = [] returns_seeds = [] + seeds = [0] + num_episodes_each_seed = 2 + total_test_episodes = num_episodes_each_seed * len(seeds) + create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base + main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 + main_config.env.n_evaluator_episode = total_test_episodes + main_config.env.replay_path = './video' + for seed in seeds: + """ + - returns_mean (:obj:`float`): The mean return of the evaluation. + - returns (:obj:`List[float]`): The returns of the evaluation. + """ returns_mean, returns = eval_muzero( [main_config, create_config], seed=seed, @@ -34,4 +51,4 @@ print(f'We eval total {len(seeds)} seeds. 
In each seed, we eval {num_episodes_each_seed} episodes.') print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') print('In all seeds, reward_mean:', returns_mean_seeds.mean()) - print("=" * 20) + print("=" * 20) \ No newline at end of file diff --git a/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py b/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py index bf958ce1c..52d337189 100644 --- a/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py +++ b/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py @@ -1,8 +1,8 @@ -from typing import Union, Optional +from datetime import datetime +from typing import Union, Optional, Dict import gym import numpy as np - from ding.envs import BaseEnv, BaseEnvTimestep from ding.envs import ObsPlusPrevActRewWrapper from ding.torch_utils import to_ndarray @@ -11,31 +11,46 @@ @ENV_REGISTRY.register('cartpole_lightzero') class CartPoleEnv(BaseEnv): + """ + LightZero version of the classic CartPole environment. This class includes methods for resetting, closing, and + stepping through the environment, as well as seeding for reproducibility, saving replay videos, and generating random + actions. It also includes properties for accessing the observation space, action space, and reward space of the + environment. + """ def __init__(self, cfg: dict = {}) -> None: + """ + Initialize the environment with a configuration dictionary. Sets up spaces for observations, actions, and rewards. + """ self._cfg = cfg self._init_flag = False - self._replay_path = None + self._continuous = False + self._replay_path = cfg.replay_path self._observation_space = gym.spaces.Box( low=np.array([-4.8, float("-inf"), -0.42, float("-inf")]), high=np.array([4.8, float("inf"), 0.42, float("inf")]), - shape=(4, ), + shape=(4,), dtype=np.float32 ) self._action_space = gym.spaces.Discrete(2) self._action_space.seed(0) # default seed - self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, ), dtype=np.float32) - self._continuous = False + self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32) - def reset(self) -> np.ndarray: + def reset(self) -> Dict[str, np.ndarray]: + """ + Reset the environment. If it hasn't been initialized yet, this method also handles that. It also handles seeding + if necessary. Returns the first observation. + """ if not self._init_flag: self._env = gym.make('CartPole-v0') if self._replay_path is not None: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + video_name = f'{self._env.spec.id}-video-{timestamp}' self._env = gym.wrappers.RecordVideo( self._env, video_folder=self._replay_path, episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(id(self)) + name_prefix=video_name ) if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward: self._env = ObsPlusPrevActRewWrapper(self._env) @@ -57,18 +72,26 @@ def reset(self) -> np.ndarray: return obs - def close(self) -> None: - if self._init_flag: - self._env.close() - self._init_flag = False - - def seed(self, seed: int, dynamic_seed: bool = True) -> None: - self._seed = seed - self._dynamic_seed = dynamic_seed - np.random.seed(self._seed) - def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: - if isinstance(action, np.ndarray) and action.shape == (1, ): + """ + Overview: + Perform a step in the environment using the provided action, and return the next state of the environment. 
+ The next state is encapsulated in a BaseEnvTimestep object, which includes the new observation, reward, + done flag, and info dictionary. + Arguments: + - action (:obj:`Union[int, np.ndarray]`): The action to be performed in the environment. If the action is + a 1-dimensional numpy array, it is squeezed to a 0-dimension array. + Returns: + - timestep (:obj:`BaseEnvTimestep`): An object containing the new observation, reward, done flag, + and info dictionary. + .. note:: + - The cumulative reward (`_eval_episode_return`) is updated with the reward obtained in this step. + - If the episode ends (done is True), the total reward for the episode is stored in the info dictionary + under the key 'eval_episode_return'. + - An action mask is created with ones, which represents the availability of each action in the action space. + - Observations are returned in a dictionary format containing 'observation', 'action_mask', and 'to_play'. + """ + if isinstance(action, np.ndarray) and action.shape == (1,): action = action.squeeze() # 0-dim array obs, rew, done, info = self._env.step(action) @@ -82,27 +105,61 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: return BaseEnvTimestep(obs, rew, done, info) + def close(self) -> None: + """ + Close the environment, and set the initialization flag to False. + """ + if self._init_flag: + self._env.close() + self._init_flag = False + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Set the seed for the environment's random number generator. Can handle both static and dynamic seeding. + """ + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + """ + Enable the saving of replay videos. If no replay path is given, a default is used. + """ if replay_path is None: replay_path = './video' self._replay_path = replay_path def random_action(self) -> np.ndarray: + """ + Generate a random action using the action space's sample method. Returns a numpy array containing the action. + """ random_action = self.action_space.sample() random_action = to_ndarray([random_action], dtype=np.int64) return random_action @property def observation_space(self) -> gym.spaces.Space: + """ + Property to access the observation space of the environment. + """ return self._observation_space @property def action_space(self) -> gym.spaces.Space: + """ + Property to access the action space of the environment. + """ return self._action_space @property def reward_space(self) -> gym.spaces.Space: + """ + Property to access the reward space of the environment. + """ return self._reward_space def __repr__(self) -> str: + """ + String representation of the environment. 
+ """ return "LightZero CartPole Env" diff --git a/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py b/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py index 204ddbc3b..1ca23fb04 100644 --- a/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py +++ b/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py @@ -1,17 +1,26 @@ import copy -from typing import Optional +from datetime import datetime +from typing import Union, Dict import gym import numpy as np -from ding.envs import BaseEnv, BaseEnvTimestep +from ding.envs import BaseEnvTimestep from ding.envs.common.common_function import affine_transform from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY from easydict import EasyDict +from zoo.classic_control.cartpole.envs.cartpole_lightzero_env import CartPoleEnv + @ENV_REGISTRY.register('pendulum_lightzero') -class PendulumEnv(BaseEnv): +class PendulumEnv(CartPoleEnv): + """ + LightZero version of the classic Pendulum environment. This class includes methods for resetting, closing, and + stepping through the environment, as well as seeding for reproducibility, saving replay videos, and generating random + actions. It also includes properties for accessing the observation space, action space, and reward space of the + environment. + """ @classmethod def default_config(cls: type) -> EasyDict: @@ -20,18 +29,18 @@ def default_config(cls: type) -> EasyDict: return cfg config = dict( + # (bool) Whether to use continuous action space continuous=True, - save_replay_gif=False, - replay_path_gif=None, + # (str) The path to save replay videos replay_path=None, + # (bool) Whether to scale action into [-2, 2] act_scale=True, - delay_reward_step=0, - prob_random_agent=0., - collect_max_episode_steps=int(1.08e5), - eval_max_episode_steps=int(1.08e5), ) def __init__(self, cfg: dict) -> None: + """ + Initialize the environment with a configuration dictionary. Sets up spaces for observations, actions, and rewards. + """ self._cfg = cfg self._act_scale = cfg.act_scale try: @@ -39,33 +48,39 @@ def __init__(self, cfg: dict) -> None: except: self._env = gym.make('Pendulum-v0') self._init_flag = False - self._replay_path = None + self._replay_path = cfg.replay_path self._continuous = cfg.get("continuous", True) self._observation_space = gym.spaces.Box( - low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3, ), dtype=np.float32 + low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3,), dtype=np.float32 ) if self._continuous: - self._action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1, ), dtype=np.float32) + self._action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32) else: self.discrete_action_num = 11 self._action_space = gym.spaces.Discrete(self.discrete_action_num) self._action_space.seed(0) # default seed self._reward_space = gym.spaces.Box( - low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1, ), dtype=np.float32 + low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1,), dtype=np.float32 ) - def reset(self) -> np.ndarray: + def reset(self) -> Dict[str, np.ndarray]: + """ + Reset the environment. If it hasn't been initialized yet, this method also handles that. It also handles seeding + if necessary. Returns the first observation. 
+        """
         if not self._init_flag:
             try:
                 self._env = gym.make('Pendulum-v1')
             except:
                 self._env = gym.make('Pendulum-v0')
             if self._replay_path is not None:
+                timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+                video_name = f'{self._env.spec.id}-video-{timestamp}'
                 self._env = gym.wrappers.RecordVideo(
                     self._env,
                     video_folder=self._replay_path,
                     episode_trigger=lambda episode_id: True,
-                    name_prefix='rl-video-{}'.format(id(self))
+                    name_prefix=video_name
                 )
             self._init_flag = True
         if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
@@ -87,29 +102,38 @@ def reset(self) -> np.ndarray:
         return obs
 
-    def close(self) -> None:
-        if self._init_flag:
-            self._env.close()
-            self._init_flag = False
-
-    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
-        self._seed = seed
-        self._dynamic_seed = dynamic_seed
-        np.random.seed(self._seed)
-
-    def step(self, action: np.ndarray) -> BaseEnvTimestep:
+    def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
+        """
+        Overview:
+            Step the environment forward with the provided action. This method returns the next state of the environment
+            (observation, reward, done flag, and info dictionary) encapsulated in a BaseEnvTimestep object.
+        Arguments:
+            - action (:obj:`Union[int, np.ndarray]`): The action to be performed in the environment.
+        Returns:
+            - timestep (:obj:`BaseEnvTimestep`): An object containing the new observation, reward, done flag,
+              and info dictionary.
+
+        .. note::
+            - If the environment requires discrete actions, they are converted to float actions in the range [-1, 1].
+            - If action scaling is enabled, continuous actions are scaled into the range [-2, 2].
+            - For each step, the cumulative reward (`_eval_episode_return`) is updated.
+            - If the episode ends (done is True), the total reward for the episode is stored in the info dictionary
+              under the key 'eval_episode_return'.
+            - If the environment requires discrete actions, an action mask is created, otherwise, it's None.
+            - Observations are returned in a dictionary format containing 'observation', 'action_mask', and 'to_play'.
+        """
         if isinstance(action, int):
             action = np.array(action)
         # if require discrete env, convert actions to [-1 ~ 1] float actions
         if not self._continuous:
             action = (action / (self.discrete_action_num - 1)) * 2 - 1
-        # scale into [-2, 2]
+        # scale the continuous action into [-2, 2]
         if self._act_scale:
             action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
 
         obs, rew, done, info = self._env.step(action)
         self._eval_episode_return += rew
         obs = to_ndarray(obs).astype(np.float32)
-        # wrapped to be transferred to a array with shape (1,)
+        # wrapped to be transferred to an array with shape (1,)
         rew = to_ndarray([rew]).astype(np.float32)
 
         if done:
@@ -123,12 +147,10 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         return BaseEnvTimestep(obs, rew, done, info)
 
-    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
-        if replay_path is None:
-            replay_path = './video'
-        self._replay_path = replay_path
-
     def random_action(self) -> np.ndarray:
+        """
+        Generate a random action using the action space's sample method. Returns a numpy array containing the action.
+ """ if self._continuous: random_action = self.action_space.sample().astype(np.float32) else: @@ -136,17 +158,8 @@ def random_action(self) -> np.ndarray: random_action = to_ndarray([random_action], dtype=np.int64) return random_action - @property - def observation_space(self) -> gym.spaces.Space: - return self._observation_space - - @property - def action_space(self) -> gym.spaces.Space: - return self._action_space - - @property - def reward_space(self) -> gym.spaces.Space: - return self._reward_space - def __repr__(self) -> str: + """ + String representation of the environment. + """ return "LightZero Pendulum Env({})".format(self._cfg.env_id)
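
Usage note (not part of the patch): the sketch below shows how the replay-saving options introduced above might be exercised for the CartPole entry. It is a minimal example; the keyword arguments passed to eval_muzero other than seed (num_episodes_each_seed, print_seed_details, model_path) are assumed from the truncated calls in the eval entry scripts of this diff, so adjust names and values to your setup.

# Minimal sketch: evaluate a MuZero agent on CartPole and record replay videos
# via the replay_path option added in this patch. The eval_muzero keyword
# arguments other than seed are assumptions based on the eval entry scripts.
import numpy as np

from lzero.entry import eval_muzero
from zoo.classic_control.cartpole.config.cartpole_muzero_config import main_config, create_config

if __name__ == "__main__":
    model_path = None  # or an absolute path such as 'exp_name/ckpt/ckpt_best.pth.tar'
    seeds = [0]
    num_episodes_each_seed = 2

    # Replay saving requires a single evaluator env in the non-subprocess ('base') manager.
    create_config.env_manager.type = 'base'
    main_config.env.evaluator_env_num = 1
    main_config.env.n_evaluator_episode = num_episodes_each_seed * len(seeds)
    main_config.env.replay_path = './video'  # videos are written here by gym.wrappers.RecordVideo

    returns_mean_seeds, returns_seeds = [], []
    for seed in seeds:
        returns_mean, returns = eval_muzero(
            [main_config, create_config],
            seed=seed,
            num_episodes_each_seed=num_episodes_each_seed,
            print_seed_details=False,
            model_path=model_path,
        )
        returns_mean_seeds.append(returns_mean)
        returns_seeds.append(returns)

    print('mean return over all seeds:', np.array(returns_mean_seeds).mean())

The same pattern applies to the LunarLander and Atari entries; for Atari, additionally set main_config.env.save_replay = True so that wrap_lightzero attaches the RecordVideo wrapper.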