fix(pu): fix alphazero ctree import
puyuan1996 committed Nov 13, 2023
1 parent 30bc99b commit c45af81
Showing 5 changed files with 65 additions and 43 deletions.
11 changes: 3 additions & 8 deletions lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
@@ -183,7 +183,7 @@ class MCTS {
}
}

// Convert action_visits into two separate arrays.
// Convert 'action_visits' into two separate arrays.
std::vector<int> actions;
std::vector<int> visits;
for (const auto& av : action_visits) {
@@ -208,7 +208,6 @@ class MCTS {

// This function performs a simulation from a given node until a leaf node is reached or a terminal state is reached.
void _simulate(Node* node, py::object simulate_env, py::object policy_value_func) {
// std::cout << "position21 " << std::endl;
while (!node->is_leaf()) {
int action;
std::tie(action, node) = _select_child(node, simulate_env);
@@ -227,20 +226,16 @@ class MCTS {
double leaf_value;
if (!done) {
leaf_value = _expand_leaf_node(node, simulate_env, policy_value_func);
// std::cout << "position23 " << std::endl;
}
else {
// if (simulate_env.attr("mcts_mode") == "self_play_mode") {
if (simulate_env.attr("mcts_mode").cast<std::string>() == "self_play_mode") { // use the get_mcts_mode() method instead of the mcts_mode member
// std::cout << "position24 " << std::endl;

if (simulate_env.attr("mcts_mode").cast<std::string>() == "self_play_mode") {
if (winner == -1) {
leaf_value = 0;
} else {
leaf_value = (simulate_env.attr("current_player").cast<int>() == winner) ? 1 : -1;
}
}
else if (simulate_env.attr("mcts_mode").cast<std::string>() == "play_with_bot_mode") { // use the get_mcts_mode() method instead of the mcts_mode member
else if (simulate_env.attr("mcts_mode").cast<std::string>() == "play_with_bot_mode") {
if (winner == -1) {
leaf_value = 0;
} else if (winner == 1) {
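For context, the first hunk above splits action_visits into parallel actions and visits arrays. Below is a minimal Python sketch of the step that typically follows in AlphaZero-style MCTS, turning visit counts into a temperature-scaled action distribution; it is illustrative only, not code from this commit, and the function name is hypothetical.

# Illustrative sketch, not from this commit: convert MCTS visit counts into a
# temperature-scaled action distribution (the usual AlphaZero post-processing).
import numpy as np

def visit_counts_to_probs(actions, visits, temperature=1.0):
    visits = np.asarray(visits, dtype=np.float64)
    # Work in log-space for numerical stability: probs ~ visits ** (1 / temperature).
    logits = np.log(np.maximum(visits, 1e-10)) / temperature
    logits -= logits.max()
    probs = np.exp(logits)
    probs /= probs.sum()
    return dict(zip(actions, probs))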
@@ -8,7 +8,7 @@

import torch

sys.path.append('/Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/build')
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')

import mcts_alphazero
mcts_alphazero = mcts_alphazero.MCTS()
@@ -1,5 +1,5 @@
import sys
sys.path.append('/Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/build')
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')

import mcts_alphazero
n = mcts_alphazero.Node()
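Both test snippets above append a relative build directory to sys.path before importing the compiled mcts_alphazero extension, so they only work when the process is started from the directory that contains the LightZero checkout. A hedged sketch of a more location-independent variant follows; it assumes the build output stays under lzero/mcts/ctree/ctree_alphazero/build, and the parents[...] depth must be adjusted to wherever the script actually lives.

# Sketch only: resolve the ctree build directory relative to this file instead of
# the current working directory. The module name mcts_alphazero follows the diff above.
import sys
from pathlib import Path

build_dir = Path(__file__).resolve().parents[4] / "lzero" / "mcts" / "ctree" / "ctree_alphazero" / "build"
if str(build_dir) not in sys.path:
    sys.path.append(str(build_dir))

import mcts_alphazero  # pybind11 extension built from mcts_alphazero.cpp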
52 changes: 41 additions & 11 deletions lzero/policy/alphazero.py
@@ -1,3 +1,4 @@
import copy
from collections import namedtuple
from typing import List, Dict, Tuple

@@ -11,7 +12,6 @@
from ding.utils.data import default_collate
from easydict import EasyDict

from lzero.mcts.ptree.ptree_az import MCTS
from lzero.policy import configure_optimizers


@@ -213,8 +213,22 @@ def _init_collect(self) -> None:
"""
self._get_simulation_env()
self._collect_model = self._model
self._collect_mcts_temperature = 1
self._collect_mcts = MCTS(self._cfg.mcts, self.simulate_env)
if self._cfg.mcts_ctree:
import sys
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
import mcts_alphazero
self._collect_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, self._cfg.mcts.num_simulations,
self._cfg.mcts.pb_c_base,
self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha,
self._cfg.mcts.root_noise_weight, self.simulate_env)
else:
if self._cfg.sampled_algo:
from lzero.mcts.ptree.ptree_az_sampled import MCTS
else:
from lzero.mcts.ptree.ptree_az import MCTS
self._collect_mcts = MCTS(self._cfg.mcts, self.simulate_env)

self.collect_mcts_temperature = 1

@torch.no_grad()
def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]:
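The _init_collect hunk above (and the matching _init_eval change below) chooses between the compiled ctree MCTS and the Python ptree implementations based on cfg.mcts_ctree and cfg.sampled_algo. A compact helper expressing the same branching is sketched here; it is illustrative only, and the positional constructor arguments simply mirror the diff rather than a documented API.

# Illustrative helper mirroring the backend selection in _init_collect/_init_eval.
# Constructor arguments follow the diff above and should be treated as assumptions.
import copy
import sys

def build_mcts(cfg, simulate_env, num_simulations=None):
    num_simulations = num_simulations or cfg.mcts.num_simulations
    if cfg.mcts_ctree:
        sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
        import mcts_alphazero
        return mcts_alphazero.MCTS(
            cfg.mcts.max_moves, num_simulations,
            cfg.mcts.pb_c_base, cfg.mcts.pb_c_init,
            cfg.mcts.root_dirichlet_alpha, cfg.mcts.root_noise_weight,
            simulate_env,
        )
    if cfg.sampled_algo:
        from lzero.mcts.ptree.ptree_az_sampled import MCTS
    else:
        from lzero.mcts.ptree.ptree_az import MCTS
    mcts_cfg = copy.deepcopy(cfg.mcts)
    mcts_cfg.num_simulations = num_simulations
    return MCTS(mcts_cfg, simulate_env)

For evaluation, the diff scales the simulation budget with min(800, num_simulations * 4), which could be passed in through num_simulations here.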
@@ -237,7 +251,7 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
self._policy_model = self._collect_model
for env_id in ready_env_id:
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
init_state=init_state[env_id], ))
init_state=init_state[env_id], ))
action, mcts_probs = self._collect_mcts.get_next_action(
state_config_for_simulation_env_reset,
policy_forward_fn=self._policy_value_fn,
@@ -256,11 +270,26 @@ def _init_eval(self) -> None:
Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
"""
self._get_simulation_env()
import copy
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
# The number of simulations for evaluation should be larger than that for collecting data.
mcts_eval_config.num_simulations = min(mcts_eval_config.num_simulations * 4, 800)
self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
if self._cfg.mcts_ctree:
import sys
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
import mcts_alphazero
# TODO(pu): how to set proper num_simulations for evaluation
self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves,
min(800, self._cfg.mcts.num_simulations * 4),
self._cfg.mcts.pb_c_base,
self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha,
self._cfg.mcts.root_noise_weight, self.simulate_env)
else:
if self._cfg.sampled_algo:
from lzero.mcts.ptree.ptree_az_sampled import MCTS
else:
from lzero.mcts.ptree.ptree_az import MCTS
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
# TODO(pu): how to set proper num_simulations for evaluation
mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4)
self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)

self._eval_model = self._model

def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
@@ -281,9 +310,10 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
self._policy_model = self._eval_model
for env_id in ready_env_id:
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
init_state=init_state[env_id],))
init_state=init_state[env_id], ))
action, mcts_probs = self._eval_mcts.get_next_action(
state_config_for_simulation_env_reset, policy_forward_fn=self._policy_value_fn, temperature=1.0, sample=False
state_config_for_simulation_env_reset, policy_forward_fn=self._policy_value_fn, temperature=1.0,
sample=False
)
output[env_id] = {
'action': action,
41 changes: 19 additions & 22 deletions lzero/policy/sampled_alphazero.py
@@ -177,7 +177,6 @@ def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
del input_dict['next_obs']['katago_game_state']

# list of dict -> dict of list
# inputs_deepcopy = copy.deepcopy(inputs)
# only for env with variable legal actions
inputs = pad_and_get_lengths(inputs, self._cfg.mcts.num_of_sampled_actions)
inputs = default_collate(inputs)
@@ -208,14 +207,8 @@ def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
# ==============================================================
# policy loss
# ==============================================================
# mcts_visit_count_probs = mcts_visit_count_probs / (mcts_visit_count_probs.sum(dim=1, keepdim=True) + 1e-6)
# policy_loss = torch.nn.functional.kl_div(
# policy_log_probs, mcts_visit_count_probs, reduction='batchmean'
# )
# orig implementation
# policy_loss = -torch.mean(torch.sum(mcts_visit_count_probs * policy_log_probs, 1))

policy_loss = self._calculate_policy_loss_disc(policy_probs, mcts_visit_count_probs, root_sampled_actions, valid_action_length)
policy_loss = self._calculate_policy_loss_disc(policy_probs, mcts_visit_count_probs, root_sampled_actions,
valid_action_length)

# ==============================================================
# value loss
@@ -286,8 +279,6 @@ def _calculate_policy_loss_disc(
sampled_policy_probs.log(), sampled_target_policy, reduction='none'
)

# TODO(pu)
# Use nan_to_num to set NaN values in the loss to 0
loss = torch.nan_to_num(loss)

# Apply the mask to the loss
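The fragment above computes an element-wise KL divergence, zeroes NaNs with torch.nan_to_num, and then masks out padded sampled actions before reducing. A self-contained sketch of that pattern follows; names, shapes, and the final reduction are assumptions for illustration, not the repository's exact code.

# Sketch of the masked KL policy-loss pattern used above; the helper is hypothetical.
import torch

def masked_kl_policy_loss(policy_probs, target_probs, valid_action_length):
    # policy_probs, target_probs: (batch, num_sampled_actions); rows may be padded.
    loss = torch.nn.functional.kl_div(
        policy_probs.log(), target_probs, reduction='none'
    )
    # Padded entries can produce NaNs (e.g. from log(0)); zero them out.
    loss = torch.nan_to_num(loss)
    # Mask positions beyond each sample's number of valid actions.
    idx = torch.arange(loss.size(1), device=loss.device)
    mask = (idx[None, :] < valid_action_length[:, None]).float()
    loss = (loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return loss.mean()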
@@ -315,7 +306,6 @@ def _calculate_policy_loss_disc(
else:
raise ValueError(f"Invalid policy_loss_type: {self._cfg.policy_loss_type}")


return loss

def _init_collect(self) -> None:
@@ -327,6 +317,9 @@ def _init_collect(self) -> None:

self._collect_model = self._model
if self._cfg.mcts_ctree:
import sys
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
import mcts_alphazero
self._collect_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, self._cfg.mcts.num_simulations,
self._cfg.mcts.pb_c_base,
self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha,
@@ -368,7 +361,6 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
for env_id in ready_env_id:
# print('[collect] start_player_index={}'.format(start_player_index[env_id]))
# print('[collect] init_state=\n{}'.format(init_state[env_id]))

state_config_for_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
init_state=init_state[env_id],
katago_policy_init=True,
@@ -397,9 +389,13 @@ def _init_eval(self) -> None:
Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
"""
self._get_simulation_env()
# TODO(pu): use double num_simulations for evaluation
if self._cfg.mcts_ctree:
self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, min(800, mcts_eval_config.num_simulations * 4),
import sys
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
import mcts_alphazero
# TODO(pu): how to set proper num_simulations for evaluation
self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves,
min(800, self._cfg.mcts.num_simulations * 4),
self._cfg.mcts.pb_c_base,
self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha,
self._cfg.mcts.root_noise_weight, self.simulate_env)
@@ -410,9 +406,7 @@ def _init_eval(self) -> None:
from lzero.mcts.ptree.ptree_az import MCTS
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
# TODO(pu): how to set proper num_simulations for evaluation
# mcts_eval_config.num_simulations = mcts_eval_config.num_simulations
mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4)

self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)

self._eval_model = self._model
@@ -449,8 +443,9 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
katago_game_state=katago_game_state[env_id]))

# try:
action, mcts_visit_count_probs = self._eval_mcts.get_next_action(state_config_for_env_reset, self._policy_value_func,
1.0, False)
action, mcts_visit_count_probs = self._eval_mcts.get_next_action(state_config_for_env_reset,
self._policy_value_func,
1.0, False)
# except Exception as e:
# print(f"Exception occurred: {e}")
# print(f"Is self._policy_value_func callable? {callable(self._policy_value_func)}")
@@ -466,7 +461,8 @@ def _get_simulation_env(self):

def _get_simulation_env(self):
assert self._cfg.simulation_env_name in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_name
assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league', 'sampled_play_with_bot'], self._cfg.simulation_env_config_type
assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league',
'sampled_play_with_bot'], self._cfg.simulation_env_config_type
if self._cfg.simulation_env_name == 'tictactoe':
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
@@ -493,7 +489,8 @@ def _get_simulation_env(self):
elif self._cfg.simulation_env_config_type == 'league':
from zoo.board_games.gomoku.config.gomoku_alphazero_league_config import gomoku_alphazero_config
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import gomoku_sampled_alphazero_config as gomoku_alphazero_config
from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import \
gomoku_sampled_alphazero_config as gomoku_alphazero_config

self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
elif self._cfg.simulation_env_name == 'go':
@@ -568,4 +565,4 @@ def _process_transition(self, obs: Dict, model_output: Dict[str, torch.Tensor],

def _get_train_sample(self, data):
# be compatible with DI-engine Policy class
pass
pass
