polish(pu): rename mcts_mode to battle_mode_in_simulation_env, add sampled alphazero config for tictactoe (#179)

* feature(pu): add sampled alphazero config for tictactoe

* polish(pu): rename mcts_mode to battle_mode_in_simulation_env
puyuan1996 authored Jan 4, 2024
1 parent 7953c54 commit 3d338ae
Showing 16 changed files with 299 additions and 74 deletions.
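Note on the rename: the simulation environment's attribute formerly read as mcts_mode is now read as battle_mode_in_simulation_env everywhere the MCTS code touches it (the C++ tree bindings and the Python MCTS trees shown below). A minimal sketch of the access pattern follows; DummySimulateEnv is a hypothetical stand-in, not the actual LightZero env API.

# Sketch only: shows how MCTS copies the env's simulation battle mode before simulating.
# DummySimulateEnv is illustrative, not the real LightZero environment class.
class DummySimulateEnv:
    def __init__(self, battle_mode_in_simulation_env: str = 'self_play_mode'):
        # 'self_play_mode' or 'play_with_bot_mode'
        self.battle_mode_in_simulation_env = battle_mode_in_simulation_env
        self.battle_mode = None

def prepare_simulation(env) -> None:
    # Mirrors `simulate_env.battle_mode = simulate_env.battle_mode_in_simulation_env`
    # as it appears in ptree_az.py / ptree_az_sampled.py after this commit.
    env.battle_mode = env.battle_mode_in_simulation_env

env = DummySimulateEnv('play_with_bot_mode')
prepare_simulation(env)
assert env.battle_mode == 'play_with_bot_mode'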
21 changes: 12 additions & 9 deletions lzero/mcts/buffer/game_buffer_gumbel_muzero.py
@@ -45,28 +45,30 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
pos_in_game_segment = pos_in_game_segment_list[i]

actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()

_improved_policy = game.improved_policy_probs[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps]
self._cfg.num_unroll_steps].tolist()

_improved_policy = game.improved_policy_probs[
pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps]
if not isinstance(_improved_policy, list):
_improved_policy = _improved_policy.tolist()

# add mask for invalid actions (out of trajectory)
mask_tmp = [1. for i in range(len(actions_tmp))]
mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps+1 - len(mask_tmp))]
mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]

# pad random action
actions_tmp += [
np.random.randint(0, game.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]

# pad improved policy with with a value such that the sum of the values is equal to 1
_improved_policy.extend(np.random.dirichlet(np.ones(game.action_space_size),size=self._cfg.num_unroll_steps + 1 - len(_improved_policy)))
# pad improved policy with a value such that the sum of the values is equal to 1
_improved_policy.extend(np.random.dirichlet(np.ones(game.action_space_size),
size=self._cfg.num_unroll_steps + 1 - len(_improved_policy)))

# obtain the input observations
# pad if length of obs in game_segment is less than stack+num_unroll_steps
# e.g. stack+num_unroll_steps 4+5
# e.g. stack+num_unroll_steps = 4+5
obs_list.append(
game_segment_list[i].get_unroll_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
@@ -80,7 +82,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

# formalize the inputs of a batch
current_batch = [obs_list, action_list, improved_policy_list, mask_list, batch_index_list, weights_list, make_time_list]
current_batch = [obs_list, action_list, improved_policy_list, mask_list, batch_index_list, weights_list,
make_time_list]
for i in range(len(current_batch)):
current_batch[i] = np.asarray(current_batch[i])

@@ -117,4 +120,4 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
policy_non_re_context = None

context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
return context
return context
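To illustrate the padding behaviour in the hunk above: when a game segment yields fewer than num_unroll_steps + 1 improved-policy vectors, the remainder is filled with random Dirichlet samples, each of which sums to 1, so every padded row is still a valid distribution. The following standalone sketch uses made-up sizes; it only mirrors the np.random.dirichlet padding line, not the full _make_batch logic.

import numpy as np

action_space_size = 9      # e.g. tic-tac-toe
num_unroll_steps = 5
# Suppose only 3 improved-policy vectors are available for this position.
improved_policy = [list(p) for p in np.random.dirichlet(np.ones(action_space_size), size=3)]

# Pad with random Dirichlet vectors so each padded row still sums to 1,
# matching `_improved_policy.extend(np.random.dirichlet(...))` in the diff above.
pad = np.random.dirichlet(np.ones(action_space_size),
                          size=num_unroll_steps + 1 - len(improved_policy))
improved_policy.extend(pad)

assert len(improved_policy) == num_unroll_steps + 1
assert all(abs(sum(p) - 1.0) < 1e-6 for p in improved_policy)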
14 changes: 7 additions & 7 deletions lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
@@ -170,7 +170,7 @@ class MCTS {
state_config_for_env_reset["katago_policy_init"].cast<bool>(),
katago_game_state
);
simulate_env.attr("battle_mode") = simulate_env.attr("mcts_mode");
simulate_env.attr("battle_mode") = simulate_env.attr("battle_mode_in_simulation_env");
_simulate(root, simulate_env, policy_value_func);
}

@@ -228,14 +228,14 @@ class MCTS {
leaf_value = _expand_leaf_node(node, simulate_env, policy_value_func);
}
else {
if (simulate_env.attr("mcts_mode").cast<std::string>() == "self_play_mode") {
if (simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>() == "self_play_mode") {
if (winner == -1) {
leaf_value = 0;
} else {
leaf_value = (simulate_env.attr("current_player").cast<int>() == winner) ? 1 : -1;
}
}
else if (simulate_env.attr("mcts_mode").cast<std::string>() == "play_with_bot_mode") {
else if (simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>() == "play_with_bot_mode") {
if (winner == -1) {
leaf_value = 0;
} else if (winner == 1) {
@@ -245,11 +245,11 @@
}
}
}
if (simulate_env.attr("mcts_mode").cast<std::string>() == "play_with_bot_mode") {
node->update_recursive(leaf_value, simulate_env.attr("mcts_mode").cast<std::string>());
if (simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>() == "play_with_bot_mode") {
node->update_recursive(leaf_value, simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>());
}
else if (simulate_env.attr("mcts_mode").cast<std::string>() == "self_play_mode") {
node->update_recursive(-leaf_value, simulate_env.attr("mcts_mode").cast<std::string>());
else if (simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>() == "self_play_mode") {
node->update_recursive(-leaf_value, simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>());
}
}

10 changes: 5 additions & 5 deletions lzero/mcts/ctree/ctree_alphazero/node_alphazero.h
@@ -29,19 +29,19 @@ class Node {
}

// Recursively updates the value and visit count of the node and its parent nodes
void update_recursive(float leaf_value, std::string mcts_mode) {
void update_recursive(float leaf_value, std::string battle_mode_in_simulation_env) {
// If the mode is "self_play_mode", the leaf_value is subtracted from the parent's value
if (mcts_mode == "self_play_mode") {
if (battle_mode_in_simulation_env == "self_play_mode") {
update(leaf_value);
if (!is_root()) {
parent->update_recursive(-leaf_value, mcts_mode);
parent->update_recursive(-leaf_value, battle_mode_in_simulation_env);
}
}
// If the mode is "play_with_bot_mode", the leaf_value is added to the parent's value
else if (mcts_mode == "play_with_bot_mode") {
else if (battle_mode_in_simulation_env == "play_with_bot_mode") {
update(leaf_value);
if (!is_root()) {
parent->update_recursive(leaf_value, mcts_mode);
parent->update_recursive(leaf_value, battle_mode_in_simulation_env);
}
}
}
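The negation logic in update_recursive above is easiest to see on a small chain of nodes. The Python sketch below uses a toy stand-in node (not the real Node class in this header or in ptree_az.py) to trace how a leaf value of +1 alternates sign on the way to the root in self_play_mode, while it keeps its sign in play_with_bot_mode.

class ToyNode:
    # Stand-in used only to trace the backup; not the actual implementation.
    def __init__(self, parent=None):
        self.parent = parent
        self.visit_count = 0
        self.value_sum = 0.0

    def update(self, value):
        self.visit_count += 1
        self.value_sum += value

    def update_recursive(self, leaf_value, battle_mode_in_simulation_env):
        self.update(leaf_value)
        if self.parent is None:
            return
        if battle_mode_in_simulation_env == 'self_play_mode':
            # The perspective flips at every ply, so the sign alternates.
            self.parent.update_recursive(-leaf_value, battle_mode_in_simulation_env)
        else:  # 'play_with_bot_mode': values stay in the agent player's perspective.
            self.parent.update_recursive(leaf_value, battle_mode_in_simulation_env)

root = ToyNode()
child = ToyNode(parent=root)
leaf = ToyNode(parent=child)
leaf.update_recursive(1.0, 'self_play_mode')
print(leaf.value_sum, child.value_sum, root.value_sum)  # 1.0 -1.0 1.0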
28 changes: 14 additions & 14 deletions lzero/mcts/ptree/ptree_az.py
@@ -75,7 +75,7 @@ def update(self, value: float) -> None:
# Updates the sum of the values of all child nodes of this node.
self._value_sum += value

def update_recursive(self, leaf_value: float, mcts_mode: str) -> None:
def update_recursive(self, leaf_value: float, battle_mode_in_simulation_env: str) -> None:
"""
Overview:
Update node information recursively.
@@ -86,19 +86,19 @@ def update_recursive(self, leaf_value: float, mcts_mode: str) -> None:
Arguments:
- leaf_value (:obj:`Float`): The value of the node.
- mcts_mode (:obj:`str`): The mode of MCTS, can be 'self_play_mode' or 'play_with_bot_mode'.
- battle_mode_in_simulation_env (:obj:`str`): The mode of MCTS, can be 'self_play_mode' or 'play_with_bot_mode'.
"""
# Update the node information recursively based on the MCTS mode.
if mcts_mode == 'self_play_mode':
if battle_mode_in_simulation_env == 'self_play_mode':
# Update the current node's information.
self.update(leaf_value)
# If the current node is the root node, return.
if self.is_root():
return
# Update the parent node's information recursively. When propagating the value back to the parent node,
# the value needs to be negated once because the perspective of evaluation has changed.
self._parent.update_recursive(-leaf_value, mcts_mode)
if mcts_mode == 'play_with_bot_mode':
self._parent.update_recursive(-leaf_value, battle_mode_in_simulation_env)
if battle_mode_in_simulation_env == 'play_with_bot_mode':
# Update the current node's information.
self.update(leaf_value)
# If the current node is the root node, return.
@@ -107,7 +107,7 @@ def update_recursive(self, leaf_value: float, mcts_mode: str) -> None:
# Update the parent node's information recursively. In ``play_with_bot_mode``, since the nodes' values
# are always evaluated from the perspective of the agent player, there is no need to negate the value
# during value propagation.
self._parent.update_recursive(leaf_value, mcts_mode)
self._parent.update_recursive(leaf_value, battle_mode_in_simulation_env)

def is_leaf(self) -> bool:
"""
@@ -242,7 +242,7 @@ def get_next_action(
# In ``play_with_bot_mode``, when the step function is called, it will play one move based on the incoming action,
# and then it will play another move based on the action generated by the built-in bot in the environment, which means two moves in total.
# Therefore, in the MCTS process, except for the terminal nodes, the player corresponding to each node is the same player as the root node.
self.simulate_env.battle_mode = self.simulate_env.mcts_mode
self.simulate_env.battle_mode = self.simulate_env.battle_mode_in_simulation_env
self.simulate_env.render_mode = None
# Run the simulation from the root to a leaf node and update the node values along the way.
self._simulate(root, self.simulate_env, policy_forward_fn)
@@ -262,7 +262,7 @@ def get_next_action(
# When the visit count of a node is 0, then the corresponding action probability will be 0 in order to prevent the selection of illegal actions.
visits_t = torch.as_tensor(visits, dtype=torch.float32)
visits_t = torch.pow(visits_t, 1/temperature)
action_probs= (visits_t / visits_t.sum()).numpy()
action_probs = (visits_t / visits_t.sum()).numpy()

# action_probs = nn.functional.softmax(1.0 / temperature * np.log(torch.as_tensor(visits) + 1e-10), dim=0).numpy()

@@ -306,7 +306,7 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
# game state from the perspective of player 1.
leaf_value = self._expand_leaf_node(node, simulate_env, policy_forward_fn)
else:
if simulate_env.mcts_mode == 'self_play_mode':
if simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
# In a tie game, the value corresponding to a terminal node is 0.
if winner == -1:
leaf_value = 0
@@ -316,7 +316,7 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
# which is convenient for subsequent updates.
leaf_value = 1 if simulate_env.current_player == winner else -1

if simulate_env.mcts_mode == 'play_with_bot_mode':
if simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
# in ``play_with_bot_mode``, the leaf_value should be transformed to the perspective of player 1.
if winner == -1:
leaf_value = 0
@@ -326,9 +326,9 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
leaf_value = -1

# Update value and visit count of nodes in this traversal.
if simulate_env.mcts_mode == 'play_with_bot_mode':
node.update_recursive(leaf_value, simulate_env.mcts_mode)
elif simulate_env.mcts_mode == 'self_play_mode':
if simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
node.update_recursive(leaf_value, simulate_env.battle_mode_in_simulation_env)
elif simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
# NOTE: e.g.
# to_play: 1 ----------> 2 ----------> 1 ----------> 2
# state: s1 ----------> s2 ----------> s3 ----------> s4
@@ -337,7 +337,7 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
# leaf_value is calculated from the perspective of player 1, leaf_value = value_func(s3),
# but node.value should be the value of E[q(s2, action)], i.e. calculated from the perspective of player 2.
# thus we add the negative when call update_recursive().
node.update_recursive(-leaf_value, simulate_env.mcts_mode)
node.update_recursive(-leaf_value, simulate_env.battle_mode_in_simulation_env)

def _select_child(self, node: Node, simulate_env: Type[BaseEnv]) -> Tuple[Union[int, float], Node]:
"""
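One of the hunks above also tidies the line that turns root visit counts into an action distribution with a temperature. As a quick standalone illustration of that computation (the visit counts and temperature are made-up values), note that an action with zero visits keeps probability 0, and a lower temperature sharpens the distribution toward the most-visited action.

import torch

visits = [0, 10, 30, 60]   # example root visit counts
temperature = 1.0

visits_t = torch.as_tensor(visits, dtype=torch.float32)
visits_t = torch.pow(visits_t, 1 / temperature)
action_probs = (visits_t / visits_t.sum()).numpy()
print(action_probs)        # [0.  0.1 0.3 0.6]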
26 changes: 13 additions & 13 deletions lzero/mcts/ptree/ptree_az_sampled.py
@@ -76,7 +76,7 @@ def update(self, value: float) -> None:
# Updates the sum of the values of all child nodes of this node.
self._value_sum += value

def update_recursive(self, leaf_value: float, mcts_mode: str) -> None:
def update_recursive(self, leaf_value: float, battle_mode_in_simulation_env: str) -> None:
"""
Overview:
Update node information recursively.
@@ -87,19 +87,19 @@ def update_recursive(self, leaf_value: float, mcts_mode: str) -> None:
Arguments:
- leaf_value (:obj:`Float`): The value of the node.
- mcts_mode (:obj:`str`): The mode of MCTS, can be 'self_play_mode' or 'play_with_bot_mode'.
- battle_mode_in_simulation_env (:obj:`str`): The mode of MCTS, can be 'self_play_mode' or 'play_with_bot_mode'.
"""
# Update the node information recursively based on the MCTS mode.
if mcts_mode == 'self_play_mode':
if battle_mode_in_simulation_env == 'self_play_mode':
# Update the current node's information.
self.update(leaf_value)
# If the current node is the root node, return.
if self.is_root():
return
# Update the parent node's information recursively. When propagating the value back to the parent node,
# the value needs to be negated once because the perspective of evaluation has changed.
self._parent.update_recursive(-leaf_value, mcts_mode)
if mcts_mode == 'play_with_bot_mode':
self._parent.update_recursive(-leaf_value, battle_mode_in_simulation_env)
if battle_mode_in_simulation_env == 'play_with_bot_mode':
# Update the current node's information.
self.update(leaf_value)
# If the current node is the root node, return.
@@ -108,7 +108,7 @@ def update_recursive(self, leaf_value: float, mcts_mode: str) -> None:
# Update the parent node's information recursively. In ``play_with_bot_mode``, since the nodes' values
# are always evaluated from the perspective of the agent player, there is no need to negate the value
# during value propagation.
self._parent.update_recursive(leaf_value, mcts_mode)
self._parent.update_recursive(leaf_value, battle_mode_in_simulation_env)

def is_leaf(self) -> bool:
"""
@@ -242,7 +242,7 @@ def get_next_action(
start_player_index=state_config_for_env_reset.start_player_index,
init_state=state_config_for_env_reset.init_state,
)
self.simulate_env.battle_mode = self.simulate_env.mcts_mode
self.simulate_env.battle_mode = self.simulate_env.battle_mode_in_simulation_env
self._simulate(self.root, self.simulate_env, policy_value_func)

# sampled related code
@@ -321,7 +321,7 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_value_func:
leaf_value = self._expand_leaf_node(node, simulate_env, policy_value_func)

else:
if simulate_env.mcts_mode == 'self_play_mode':
if simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
# In a tie game, the value corresponding to a terminal node is 0.
if winner == -1:
leaf_value = 0
@@ -331,7 +331,7 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_value_func:
# which is convenient for subsequent updates.
leaf_value = 1 if simulate_env.current_player == winner else -1

if simulate_env.mcts_mode == 'play_with_bot_mode':
if simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
# in ``play_with_bot_mode``, the leaf_value should be transformed to the perspective of player 1.
if winner == -1:
leaf_value = 0
@@ -341,9 +341,9 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_value_func:
leaf_value = -1

# Update value and visit count of nodes in this traversal.
if simulate_env.mcts_mode == 'play_with_bot_mode':
node.update_recursive(leaf_value, simulate_env.mcts_mode)
elif simulate_env.mcts_mode == 'self_play_mode':
if simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
node.update_recursive(leaf_value, simulate_env.battle_mode_in_simulation_env)
elif simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
# NOTE: e.g.
# to_play: 1 ----------> 2 ----------> 1 ----------> 2
# state: s1 ----------> s2 ----------> s3 ----------> s4
@@ -352,7 +352,7 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_value_func:
# leaf_value is calculated from the perspective of player 1, leaf_value = value_func(s3),
# but node.value should be the value of E[q(s2, action)], i.e. calculated from the perspective of player 2.
# thus we add the negative when call update_recursive().
node.update_recursive(-leaf_value, simulate_env.mcts_mode)
node.update_recursive(-leaf_value, simulate_env.battle_mode_in_simulation_env)

def _select_child(self, node: Node, simulate_env: Type[BaseEnv]) -> Tuple[Union[int, float], Node]:
"""
4 changes: 2 additions & 2 deletions lzero/policy/alphazero.py
@@ -220,7 +220,7 @@ def _init_collect(self) -> None:
self._collect_model = self._model
if self._cfg.mcts_ctree:
import sys
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
sys.path.append('/Users/your_user_name/code/LightZero/lzero/mcts/ctree/ctree_alphazero/build')
import mcts_alphazero
self._collect_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, self._cfg.mcts.num_simulations,
self._cfg.mcts.pb_c_base,
@@ -279,7 +279,7 @@ def _init_eval(self) -> None:
self._get_simulation_env()
if self._cfg.mcts_ctree:
import sys
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
sys.path.append('/Users/your_user_name/code/LightZero/lzero/mcts/ctree/ctree_alphazero/build')
import mcts_alphazero
# TODO(pu): how to set proper num_simulations for evaluation
self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves,
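The two hunks above replace the relative build path with an absolute, user-specific one, so the string has to be edited per machine before the compiled mcts_alphazero module can be imported. A hedged alternative sketch, assuming the build directory stays at lzero/mcts/ctree/ctree_alphazero/build inside the repository and that the snippet lives in lzero/policy/alphazero.py, would derive the path from the file location instead:

import sys
from pathlib import Path

# Assumption: this file is lzero/policy/alphazero.py, so parents[2] is the repo root;
# adjust the index if the file lives at a different depth.
repo_root = Path(__file__).resolve().parents[2]
build_dir = repo_root / 'lzero' / 'mcts' / 'ctree' / 'ctree_alphazero' / 'build'
sys.path.append(str(build_dir))

import mcts_alphazero  # compiled pybind11 extension located in `build_dir`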