-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
polish(pu): polish comments and render in tictactoe, gomoku, connect4…
…, 2048
- Loading branch information
1 parent
9e97184
commit d3aaccd
Showing
19 changed files
with
532 additions
and
200 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from zoo.board_games.connect4.config.connect4_muzero_bot_mode_config import main_config, create_config | ||
from lzero.entry import eval_muzero | ||
import numpy as np | ||
|
||
if __name__ == '__main__': | ||
""" | ||
Entry point for the evaluation of the MuZero model on the Connect4 environment. | ||
Variables: | ||
- model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the | ||
pretrained model. An absolute path is recommended. In LightZero, the path is usually something like | ||
``exp_name/ckpt/ckpt_best.pth.tar``. | ||
- returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. | ||
- returns_seeds (:obj:`List[float]`): List to store the returns for each seed. | ||
- seeds (:obj:`List[int]`): List of seeds for the environment. | ||
- num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. | ||
- total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of | ||
seeds and the number of episodes per seed. | ||
""" | ||
# model_path = './ckpt/ckpt_best.pth.tar' | ||
model_path = None | ||
seeds = [0] | ||
num_episodes_each_seed = 1 | ||
# If True, you can play with the agent. | ||
# main_config.env.agent_vs_human = True | ||
main_config.env.agent_vs_human = False | ||
# main_config.env.render_mode = 'image_realtime_mode' | ||
main_config.env.render_mode = 'image_savefile_mode' | ||
main_config.env.replay_path = './video' | ||
|
||
main_config.env.prob_random_action_in_bot = 0. | ||
main_config.env.bot_action_type = 'rule' | ||
create_config.env_manager.type = 'base' | ||
main_config.env.evaluator_env_num = 1 | ||
main_config.env.n_evaluator_episode = 1 | ||
total_test_episodes = num_episodes_each_seed * len(seeds) | ||
returns_mean_seeds = [] | ||
returns_seeds = [] | ||
for seed in seeds: | ||
returns_mean, returns = eval_muzero( | ||
[main_config, create_config], | ||
seed=seed, | ||
num_episodes_each_seed=num_episodes_each_seed, | ||
print_seed_details=True, | ||
model_path=model_path | ||
) | ||
returns_mean_seeds.append(returns_mean) | ||
returns_seeds.append(returns) | ||
|
||
returns_mean_seeds = np.array(returns_mean_seeds) | ||
returns_seeds = np.array(returns_seeds) | ||
|
||
print("=" * 20) | ||
print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).") | ||
print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") | ||
print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) | ||
print( | ||
f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}' | ||
) | ||
print("=" * 20) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.