From b86aca39b3a2a0edb78833829c33c9632d7a2937 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com> Date: Mon, 6 Nov 2023 20:30:21 +0800 Subject: [PATCH 1/3] fix(pu): fix assert bug in game_segment.py (#138) --- lzero/mcts/buffer/game_segment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index ecbd8f1e1..ae9260921 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -170,13 +170,13 @@ def pad_over( """ assert len(next_segment_observations) <= self.num_unroll_steps assert len(next_segment_child_visits) <= self.num_unroll_steps - assert len(next_segment_root_values) <= self.num_unroll_steps + self.num_unroll_steps - assert len(next_segment_rewards) <= self.num_unroll_steps + self.num_unroll_steps - 1 + assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps + assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1 # ============================================================== # The core difference between GumbelMuZero and MuZero # ============================================================== if self.gumbel_algo: - assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.num_unroll_steps + assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps # NOTE: next block observation should start from (stacked_observation - 1) in next trajectory for observation in next_segment_observations: From 8285db3274cc174558f38ea4131d9ffb502d2980 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com> Date: Thu, 9 Nov 2023 17:29:57 +0800 Subject: [PATCH 2/3] fix(pu): fix gumbel muzero collector bug, fix gumbel typo (#144) --- lzero/mcts/buffer/game_buffer.py | 8 ++++---- lzero/policy/gumbel_muzero.py | 6 +++++- lzero/policy/muzero.py | 2 ++ .../gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py | 4 ---- .../config/tictactoe_gumbel_muzero_bot_mode_config.py | 4 ---- .../config/lunarlander_disc_gumbel_muzero_config.py | 5 ----- .../cartpole/config/cartpole_gumbel_muzero_config.py | 4 ---- .../config/pendulum_cont_disc_gumbel_muzero_config.py | 4 ---- 8 files changed, 11 insertions(+), 26 deletions(-) diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index f632d1e7d..e5066bca5 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -9,14 +9,14 @@ from easydict import EasyDict if TYPE_CHECKING: - from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumeblMuZeroPolicy + from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy @BUFFER_REGISTRY.register('game_buffer') class GameBuffer(ABC, object): """ Overview: - The base game buffer class for MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumeblMuZeroPolicy. + The base game buffer class for MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy. 
""" @classmethod @@ -69,14 +69,14 @@ def __init__(self, cfg: dict): @abstractmethod def sample( - self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumeblMuZeroPolicy"] + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumbelMuZeroPolicy"] ) -> List[Any]: """ Overview: sample data from ``GameBuffer`` and prepare the current and target batch for training. Arguments: - batch_size (:obj:`int`): batch size. - - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumeblMuZeroPolicy"]`): policy. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumbelMuZeroPolicy"]`): policy. Returns: - train_data (:obj:`List`): List of train data, including current_batch and target_batch. """ diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index 218ab99b6..0f91a408a 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -20,7 +20,7 @@ @POLICY_REGISTRY.register('gumbel_muzero') -class GumeblMuZeroPolicy(MuZeroPolicy): +class GumbelMuZeroPolicy(MuZeroPolicy): """ Overview: The policy class for Gumbel MuZero proposed in the paper https://openreview.net/forum?id=bERaNdoegnO. @@ -486,6 +486,7 @@ def _forward_collect( action_mask: list = None, temperature: float = 1, to_play: List = [-1], + epsilon: float = 0.25, ready_env_id: np.array = None, ) -> Dict: """ @@ -514,6 +515,7 @@ def _forward_collect( """ self._collect_model.eval() self._collect_mcts_temperature = temperature + self.collect_epsilon = epsilon active_collect_env_num = data.shape[0] with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} @@ -560,6 +562,7 @@ def _forward_collect( for i, env_id in enumerate(ready_env_id): distributions, value, improved_policy_probs = roots_visit_count_distributions[i], roots_values[i], roots_improved_policy_probs[i] + roots_completed_value = roots_completed_values[i] # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. @@ -570,6 +573,7 @@ def _forward_collect( # entire action set. valid_value = np.where(action_mask[i] == 1.0, improved_policy_probs, 0.0) action = np.argmax([v for v in valid_value]) + output[env_id] = { 'action': action, 'visit_count_distributions': distributions, diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index b4094e0a0..b1c0cbaf6 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -526,6 +526,7 @@ def _forward_collect( - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - temperature (:obj:`float`): The temperature of the policy. - to_play (:obj:`int`): The player to play. + - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. Shape: - data (:obj:`torch.Tensor`): @@ -535,6 +536,7 @@ def _forward_collect( - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - temperature: :math:`(1, )`. - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - epsilon: :math:`(1, )`. 
- ready_env_id: None Returns: - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ diff --git a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py index cdecb0f50..6f87424cd 100644 --- a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py @@ -80,10 +80,6 @@ type='gumbel_muzero', import_names=['lzero.policy.gumbel_muzero'], ), - collector=dict( - type='gumbel_muzero', - import_names=['lzero.worker.gumbel_muzero_collector'], - ) ) gomoku_gumbel_muzero_create_config = EasyDict(gomoku_gumbel_muzero_create_config) create_config = gomoku_gumbel_muzero_create_config diff --git a/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py index d2bc056c6..c11c79309 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py @@ -77,10 +77,6 @@ type='gumbel_muzero', import_names=['lzero.policy.gumbel_muzero'], ), - collector=dict( - type='gumbel_muzero', - import_names=['lzero.worker.gumbel_muzero_collector'], - ) ) tictactoe_gumbel_muzero_create_config = EasyDict(tictactoe_gumbel_muzero_create_config) create_config = tictactoe_gumbel_muzero_create_config diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py index 54767af7c..daeb0738e 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py @@ -71,11 +71,6 @@ type='gumbel_muzero', import_names=['lzero.policy.gumbel_muzero'], ), - collector=dict( - type='gumbel_muzero', - get_train_sample=True, - import_names=['lzero.worker.gumbel_muzero_collector'], - ) ) lunarlander_gumbel_muzero_create_config = EasyDict(lunarlander_gumbel_muzero_create_config) create_config = lunarlander_gumbel_muzero_create_config diff --git a/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py index 5e86a5831..4a84a861b 100644 --- a/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py @@ -70,10 +70,6 @@ type='gumbel_muzero', import_names=['lzero.policy.gumbel_muzero'], ), - collector=dict( - type='gumbel_muzero', - import_names=['lzero.worker.gumbel_muzero_collector'], - ) ) cartpole_gumbel_muzero_create_config = EasyDict(cartpole_gumbel_muzero_create_config) create_config = cartpole_gumbel_muzero_create_config diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py index e28cf9fd6..1cc2db463 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py @@ -73,10 +73,6 @@ type='gumbel_muzero', import_names=['lzero.policy.gumbel_muzero'], ), - collector=dict( - type='gumbel_muzero', - import_names=['lzero.worker.gumbel_muzero_collector'], - ) ) pendulum_disc_gumbel_muzero_create_config = EasyDict(pendulum_disc_gumbel_muzero_create_config) 
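# [Editor's note] Each of the config diffs in this commit removes the dedicated
# collector=dict(type='gumbel_muzero', import_names=['lzero.worker.gumbel_muzero_collector'])
# entry, so these Gumbel MuZero experiments presumably fall back to the standard MuZero
# collector, consistent with the "fix gumbel muzero collector bug" commit title.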
create_config = pendulum_disc_gumbel_muzero_create_config From cee2849b0f600510b32cee8b29f5fdf6836e6de5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com> Date: Mon, 13 Nov 2023 20:03:12 +0800 Subject: [PATCH 3/3] feature(pu): add sampled alphazero and polish gomoku env (#141) * feature(pu): add sampled alphazero * feature(pu): polish gomoku render method and add gomoku_human_vs_bot_ui * polish(pu): polish gomoku render method * polish(pu): polish gomoku_rule_bot_v0 * polish(pu): polish gomoku_human_vs_bot_UI * polish(pu): polish gomoku_rule_bot_v0 * polish(pu): use dp to check_five_in_a_row in gomoku_rule_bot_v0 * fix(pu): fix emplace_back * polish(pu): polish comments and eval num_sim --- lzero/entry/eval_alphazero.py | 2 +- .../ctree_sampled_efficientzero/lib/cnode.cpp | 1 + lzero/mcts/ptree/ptree_az_sampled.py | 521 ++++++++++++++++ lzero/mcts/ptree/ptree_sez.py | 73 ++- lzero/model/alphazero_model.py | 107 +++- lzero/model/tests/test_alphazero_model.py | 6 +- lzero/policy/alphazero.py | 8 +- lzero/policy/sampled_alphazero.py | 571 ++++++++++++++++++ lzero/policy/utils.py | 34 ++ zoo/board_games/connect4/envs/connect4_env.py | 25 +- .../gomoku_alphazero_bot_mode_config.py | 2 + ...omoku_sampled_alphazero_bot_mode_config.py | 122 ++++ ...gomoku_sampled_alphazero_sp_mode_config.py | 120 ++++ .../gomoku/entry/gomoku_alphazero_eval.py | 3 +- .../entry/gomoku_sampled_alphazero_eval.py | 47 ++ zoo/board_games/gomoku/envs/gomoku_env.py | 556 ++++++++--------- .../gomoku/envs/gomoku_human_vs_bot_UI.py | 235 +++++++ .../gomoku/envs/gomoku_rule_bot_v0.py | 376 ++++++++++++ .../gomoku/envs/gomoku_rule_bot_v1.py | 1 + .../gomoku/envs/test_gomoku_env.py | 18 +- .../gomoku/envs/test_gomoku_rule_bot_v0.py | 50 +- 21 files changed, 2495 insertions(+), 383 deletions(-) create mode 100644 lzero/mcts/ptree/ptree_az_sampled.py create mode 100644 lzero/policy/sampled_alphazero.py create mode 100644 zoo/board_games/gomoku/config/gomoku_sampled_alphazero_bot_mode_config.py create mode 100644 zoo/board_games/gomoku/config/gomoku_sampled_alphazero_sp_mode_config.py create mode 100644 zoo/board_games/gomoku/entry/gomoku_sampled_alphazero_eval.py create mode 100644 zoo/board_games/gomoku/envs/gomoku_human_vs_bot_UI.py create mode 100644 zoo/board_games/gomoku/envs/gomoku_rule_bot_v0.py diff --git a/lzero/entry/eval_alphazero.py b/lzero/entry/eval_alphazero.py index eb093af0b..486e2e6e5 100644 --- a/lzero/entry/eval_alphazero.py +++ b/lzero/entry/eval_alphazero.py @@ -87,7 +87,7 @@ def eval_alphazero( if print_seed_details: print("=" * 20) print(f'In seed {seed}, returns: {returns}') - if cfg.policy.env_type == 'board_games': + if cfg.policy.simulation_env_name in ['tictactoe', 'connect4', 'gomoku', 'chess']: print( f'win rate: {len(np.where(returns == 1.)[0]) / num_episodes_each_seed}, draw rate: {len(np.where(returns == 0.)[0]) / num_episodes_each_seed}, lose rate: {len(np.where(returns == -1.)[0]) / num_episodes_each_seed}' ) diff --git a/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp index 0f0a6ae5f..e9a6cc628 100644 --- a/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp @@ -385,6 +385,7 @@ namespace tree #else disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter])); #endif + // disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter])); } 
std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp); diff --git a/lzero/mcts/ptree/ptree_az_sampled.py b/lzero/mcts/ptree/ptree_az_sampled.py new file mode 100644 index 000000000..37b96927e --- /dev/null +++ b/lzero/mcts/ptree/ptree_az_sampled.py @@ -0,0 +1,521 @@ +""" +Overview: + This code implements the Monte Carlo Tree Search (MCTS) algorithm with the integration of neural networks. + The Node class represents a node in the Monte Carlo tree and implements the basic functionalities expected in a node. + The MCTS class implements the specific search functionality and provides the optimal action through the ``get_next_action`` method. + Compared to traditional MCTS, the introduction of value networks and policy networks brings several advantages. + During the expansion of nodes, it is no longer necessary to explore every single child node, but instead, + the child nodes are directly selected based on the prior probabilities provided by the neural network. + This reduces the breadth of the search. When estimating the value of leaf nodes, there is no need for a rollout; + instead, the value output by the neural network is used, which saves the depth of the search. +""" + +import math +from typing import List, Tuple, Union, Callable, Type, Dict, Any + +import numpy as np +import torch +from ding.envs import BaseEnv +from easydict import EasyDict +from numpy import ndarray + +from lzero.mcts.ptree.ptree_sez import Action + + +class Node(object): + """ + Overview: + A class for a node in a Monte Carlo Tree. The properties of this class store basic information about the node, + such as its parent node, child nodes, and the number of times the node has been visited. + The methods of this class implement basic functionalities that a node should have, such as propagating the value back, + checking if the node is the root node, and determining if it is a leaf node. + """ + + def __init__(self, parent: "Node" = None, prior_p: float = 1.0) -> None: + """ + Overview: + Initialize a Node object. + Arguments: + - parent (:obj:`Node`): The parent node of the current node. + - prior_p (:obj:`Float`): The prior probability of selecting this node. + """ + # The parent node. + self._parent = parent + # A dictionary representing the children of the current node. The keys are the actions, and the values are + # the child nodes. + self._children = {} + # The number of times this node has been visited. + self._visit_count = 0 + # The sum of the values of all child nodes of this node. + self._value_sum = 0 + # The prior probability of selecting this node. + self.prior_p = prior_p + + @property + def value(self) -> float: + """ + Overview: + The value of the current node. + Returns: + - output (:obj:`Int`): Current value, used to compute ucb score. + """ + # Computes the average value of the current node. + if self._visit_count == 0: + return 0 + return self._value_sum / self._visit_count + + def update(self, value: float) -> None: + """ + Overview: + Update the current node information, such as ``_visit_count`` and ``_value_sum``. + Arguments: + - value (:obj:`Float`): The value of the node. + """ + # Updates the number of times this node has been visited. + self._visit_count += 1 + # Updates the sum of the values of all child nodes of this node. + self._value_sum += value + + def update_recursive(self, leaf_value: float, mcts_mode: str) -> None: + """ + Overview: + Update node information recursively. 
+ The same game state has opposite values in the eyes of two players playing against each other. + The value of a node is evaluated from the perspective of the player corresponding to its parent node. + In ``self_play_mode``, because the player corresponding to a node changes every step during the backpropagation process, the value needs to be negated once. + In ``play_with_bot_mode``, since all nodes correspond to the same player, the value does not need to be negated. + + Arguments: + - leaf_value (:obj:`Float`): The value of the node. + - mcts_mode (:obj:`str`): The mode of MCTS, can be 'self_play_mode' or 'play_with_bot_mode'. + """ + # Update the node information recursively based on the MCTS mode. + if mcts_mode == 'self_play_mode': + # Update the current node's information. + self.update(leaf_value) + # If the current node is the root node, return. + if self.is_root(): + return + # Update the parent node's information recursively. When propagating the value back to the parent node, + # the value needs to be negated once because the perspective of evaluation has changed. + self._parent.update_recursive(-leaf_value, mcts_mode) + if mcts_mode == 'play_with_bot_mode': + # Update the current node's information. + self.update(leaf_value) + # If the current node is the root node, return. + if self.is_root(): + return + # Update the parent node's information recursively. In ``play_with_bot_mode``, since the nodes' values + # are always evaluated from the perspective of the agent player, there is no need to negate the value + # during value propagation. + self._parent.update_recursive(leaf_value, mcts_mode) + + def is_leaf(self) -> bool: + """ + Overview: + Check if the current node is a leaf node or not. + Returns: + - output (:obj:`Bool`): If self._children is empty, it means that the node has not + been expanded yet, which indicates that the node is a leaf node. + """ + # Returns True if the node is a leaf node (i.e., has no children), and False otherwise. + return self._children == {} + + def is_root(self) -> bool: + """ + Overview: + Check if the current node is a root node or not. + Returns: + - output (:obj:`Bool`): If the node does not have a parent node, + then it is a root node. + """ + return self._parent is None + + @property + def parent(self) -> None: + """ + Overview: + Get the parent node of the current node. + Returns: + - output (:obj:`Node`): The parent node of the current node. + """ + return self._parent + + @property + def children(self) -> None: + """ + Overview: + Get the dictionary of children nodes of the current node. + Returns: + - output (:obj:`dict`): A dictionary representing the children of the current node. + """ + return self._children + + @property + def visit_count(self) -> None: + """ + Overview: + Get the number of times the current node has been visited. + Returns: + - output (:obj:`Int`): The number of times the current node has been visited. + """ + return self._visit_count + + +class MCTS(object): + """ + Overview: + A class for Monte Carlo Tree Search (MCTS). The methods in this class implement the steps involved in MCTS, such as selection and expansion. + Based on this, the ``_simulate`` method is used to traverse from the root node to a leaf node. + Finally, by repeatedly calling ``_simulate`` through ``get_next_action``, the optimal action is obtained. + """ + + def __init__(self, cfg: EasyDict, simulate_env: Type[BaseEnv]) -> None: + """ + Overview: + Initializes the MCTS process. 
+ Arguments: + - cfg (:obj:`EasyDict`): A dictionary containing the configuration parameters for the MCTS process. + """ + # Stores the configuration parameters for the MCTS search process. + self._cfg = cfg + + # ============================================================== + # sampled related core code + # ============================================================== + self.legal_actions = self._cfg.legal_actions + self.action_space_size = self._cfg.action_space_size + self.num_of_sampled_actions = self._cfg.num_of_sampled_actions + print(f'num_of_sampled_actions: {self.num_of_sampled_actions}') + self.continuous_action_space = self._cfg.continuous_action_space + + # The maximum number of moves allowed in a game. + self._max_moves = self._cfg.get('max_moves', 512) # for chess and shogi, 722 for Go. + # The number of simulations to run for each move. + self._num_simulations = self._cfg.get('num_simulations', 800) + + # UCB formula + self._pb_c_base = self._cfg.get('pb_c_base', 19652) # 19652 + self._pb_c_init = self._cfg.get('pb_c_init', 1.25) # 1.25 + + # Root prior exploration noise. + self._root_dirichlet_alpha = self._cfg.get( + 'root_dirichlet_alpha', 0.3 + ) # 0.3 # for chess, 0.03 for Go and 0.15 for shogi. + self._root_noise_weight = self._cfg.get('root_noise_weight', 0.25) # 0.25 + self.mcts_search_cnt = 0 + self.simulate_env = simulate_env + + def get_next_action( + self, + state_config_for_env_reset: Dict[str, Any], + policy_value_func: Callable, + temperature: float = 1.0, + sample: bool = True + ) -> Tuple[int, List[float]]: + """ + Overview: + Get the next action to take based on the current state of the game. + Arguments: + - state_config_for_env_reset (:obj:`Dict`): The config of state when reset the env. + - policy_value_func (:obj:`Function`): The Callable to compute the action probs and state value. + - temperature (:obj:`Float`): The exploration temperature. + - sample (:obj:`Bool`): Whether to sample an action from the probabilities or choose the most probable action. + Returns: + - action (:obj:`Int`): The selected action to take. + - action_probs (:obj:`List`): The output probability of each action. + """ + # Create a new root node for the MCTS search. + self.root = Node() + self.simulate_env.reset( + start_player_index=state_config_for_env_reset.start_player_index, + init_state=state_config_for_env_reset.init_state, + ) + # self.simulate_env_root = copy.deepcopy(self.simulate_env) + self._expand_leaf_node(self.root, self.simulate_env, policy_value_func) + + if sample: + self._add_exploration_noise(self.root) + + for n in range(self._num_simulations): + self.simulate_env.reset( + start_player_index=state_config_for_env_reset.start_player_index, + init_state=state_config_for_env_reset.init_state, + ) + self.simulate_env.battle_mode = self.simulate_env.mcts_mode + self._simulate(self.root, self.simulate_env, policy_value_func) + + # sampled related code + # Get the visit count for each possible action at the root node. + action_visits = [] + for action in range(self.simulate_env.action_space.n): + # Create an Action object for the current action + current_action_object = Action(action) + + # Use the Action object to look up the child node in the dictionary + if current_action_object in self.root.children: + action_visits.append((action, self.root.children[current_action_object].visit_count)) + else: + action_visits.append((action, 0)) + + # Unpack the tuples in action_visits list into two separate tuples: actions and visits. 
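# [Editor's note] Minimal self-contained sketch, not part of this patch, of how the root
# visit counts collected below are turned into an action distribution; all *_sketch names
# and numbers are made up for illustration.
import numpy as np
visits_sketch = np.array([30., 10., 0., 10.])
temperature_sketch = 0.5
probs_as_below = (visits_sketch / temperature_sketch) / (visits_sketch / temperature_sketch).sum()
probs_conventional = visits_sketch ** (1.0 / temperature_sketch)
probs_conventional = probs_conventional / probs_conventional.sum()
# Note: dividing every count by the same positive temperature and then renormalizing, as the
# code below does, yields the same probabilities for any temperature value; the conventional
# visit_count ** (1 / temperature) form sketched above is what actually sharpens or flattens them.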
+ actions, visits = zip(*action_visits) + # print('action_visits= {}'.format(visits)) + + visits_t = torch.as_tensor(visits, dtype=torch.float32) + visits_t /= temperature + action_probs = (visits_t / visits_t.sum()).numpy() + + if sample: + action = np.random.choice(actions, p=action_probs) + else: + action = actions[np.argmax(action_probs)] + self.mcts_search_cnt += 1 + + # get the root sampled actions according to the action_probs + self.root_sampled_actions = np.nonzero(action_probs)[0] + + # print(f'self.simulate_env_root: {self.simulate_env_root.legal_actions}') + # print(f'mcts_search_cnt: {self.mcts_search_cnt}') + # print('action= {}'.format(action)) + # print('action_probs= {}'.format(action_probs)) + # Return the selected action and the output probability of each action. + return action, action_probs + + def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_value_func: Callable) -> None: + """ + Overview: + Run a single playout from the root to the leaf, getting a value at the leaf and propagating it back through its parents. + State is modified in-place, so a deepcopy must be provided. + Arguments: + - node (:obj:`Class Node`): Current node when performing mcts search. + - simulate_env (:obj:`Class BaseGameEnv`): The class of simulate env. + - policy_value_func (:obj:`Function`): The Callable to compute the action probs and state value. + """ + while not node.is_leaf(): + # only for debug + # print('=='*20) + # print('node.legal_actions: ', node.legal_actions) + # print(node.children.keys()) + # print('simulate_env.board: ', simulate_env.board) + # print('simulate_env.legal_actions:', simulate_env.legal_actions) + + # Traverse the tree until the leaf node. + action, node = self._select_child(node, simulate_env) + if action is None: + break + # sampled related code + simulate_env.step(action.value) + + done, winner = simulate_env.get_done_winner() + """ + in ``self_play_mode``, the leaf_value is calculated from the perspective of player ``simulate_env.current_player``. + in ``play_with_bot_mode``, the leaf_value is calculated from the perspective of player 1. + """ + + if not done: + # The leaf_value here is obtained from the neural network. The perspective of this value is from the + # player corresponding to the game state input to the neural network. For example, if the current_player + # of the current node is player 1, the value output by the network represents the goodness of the current + # game state from the perspective of player 1. + leaf_value = self._expand_leaf_node(node, simulate_env, policy_value_func) + + else: + if simulate_env.mcts_mode == 'self_play_mode': + # In a tie game, the value corresponding to a terminal node is 0. + if winner == -1: + leaf_value = 0 + else: + # To maintain consistency with the perspective of the neural network, the value of a terminal + # node is also calculated from the perspective of the current_player of the terminal node, + # which is convenient for subsequent updates. + leaf_value = 1 if simulate_env.current_player == winner else -1 + + if simulate_env.mcts_mode == 'play_with_bot_mode': + # in ``play_with_bot_mode``, the leaf_value should be transformed to the perspective of player 1. + if winner == -1: + leaf_value = 0 + elif winner == 1: + leaf_value = 1 + elif winner == 2: + leaf_value = -1 + + # Update value and visit count of nodes in this traversal. 
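# [Editor's note] Illustrative sketch, not part of this patch, of the sign handling chosen
# just below: in self_play_mode the backed-up value flips sign at every tree level (each
# parent evaluates the position from the opponent's perspective), while in
# play_with_bot_mode it is propagated unchanged.
def backup_values_sketch(leaf_value: float, depth: int, self_play: bool) -> list:
    """Value credited to each node on the path from the leaf up to the root."""
    values, v = [], leaf_value
    for _ in range(depth):
        values.append(v)
        if self_play:
            v = -v
    return values
# backup_values_sketch(1.0, 3, self_play=True)  -> [1.0, -1.0, 1.0]
# backup_values_sketch(1.0, 3, self_play=False) -> [1.0, 1.0, 1.0]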
+ if simulate_env.mcts_mode == 'play_with_bot_mode': + node.update_recursive(leaf_value, simulate_env.mcts_mode) + elif simulate_env.mcts_mode == 'self_play_mode': + # NOTE: e.g. + # to_play: 1 ----------> 2 ----------> 1 ----------> 2 + # state: s1 ----------> s2 ----------> s3 ----------> s4 + # action node + # leaf_value + # leaf_value is calculated from the perspective of player 1, leaf_value = value_func(s3), + # but node.value should be the value of E[q(s2, action)], i.e. calculated from the perspective of player 2. + # thus we add the negative when call update_recursive(). + node.update_recursive(-leaf_value, simulate_env.mcts_mode) + + def _select_child(self, node: Node, simulate_env: Type[BaseEnv]) -> Tuple[Union[int, float], Node]: + """ + Overview: + Select the child with the highest UCB score. + Arguments: + - node (:obj:`Class Node`): Current node. + Returns: + - action (:obj:`Int`): choose the action with the highest ucb score. + - child (:obj:`Node`): the child node reached by executing the action with the highest ucb score. + """ + action = None + child_node = None + best_score = -9999999 + # print(simulate_env._raw_env._go.board, simulate_env.legal_actions) + # Iterate over each child of the current node. + for action_tmp, child_node_tmp in node.children.items(): + # print(a, simulate_env.legal_actions) + # print('node.legal_actions: ', node.legal_actions) + if action_tmp.value in simulate_env.legal_actions: + score = self._ucb_score(node, child_node_tmp) + # Check if the score of the current child is higher than the best score so far. + if score > best_score: + best_score = score + action = action_tmp + child_node = child_node_tmp + else: + print(f'error: {action_tmp} not in {simulate_env.legal_actions}') + if child_node is None: + child_node = node # child==None, node is leaf node in play_with_bot_mode. + if action is None: + print('error: action is None') + + return action, child_node + + def _expand_leaf_node(self, node: Node, simulate_env: Type[BaseEnv], policy_value_func: Callable) -> float: + """ + Overview: + expand the node with the policy_value_func. + Arguments: + - node (:obj:`Class Node`): current node when performing mcts search. + - simulate_env (:obj:`Class BaseGameEnv`): the class of simulate env. + - policy_value_func (:obj:`Function`): the Callable to compute the action probs and state value. + Returns: + - leaf_value (:obj:`Bool`): the leaf node's value. + """ + # ============================================================== + # sampled related core code + # ============================================================== + if self.continuous_action_space: + pass + else: + # discrete action space + + # Call the policy_value_func function to compute the action probabilities and state value, and return a + # dictionary and the value of the leaf node. 
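# [Editor's note] Illustrative sketch, not part of this patch, of the sampled expansion
# performed below: only num_of_sampled_actions children are drawn from the network prior
# without replacement. The *_sketch names and probabilities are made up.
import numpy as np
prior_sketch = {0: 0.50, 3: 0.30, 7: 0.15, 8: 0.05}      # hypothetical legal-action priors
k_sketch = 2                                             # hypothetical num_of_sampled_actions
actions_sketch = list(prior_sketch.keys())
p_sketch = np.array(list(prior_sketch.values()))
p_sketch = p_sketch / p_sketch.sum()
sampled_sketch = np.random.choice(actions_sketch, size=min(k_sketch, len(actions_sketch)),
                                  p=p_sketch, replace=False)
# e.g. array([0, 3]) -- only these actions become children of the expanded leaf node.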
+ legal_action_probs_dict, leaf_value = policy_value_func(simulate_env) + + node.legal_actions = [] + + # Extract actions and their corresponding probabilities from the dictionary + actions = list(legal_action_probs_dict.keys()) + probabilities = list(legal_action_probs_dict.values()) + + # Normalize the probabilities so they sum to 1 + probabilities = np.array(probabilities) + probabilities /= probabilities.sum() + + # self.num_of_sampled_actions = len(actions) + + # If there are fewer legal actions than the desired number of samples, + # adjust the number of samples to the number of legal actions + num_samples = min(len(actions), self.num_of_sampled_actions) + # Use numpy to randomly sample actions according to the given probabilities, without replacement + sampled_actions = np.random.choice(actions, size=num_samples, p=probabilities, replace=False) + sampled_actions = sampled_actions.tolist() # Convert numpy array to list + + for action_index in range(num_samples): + node.children[Action(sampled_actions[action_index])] = \ + Node( + parent=node, + prior_p=legal_action_probs_dict[sampled_actions[action_index]], + ) + node.legal_actions.append(Action(sampled_actions[action_index])) + + # Return the value of the leaf node. + return leaf_value + + def _ucb_score(self, parent: Node, child: Node) -> float: + """ + Overview: + Compute UCB score. The score for a node is based on its value, plus an exploration bonus based on the prior. + For more details, please refer to this paper: http://gauss.ececs.uc.edu/Workshops/isaim2010/papers/rosin.pdf + UCB = Q(s,a) + P(s,a) \cdot \frac{N(\text{parent})}{1+N(\text{child})} \cdot \left(c_1 + \log\left(\frac{N(\text{parent})+c_2+1}{c_2}\right)\right) + - Q(s,a): value of a child node. + - P(s,a): The prior of a child node. + - N(parent): The number of the visiting of the parent node. + - N(child): The number of the visiting of the child node. + - c_1: a parameter given by self._pb_c_init to control the influence of the prior P(s,a) relative to the value Q(s,a). + - c_2: a parameter given by self._pb_c_base to control the influence of the prior P(s,a) relative to the value Q(s,a). + Arguments: + - parent (:obj:`Class Node`): Current node. + - child (:obj:`Class Node`): Current node's child. + Returns: + - score (:obj:`Bool`): The UCB score. + """ + # Compute the value of parameter pb_c using the formula of the UCB algorithm. + pb_c = math.log((parent.visit_count + self._pb_c_base + 1) / self._pb_c_base) + self._pb_c_init + pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1) + + # ============================================================== + # sampled related core code + # ============================================================== + # TODO(pu) + node_prior = "density" + # node_prior = "uniform" + if node_prior == "uniform": + # Uniform prior for continuous action space + prior_score = pb_c * (1 / len(parent.children)) + elif node_prior == "density": + # TODO(pu): empirical distribution + if self.continuous_action_space: + # prior is log_prob + prior_score = pb_c * ( + torch.exp(child.prior_p) / ( + sum([torch.exp(node.prior_p) for node in parent.children.values()]) + 1e-6) + ) + else: + # prior is prob + prior_score = pb_c * (child.prior_p / (sum([node.prior_p for node in parent.children.values()]) + 1e-6)) + else: + raise ValueError("{} is unknown prior option, choose uniform or density") + + # Compute the UCB score by combining the prior score and value score. 
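# [Editor's note] Numeric sketch, not part of this patch, of the PUCT score assembled
# below; all counts and priors are made up. Note that the plain
# `prior_score = pb_c * child.prior_p` assignment below overrides the density/uniform
# prior_score branch computed above, mirroring the code as written in this patch.
import math
pb_c_base_sketch, pb_c_init_sketch = 19652, 1.25
parent_visits_sketch, child_visits_sketch = 50, 10
child_prior_sketch, child_value_sketch = 0.3, 0.1
pb_c_sketch = math.log((parent_visits_sketch + pb_c_base_sketch + 1) / pb_c_base_sketch) + pb_c_init_sketch
pb_c_sketch *= math.sqrt(parent_visits_sketch) / (child_visits_sketch + 1)
ucb_sketch = child_value_sketch + pb_c_sketch * child_prior_sketch   # value term + exploration term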
+ value_score = child.value + prior_score = pb_c * child.prior_p + + return prior_score + value_score + + def _add_exploration_noise(self, node: Node) -> None: + """ + Overview: + Add exploration noise. + Arguments: + - node (:obj:`Class Node`): Current node. + """ + # Get a list of actions corresponding to the child nodes. + actions = list(node.children.keys()) + # Create a list of alpha values for Dirichlet noise. + alpha = [self._root_dirichlet_alpha] * len(actions) + # Generate Dirichlet noise using the alpha values. + noise = np.random.dirichlet(alpha) + # Compute the weight of the exploration noise. + frac = self._root_noise_weight + # Update the prior probability of each child node with the exploration noise. + for a, n in zip(actions, noise): + node.children[a].prior_p = node.children[a].prior_p * (1 - frac) + n * frac + + def get_sampled_actions(self) -> List[int]: + """ + Overview: + Get the sampled actions of the root node. + Returns: + - output (:obj:`List`): The sampled actions of the root node. + """ + return self.root_sampled_actions diff --git a/lzero/mcts/ptree/ptree_sez.py b/lzero/mcts/ptree/ptree_sez.py index c6449a850..4e9f0890b 100644 --- a/lzero/mcts/ptree/ptree_sez.py +++ b/lzero/mcts/ptree/ptree_sez.py @@ -123,7 +123,6 @@ def expand( for action_index in range(self.num_of_sampled_actions): self.children[Action(sampled_actions[action_index].detach().cpu().numpy())] = Node( - # prob[action_index], # NOTE: this is a bug prob[sampled_actions[action_index]], # action_space_size=self.action_space_size, num_of_sampled_actions=self.num_of_sampled_actions, @@ -403,6 +402,7 @@ def get_sampled_actions(self) -> List[List[Union[int, float]]]: - python_sampled_actions: a vector of sampled_actions for each root, e.g. the size of original action space is 6, the K=3, python_sampled_actions = [[1,3,0], [2,4,0], [5,4,1]]. """ + # TODO(pu): root_sampled_actions bug in discere action space? sampled_actions = [] for i in range(self.root_num): sampled_actions.append(self.roots[i].legal_actions) @@ -774,20 +774,79 @@ def batch_backpropagate( ) +from typing import Union +import numpy as np + class Action: - """Class that represent an action of a game.""" + """ + Class that represents an action of a game. + + Attributes: + value (Union[int, np.ndarray]): The value of the action. Can be either an integer or a numpy array. + """ + + def __init__(self, value: Union[int, np.ndarray]) -> None: + """ + Initializes the Action with the given value. - def __init__(self, value: float) -> None: + Args: + value (Union[int, np.ndarray]): The value of the action. + """ self.value = value - def __hash__(self) -> hash: - return hash(self.value.tostring()) + def __hash__(self) -> int: + """ + Returns a hash of the Action's value. + + If the value is a numpy array, it is flattened to a tuple and then hashed. + If the value is a single integer, it is hashed directly. + + Returns: + int: The hash of the Action's value. + """ + if isinstance(self.value, np.ndarray): + if self.value.ndim == 0: + return hash(self.value.item()) + else: + return hash(tuple(self.value.flatten())) + else: + return hash(self.value) def __eq__(self, other: "Action") -> bool: - return (self.value == other.value).all() + """ + Determines if this Action is equal to another Action. + + If both values are numpy arrays, they are compared element-wise. + Otherwise, they are compared directly. + + Args: + other (Action): The Action to compare with. + + Returns: + bool: True if the two Actions are equal, False otherwise. 
+ """ + if isinstance(self.value, np.ndarray) and isinstance(other.value, np.ndarray): + return np.array_equal(self.value, other.value) + else: + return self.value == other.value def __gt__(self, other: "Action") -> bool: - return self.value[0] > other.value[0] + """ + Determines if this Action's value is greater than another Action's value. + + Args: + other (Action): The Action to compare with. + + Returns: + bool: True if this Action's value is greater, False otherwise. + """ + return self.value > other.value def __repr__(self) -> str: + """ + Returns a string representation of this Action. + + Returns: + str: A string representation of the Action's value. + """ return str(self.value) diff --git a/lzero/model/alphazero_model.py b/lzero/model/alphazero_model.py index b916be093..286e46ff5 100644 --- a/lzero/model/alphazero_model.py +++ b/lzero/model/alphazero_model.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from ding.model import ReparameterizationHead from ding.torch_utils import MLP, ResBlock from ding.utils import MODEL_REGISTRY, SequenceType @@ -34,6 +35,16 @@ def __init__( fc_value_layers: SequenceType = [32], fc_policy_layers: SequenceType = [32], value_support_size: int = 601, + # ============================================================== + # specific sampled related config + # ============================================================== + continuous_action_space: bool = False, + num_of_sampled_actions: int = 6, + sigma_type='conditioned', + fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = 'BN', + discrete_action_encoding_type: str = 'one_hot', ): """ Overview: @@ -70,7 +81,26 @@ def __init__( self.last_linear_layer_init_zero = last_linear_layer_init_zero self.representation_network = representation_network + self.continuous_action_space = continuous_action_space self.action_space_size = action_space_size + # The dim of action space. For discrete action space, it's 1. + # For continuous action space, it is the dim of action. 
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.norm_type = norm_type + self.num_of_sampled_actions = num_of_sampled_actions + # TODO use more adaptive way to get the flatten output size flatten_output_size_for_value_head = ( ( @@ -88,6 +118,7 @@ def __init__( self.prediction_network = PredictionNetwork( action_space_size, + self.continuous_action_space, num_res_blocks, num_channels, value_head_channels, @@ -99,6 +130,10 @@ def __init__( flatten_output_size_for_policy_head, last_linear_layer_init_zero=self.last_linear_layer_init_zero, activation=activation, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + bound_type=self.bound_type, + norm_type=self.norm_type, ) if self.representation_network is None: @@ -131,7 +166,7 @@ def forward(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor logit, value = self.prediction_network(encoded_state) return logit, value - def compute_prob_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def compute_policy_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: The computation graph of AlphaZero model to calculate action selection probability and value. @@ -147,9 +182,7 @@ def compute_prob_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, t - value (:obj:`torch.Tensor`): :math:`(B, 1)`, where B is batch size. """ logit, value = self.forward(state_batch) - # construct categorical distribution to calculate probability - dist = torch.distributions.Categorical(logits=logit) - prob = dist.probs + prob = torch.nn.functional.softmax(logit, dim=-1) return prob, value def compute_logp_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -178,6 +211,7 @@ class PredictionNetwork(nn.Module): def __init__( self, action_space_size: int, + continuous_action_space: bool, num_res_blocks: int, num_channels: int, value_head_channels: int, @@ -189,6 +223,13 @@ def __init__( flatten_output_size_for_policy_head: int, last_linear_layer_init_zero: bool = True, activation: Optional[nn.Module] = nn.ReLU(inplace=True), + # ============================================================== + # specific sampled related config + # ============================================================== + sigma_type='conditioned', + fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = 'BN', ) -> None: """ Overview: @@ -213,6 +254,15 @@ def __init__( operation to speedup, e.g. ReLU(inplace=True). 
""" super().__init__() + self.continuous_action_space = continuous_action_space + self.flatten_output_size_for_value_head = flatten_output_size_for_value_head + self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head + self.norm_type = norm_type + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.activation = activation + self.resblocks = nn.ModuleList( [ ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) @@ -226,7 +276,7 @@ def __init__( self.norm_policy = nn.BatchNorm2d(policy_head_channels) self.flatten_output_size_for_value_head = flatten_output_size_for_value_head self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head - self.fc_value = MLP( + self.fc_value_head = MLP( in_channels=self.flatten_output_size_for_value_head, hidden_channels=fc_value_layers[0], out_channels=output_support_size, @@ -237,17 +287,31 @@ def __init__( output_norm=False, last_linear_layer_init_zero=last_linear_layer_init_zero ) - self.fc_policy = MLP( - in_channels=self.flatten_output_size_for_policy_head, - hidden_channels=fc_policy_layers[0], - out_channels=action_space_size, - layer_num=len(fc_policy_layers) + 1, - activation=activation, - norm_type='LN', - output_activation=False, - output_norm=False, - last_linear_layer_init_zero=last_linear_layer_init_zero - ) + + # sampled related core code + if self.continuous_action_space: + self.fc_policy_head = ReparameterizationHead( + input_size=self.flatten_output_size_for_policy_head, + output_size=action_space_size, + layer_num=len(fc_policy_layers) + 1, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + activation=nn.ReLU(), + norm_type=None, + bound_type=self.bound_type + ) + else: + self.fc_policy_head = MLP( + in_channels=self.flatten_output_size_for_policy_head, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=activation, + norm_type='LN', + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) self.activation = activation @@ -279,6 +343,11 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: value = value.reshape(-1, self.flatten_output_size_for_value_head) policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) - value = self.fc_value(value) - logit = self.fc_policy(policy) - return logit, value + value = self.fc_value_head(value) + + # sampled related core code + policy = self.fc_policy_head(policy) + if self.continuous_action_space: + policy = torch.cat([policy['mu'], policy['sigma']], dim=-1) + + return policy, value diff --git a/lzero/model/tests/test_alphazero_model.py b/lzero/model/tests/test_alphazero_model.py index 892865d60..5a5a6f888 100644 --- a/lzero/model/tests/test_alphazero_model.py +++ b/lzero/model/tests/test_alphazero_model.py @@ -53,8 +53,9 @@ def output_check(self, model, outputs): prediction_network_args ) def test_prediction_network( - self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, policy_head_channels, - fc_value_layers, fc_policy_layers, output_support_size + self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, + policy_head_channels, + fc_value_layers, fc_policy_layers, output_support_size ): obs = torch.rand(batch_size, num_channels, 3, 3) flatten_output_size_for_value_head = value_head_channels * 
observation_shape[1] * observation_shape[2] @@ -64,6 +65,7 @@ def test_prediction_network( # print('='*20) prediction_network = PredictionNetwork( action_space_size=action_space_size, + continuous_action_space=False, num_res_blocks=num_res_blocks, num_channels=num_channels, value_head_channels=value_head_channels, diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py index bb9ee2637..f924e23c1 100644 --- a/lzero/policy/alphazero.py +++ b/lzero/policy/alphazero.py @@ -162,7 +162,7 @@ def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]: mcts_probs = mcts_probs.to(device=self._device, dtype=torch.float) reward = reward.to(device=self._device, dtype=torch.float) - action_probs, values = self._learn_model.compute_prob_value(state_batch) + action_probs, values = self._learn_model.compute_policy_value(state_batch) log_probs = torch.log(action_probs) # calculate policy entropy, for monitoring only @@ -258,7 +258,9 @@ def _init_eval(self) -> None: self._get_simulation_env() import copy mcts_eval_config = copy.deepcopy(self._cfg.mcts) - mcts_eval_config.num_simulations = mcts_eval_config.num_simulations * 2 + # TODO(pu): how to set proper num_simulations for evaluation + # mcts_eval_config.num_simulations = mcts_eval_config.num_simulations + mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4) self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env) self._eval_model = self._model @@ -323,7 +325,7 @@ def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: device=self._device, dtype=torch.float ).unsqueeze(0) with torch.no_grad(): - action_probs, value = self._policy_model.compute_prob_value(current_state_scale) + action_probs, value = self._policy_model.compute_policy_value(current_state_scale) action_probs_dict = dict(zip(legal_actions, action_probs.squeeze(0)[legal_actions].detach().cpu().numpy())) return action_probs_dict, value.item() diff --git a/lzero/policy/sampled_alphazero.py b/lzero/policy/sampled_alphazero.py new file mode 100644 index 000000000..4ef2058ef --- /dev/null +++ b/lzero/policy/sampled_alphazero.py @@ -0,0 +1,571 @@ +import copy +from collections import namedtuple +from typing import List, Dict, Tuple + +import numpy as np +import torch.distributions +import torch.nn.functional as F +import torch.optim as optim +from ding.policy.base_policy import Policy +from ding.torch_utils import to_device +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_collate +from easydict import EasyDict + +from lzero.policy import configure_optimizers +from lzero.policy.utils import pad_and_get_lengths, compute_entropy + + +@POLICY_REGISTRY.register('sampled_alphazero') +class SampledAlphaZeroPolicy(Policy): + """ + Overview: + The policy class for Sampled AlphaZero. + """ + + # The default_config for AlphaZero policy. + config = dict( + # (str) The type of policy, as the key of the policy registry. + type='alphazero', + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled AlphaZero) + # this variable is used in ``collector``. + sampled_algo=False, + normalize_prob_of_sampled_actions=False, + policy_loss_type='cross_entropy', # options={'cross_entropy', 'KL'} + # (bool) Whether to use torch.compile method to speed up our model, which required torch>=2.0. + torch_compile=False, + # (bool) Whether to use TF32 for our model. + tensor_float_32=False, + model=dict( + # (tuple) The stacked obs shape. 
+ observation_shape=(3, 6, 6), + # (int) The number of res blocks in AlphaZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in AlphaZero model. + num_channels=32, + ), + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + model_update_ratio=0.1, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (float) The weight of value loss. + value_weight=1.0, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e5), + # (bool) Whether to use manually temperature decay. + # i.e. temperature: 1 -> 0.5 -> 0.25 + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + mcts=dict( + # (int) The number of simulations to perform at each move. + num_simulations=50, + # (int) The maximum number of moves to make in a game. + max_moves=512, # for chess and shogi, 722 for Go. + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_base=19652, + # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_init=1.25, + # + legal_actions=None, + # (int) The action space size. + action_space_size=9, + # (int) The number of sampled actions for each state. 
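# [Editor's note] With num_of_sampled_actions = K, each expanded node keeps at most K
# children drawn from the policy prior (see ptree_az_sampled.py above), so K caps the
# branching factor of the search tree.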
+ num_of_sampled_actions=2, + # + continuous_action_space=False, + ), + other=dict(replay_buffer=dict( + replay_buffer_size=int(1e6), + save_episode=False, + )), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + """ + return 'AlphaZeroModel', ['lzero.model.alphazero_model'] + + def _init_learn(self) -> None: + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + device_type=self._cfg.device + ) + + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + # lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + lr_lambda = lambda step: 1 if step < max_step * 0.33 else (0.1 if step < max_step * 0.66 else 0.01) # noqa + + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # Algorithm config + self._value_weight = self._cfg.value_weight + self._entropy_weight = self._cfg.entropy_weight + # Main and target models + self._learn_model = self._model + + # TODO(pu): test the effect of torch 2.0 + if self._cfg.torch_compile: + self._learn_model = torch.compile(self._learn_model) + + def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]: + for input_dict in inputs: + # Check and remove 'katago_game_state' from 'obs' if it exists + if 'katago_game_state' in input_dict['obs']: + del input_dict['obs']['katago_game_state'] + + # Check and remove 'katago_game_state' from 'next_obs' if it exists + if 'katago_game_state' in input_dict['next_obs']: + del input_dict['next_obs']['katago_game_state'] + + # list of dict -> dict of list + # inputs_deepcopy = copy.deepcopy(inputs) + # only for env with variable legal actions + inputs = pad_and_get_lengths(inputs, self._cfg.mcts.num_of_sampled_actions) + inputs = default_collate(inputs) + valid_action_length = inputs['action_length'] + + if self._cuda: + inputs = to_device(inputs, self._device) + self._learn_model.train() + + state_batch = inputs['obs']['observation'] + mcts_visit_count_probs = inputs['probs'] + reward = inputs['reward'] + root_sampled_actions = inputs['root_sampled_actions'] + + if len(root_sampled_actions.shape) == 1: + print(f"root_sampled_actions.shape: {root_sampled_actions.shape}") + state_batch = state_batch.to(device=self._device, dtype=torch.float) + mcts_visit_count_probs = mcts_visit_count_probs.to(device=self._device, dtype=torch.float) + reward = reward.to(device=self._device, dtype=torch.float) + + policy_probs, values = self._learn_model.compute_policy_value(state_batch) + policy_log_probs = torch.log(policy_probs) + + # calculate policy 
entropy, for monitoring only + entropy = compute_entropy(policy_probs) + entropy_loss = -entropy + + # ============================================================== + # policy loss + # ============================================================== + # mcts_visit_count_probs = mcts_visit_count_probs / (mcts_visit_count_probs.sum(dim=1, keepdim=True) + 1e-6) + # policy_loss = torch.nn.functional.kl_div( + # policy_log_probs, mcts_visit_count_probs, reduction='batchmean' + # ) + # orig implementation + # policy_loss = -torch.mean(torch.sum(mcts_visit_count_probs * policy_log_probs, 1)) + + policy_loss = self._calculate_policy_loss_disc(policy_probs, mcts_visit_count_probs, root_sampled_actions, valid_action_length) + + # ============================================================== + # value loss + # ============================================================== + value_loss = F.mse_loss(values.view(-1), reward) + + total_loss = self._value_weight * value_loss + policy_loss + self._entropy_weight * entropy_loss + self._optimizer.zero_grad() + total_loss.backward() + + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + list(self._model.parameters()), + max_norm=self._cfg.grad_clip_value, + ) + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay is True: + self.lr_scheduler.step() + + # ============= + # after update + # ============= + return { + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'total_loss': total_loss.item(), + 'policy_loss': policy_loss.item(), + 'value_loss': value_loss.item(), + 'entropy_loss': entropy_loss.item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), + 'collect_mcts_temperature': self.collect_mcts_temperature, + } + + def _calculate_policy_loss_disc( + self, policy_probs: torch.Tensor, target_policy: torch.Tensor, + target_sampled_actions: torch.Tensor, valid_action_lengths: torch.Tensor + ) -> torch.Tensor: + + # For each batch and each sampled action, get the corresponding probability + # from policy_probs and target_policy, and put it into sampled_policy_probs and + # sampled_target_policy at the same position. 
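# [Editor's note] Minimal sketch, not part of this patch, of the gather step performed
# below, with made-up numbers: for each row, the probabilities at the sampled action
# indices are picked out column-wise.
import torch
probs_sketch = torch.tensor([[0.1, 0.2, 0.3, 0.4]])          # full policy over 4 actions, batch of 1
sampled_idx_sketch = torch.tensor([[3, 1]])                  # two sampled actions for this row
picked_sketch = probs_sketch.gather(1, sampled_idx_sketch)   # tensor([[0.4000, 0.2000]])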
+        sampled_policy_probs = policy_probs.gather(1, target_sampled_actions)
+        sampled_target_policy = target_policy.gather(1, target_sampled_actions)
+
+        # Create a mask for valid actions.
+        max_length = target_sampled_actions.size(1)
+        mask = torch.arange(max_length).expand(len(valid_action_lengths), max_length) < valid_action_lengths.unsqueeze(1)
+        mask = mask.to(device=self._device)
+
+        # Apply the mask to sampled_policy_probs and sampled_target_policy.
+        sampled_policy_probs = sampled_policy_probs * mask.float()
+        sampled_target_policy = sampled_target_policy * mask.float()
+
+        # Normalize sampled_policy_probs and sampled_target_policy.
+        sampled_policy_probs = sampled_policy_probs / (sampled_policy_probs.sum(dim=1, keepdim=True) + 1e-6)
+        sampled_target_policy = sampled_target_policy / (sampled_target_policy.sum(dim=1, keepdim=True) + 1e-6)
+
+        # After normalization each row sums to (approximately) 1, but the entries corresponding to invalid (padded)
+        # actions may carry small non-zero values. Use torch.where so that they stay exactly zero and receive no gradient.
+        sampled_policy_probs = torch.where(mask, sampled_policy_probs, torch.zeros_like(sampled_policy_probs))
+        sampled_target_policy = torch.where(mask, sampled_target_policy, torch.zeros_like(sampled_target_policy))
+
+        if self._cfg.policy_loss_type == 'KL':
+            # Calculate the KL divergence between sampled_policy_probs and sampled_target_policy.
+            # The KL divergence between two probability distributions P and Q is defined as:
+            #     KL(P || Q) = sum(P(i) * log(P(i) / Q(i)))
+            # We use the PyTorch function kl_div to calculate it.
+            loss = torch.nn.functional.kl_div(
+                sampled_policy_probs.log(), sampled_target_policy, reduction='none'
+            )
+
+            # TODO(pu): use nan_to_num to set the NaN values in the loss to 0.
+            loss = torch.nan_to_num(loss)
+
+            # Apply the mask to the loss.
+            loss = loss * mask.float()
+            # Calculate the mean loss over the batch.
+            loss = loss.sum() / mask.sum()
+
+        elif self._cfg.policy_loss_type == 'cross_entropy':
+            # Calculate the cross-entropy loss between sampled_policy_probs and sampled_target_policy.
+            # The cross entropy between two probability distributions P and Q is defined as:
+            #     H(P, Q) = -sum(P(i) * log(Q(i)))
+            # We use the PyTorch function cross_entropy to calculate it.
+            loss = torch.nn.functional.cross_entropy(
+                sampled_policy_probs, torch.argmax(sampled_target_policy, dim=1), reduction='none'
+            )
+
+            # Use nan_to_num to set the NaN values in the loss to 0.
+            loss = torch.nan_to_num(loss)
+
+            # Apply the mask to the loss.
+            loss = loss * mask.float()
+            # Calculate the mean loss over the batch.
+            loss = loss.sum() / mask.sum()
+
+        else:
+            raise ValueError(f"Invalid policy_loss_type: {self._cfg.policy_loss_type}")
+
+        return loss
+
+    def _init_collect(self) -> None:
+        """
+        Overview:
+            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
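+            Depending on ``self._cfg.mcts_ctree``, either the C++ MCTS (``mcts_alphazero.MCTS``) or a Python MCTS
+            (``ptree_az_sampled`` when ``sampled_algo`` is True, otherwise ``ptree_az``) is instantiated.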
+ """ + self._get_simulation_env() + + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._collect_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, self._cfg.mcts.num_simulations, + self._cfg.mcts.pb_c_base, + self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha, + self._cfg.mcts.root_noise_weight, self.simulate_env) + else: + if self._cfg.sampled_algo: + from lzero.mcts.ptree.ptree_az_sampled import MCTS + else: + from lzero.mcts.ptree.ptree_az import MCTS + self._collect_mcts = MCTS(self._cfg.mcts, self.simulate_env) + + self.collect_mcts_temperature = 1 + + @torch.no_grad() + def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]: + + """ + Overview: + The forward function for collecting data in collect mode. Use real env to execute MCTS search. + Arguments: + - obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \ + corresponding obs in this timestep. + - temperature (:obj:`float`): The temperature for MCTS search. + Returns: + - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is the \ + the corresponding policy output in this timestep, including action, probs and so on. + """ + self.collect_mcts_temperature = temperature + ready_env_id = list(obs.keys()) + init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id} + try: + katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id} + except Exception as e: + katago_game_state = {env_id: None for env_id in ready_env_id} + + start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id} + output = {} + self._policy_model = self._collect_model + for env_id in ready_env_id: + # print('[collect] start_player_index={}'.format(start_player_index[env_id])) + # print('[collect] init_state=\n{}'.format(init_state[env_id])) + + state_config_for_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id], + init_state=init_state[env_id], + katago_policy_init=True, + katago_game_state=katago_game_state[env_id])) + + action, mcts_visit_count_probs = self._collect_mcts.get_next_action( + state_config_for_env_reset, + self._policy_value_func, + self.collect_mcts_temperature, + True, + ) + + # if np.array_equal(self._collect_mcts.get_sampled_actions(), np.array([2, 2, 3])): + # print('debug') + output[env_id] = { + 'action': action, + 'probs': mcts_visit_count_probs, + 'root_sampled_actions': self._collect_mcts.get_sampled_actions(), + } + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. 
+ """ + self._get_simulation_env() + # TODO(pu): use double num_simulations for evaluation + if self._cfg.mcts_ctree: + self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, min(800, mcts_eval_config.num_simulations * 4), + self._cfg.mcts.pb_c_base, + self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha, + self._cfg.mcts.root_noise_weight, self.simulate_env) + else: + if self._cfg.sampled_algo: + from lzero.mcts.ptree.ptree_az_sampled import MCTS + else: + from lzero.mcts.ptree.ptree_az import MCTS + mcts_eval_config = copy.deepcopy(self._cfg.mcts) + # TODO(pu): how to set proper num_simulations for evaluation + # mcts_eval_config.num_simulations = mcts_eval_config.num_simulations + mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4) + + self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env) + + self._eval_model = self._model + + def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]: + + """ + Overview: + The forward function for evaluating the current policy in eval mode, similar to ``self._forward_collect``. + Arguments: + - obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \ + corresponding obs in this timestep. + Returns: + - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is the \ + the corresponding policy output in this timestep, including action, probs and so on. + """ + ready_env_id = list(obs.keys()) + init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id} + try: + katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id} + except Exception as e: + katago_game_state = {env_id: None for env_id in ready_env_id} + + start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id} + output = {} + self._policy_model = self._eval_model + for env_id in ready_env_id: + # print('[eval] start_player_index={}'.format(start_player_index[env_id])) + # print('[eval] init_state=\n {}'.format(init_state[env_id])) + + state_config_for_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id], + init_state=init_state[env_id], + katago_policy_init=False, + katago_game_state=katago_game_state[env_id])) + + # try: + action, mcts_visit_count_probs = self._eval_mcts.get_next_action(state_config_for_env_reset, self._policy_value_func, + 1.0, False) + # except Exception as e: + # print(f"Exception occurred: {e}") + # print(f"Is self._policy_value_func callable? 
{callable(self._policy_value_func)}") + # raise # re-raise the exception + # print("="*20) + # print(action, mcts_visit_count_probs) + # print("="*20) + output[env_id] = { + 'action': action, + 'probs': mcts_visit_count_probs, + } + return output + + def _get_simulation_env(self): + assert self._cfg.simulation_env_name in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_name + assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league', 'sampled_play_with_bot'], self._cfg.simulation_env_config_type + if self._cfg.simulation_env_name == 'tictactoe': + from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv + if self._cfg.simulation_env_config_type == 'play_with_bot': + from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \ + tictactoe_alphazero_config + elif self._cfg.simulation_env_config_type == 'self_play': + from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \ + tictactoe_alphazero_config + elif self._cfg.simulation_env_config_type == 'league': + from zoo.board_games.tictactoe.config.tictactoe_alphazero_league_config import \ + tictactoe_alphazero_config + elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot': + from zoo.board_games.tictactoe.config.tictactoe_sampled_alphazero_bot_mode_config import \ + tictactoe_sampled_alphazero_config as tictactoe_alphazero_config + + self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env) + + elif self._cfg.simulation_env_name == 'gomoku': + from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv + if self._cfg.simulation_env_config_type == 'play_with_bot': + from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config + elif self._cfg.simulation_env_config_type == 'self_play': + from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config + elif self._cfg.simulation_env_config_type == 'league': + from zoo.board_games.gomoku.config.gomoku_alphazero_league_config import gomoku_alphazero_config + elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot': + from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import gomoku_sampled_alphazero_config as gomoku_alphazero_config + + self.simulate_env = GomokuEnv(gomoku_alphazero_config.env) + elif self._cfg.simulation_env_name == 'go': + from zoo.board_games.go.envs.go_env import GoEnv + if self._cfg.simulation_env_config_type == 'play_with_bot': + from zoo.board_games.go.config.go_alphazero_bot_mode_config import go_alphazero_config + elif self._cfg.simulation_env_config_type == 'self_play': + from zoo.board_games.go.config.go_alphazero_sp_mode_config import go_alphazero_config + elif self._cfg.simulation_env_config_type == 'league': + from zoo.board_games.go.config.go_alphazero_league_config import go_alphazero_config + elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot': + from zoo.board_games.go.config.go_sampled_alphazero_bot_mode_config import \ + go_sampled_alphazero_config as go_alphazero_config + + self.simulate_env = GoEnv(go_alphazero_config.env) + + @torch.no_grad() + def _policy_value_func(self, environment: 'Environment') -> Tuple[Dict[int, np.ndarray], float]: + # Retrieve the legal actions in the current environment + legal_actions = environment.legal_actions + + # Retrieve the current state and its scale from the environment + current_state, state_scale = environment.current_state() + + # Convert the state scale to a PyTorch 
FloatTensor, adding a dimension to match the model's input requirements + state_scale_tensor = torch.from_numpy(state_scale).to( + device=self._device, dtype=torch.float + ).unsqueeze(0) + + # Compute action probabilities and state value for the current state using the policy model, without gradient computation + with torch.no_grad(): + action_probabilities, state_value = self._policy_model.compute_policy_value(state_scale_tensor) + + # Extract the probabilities of the legal actions from the action probabilities, and convert the result to a numpy array + legal_action_probabilities = dict( + zip(legal_actions, action_probabilities.squeeze(0)[legal_actions].detach().cpu().numpy())) + + # Return probabilities of the legal actions and the state value + return legal_action_probabilities, state_value.item() + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return super()._monitor_vars_learn() + [ + 'cur_lr', 'total_loss', 'policy_loss', 'value_loss', 'entropy_loss', 'total_grad_norm_before_clip', + 'collect_mcts_temperature' + ] + + def _process_transition(self, obs: Dict, model_output: Dict[str, torch.Tensor], timestep: namedtuple) -> Dict: + """ + Overview: + Generate the dict type transition (one timestep) data from policy learning. + """ + if 'katago_game_state' in obs.keys(): + del obs['katago_game_state'] + # if 'katago_game_state' in timestep.obs.keys(): + # del timestep.obs['katago_game_state'] + # Note: used in _foward_collect in alphazero_collector now + + return { + 'obs': obs, + 'next_obs': timestep.obs, + 'action': model_output['action'], + 'root_sampled_actions': model_output['root_sampled_actions'], + 'probs': model_output['probs'], + 'reward': timestep.reward, + 'done': timestep.done, + } + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass \ No newline at end of file diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 0c49a4ac7..1323fbf89 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -11,6 +11,34 @@ from torch.nn import functional as F +def pad_and_get_lengths(inputs, num_of_sampled_actions): + """ + Overview: + Pad root_sampled_actions to make sure that the length of root_sampled_actions is equal to num_of_sampled_actions. + Also record the true length of each sequence before padding. + Arguments: + - inputs (:obj:`List[dict]`): The input data. + - num_of_sampled_actions (:obj:`int`): The number of sampled actions. + Returns: + - inputs (:obj:`List[dict]`): The input data after padding. Each dict also contains 'action_length' which indicates + the true length of 'root_sampled_actions' before padding. 
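+    Note:
+        Padding repeats the last sampled action, so the padded indices remain valid for a later ``gather``;
+        the recorded 'action_length' is used to mask out the padded entries when computing the policy loss.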
+ Example: + >>> inputs = [{'root_sampled_actions': torch.tensor([1, 2])}, {'root_sampled_actions': torch.tensor([3, 4, 5])}] + >>> num_of_sampled_actions = 5 + >>> result = pad_and_get_lengths(inputs, num_of_sampled_actions) + >>> print(result) # Prints [{'root_sampled_actions': tensor([1, 2, 2, 2, 2]), 'action_length': 2}, + {'root_sampled_actions': tensor([3, 4, 5, 5, 5]), 'action_length': 3}] + """ + for input_dict in inputs: + root_sampled_actions = input_dict['root_sampled_actions'] + input_dict['action_length'] = len(root_sampled_actions) + if len(root_sampled_actions) < num_of_sampled_actions: + # Use the last element to pad root_sampled_actions + padding = root_sampled_actions[-1].repeat(num_of_sampled_actions - len(root_sampled_actions)) + input_dict['root_sampled_actions'] = torch.cat((root_sampled_actions, padding)) + return inputs + + def visualize_avg_softmax(logits): """ Overview: @@ -346,6 +374,12 @@ def negative_cosine_similarity(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tens return -(x1 * x2).sum(dim=1) +def compute_entropy(policy_probs: torch.Tensor) -> torch.Tensor: + dist = torch.distributions.Categorical(probs=policy_probs) + entropy = dist.entropy().mean() + return entropy + + def get_max_entropy(action_shape: int) -> np.float32: """ Overview: diff --git a/zoo/board_games/connect4/envs/connect4_env.py b/zoo/board_games/connect4/envs/connect4_env.py index 38fcc469f..9b73f9291 100644 --- a/zoo/board_games/connect4/envs/connect4_env.py +++ b/zoo/board_games/connect4/envs/connect4_env.py @@ -47,16 +47,7 @@ from zoo.board_games.mcts_bot import MCTSBot -def get_image(path: str) -> Any: - from os import path as os_path - import pygame - - cwd = os_path.dirname(__file__) - image = pygame.image.load(cwd + "/" + path) - sfc = pygame.Surface(image.get_size(), flags=pygame.SRCALPHA) - sfc.blit(image, (0, 0)) - return sfc @ENV_REGISTRY.register('connect4') @@ -421,17 +412,17 @@ def render(self, mode: str = None) -> None: # Load and scale all of the necessary images. 
tile_size = (screen_width * (91 / 99)) / 7 - red_chip = get_image(os.path.join("img", "C4RedPiece.png")) + red_chip = self.get_image(os.path.join("img", "C4RedPiece.png")) red_chip = pygame.transform.scale( red_chip, (int(tile_size * (9 / 13)), int(tile_size * (9 / 13))) ) - black_chip = get_image(os.path.join("img", "C4BlackPiece.png")) + black_chip = self.get_image(os.path.join("img", "C4BlackPiece.png")) black_chip = pygame.transform.scale( black_chip, (int(tile_size * (9 / 13)), int(tile_size * (9 / 13))) ) - board_img = get_image(os.path.join("img", "Connect4Board.png")) + board_img = self.get_image(os.path.join("img", "Connect4Board.png")) board_img = pygame.transform.scale( board_img, ((int(screen_width)), int(screen_height)) ) @@ -718,3 +709,13 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: def close(self) -> None: pass + + def get_image(self, path: str) -> Any: + from os import path as os_path + import pygame + + cwd = os_path.dirname(__file__) + image = pygame.image.load(cwd + "/" + path) + sfc = pygame.Surface(image.get_size(), flags=pygame.SRCALPHA) + sfc.blit(image, (0, 0)) + return sfc diff --git a/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py index 9321037f4..2382d5c39 100644 --- a/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py @@ -36,6 +36,8 @@ scale=True, check_action_to_connect4_in_bot_v0=False, # ============================================================== + screen_scaling=9, + render_mode=None, ), policy=dict( # ============================================================== diff --git a/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_bot_mode_config.py new file mode 100644 index 000000000..fbb52d70e --- /dev/null +++ b/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_bot_mode_config.py @@ -0,0 +1,122 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +board_size = 6 +num_simulations = 100 +update_per_collect = 50 +# board_size = 9 +# num_simulations = 200 +# update_per_collect = 100 +num_of_sampled_actions = 20 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +batch_size = 256 +max_env_step = int(10e6) +prob_random_action_in_bot = 0.5 +mcts_ctree = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== +gomoku_sampled_alphazero_config = dict( + exp_name= + f'data_saz_ptree/gomoku_bs{board_size}_sampled_alphazero_bot-mode_rand{prob_random_action_in_bot}_na{num_of_sampled_actions}_ns{num_simulations}_upc{update_per_collect}_seed0', + env=dict( + stop_value=2, + board_size=board_size, + battle_mode='play_with_bot_mode', + bot_action_type='v0', + prob_random_action_in_bot=prob_random_action_in_bot, + channel_last=False, + use_katago_bot=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # ============================================================== + # for the creation of simulation env + agent_vs_human=False, + prob_random_agent=0, + 
prob_expert_agent=0, + scale=True, + check_action_to_connect4_in_bot_v0=False, + simulation_env_name="gomoku", + # ============================================================== + mcts_ctree=mcts_ctree, + screen_scaling=9, + render_mode=None, + ), + policy=dict( + # ============================================================== + # for the creation of simulation env + simulation_env_name='gomoku', + simulation_env_config_type='sampled_play_with_bot', + # ============================================================== + torch_compile=False, + tensor_float_32=False, + model=dict( + observation_shape=(3, board_size, board_size), + action_space_size=int(1 * board_size * board_size), + num_res_blocks=1, + num_channels=64, + ), + sampled_algo=True, + mcts_ctree=mcts_ctree, + policy_loss_type='KL', + cuda=True, + board_size=board_size, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + value_weight=1.0, + entropy_weight=0.0, + n_episode=n_episode, + eval_freq=int(2e3), + mcts=dict(num_simulations=num_simulations, num_of_sampled_actions=num_of_sampled_actions), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +gomoku_sampled_alphazero_config = EasyDict(gomoku_sampled_alphazero_config) +main_config = gomoku_sampled_alphazero_config + +gomoku_sampled_alphazero_create_config = dict( + env=dict( + type='gomoku', + import_names=['zoo.board_games.gomoku.envs.gomoku_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_alphazero', + import_names=['lzero.policy.sampled_alphazero'], + ), + collector=dict( + type='episode_alphazero', + get_train_sample=False, + import_names=['lzero.worker.alphazero_collector'], + ), + evaluator=dict( + type='alphazero', + import_names=['lzero.worker.alphazero_evaluator'], + ) +) +gomoku_sampled_alphazero_create_config = EasyDict(gomoku_sampled_alphazero_create_config) +create_config = gomoku_sampled_alphazero_create_config + +if __name__ == '__main__': + if main_config.policy.tensor_float_32: + import torch + + # The flag below controls whether to allow TF32 on matmul. This flag defaults to False + # in PyTorch 1.12 and later. + torch.backends.cuda.matmul.allow_tf32 = True + # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
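+        # NOTE: TF32 only takes effect on NVIDIA Ampere (or newer) GPUs; it trades a little precision for speed.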
+ torch.backends.cudnn.allow_tf32 = True + + from lzero.entry import train_alphazero + train_alphazero([main_config, create_config], seed=0, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_sp_mode_config.py b/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_sp_mode_config.py new file mode 100644 index 000000000..a018b8b91 --- /dev/null +++ b/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_sp_mode_config.py @@ -0,0 +1,120 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +board_size = 6 # default_size is 15 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 100 # 8*36=288 +batch_size = 256 +max_env_step = int(10e6) +prob_random_action_in_bot = 0.5 +mcts_ctree = False +num_of_sampled_actions = 20 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== +gomoku_sampled_alphazero_config = dict( + exp_name= + f'data_saz_ptree/gomoku_sampled_alphazero_sp-mode_rand{prob_random_action_in_bot}_na{num_of_sampled_actions}_ns{num_simulations}_upc{update_per_collect}_seed0', + env=dict( + stop_value=2, + board_size=board_size, + battle_mode='self_play_mode', + bot_action_type='v0', + prob_random_action_in_bot=prob_random_action_in_bot, + channel_last=False, # NOTE + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # ============================================================== + # for the creation of simulation env + agent_vs_human=False, + prob_random_agent=0, + prob_expert_agent=0, + scale=True, + check_action_to_connect4_in_bot_v0=False, + simulation_env_name="gomoku", + # ============================================================== + mcts_ctree=mcts_ctree, + screen_scaling=9, + render_mode=None, + ), + policy=dict( + # ============================================================== + # for the creation of simulation env + simulation_env_name='gomoku', + simulation_env_config_type='sampled_self_play', + # ============================================================== + torch_compile=False, + tensor_float_32=False, + model=dict( + observation_shape=(3, board_size, board_size), + action_space_size=int(1 * board_size * board_size), + num_res_blocks=1, + num_channels=32, + ), + sampled_algo=True, + mcts_ctree=mcts_ctree, + policy_loss_type='KL', + cuda=True, + board_size=board_size, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + manual_temperature_decay=True, + grad_clip_value=0.5, + value_weight=1.0, + entropy_weight=0.0, + n_episode=n_episode, + eval_freq=int(2e3), + mcts=dict(num_simulations=num_simulations, num_of_sampled_actions=num_of_sampled_actions), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +gomoku_sampled_alphazero_config = EasyDict(gomoku_sampled_alphazero_config) +main_config = gomoku_sampled_alphazero_config + +gomoku_sampled_alphazero_create_config = dict( + env=dict( + type='gomoku', + import_names=['zoo.board_games.gomoku.envs.gomoku_env'], + ), + 
env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_alphazero', + import_names=['lzero.policy.sampled_alphazero'], + ), + collector=dict( + type='episode_alphazero', + get_train_sample=False, + import_names=['lzero.worker.alphazero_collector'], + ), + evaluator=dict( + type='alphazero', + import_names=['lzero.worker.alphazero_evaluator'], + ) +) +gomoku_sampled_alphazero_create_config = EasyDict(gomoku_sampled_alphazero_create_config) +create_config = gomoku_sampled_alphazero_create_config + +if __name__ == '__main__': + if main_config.policy.tensor_float_32: + import torch + + # The flag below controls whether to allow TF32 on matmul. This flag defaults to False + # in PyTorch 1.12 and later. + torch.backends.cuda.matmul.allow_tf32 = True + # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. + torch.backends.cudnn.allow_tf32 = True + + from lzero.entry import train_alphazero + train_alphazero([main_config, create_config], seed=0, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py b/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py index 9cbd67392..e5ae336b0 100644 --- a/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py +++ b/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py @@ -12,7 +12,8 @@ seeds = [0] num_episodes_each_seed = 5 # If True, you can play with the agent. - main_config.env.agent_vs_human = False + main_config.env.agent_vs_human = True + main_config.env.render_mode = 'image_realtime_mode' create_config.env_manager.type = 'base' main_config.env.evaluator_env_num = 1 main_config.env.n_evaluator_episode = 1 diff --git a/zoo/board_games/gomoku/entry/gomoku_sampled_alphazero_eval.py b/zoo/board_games/gomoku/entry/gomoku_sampled_alphazero_eval.py new file mode 100644 index 000000000..3cc2e0d90 --- /dev/null +++ b/zoo/board_games/gomoku/entry/gomoku_sampled_alphazero_eval.py @@ -0,0 +1,47 @@ +from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import main_config, create_config +from lzero.entry import eval_alphazero +import numpy as np + +if __name__ == '__main__': + """ + model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + """ + # model_path = './ckpt/ckpt_best.pth.tar' + model_path = None + seeds = [0] + num_episodes_each_seed = 1 + # If True, you can play with the agent. + main_config.env.agent_vs_human = True + main_config.env.battle_mode = 'eval_mode' + main_config.env.render_mode = 'image_realtime_mode' + create_config.env_manager.type = 'base' + main_config.env.collector_env_num = 1 + main_config.env.evaluator_env_num = 1 + main_config.env.n_evaluator_episode = 1 + total_test_episodes = num_episodes_each_seed * len(seeds) + returns_mean_seeds = [] + returns_seeds = [] + for seed in seeds: + returns_mean, returns = eval_alphazero( + [main_config, create_config], + seed=seed, + num_episodes_each_seed=num_episodes_each_seed, + print_seed_details=True, + model_path=model_path + ) + returns_mean_seeds.append(returns_mean) + returns_seeds.append(returns) + + returns_mean_seeds = np.array(returns_mean_seeds) + returns_seeds = np.array(returns_seeds) + + print("=" * 20) + print(f'We eval total {len(seeds)} seeds. 
In each seed, we eval {num_episodes_each_seed} episodes.') + print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') + print('In all seeds, reward_mean:', returns_mean_seeds.mean(), end='. ') + print( + f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}' + ) + print("=" * 20) diff --git a/zoo/board_games/gomoku/envs/gomoku_env.py b/zoo/board_games/gomoku/envs/gomoku_env.py index 684f861fa..aa4d4bd29 100644 --- a/zoo/board_games/gomoku/envs/gomoku_env.py +++ b/zoo/board_games/gomoku/envs/gomoku_env.py @@ -1,23 +1,25 @@ import copy -import random +import os import sys from functools import lru_cache -from typing import List +from typing import List, Any import gym +import imageio +import matplotlib.patches as patches +import matplotlib.pyplot as plt import numpy as np +import pygame from ding.envs import BaseEnv, BaseEnvTimestep from ding.utils import ENV_REGISTRY from ditk import logging from easydict import EasyDict -from zoo.board_games.gomoku.envs.legal_actions_cython import legal_actions_cython from zoo.board_games.gomoku.envs.get_done_winner_cython import get_done_winner_cython +from zoo.board_games.gomoku.envs.legal_actions_cython import legal_actions_cython from zoo.board_games.alphabeta_pruning_bot import AlphaBetaPruningBot +from zoo.board_games.gomoku.envs.gomoku_rule_bot_v0 import GomokuRuleBotV0 from zoo.board_games.gomoku.envs.gomoku_rule_bot_v1 import GomokuRuleBotV1 -from zoo.board_games.gomoku.envs.utils import check_action_to_special_connect4_case1, \ - check_action_to_special_connect4_case2, \ - check_action_to_connect4 @lru_cache(maxsize=512) @@ -41,18 +43,35 @@ def _get_done_winner_func_lru(board_size, board_tuple): @ENV_REGISTRY.register('gomoku') class GomokuEnv(BaseEnv): config = dict( + # (str) The name of the environment registered in the environment registry. env_name="Gomoku", - prob_random_agent=0, + # (int) The size of the board. board_size=6, + # (str) The mode of the environment when take a step. battle_mode='self_play_mode', + # (str) The mode of the environment when doing the MCTS. mcts_mode='self_play_mode', # only used in AlphaZero + # (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. + # If None, then the game will not be rendered. + render_mode=None, + # (float) The scale of the render screen. + screen_scaling=9, + # (bool) Whether to use the 'channel last' format for the observation space. If False, 'channel first' format is used. channel_last=False, + # (bool) Whether to scale the observation. scale=True, + # (bool) Whether to let human to play with the agent when evaluating. If False, then use the bot to evaluate the agent. agent_vs_human=False, - bot_action_type='v0', # {'v0', 'alpha_beta_pruning'} + # (str) The type of the bot of the environment. + bot_action_type='v1', # {'v0', 'v1', 'alpha_beta_pruning'}, 'v1' is faster and stronger than 'v0' now. + # (float) The probability that a random agent is used instead of the learning agent. + prob_random_agent=0, + # (float) The probability that a random action will be taken when calling the bot. prob_random_action_in_bot=0., + # (bool) Whether to check the action to connect 4 in the bot v0. check_action_to_connect4_in_bot_v0=False, - stop_value=1, + # (float) The stop value when training the agent. 
If the evalue return reach the stop value, then the training will stop. + stop_value=2, ) @classmethod @@ -102,7 +121,18 @@ def __init__(self, cfg: dict = None): self.agent_vs_human = cfg.agent_vs_human self.bot_action_type = cfg.bot_action_type + # Set the parameters about replay render. + self.screen_scaling = cfg.screen_scaling + # options = {None, 'state_realtime_mode', 'image_realtime_mode', 'image_savefile_mode'} + self.render_mode = cfg.render_mode + self.replay_name_suffix = "test" + self.replay_path = None + self.replay_format = 'gif' # 'mp4' # + self.screen = None + self.frames = [] + self.players = [1, 2] + self._current_player = 1 self.board_markers = [str(i + 1) for i in range(self.board_size)] self.total_num_actions = self.board_size * self.board_size self.gomoku_rule_bot_v1 = GomokuRuleBotV1() @@ -110,6 +140,11 @@ def __init__(self, cfg: dict = None): if self.bot_action_type == 'alpha_beta_pruning': self.alpha_beta_pruning_player = AlphaBetaPruningBot(self, cfg, 'alpha_beta_pruning_player') + elif self.bot_action_type == 'v0': + self.rule_bot = GomokuRuleBotV0(self, self._current_player) + + self.fig, self.ax = plt.subplots(figsize=(self.board_size, self.board_size)) + plt.ion() def reset(self, start_player_index=0, init_state=None): self._observation_space = gym.spaces.Box( @@ -146,6 +181,11 @@ def reset(self, start_player_index=0, init_state=None): 'current_player_index': self.start_player_index, 'to_play': self.current_player } + + # Render the beginning state of the game. + if self.render_mode is not None: + self.render(self.render_mode) + return obs def reset_v2(self, start_player_index=0, init_state=None): @@ -167,7 +207,8 @@ def step(self, action): timestep = self._player_step(action) if timestep.done: # The eval_episode_return is calculated from Player 1's perspective. - timestep.info['eval_episode_return'] = -timestep.reward if timestep.obs['to_play'] == 1 else timestep.reward + timestep.info['eval_episode_return'] = -timestep.reward if timestep.obs[ + 'to_play'] == 1 else timestep.reward return timestep elif self.battle_mode == 'play_with_bot_mode': # player 1 battle with expert player 2 @@ -202,7 +243,7 @@ def step(self, action): timestep_player1 = self._player_step(action) if self.agent_vs_human: print('player 1 (agent): ' + self.action_to_string(action)) # Note: visualize - self.render() + self.render(mode="image_realtime_mode") if timestep_player1.done: # in eval_mode, we set to_play as None/-1, because we don't consider the alternation between players @@ -219,7 +260,7 @@ def step(self, action): timestep_player2 = self._player_step(bot_action) if self.agent_vs_human: print('player 2 (human): ' + self.action_to_string(bot_action)) # Note: visualize - self.render() + self.render(mode="image_realtime_mode") # the eval_episode_return is calculated from Player 1's perspective timestep_player2.info['eval_episode_return'] = -timestep_player2.reward @@ -254,9 +295,17 @@ def _player_step(self, action): """ self.current_player = self.to_play + # Render the new step. + # The following code is used to save the rendered images in both + # collect/eval step and the simulated mcts step. 
+ # if self.render_mode is not None: + # self.render(self.render_mode) + if done: info['eval_episode_return'] = reward - # print('gomoku one episode done: ', info) + if self.render_mode == 'image_savefile_mode': + self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path, + format=self.replay_format) action_mask = np.zeros(self.total_num_actions, 'int8') action_mask[self.legal_actions] = 1 @@ -330,8 +379,10 @@ def bot_action(self): if np.random.rand() < self.prob_random_action_in_bot: return self.random_action() else: + # if self.bot_action_type == 'v0': + # return self.rule_bot_v0() if self.bot_action_type == 'v0': - return self.rule_bot_v0() + return self.rule_bot.get_rule_bot_action(self.board, self._current_player) elif self.bot_action_type == 'v1': return self.rule_bot_v1() elif self.bot_action_type == 'alpha_beta_pruning': @@ -348,299 +399,6 @@ def rule_bot_v1(self): obs = {'observation': self.current_state()[0], 'action_mask': action_mask} return self.gomoku_rule_bot_v1.get_action(obs) - def rule_bot_v0(self): - """ - Overview: - Hard coded agent v0 for gomoku env. - Considering the situation of to-connect-4 and to-connect-5 in a sliding window of 5X5, and lacks the consideration of the entire chessboard. - In each sliding window of 5X5, first random sample a action from legal_actions, - then take the action that will lead a connect4 or connect-5 of current/oppenent player's pieces. - Returns: - - action (:obj:`int`): the expert action to take in the current game state. - """ - assert self.board_size >= 5, "current rule_bot_v0 is only support board_size>=5!" - # To easily calculate expert action, we convert the chessboard notation: - # from player 1: 1, player 2: 2 - # to player 1: -1, player 2: 1 - # TODO: more elegant implementation - board_deepcopy = copy.deepcopy(self.board) - for i in range(board_deepcopy.shape[0]): - for j in range(board_deepcopy.shape[1]): - if board_deepcopy[i][j] == 1: - board_deepcopy[i][j] = -1 - elif board_deepcopy[i][j] == 2: - board_deepcopy[i][j] = 1 - - # first random sample a action from legal_actions - action = np.random.choice(self.legal_actions) - - size_of_board_template = 5 - shift_distance = [ - [i, j] for i in range(self.board_size - size_of_board_template + 1) - for j in range(self.board_size - size_of_board_template + 1) - ] - action_block_opponent_to_connect5 = None - action_to_connect4 = None - action_to_special_connect4_case1 = None - action_to_special_connect4_case2 = None - - min_to_connect = 3 - - for board_block_index in range((self.board_size - size_of_board_template + 1) ** 2): - """ - e.g., self.board_size=6 - board_block_index =[0,1,2,3] - shift_distance = (0,0), (0,1), (1,0), (1,1) - """ - shfit_tmp_board = copy.deepcopy( - board_deepcopy[shift_distance[board_block_index][0]:size_of_board_template + - shift_distance[board_block_index][0], - shift_distance[board_block_index][1]:size_of_board_template + - shift_distance[board_block_index][1]] - ) - - # Horizontal and vertical checks - for i in range(size_of_board_template): - if abs(sum(shfit_tmp_board[i, :])) >= min_to_connect: - # if i-th horizontal line has three same pieces and two empty position, or four same pieces and one opponent piece. - # e.g., case1: .xxx. , case2: oxxxx - - # find the index in the i-th horizontal line - zero_position_index = np.where(shfit_tmp_board[i, :] == 0)[0] - if zero_position_index.shape[0] == 0: - logging.debug( - 'there is no empty position in this searched five positions, continue to search...' 
- ) - else: - if zero_position_index.shape[0] == 2: - ind = random.choice(zero_position_index) - elif zero_position_index.shape[0] == 1: - ind = zero_position_index[0] - # convert ind to action - # the action that will lead a connect5 of current or opponent player's pieces - action = np.ravel_multi_index( - ( - np.array([i + shift_distance[board_block_index][0]] - ), np.array([ind + shift_distance[board_block_index][1]]) - ), (self.board_size, self.board_size) - )[0] - if self.check_action_to_connect4_in_bot_v0: - if check_action_to_special_connect4_case1(shfit_tmp_board[i, :]): - action_to_special_connect4_case1 = action - if check_action_to_special_connect4_case2(shfit_tmp_board[i, :]): - action_to_special_connect4_case2 = action - if check_action_to_connect4(shfit_tmp_board[i, :]): - action_to_connect4 = action - if (self.current_player_to_compute_bot_action * sum(shfit_tmp_board[i, :]) > 0) and abs(sum( - shfit_tmp_board[i, :])) == size_of_board_template - 1: - # immediately take the action that will lead a connect5 of current player's pieces - return action - if (self.current_player_to_compute_bot_action * sum(shfit_tmp_board[i, :]) < 0) and abs(sum( - shfit_tmp_board[i, :])) == size_of_board_template - 1: - # memory the action that will lead a connect5 of opponent player's pieces, to avoid the forget - action_block_opponent_to_connect5 = action - - if abs(sum(shfit_tmp_board[:, i])) >= min_to_connect: - # if i-th vertical has three same pieces and two empty position, or four same pieces and one opponent piece. - # e.g., case1: .xxx. , case2: oxxxx - - # find the index in the i-th vertical line - zero_position_index = np.where(shfit_tmp_board[:, i] == 0)[0] - if zero_position_index.shape[0] == 0: - logging.debug( - 'there is no empty position in this searched five positions, continue to search...' - ) - else: - if zero_position_index.shape[0] == 2: - ind = random.choice(zero_position_index) - elif zero_position_index.shape[0] == 1: - ind = zero_position_index[0] - - # convert ind to action - # the action that will lead a connect5 of current or opponent player's pieces - action = np.ravel_multi_index( - ( - np.array([ind + shift_distance[board_block_index][0]] - ), np.array([i + shift_distance[board_block_index][1]]) - ), (self.board_size, self.board_size) - )[0] - if self.check_action_to_connect4_in_bot_v0: - if check_action_to_special_connect4_case1(shfit_tmp_board[:, i]): - action_to_special_connect4_case1 = action - if check_action_to_special_connect4_case2(shfit_tmp_board[:, i]): - action_to_special_connect4_case2 = action - if check_action_to_connect4(shfit_tmp_board[:, i]): - action_to_connect4 = action - if (self.current_player_to_compute_bot_action * sum(shfit_tmp_board[:, i]) > 0) and abs(sum( - shfit_tmp_board[:, i])) == size_of_board_template - 1: - # immediately take the action that will lead a connect5 of current player's pieces - return action - if (self.current_player_to_compute_bot_action * sum(shfit_tmp_board[:, i]) < 0) and abs(sum( - shfit_tmp_board[:, i])) == size_of_board_template - 1: - # memory the action that will lead a connect5 of opponent player's pieces, to avoid the forget - action_block_opponent_to_connect5 = action - - # Diagonal checks - diag = shfit_tmp_board.diagonal() - anti_diag = np.fliplr(shfit_tmp_board).diagonal() - if abs(sum(diag)) >= min_to_connect: - # if diagonal has three same pieces and two empty position, or four same pieces and one opponent piece. - # e.g., case1: .xxx. 
, case2: oxxxx - # find the index in the diag vector - - zero_position_index = np.where(diag == 0)[0] - if zero_position_index.shape[0] == 0: - logging.debug('there is no empty position in this searched five positions, continue to search...') - else: - if zero_position_index.shape[0] == 2: - ind = random.choice(zero_position_index) - elif zero_position_index.shape[0] == 1: - ind = zero_position_index[0] - - # convert ind to action - # the action that will lead a connect5 of current or opponent player's pieces - action = np.ravel_multi_index( - ( - np.array([ind + shift_distance[board_block_index][0]] - ), np.array([ind + shift_distance[board_block_index][1]]) - ), (self.board_size, self.board_size) - )[0] - if self.check_action_to_connect4_in_bot_v0: - if check_action_to_special_connect4_case1(diag): - action_to_special_connect4_case1 = action - if check_action_to_special_connect4_case2(diag): - action_to_special_connect4_case2 = action - if check_action_to_connect4(diag): - action_to_connect4 = action - if self.current_player_to_compute_bot_action * sum(diag) > 0 and abs( - sum(diag)) == size_of_board_template - 1: - # immediately take the action that will lead a connect5 of current player's pieces - return action - if self.current_player_to_compute_bot_action * sum(diag) < 0 and abs( - sum(diag)) == size_of_board_template - 1: - # memory the action that will lead a connect5 of opponent player's pieces, to avoid the forget - action_block_opponent_to_connect5 = action - - if abs(sum(anti_diag)) >= min_to_connect: - # if anti-diagonal has three same pieces and two empty position, or four same pieces and one opponent piece. - # e.g., case1: .xxx. , case2: oxxxx - - # find the index in the anti_diag vector - zero_position_index = np.where(anti_diag == 0)[0] - if zero_position_index.shape[0] == 0: - logging.debug('there is no empty position in this searched five positions, continue to search...') - else: - if zero_position_index.shape[0] == 2: - ind = random.choice(zero_position_index) - elif zero_position_index.shape[0] == 1: - ind = zero_position_index[0] - # convert ind to action - # the action that will lead a connect5 of current or opponent player's pieces - action = np.ravel_multi_index( - ( - np.array([ind + shift_distance[board_block_index][0]]), - np.array([size_of_board_template - 1 - ind + shift_distance[board_block_index][1]]) - ), (self.board_size, self.board_size) - )[0] - if self.check_action_to_connect4_in_bot_v0: - if check_action_to_special_connect4_case1(anti_diag): - action_to_special_connect4_case1 = action - if check_action_to_special_connect4_case2(anti_diag): - action_to_special_connect4_case2 = action - if check_action_to_connect4(anti_diag): - action_to_connect4 = action - if self.current_player_to_compute_bot_action * sum(anti_diag) > 0 and abs( - sum(anti_diag)) == size_of_board_template - 1: - # immediately take the action that will lead a connect5 of current player's pieces - return action - if self.current_player_to_compute_bot_action * sum(anti_diag) < 0 and abs( - sum(anti_diag)) == size_of_board_template - 1: - # memory the action that will lead a connect5 of opponent player's pieces, to avoid the forget - action_block_opponent_to_connect5 = action - - if action_block_opponent_to_connect5 is not None: - return action_block_opponent_to_connect5 - elif action_to_special_connect4_case1 is not None: - return action_to_special_connect4_case1 - elif action_to_special_connect4_case2 is not None: - return action_to_special_connect4_case2 - elif action_to_connect4 is 
not None: - return action_to_connect4 - else: - return action - - def naive_rule_bot_v0_for_board_size_5(self): - """ - Overview: - Hard coded expert agent for gomoku env. - First random sample a action from legal_actions, then take the action that will lead a connect4 of current player's pieces. - Returns: - - action (:obj:`int`): the expert action to take in the current game state. - """ - assert self.board_size == 5, "current naive_rule_bot_v0 is only support board_size=5!" - # To easily calculate expert action, we convert the chessboard notation: - # from player 1: 1, player 2: 2 - # to player 1: -1, player 2: 1 - # TODO: more elegant implementation - board = copy.deepcopy(self.board) - for i in range(board.shape[0]): - for j in range(board.shape[1]): - if board[i][j] == 1: - board[i][j] = -1 - elif board[i][j] == 2: - board[i][j] = 1 - - # first random sample a action from legal_actions - action = np.random.choice(self.legal_actions) - # Horizontal and vertical checks - for i in range(self.board_size): - if abs(sum(board[i, :])) == 4: - # if i-th horizontal line has four same pieces and one empty position - # find the index in the i-th horizontal line - ind = np.where(board[i, :] == 0)[0][0] - # convert ind to action - action = np.ravel_multi_index((np.array([i]), np.array([ind])), (self.board_size, self.board_size))[0] - if self.current_player_to_compute_bot_action * sum(board[i, :]) > 0: - # immediately take the action that will lead a connect5 of current player's pieces - return action - - if abs(sum(board[:, i])) == 4: - # if i-th vertical line has two same pieces and one empty position - # find the index in the i-th vertical line - ind = np.where(board[:, i] == 0)[0][0] - # convert ind to action - action = np.ravel_multi_index((np.array([ind]), np.array([i])), (self.board_size, self.board_size))[0] - if self.current_player_to_compute_bot_action * sum(board[:, i]) > 0: - # immediately take the action that will lead a connect5 of current player's pieces - return action - - # Diagonal checks - diag = board.diagonal() - anti_diag = np.fliplr(board).diagonal() - if abs(sum(diag)) == 4: - # if diagonal has two same pieces and one empty position - # find the index in the diag vector - ind = np.where(diag == 0)[0][0] - # convert ind to action - action = np.ravel_multi_index((np.array([ind]), np.array([ind])), (self.board_size, self.board_size))[0] - if self.current_player_to_compute_bot_action * sum(diag) > 0: - # immediately take the action that will lead a connect5 of current player's pieces - return action - - if abs(sum(anti_diag)) == 4: - # if anti-diagonal has two same pieces and one empty position - # find the index in the anti_diag vector - ind = np.where(anti_diag == 0)[0][0] - # convert ind to action - action = np.ravel_multi_index( - (np.array([ind]), np.array([self.board_size - 1 - ind])), (self.board_size, self.board_size) - )[0] - if self.current_player_to_compute_bot_action * sum(anti_diag) > 0: - # immediately take the action that will lead a connect5 of current player's pieces - return action - - return action - @property def current_player(self): return self._current_player @@ -770,7 +528,179 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None: self._dynamic_seed = dynamic_seed np.random.seed(self._seed) - def render(self, mode="human"): + def draw_board(self): + """ + Overview: + This method draws the Gomoku board using matplotlib. 
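+            It clears the previous frame, sets the board background, and draws the grid lines; the pieces
+            themselves are added afterwards by ``render``.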
+ """ + + # Clear the previous board + self.ax.clear() + + # Set the limits of the x and y axes + self.ax.set_xlim(0, self.board_size + 1) + self.ax.set_ylim(self.board_size + 1, 0) + + # Set the board background color + self.ax.set_facecolor('peachpuff') + + # Draw the grid lines + for i in range(self.board_size + 1): + self.ax.plot([i + 1, i + 1], [1, self.board_size], color='black') + self.ax.plot([1, self.board_size], [i + 1, i + 1], color='black') + + def render(self, mode="state_realtime_mode"): + """ + Overview: + The render method is used to draw the current state of the game. The rendering mode can be + set according to the needs of the user. + Arguments: + - mode (str): Rendering mode, options are "state_realtime_mode", "image_realtime_mode", + and "image_savefile_mode". + """ + # Print the state of the board directly + if mode == "state_realtime_mode": + print(np.array(self.board).reshape(self.board_size, self.board_size)) + return + # Render the game as an image + elif mode == "image_realtime_mode" or mode == "image_savefile_mode": + self.draw_board() + # Draw the pieces on the board + for x in range(self.board_size): + for y in range(self.board_size): + if self.board[x][y] == 1: # Black piece + circle = patches.Circle((y + 1, x + 1), 0.4, edgecolor='black', + facecolor='black', zorder=3) + self.ax.add_patch(circle) + elif self.board[x][y] == 2: # White piece + circle = patches.Circle((y + 1, x + 1), 0.4, edgecolor='black', + facecolor='white', zorder=3) + self.ax.add_patch(circle) + # Set the title of the game + plt.title('Agent vs. Human: ' + ('Black Turn' if self.current_player == 1 else 'White Turn')) + # If in realtime mode, draw and pause briefly + if mode == "image_realtime_mode": + plt.draw() + plt.pause(0.1) + # In savefile mode, save the current frame to the frames list + elif mode == "image_savefile_mode": + # Save the current frame to the frames list. + self.fig.canvas.draw() + image = np.frombuffer(self.fig.canvas.tostring_rgb(), dtype='uint8') + image = image.reshape(self.fig.canvas.get_width_height()[::-1] + (3,)) + self.frames.append(image) + + def close(self): + """ + Overview: + This method is used to display the final game board to the user and turn off interactive + mode in matplotlib. + """ + plt.ioff() + plt.show() + + def render_for_b15(self, mode: str = None) -> None: + """ + Overview: + Renders the Gomoku (Five in a Row) game environment. Now only support board_size=15. + Arguments: + - mode (:obj:`str`): The mode to render with. Options are: None, 'human', 'state_realtime_mode', + 'image_realtime_mode', 'image_savefile_mode'. + """ + # 'state_realtime_mode' mode, print the current game board for rendering. + if mode == "state_realtime_mode": + print(np.array(self.board).reshape(self.board_size, self.board_size)) + return + else: + # Other modes, use a screen for rendering. + screen_width = self.board_size * self.screen_scaling + screen_height = self.board_size * self.screen_scaling + pygame.init() + self.screen = pygame.Surface((screen_width, screen_height)) + + # Load and scale all of the necessary images. 
+ tile_size = screen_width / self.board_size + + black_chip = self.get_image(os.path.join("img", "Gomoku_BlackPiece.png")) + black_chip = pygame.transform.scale( + black_chip, (int(tile_size), int(tile_size)) + ) + + white_chip = self.get_image(os.path.join("img", "Gomoku_WhitePiece.png")) + white_chip = pygame.transform.scale( + white_chip, (int(tile_size), int(tile_size)) + ) + + board_img = self.get_image(os.path.join("img", "GomokuBoard.png")) + board_img = pygame.transform.scale( + board_img, (int(screen_width), int(screen_height)) + ) + + self.screen.blit(board_img, (0, 0)) + + # Blit the necessary chips and their positions. + for row in range(self.board_size): + for col in range(self.board_size): + if self.board[row][col] == 1: # Black piece + self.screen.blit( + black_chip, + ( + col * tile_size, + row * tile_size, + ), + ) + elif self.board[row][col] == 2: # White piece + self.screen.blit( + white_chip, + ( + col * tile_size, + row * tile_size, + ), + ) + if mode == "image_realtime_mode": + surface_array = pygame.surfarray.pixels3d(self.screen) + surface_array = np.transpose(surface_array, (1, 0, 2)) + plt.imshow(surface_array) + plt.draw() + plt.pause(0.001) + elif mode == "image_savefile_mode": + # Draw the observation and save to frames. + observation = np.array(pygame.surfarray.pixels3d(self.screen)) + self.frames.append(np.transpose(observation, axes=(1, 0, 2))) + + self.screen = None + + return None + + def save_render_output(self, replay_name_suffix: str = '', replay_path: str = None, format: str = 'gif') -> None: + """ + Overview: + Save the rendered frames as an output file. + Arguments: + - replay_name_suffix (:obj:`str`): The suffix to be added to the replay filename. + - replay_path (:obj:`str`): The path to save the replay file. If None, the default filename will be used. + - format (:obj:`str`): The format of the output file. Options are 'gif' or 'mp4'. + """ + # At the end of the episode, save the frames. + if replay_path is None: + filename = f'game_gomoku_{self.board_size}_{replay_name_suffix}.{format}' + else: + filename = f'{replay_path}.{format}' + + if format == 'gif': + # Save frames as a GIF with a duration of 0.1 seconds per frame. + # imageio.mimsave(filename, self.frames, 'GIF', duration=0.1) + imageio.mimsave(filename, self.frames, 'GIF', fps=30, subrectangles=True) + elif format == 'mp4': + # Save frames as an MP4 video with a frame rate of 30 frames per second. 
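+            # NOTE: this assumes an ffmpeg-capable imageio backend (e.g. the optional imageio-ffmpeg package) is installed.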
+            imageio.mimsave(filename, self.frames, fps=30, codec='mpeg4')
+
+        else:
+            raise ValueError("Unsupported format: {}".format(format))
+        logging.info("Saved output to {}".format(filename))
+        self.frames = []
+
+    def render_naive(self, mode="human"):
         marker = "   "
         for i in range(self.board_size):
             if i <= 8:
@@ -829,3 +759,13 @@ def __repr__(self) -> str:
     def close(self) -> None:
         pass
+
+    def get_image(self, path: str) -> Any:
+        from os import path as os_path
+        import pygame
+
+        cwd = os_path.dirname(__file__)
+        image = pygame.image.load(cwd + "/" + path)
+        sfc = pygame.Surface(image.get_size(), flags=pygame.SRCALPHA)
+        sfc.blit(image, (0, 0))
+        return sfc
diff --git a/zoo/board_games/gomoku/envs/gomoku_human_vs_bot_UI.py b/zoo/board_games/gomoku/envs/gomoku_human_vs_bot_UI.py
new file mode 100644
index 000000000..3cd0de766
--- /dev/null
+++ b/zoo/board_games/gomoku/envs/gomoku_human_vs_bot_UI.py
@@ -0,0 +1,235 @@
+import os
+import subprocess
+from typing import Optional
+
+import matplotlib
+
+# Use the TkAgg backend for matplotlib
+matplotlib.use("TkAgg")
+import tkinter as tk
+
+import imageio
+from PIL import ImageGrab
+
+
+class GomokuUI(tk.Tk):
+    def __init__(
+            self,
+            gomoku_env: "GomokuEnv",
+            save_frames: bool = True
+    ) -> None:
+        """
+        Overview:
+            Initialize the GomokuUI class. This class provides the user interface for the Gomoku game.
+        Arguments:
+            - gomoku_env (:obj:`GomokuEnv`): An instance of GomokuEnv which provides the game environment.
+            - save_frames (:obj:`bool`): A boolean to decide whether to save frames for creating a gif, default is True.
+        """
+        tk.Tk.__init__(self)
+        self.env = gomoku_env
+        self.board_size = gomoku_env.board_size
+        self.cell_size = 50  # the size of each cell in the UI
+        self.canvas_size = self.cell_size * (self.board_size + 1)  # the size of the canvas
+
+        # Create a canvas for drawing
+        self.canvas = tk.Canvas(self, width=self.canvas_size, height=self.canvas_size, bg='peach puff')
+        self.canvas.pack()
+        self.frames = []  # used to store frames when save_frames is True
+        self.canvas.bind("<Button-1>", self.click)  # bind the left-click event to the canvas
+        self.save_frames = save_frames
+
+    def click(self, event: tk.Event) -> None:
+        """
+        Overview:
+            This method is called every time the canvas is clicked.
+        Arguments:
+            - event (:obj:`tk.Event`): The event object containing information about the click.
+        """
+        # Adjust the x and y coordinates to account for the boundary
+        adjusted_x = event.y - self.cell_size
+        adjusted_y = event.x - self.cell_size
+
+        # Map the click to the nearest intersection point
+        x = (adjusted_x + self.cell_size // 2) // self.cell_size
+        y = (adjusted_y + self.cell_size // 2) // self.cell_size
+
+        action = self.coord_to_action(x, y)
+        self.update_board(action, from_ui=True)
+
+    def update_board(self, action: Optional[int] = None, from_ui: bool = False) -> None:
+        """
+        Overview:
+            Update the board state based on the action taken.
+        Arguments:
+            - action (:obj:`int`, optional): The action to be taken, default is None.
+            - from_ui (:obj:`bool`, optional): Flag to indicate if action is from user interface, default is False.
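+        Note:
+            When called from the UI, the human action is stepped through ``self.env`` (which, in
+            ``play_with_bot_mode``, also plays the bot's reply) and the board is then redrawn from the
+            returned observation.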
+ """ + if from_ui: + print('player 1: ' + self.env.action_to_string(action)) + timestep = self.env.step(action) + self.timestep = timestep + obs = self.timestep.obs + else: + obs = {'board': self.env.board} + + # Update the board UI + for i in range(0, self.board_size): + for j in range(0, self.board_size): + if obs['board'][i][j] == 1: # black + color = 'black' + self.draw_piece(i, j, color) + elif obs['board'][i][j] == 2: # white + color = 'white' + self.draw_piece(i, j, color) + # else: + # # only for debug + # self.draw_piece(i, j, color) + if self.save_frames: + self.save_frame() + self.update_turn_label() + # time.sleep(0.1) + + # Check if the game has ended + if self.timestep.done: + self.quit() + + def draw_piece(self, x: int, y: int, color: str) -> None: + """ + Overview: + Draw a game piece on the board. + Arguments: + - x (:obj:`int`): The x-coordinate of the piece. + - y (:obj:`int`): The y-coordinate of the piece. + - color (:obj:`str`): The color of the piece. + """ + padding = self.cell_size // 2 + self.canvas.create_oval(y * self.cell_size + padding, x * self.cell_size + padding, + (y + 1) * self.cell_size + padding, (x + 1) * self.cell_size + padding, fill=color) + + def save_frame_bkp(self): + # Get the bounds of the window + x = self.canvas.winfo_rootx() + y = self.canvas.winfo_rooty() + x1 = x + self.canvas.winfo_width() + y1 = y + self.canvas.winfo_height() + + # Grab the image and save it + img = ImageGrab.grab(bbox=(x, y, x1, y1)) + img.save("frame.png") + + # Append the image to the frames + self.frames.append(imageio.imread("frame.png")) + + def save_frame(self) -> None: + """ + Overview: + Save the current frame of the game board. + """ + # Generate Postscript from the canvas + ps = self.canvas.postscript(colormode='color') + + # Use ImageMagick to convert the Postscript to PNG + with open('temp.ps', 'w') as f: + f.write(ps) + # subprocess.run(['convert', 'temp.ps', 'frame.png']) + subprocess.run(['convert', '-colorspace', 'sRGB', 'temp.ps', 'frame.png']) + os.remove('temp.ps') + + # Append the PNG to the frames + self.frames.append(imageio.imread('frame.png')) + + def save_gif(self, file_name: str) -> None: + """ + Overview: + Save all stored frames as a gif file. + Arguments: + - file_name (:obj:`str`): The name of the gif file to be saved. + """ + imageio.mimsave(file_name, self.frames, 'GIF', duration=0.1) + + def draw_board(self) -> None: + """ + Overview: + Draw the game board on the canvas. + """ + self.canvas.create_text(self.canvas_size // 2, self.cell_size // 2, text="Gomoku (Human vs AI)", + font=("Arial", 10)) + # Reduce the loop count to avoid drawing extra lines + for i in range(1, self.board_size + 1): + self.canvas.create_line(i * self.cell_size, self.cell_size, i * self.cell_size, + self.canvas_size - self.cell_size) + self.canvas.create_line(self.cell_size, i * self.cell_size, self.canvas_size - self.cell_size, + i * self.cell_size) + self.update_turn_label() + + def update_turn_label(self) -> None: + """ + Overview: + Update the turn label on the canvas. + """ + # Change the label text + turn_text = "Human's Turn (Black)" if self.env.current_player == 1 else "AI's Turn (White)" + self.canvas.create_text(self.canvas_size // 2 + 5, self.cell_size // 2 + 15, text=turn_text, font=("Arial", 10)) + + def coord_to_action(self, x: int, y: int) -> int: + """ + Overview: + Convert coordinates to an action. + Arguments: + - x (:obj:`int`): The x-coordinate. + - y (:obj:`int`): The y-coordinate. 
+ Returns: + - action (:obj:`int`): The action corresponding to the coordinates. + """ + # Adjusted the coordinate system + return x * self.board_size + y + + +from easydict import EasyDict +from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv + + +def test_human_vs_bot_ui() -> None: + """ + Overview: + Test function for running a Gomoku game between a human player and a bot. + """ + cfg = EasyDict( + # board_size=15, + # board_size=9, + board_size=6, + battle_mode='play_with_bot_mode', + prob_random_agent=0, + channel_last=False, + scale=True, + agent_vs_human=False, + bot_action_type='v0', + prob_random_action_in_bot=0., + render_mode='state_realtime_mode', + screen_scaling=9, + check_action_to_connect4_in_bot_v0=False, + save_frames=True, + # save_frames=False, + ) + env = GomokuEnv(cfg) + env.reset() + game_ui = GomokuUI(env, save_frames=cfg.save_frames) + game_ui.draw_board() + + while True: + game_ui.mainloop() + if game_ui.timestep.done: + game_ui.save_gif('gomoku_human_vs_bot.gif') + + if game_ui.timestep.reward != 0 and game_ui.timestep.info['next player to play'] == 2: + print('player 1 (human player) win') + + elif game_ui.timestep.reward != 0 and game_ui.timestep.info['next player to play'] == 1: + print('player 2 (AI player) win') + else: + print('draw') + break + + +if __name__ == "__main__": + test_human_vs_bot_ui() diff --git a/zoo/board_games/gomoku/envs/gomoku_rule_bot_v0.py b/zoo/board_games/gomoku/envs/gomoku_rule_bot_v0.py new file mode 100644 index 000000000..9462a6685 --- /dev/null +++ b/zoo/board_games/gomoku/envs/gomoku_rule_bot_v0.py @@ -0,0 +1,376 @@ +import copy +from typing import Any, Tuple, List + +import numpy as np + + +class GomokuRuleBotV0: + """ + Overview: + The rule-based bot for the Gomoku game. The bot follows a set of rules in a certain order until a valid move is found.\ + The rules are: winning move, blocking move, do not take a move which may lead to opponent win in 3 steps, \ + forming a sequence of 4, forming a sequence of 3, forming a sequence of 2, and a random move. + """ + + def __init__(self, env: Any, player: int, search_only_in_neighbor: bool = False) -> None: + """ + Overview: + Initializes the bot with the game environment and the player it represents. + Arguments: + - env (:obj:`Any`): The game environment, which contains the game state and allows interactions with it. + - player (:obj:`int`): The player that the bot represents in the game. + """ + self.env = env + self.current_player = player + self.players = self.env.players + self.board_size = self.env.board_size + self.dp = None + self.search_only_in_neighbor = search_only_in_neighbor + + def get_neighbor_actions(self, board: np.ndarray) -> List[int]: + """ + Overview: + Get the legal actions in the neighborhood of existing pieces on the board. + Arguments: + - board (:obj:`np.ndarray`): The current game board. + Returns: + - neighbor_actions (:obj:`list` of :obj:`int`): The legal actions in the neighborhood of existing pieces. 
+ """ + neighbor_actions = set() + for i in range(self.board_size): + for j in range(self.board_size): + # If there is a piece at (i, j) + if board[i, j] != 0: + # Check the neighborhood + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + nx, ny = i + dx, j + dy + # If the neighbor coordinate is valid and there is no piece at (nx, ny) + if 0 <= nx < self.board_size and 0 <= ny < self.board_size and board[nx, ny] == 0: + neighbor_action = self.env.coord_to_action(nx, ny) + neighbor_actions.add(neighbor_action) + return list(neighbor_actions) + + def get_rule_bot_action(self, board: np.ndarray, player: int) -> int: + """ + Overview: + Determines the next action of the bot based on the current game board and player. + Arguments: + - board (:obj:`np.ndarray`): The current game board. + - player (:obj:`int`): The current player. + Returns: + - action (:obj:`int`): The next action of the bot. + """ + if self.search_only_in_neighbor: + # Get the legal actions in the neighborhood of existing pieces. + self.legal_actions = self.get_neighbor_actions(board) + else: + self.legal_actions = self.env.legal_actions.copy() + self.current_player = player + self.next_player = self.players[0] if self.current_player == self.players[1] else self.players[1] + self.board = np.array(copy.deepcopy(board)).reshape(self.board_size, self.board_size) + # Initialize dp array if it's None + if self.dp is None: + self.dp = np.zeros((self.board_size, self.board_size, 8), dtype=int) + self.update_dp(self.board) + + # Check if there is a winning move. + for action in self.legal_actions: + if self.is_winning_move(action): + return action + + # Check if there is a move to block opponent's winning move. + for action in self.legal_actions: + if self.is_blocking_move(action): + return action + + # Remove the actions which may lead to opponent to win. + self.remove_actions() + + # If all the actions are removed, then randomly select an action. + if len(self.legal_actions) == 0: + return np.random.choice(self.env.legal_actions) + + # Check if there is a move to form a sequence of 4. + for action in self.legal_actions: + if self.is_sequence_X_move(action, 4): + return action + + # Check if there is a move to form a sequence of 3. + for action in self.legal_actions: + if self.is_sequence_X_move(action, 3): + return action + + # Check if there is a move to form a sequence of 2. + for action in self.legal_actions: + if self.is_sequence_X_move(action, 2): + return action + + # Randomly select a legal move. + return np.random.choice(self.legal_actions) + + def is_winning_move(self, action: int) -> bool: + """ + Overview: + Checks if an action is a winning move. + Arguments: + - action (:obj:`int`): The action to be checked. + Returns: + - result (:obj:`bool`): True if the action is a winning move; False otherwise. + """ + piece = self.current_player + temp_board, dp_backup = self._place_piece(action, piece) + + result = self.check_five_in_a_row(temp_board, piece) + # Restore the dp array + self.dp = dp_backup + return result + + def is_blocking_move(self, action: int) -> bool: + """ + Overview: + Checks if an action can block the opponent's winning move. + Arguments: + - action (:obj:`int`): The action to be checked. + Returns: + - result (:obj:`bool`): True if the action can block the opponent's winning move; False otherwise. 
+ """ + piece = 2 if self.current_player == 1 else 1 + temp_board, dp_backup = self._place_piece(action, piece) + + result = self.check_five_in_a_row(temp_board, piece) + # Restore the dp array + self.dp = dp_backup + return result + + def is_winning_move_in_two_steps(self, action: int) -> bool: + """ + Overview: + Checks if the specified action can lead to a win in two steps. + Arguments: + - action (:obj:`int`): The action to be checked. + Returns: + - result (:obj:`bool`): True if the action can lead to a win in two steps; False otherwise. + """ + # Simulate the action + piece = self.current_player + # player_current_1step (assessing_action_now) -> player_opponent_1step -> player_current_2step -> player_opponent_2step + # -- action is here -- + temp_board, dp_backup = self._place_piece(action, piece) + temp = [self.board.copy(), self.current_player] + + # Swap players + self.board = temp_board + self.current_player = 3 - self.current_player + + # Get legal actions + legal_actions = [ + action + for action in range(self.board_size * self.board_size) + if self.board[self.env.action_to_coord(action)] == 0 + ] + # player_current_1step (assessing_action_now) -> player_opponent_1step -> player_current_2step -> player_opponent_2step + # -- action is here -- + # Check if the player_current_2step has a winning move. + if any(self.is_winning_move(action) for action in legal_actions): + self.board, self.current_player = temp + return False + + # player_current_1step (assessing_action_now) -> player_opponent_1step -> player_current_2step -> player_opponent_2step + # -- action is here -- + # Count blocking moves. If player_current_2step has more than two blocking_move, which means that + # if player_current take assessing_action_now, then the player_opponent_2step will have at least one wining move + blocking_count = sum(self.is_blocking_move(action) for action in legal_actions) + + # Restore the original state + self.board, self.current_player = temp + + # Check if there are more than one blocking moves + return blocking_count >= 2 + + def remove_actions(self) -> None: + """ + Overview: + Removes the actions from `self.legal_actions` that could potentially lead to the opponent's win. + """ + temp_list = self.legal_actions.copy() + for action in temp_list: + temp = [self.board.copy(), self.current_player] + + piece = self.current_player + action_x, action_y = self.env.action_to_coord(action) + self.board[action_x][action_y] = piece + + self.current_player = self.next_player + # Get legal actions + legal_actions = [ + action + for action in range(self.board_size * self.board_size) + if self.board[self.env.action_to_coord(action)] == 0 + ] + # print(f'if we take action {action}, then the legal actions for opponent are {legal_actions}') + for a in legal_actions: + if self.is_winning_move(a) or self.is_winning_move_in_two_steps(a): + self.legal_actions.remove(action) + # print(f"if take action {action}, then opponent take{a} may win") + # print(f"so we should remove action from {self.legal_actions}") + break + + self.board, self.current_player = temp + + def is_sequence_X_move(self, action: int, X: int) -> bool: + """ + Overview: + Checks if the specified action can form a sequence of 'X' pieces for the bot. + Arguments: + - action (:obj:`int`): The action to be checked. + - X (:obj:`int`): The length of the sequence to be checked. + Returns: + - result (:obj:`bool`): True if the action can form a sequence of 'X' pieces; False otherwise. 
+ """ + piece = self.current_player + + temp_board, dp_backup = self._place_piece(action, piece) + + result = self.check_sequence_in_neighbor_board(temp_board, piece, X, action) + # Restore the dp array + self.dp = dp_backup + return result + + def _place_piece(self, action: int, piece: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Overview: + Places a piece on the board and updates the 'dp' array. + Arguments: + - action (:obj:`int`): The action indicating where to place the piece. + - piece (:obj:`int`): The piece to be placed. + Returns: + - temp_board (:obj:`np.ndarray`): The updated game board. + - dp_backup (:obj:`np.ndarray`): The backup of the 'dp' array before updating. + """ + action_x, action_y = self.env.action_to_coord(action) + temp_board = self.board.copy() + temp_board[action_x][action_y] = piece + + # Backup the dp array + dp_backup = copy.deepcopy(self.dp) + # Update dp array + self.update_dp(temp_board) + + return temp_board, dp_backup + + def check_sequence_in_neighbor_board(self, board: np.ndarray, piece: int, seq_len: int, action: int) -> bool: + """ + Overview: + Checks if a sequence of the bot's pieces of a given length can be formed in the neighborhood of a given action. + Arguments: + - board (:obj:`np.ndarray`): The current game board. + - piece (:obj:`int`): The piece of the bot. + - seq_len (:obj:`int`): The length of the sequence to be checked. + - action (:obj:`int`): The action to be checked. + Returns: + - result (:obj:`bool`): True if the sequence of the bot's pieces can be formed; False otherwise. + """ + # Convert action to coordinates + row, col = self.env.action_to_coord(action) + + # Check horizontal locations + for c in range(max(0, col - seq_len + 1), min(self.board_size - seq_len + 1, col + 1)): + window = list(board[row, c:c + seq_len]) + if window.count(piece) == seq_len: + return True + + # Check vertical locations + for r in range(max(0, row - seq_len + 1), min(self.board_size - seq_len + 1, row + 1)): + window = list(board[r:r + seq_len, col]) + if window.count(piece) == seq_len: + return True + + # Check positively sloped diagonals + for r in range(max(0, row - seq_len + 1), min(self.board_size - seq_len + 1, row + 1)): + for c in range(max(0, col - seq_len + 1), min(self.board_size - seq_len + 1, col + 1)): + if r - c == row - col: + window = [board[r + i][c + i] for i in range(seq_len)] + if window.count(piece) == seq_len: + return True + + # Check negatively sloped diagonals + for r in range(max(0, row - seq_len + 1), min(self.board_size - seq_len + 1, row + 1)): + for c in range(max(0, col - seq_len + 1), min(self.board_size - seq_len + 1, col + 1)): + if r + c == row + col: + window = [board[r + i][c - i] for i in range(seq_len)] + if window.count(piece) == seq_len: + return True + + return False + + def update_dp(self, board: np.ndarray = None) -> None: + """ + Overview: + Updates the dynamic programming (dp) array based on the current game board. + Arguments: + - board (:obj:`np.ndarray`): The current game board. Defaults to None. 
+ """ + directions = [(0, 1), (1, 0), (-1, 1), (1, 1)] + for i in range(self.board_size): + for j in range(self.board_size): + if board[i, j]: + for d, (dx, dy) in enumerate(directions): + nx, ny = i + dx, j + dy + if 0 <= nx < self.board_size and 0 <= ny < self.board_size: + self.dp[nx, ny, d] = (self.dp[i, j, d] + 1) if board[i, j] == board[nx, ny] else 0 + + def check_five_in_a_row(self, board: np.ndarray, piece: int) -> bool: + """ + Overview: + Uses the dynamic programming (dp) array to check if there are five of the bot's pieces in a row. + Arguments: + - board (:obj:`np.ndarray`): The current game board. + - piece (:obj:`int`): The piece of the bot. + Returns: + - result (:obj:`bool`): True if there are five of the bot's pieces in a row; False otherwise. + """ + directions = [(0, 1), (1, 0), (-1, 1), (1, 1)] # Four possible directions: right, down, down-left, down-right + for i in range(self.board_size): + for j in range(self.board_size): + if board[i, j] == piece: # If the piece at this location is the same as the one we're checking + for d, (dx, dy) in enumerate(directions): + # Check if the previous location is within the board's range + if 0 <= i - dx < self.board_size and 0 <= j - dy < self.board_size: + if self.dp[ + i, j, d] + 1 >= 5: # If there are at least 4 more pieces of the same type in this direction + return True # We found five in a row + + return False # If we checked every location and didn't find five in a row + + def check_five_in_a_row_naive(self, board: np.ndarray, piece: int) -> bool: + """ + Overview: + Checks if there are five of the bot's pieces in a row on the current game board. + Arguments: + - board (:obj:`np.ndarray`): The current game board. + - piece (:obj:`int`): The piece of the bot. + Returns: + - result (:obj:`bool`): True if there are five of the bot's pieces in a row; False otherwise. + """ + # Check horizontal and vertical locations + for i in range(self.board_size): + for j in range(self.board_size - 5 + 1): + # Check horizontal + if np.all(board[i, j:j + 5] == piece): + return True + # Check vertical + if np.all(board[j:j + 5, i] == piece): + return True + + # Check diagonals + for i in range(self.board_size - 5 + 1): + for j in range(self.board_size - 5 + 1): + # Check positively sloped diagonals + if np.all(board[range(i, i + 5), range(j, j + 5)] == piece): + return True + # Check negatively sloped diagonals + if np.all(board[range(i, i + 5), range(j + 5 - 1, j - 1, -1)] == piece): + return True + + return False diff --git a/zoo/board_games/gomoku/envs/gomoku_rule_bot_v1.py b/zoo/board_games/gomoku/envs/gomoku_rule_bot_v1.py index 6c2f7d225..530342467 100644 --- a/zoo/board_games/gomoku/envs/gomoku_rule_bot_v1.py +++ b/zoo/board_games/gomoku/envs/gomoku_rule_bot_v1.py @@ -1,4 +1,5 @@ # Reference link: + # https://github.com/LouisCaixuran/gomoku/blob/c1b6d508522d9e8c78be827f326bbee54c4dfd8b/gomoku/expert.py """ Sometimes, when GomokuRuleBotV1 has 4-connect, and the opponent also have 4-connect, GomokuRuleBotV1 will block the opponent and don't diff --git a/zoo/board_games/gomoku/envs/test_gomoku_env.py b/zoo/board_games/gomoku/envs/test_gomoku_env.py index 254caaf6f..472c8874a 100644 --- a/zoo/board_games/gomoku/envs/test_gomoku_env.py +++ b/zoo/board_games/gomoku/envs/test_gomoku_env.py @@ -18,30 +18,32 @@ def test_self_play_mode(self): bot_action_type='v0', prob_random_action_in_bot=0., check_action_to_connect4_in_bot_v0=False, + # (str) The render mode. 
Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. + # If None, then the game will not be rendered. + render_mode=None, + screen_scaling=9, ) env = GomokuEnv(cfg) obs = env.reset() print('init board state: ') - env.render() while True: action = env.random_action() # action = env.human_to_action() print('player 1: ' + env.action_to_string(action)) obs, reward, done, info = env.step(action) - env.render() + env.render(mode=cfg.render_mode) if done: if reward > 0: print('player 1 (human player) win') else: print('draw') break - # action = env.bot_action() action = env.random_action() # action = env.human_to_action() print('player 2 (computer player): ' + env.action_to_string(action)) obs, reward, done, info = env.step(action) - env.render() + env.render(mode=cfg.render_mode) if done: if reward > 0: print('player 2 (computer player) win') @@ -60,11 +62,15 @@ def test_play_with_bot_mode(self): bot_action_type='v0', prob_random_action_in_bot=0., check_action_to_connect4_in_bot_v0=False, + # (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. + # If None, then the game will not be rendered. + render_mode='state_realtime_mode', # 'image_realtime_mode' # "state_realtime_mode", + screen_scaling=9, ) env = GomokuEnv(cfg) env.reset() print('init board state: ') - env.render() + env.render(mode=cfg.render_mode) while True: """player 1""" # action = env.human_to_action() @@ -73,7 +79,7 @@ def test_play_with_bot_mode(self): print('player 1: ' + env.action_to_string(action)) obs, reward, done, info = env.step(action) # reward is in the perspective of player1 - env.render() + env.render(mode=cfg.render_mode) if done: if reward != 0 and info['next player to play'] == 2: print('player 1 (human player) win') diff --git a/zoo/board_games/gomoku/envs/test_gomoku_rule_bot_v0.py b/zoo/board_games/gomoku/envs/test_gomoku_rule_bot_v0.py index 2dec03ca0..1e9ec0f4e 100644 --- a/zoo/board_games/gomoku/envs/test_gomoku_rule_bot_v0.py +++ b/zoo/board_games/gomoku/envs/test_gomoku_rule_bot_v0.py @@ -11,8 +11,10 @@ scale=True, agent_vs_human=False, bot_action_type='v0', # {'v0', 'v1', 'alpha_beta_pruning'} - prob_random_action_in_bot=0.5, + prob_random_action_in_bot=0., check_action_to_connect4_in_bot_v0=False, + screen_scaling=9, + render_mode=None, ) @@ -21,11 +23,11 @@ class TestExpertActionV0: def test_naive(self): env = GomokuEnv(cfg) - test_episodes = 100 + test_episodes = 1 for i in range(test_episodes): obs = env.reset() # print('init board state: ', obs) - env.render() + # env.render('image_realtime_mode') while True: action = env.bot_action() # action = env.random_action() @@ -33,7 +35,7 @@ def test_naive(self): print('action index of player 1 is:', action) print('player 1: ' + env.action_to_string(action)) obs, reward, done, info = env.step(action) - env.render() + # env.render('image_realtime_mode') if done: print('=' * 20) if reward > 0: @@ -43,13 +45,13 @@ def test_naive(self): print('=' * 20) break - action = env.bot_action() + # action = env.bot_action() # action = env.random_action() - # action = env.human_to_action() + action = env.human_to_action() print('action index of player 2 is:', action) print('player 2: ' + env.action_to_string(action)) obs, reward, done, info = env.step(action) - env.render() + # env.render('image_realtime_mode') if done: print('=' * 20) if reward > 0: @@ -61,28 +63,28 @@ def test_naive(self): def test_v0_vs_v1(self): """ - board_size=6, test_100_episodes: + board_size=6, test 
10 episodes:
 =================================================
- v0 vs v1: 0 bot_v0 win, 16 bot_v1 win, 84 draw
- v1 vs v0: 1 bot_v0 win, 35 bot_v1 win, 64 draw
- v0 vs v0: 100 draw
- v1 vs v1: 100 draw
- v0 vs random: 93 bot_v0 win, 35 random win, 7 draw
- v1 vs random: 100 bot_v1 win, 0 random win, 0 draw
+ v0 vs v1: 0 bot_v0 win, 5 bot_v1 win, 5 draw
+ v1 vs v0: 0 bot_v0 win, 4 bot_v1 win, 6 draw
+ v0 vs v0: 0 player1 win, 4 player2 win, 6 draw
+ v1 vs v1: 0 player1 win, 0 player2 win, 10 draw
+ v0 vs random: 10 bot_v0 win, 0 random win, 0 draw
+ v1 vs random: 10 bot_v1 win, 0 random win, 0 draw
 =================================================

- board_size=5, test_100_episodes:
+ board_size=9, test 3 episodes:
 =================================================
- v0 vs v1: 0 bot_v0 win, 0 bot_v1 win, 100 draw
- v1 vs v0: 1 bot_v0 win, 35 bot_v1 win, 64 draw
- v0 vs v0: 100 draw
- v1 vs v1: 100 draw
- v0 vs random: 68 bot_v0 win, 0 random win, 32 draw
- v1 vs random: 98 bot_v1 win, 0 random win, 2 draw
+ v0 vs v1: 0 bot_v0 win, 3 bot_v1 win, 0 draw
+ v1 vs v0: 3 bot_v0 win, 0 bot_v1 win, 0 draw
+ v0 vs v0: 3 player1 win, 0 player2 win, 0 draw
+ v1 vs v1: 0 player1 win, 0 player2 win, 3 draw
+ v0 vs random: 3 bot_v0 win, 0 random win, 0 draw
+ v1 vs random: 3 bot_v1 win, 0 random win, 0 draw
 =================================================
 """
 env = GomokuEnv(cfg)
- test_episodes = 10
+ test_episodes = 1
 for i in range(test_episodes):
 obs = env.reset()
 # print('init board state: ', obs)
@@ -123,6 +125,6 @@ def test_v0_vs_v1(self):
 break


-test = TestExpertActionV0()
+# test = TestExpertActionV0()
 # test.test_v0_vs_v1()
-test.test_naive()
+# test.test_naive()
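For reference, the snippet below is a minimal sketch of how the GomokuRuleBotV0 introduced above could be matched against a random opponent outside of pytest. It mirrors the configuration used in these tests, assumes LightZero is installed so that GomokuEnv and GomokuRuleBotV0 are importable, assumes the 'self_play_mode' battle mode used by test_self_play_mode, and relies only on interfaces exercised in this patch: env.step returning (obs, reward, done, info), env.current_player, env.random_action() and env.board.

from easydict import EasyDict
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
from zoo.board_games.gomoku.envs.gomoku_rule_bot_v0 import GomokuRuleBotV0

cfg = EasyDict(
    board_size=6,
    battle_mode='self_play_mode',
    prob_random_agent=0,
    channel_last=False,
    scale=True,
    agent_vs_human=False,
    bot_action_type='v0',
    prob_random_action_in_bot=0.,
    check_action_to_connect4_in_bot_v0=False,
    render_mode=None,
    screen_scaling=9,
)
env = GomokuEnv(cfg)
env.reset()
bot = GomokuRuleBotV0(env, player=1)  # the rule-based bot plays as player 1

while True:
    mover = env.current_player
    # Player 1 follows the rule-based bot; player 2 plays uniformly random legal moves.
    action = bot.get_rule_bot_action(env.board, player=mover) if mover == 1 else env.random_action()
    obs, reward, done, info = env.step(action)
    if done:
        # As in test_self_play_mode, a positive terminal reward means the player who just moved has won.
        print(f'player {mover} wins' if reward > 0 else 'draw')
        break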