Skip to content

Commit

Permalink
polish(pu): use larger num_sim in eval phase
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Nov 8, 2023
1 parent d84fdda commit 4e08396
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
3 changes: 2 additions & 1 deletion lzero/policy/alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ def _init_eval(self) -> None:
self._get_simulation_env()
import copy
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
mcts_eval_config.num_simulations = mcts_eval_config.num_simulations * 2
# The number of simulations for evaluation should be larger than that for collecting data.
mcts_eval_config.num_simulations = min(mcts_eval_config.num_simulations * 4, 800)
self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
self._eval_model = self._model

Expand Down
7 changes: 4 additions & 3 deletions lzero/policy/sampled_alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def _init_eval(self) -> None:
self._get_simulation_env()
# TODO(pu): use double num_simulations for evaluation
if self._cfg.mcts_ctree:
self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, 2 * self._cfg.mcts.num_simulations,
self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, min(4 * self._cfg.mcts.num_simulations, 800),
self._cfg.mcts.pb_c_base,
self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha,
self._cfg.mcts.root_noise_weight, self.simulate_env)
Expand All @@ -409,7 +409,8 @@ def _init_eval(self) -> None:
else:
from lzero.mcts.ptree.ptree_az import MCTS
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
mcts_eval_config.num_simulations = mcts_eval_config.num_simulations * 2
# The number of simulations for evaluation should be larger than that for collecting data.
mcts_eval_config.num_simulations = min(mcts_eval_config.num_simulations * 4, 800)

self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)

Expand Down Expand Up @@ -483,7 +484,7 @@ def _get_simulation_env(self):
self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)

elif self._cfg.simulate_env_name == 'gomoku':
from zoo.board_games.gomoku.envs.gomoku_env_ui import GomokuEnv
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
if self._cfg.simulate_env_config_type == 'play_with_bot':
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
elif self._cfg.simulate_env_config_type == 'self_play':
Expand Down

0 comments on commit 4e08396

Please sign in to comment.