feature(pu): add sampled alphazero and polish gomoku env (#141)
* feature(pu): add sampled alphazero

* feature(pu): polish gomoku render method and add gomoku_human_vs_bot_ui

* polish(pu): polish gomoku render method

* polish(pu): polish gomoku_rule_bot_v0

* polish(pu): polish gomoku_human_vs_bot_UI

* polish(pu): polish gomoku_rule_bot_v0

* polish(pu): use dp to check_five_in_a_row in gomoku_rule_bot_v0

* fix(pu): fix emplace_back

* polish(pu): polish comments and eval num_sim
puyuan1996 authored Nov 13, 2023
1 parent 8285db3 commit cee2849
Showing 21 changed files with 2,495 additions and 383 deletions.
2 changes: 1 addition & 1 deletion lzero/entry/eval_alphazero.py
@@ -87,7 +87,7 @@ def eval_alphazero(
if print_seed_details:
print("=" * 20)
print(f'In seed {seed}, returns: {returns}')
if cfg.policy.env_type == 'board_games':
if cfg.policy.simulation_env_name in ['tictactoe', 'connect4', 'gomoku', 'chess']:
print(
f'win rate: {len(np.where(returns == 1.)[0]) / num_episodes_each_seed}, draw rate: {len(np.where(returns == 0.)[0]) / num_episodes_each_seed}, lose rate: {len(np.where(returns == -1.)[0]) / num_episodes_each_seed}'
)
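Note: a minimal standalone sketch (not part of the diff) of the rate computation in the hunk above, assuming returns is a NumPy array of per-episode returns in {-1., 0., 1.}:

import numpy as np

returns = np.array([1., 0., -1., 1., 1.])  # illustrative returns from 5 episodes of one seed
num_episodes_each_seed = len(returns)

win_rate = len(np.where(returns == 1.)[0]) / num_episodes_each_seed
draw_rate = len(np.where(returns == 0.)[0]) / num_episodes_each_seed
lose_rate = len(np.where(returns == -1.)[0]) / num_episodes_each_seed
print(f'win rate: {win_rate}, draw rate: {draw_rate}, lose rate: {lose_rate}')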
1 change: 1 addition & 0 deletions lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
@@ -385,6 +385,7 @@ namespace tree
#else
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#endif
// disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
}

std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
521 changes: 521 additions & 0 deletions lzero/mcts/ptree/ptree_az_sampled.py

Large diffs are not rendered by default.

73 changes: 66 additions & 7 deletions lzero/mcts/ptree/ptree_sez.py
@@ -123,7 +123,6 @@ def expand(

for action_index in range(self.num_of_sampled_actions):
self.children[Action(sampled_actions[action_index].detach().cpu().numpy())] = Node(
# prob[action_index], # NOTE: this is a bug
prob[sampled_actions[action_index]], #
action_space_size=self.action_space_size,
num_of_sampled_actions=self.num_of_sampled_actions,
@@ -403,6 +402,7 @@ def get_sampled_actions(self) -> List[List[Union[int, float]]]:
- python_sampled_actions: a vector of sampled_actions for each root, e.g. the size of original action space is 6, the K=3,
python_sampled_actions = [[1,3,0], [2,4,0], [5,4,1]].
"""
# TODO(pu): root_sampled_actions bug in discrete action space?
sampled_actions = []
for i in range(self.root_num):
sampled_actions.append(self.roots[i].legal_actions)
@@ -774,20 +774,79 @@ def batch_backpropagate(
)


from typing import Union
import numpy as np

class Action:
"""Class that represent an action of a game."""
"""
Class that represents an action of a game.
Attributes:
value (Union[int, np.ndarray]): The value of the action. Can be either an integer or a numpy array.
"""

def __init__(self, value: Union[int, np.ndarray]) -> None:
"""
Initializes the Action with the given value.
def __init__(self, value: float) -> None:
Args:
value (Union[int, np.ndarray]): The value of the action.
"""
self.value = value

def __hash__(self) -> hash:
return hash(self.value.tostring())
def __hash__(self) -> int:
"""
Returns a hash of the Action's value.
If the value is a numpy array, it is flattened to a tuple and then hashed.
If the value is a single integer, it is hashed directly.
Returns:
int: The hash of the Action's value.
"""
if isinstance(self.value, np.ndarray):
if self.value.ndim == 0:
return hash(self.value.item())
else:
return hash(tuple(self.value.flatten()))
else:
return hash(self.value)

def __eq__(self, other: "Action") -> bool:
return (self.value == other.value).all()
"""
Determines if this Action is equal to another Action.
If both values are numpy arrays, they are compared element-wise.
Otherwise, they are compared directly.
Args:
other (Action): The Action to compare with.
Returns:
bool: True if the two Actions are equal, False otherwise.
"""
if isinstance(self.value, np.ndarray) and isinstance(other.value, np.ndarray):
return np.array_equal(self.value, other.value)
else:
return self.value == other.value

def __gt__(self, other: "Action") -> bool:
return self.value[0] > other.value[0]
"""
Determines if this Action's value is greater than another Action's value.
Args:
other (Action): The Action to compare with.
Returns:
bool: True if this Action's value is greater, False otherwise.
"""
return self.value > other.value

def __repr__(self) -> str:
"""
Returns a string representation of this Action.
Returns:
str: A string representation of the Action's value.
"""
return str(self.value)
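Note: a quick usage sketch (not from the diff) of the Action semantics above. The array-aware __hash__ and __eq__ let Action objects wrapping NumPy arrays serve as dictionary keys for a node's children, which is how expand uses them; the values below are illustrative.

import numpy as np
from lzero.mcts.ptree.ptree_sez import Action

a1 = Action(np.array([0.1, -0.3]))   # continuous action
a2 = Action(np.array([0.1, -0.3]))   # equal content, different object
a3 = Action(2)                       # discrete action

children = {a1: 'child_node'}
assert a1 == a2 and hash(a1) == hash(a2)   # equal arrays compare equal and hash identically
assert children[a2] == 'child_node'        # so lookups by content succeed
assert a3 == Action(2) and hash(a3) == hash(Action(2))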
107 changes: 88 additions & 19 deletions lzero/model/alphazero_model.py
@@ -8,6 +8,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.model import ReparameterizationHead
from ding.torch_utils import MLP, ResBlock
from ding.utils import MODEL_REGISTRY, SequenceType

@@ -34,6 +35,16 @@ def __init__(
fc_value_layers: SequenceType = [32],
fc_policy_layers: SequenceType = [32],
value_support_size: int = 601,
# ==============================================================
# specific sampled related config
# ==============================================================
continuous_action_space: bool = False,
num_of_sampled_actions: int = 6,
sigma_type='conditioned',
fixed_sigma_value: float = 0.3,
bound_type: str = None,
norm_type: str = 'BN',
discrete_action_encoding_type: str = 'one_hot',
):
"""
Overview:
@@ -70,7 +81,26 @@ def __init__(
self.last_linear_layer_init_zero = last_linear_layer_init_zero
self.representation_network = representation_network

self.continuous_action_space = continuous_action_space
self.action_space_size = action_space_size
# The dim of action space. For discrete action space, it's 1.
# For continuous action space, it is the dim of action.
self.action_space_dim = action_space_size if self.continuous_action_space else 1
assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type
self.discrete_action_encoding_type = discrete_action_encoding_type
if self.continuous_action_space:
self.action_encoding_dim = action_space_size
else:
if self.discrete_action_encoding_type == 'one_hot':
self.action_encoding_dim = action_space_size
elif self.discrete_action_encoding_type == 'not_one_hot':
self.action_encoding_dim = 1
self.sigma_type = sigma_type
self.fixed_sigma_value = fixed_sigma_value
self.bound_type = bound_type
self.norm_type = norm_type
self.num_of_sampled_actions = num_of_sampled_actions

# TODO use more adaptive way to get the flatten output size
flatten_output_size_for_value_head = (
(
@@ -88,6 +118,7 @@ def __init__(

self.prediction_network = PredictionNetwork(
action_space_size,
self.continuous_action_space,
num_res_blocks,
num_channels,
value_head_channels,
@@ -99,6 +130,10 @@ def __init__(
flatten_output_size_for_policy_head,
last_linear_layer_init_zero=self.last_linear_layer_init_zero,
activation=activation,
sigma_type=self.sigma_type,
fixed_sigma_value=self.fixed_sigma_value,
bound_type=self.bound_type,
norm_type=self.norm_type,
)

if self.representation_network is None:
@@ -131,7 +166,7 @@ def forward(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
logit, value = self.prediction_network(encoded_state)
return logit, value

def compute_prob_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_policy_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Overview:
The computation graph of AlphaZero model to calculate action selection probability and value.
@@ -147,9 +182,7 @@ def compute_prob_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, t
- value (:obj:`torch.Tensor`): :math:`(B, 1)`, where B is batch size.
"""
logit, value = self.forward(state_batch)
# construct categorical distribution to calculate probability
dist = torch.distributions.Categorical(logits=logit)
prob = dist.probs
prob = torch.nn.functional.softmax(logit, dim=-1)
return prob, value

def compute_logp_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
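Note: the switch from a Categorical distribution to a direct softmax in compute_policy_value is mathematically equivalent; a minimal standalone check (not part of the diff):

import torch
import torch.nn.functional as F

logit = torch.randn(4, 9)  # e.g. a batch of 4 states with 9 discrete actions
prob_categorical = torch.distributions.Categorical(logits=logit).probs
prob_softmax = F.softmax(logit, dim=-1)
assert torch.allclose(prob_categorical, prob_softmax)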
@@ -178,6 +211,7 @@ class PredictionNetwork(nn.Module):
def __init__(
self,
action_space_size: int,
continuous_action_space: bool,
num_res_blocks: int,
num_channels: int,
value_head_channels: int,
@@ -189,6 +223,13 @@ def __init__(
flatten_output_size_for_policy_head: int,
last_linear_layer_init_zero: bool = True,
activation: Optional[nn.Module] = nn.ReLU(inplace=True),
# ==============================================================
# specific sampled related config
# ==============================================================
sigma_type='conditioned',
fixed_sigma_value: float = 0.3,
bound_type: str = None,
norm_type: str = 'BN',
) -> None:
"""
Overview:
@@ -213,6 +254,15 @@ def __init__(
operation to speedup, e.g. ReLU(inplace=True).
"""
super().__init__()
self.continuous_action_space = continuous_action_space
self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
self.norm_type = norm_type
self.sigma_type = sigma_type
self.fixed_sigma_value = fixed_sigma_value
self.bound_type = bound_type
self.activation = activation

self.resblocks = nn.ModuleList(
[
ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False)
@@ -226,7 +276,7 @@ def __init__(
self.norm_policy = nn.BatchNorm2d(policy_head_channels)
self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
self.fc_value = MLP(
self.fc_value_head = MLP(
in_channels=self.flatten_output_size_for_value_head,
hidden_channels=fc_value_layers[0],
out_channels=output_support_size,
Expand All @@ -237,17 +287,31 @@ def __init__(
output_norm=False,
last_linear_layer_init_zero=last_linear_layer_init_zero
)
self.fc_policy = MLP(
in_channels=self.flatten_output_size_for_policy_head,
hidden_channels=fc_policy_layers[0],
out_channels=action_space_size,
layer_num=len(fc_policy_layers) + 1,
activation=activation,
norm_type='LN',
output_activation=False,
output_norm=False,
last_linear_layer_init_zero=last_linear_layer_init_zero
)

# sampled related core code
if self.continuous_action_space:
self.fc_policy_head = ReparameterizationHead(
input_size=self.flatten_output_size_for_policy_head,
output_size=action_space_size,
layer_num=len(fc_policy_layers) + 1,
sigma_type=self.sigma_type,
fixed_sigma_value=self.fixed_sigma_value,
activation=nn.ReLU(),
norm_type=None,
bound_type=self.bound_type
)
else:
self.fc_policy_head = MLP(
in_channels=self.flatten_output_size_for_policy_head,
hidden_channels=fc_policy_layers[0],
out_channels=action_space_size,
layer_num=len(fc_policy_layers) + 1,
activation=activation,
norm_type='LN',
output_activation=False,
output_norm=False,
last_linear_layer_init_zero=last_linear_layer_init_zero
)

self.activation = activation

@@ -279,6 +343,11 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
value = value.reshape(-1, self.flatten_output_size_for_value_head)
policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)

value = self.fc_value(value)
logit = self.fc_policy(policy)
return logit, value
value = self.fc_value_head(value)

# sampled related core code
policy = self.fc_policy_head(policy)
if self.continuous_action_space:
policy = torch.cat([policy['mu'], policy['sigma']], dim=-1)

return policy, value
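Note: for the continuous case the head returns torch.cat([mu, sigma], dim=-1); downstream code would presumably split this back into Gaussian parameters before sampling candidate actions. The sketch below is only an assumption-based illustration; the helper name sample_candidate_actions is hypothetical and not from this commit.

import torch

def sample_candidate_actions(policy_out: torch.Tensor, num_of_sampled_actions: int = 6) -> torch.Tensor:
    # policy_out concatenates mu and sigma along the last dim, so it splits evenly in half.
    action_dim = policy_out.shape[-1] // 2
    mu, sigma = policy_out[..., :action_dim], policy_out[..., action_dim:]
    dist = torch.distributions.Independent(torch.distributions.Normal(mu, sigma), 1)
    # Draw K candidate actions per state: shape (K, batch_size, action_dim).
    return dist.sample((num_of_sampled_actions,))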
6 changes: 4 additions & 2 deletions lzero/model/tests/test_alphazero_model.py
@@ -53,8 +53,9 @@ def output_check(self, model, outputs):
prediction_network_args
)
def test_prediction_network(
self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, policy_head_channels,
fc_value_layers, fc_policy_layers, output_support_size
self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels,
policy_head_channels,
fc_value_layers, fc_policy_layers, output_support_size
):
obs = torch.rand(batch_size, num_channels, 3, 3)
flatten_output_size_for_value_head = value_head_channels * observation_shape[1] * observation_shape[2]
@@ -64,6 +65,7 @@ def test_prediction_network(
# print('='*20)
prediction_network = PredictionNetwork(
action_space_size=action_space_size,
continuous_action_space=False,
num_res_blocks=num_res_blocks,
num_channels=num_channels,
value_head_channels=value_head_channels,
8 changes: 5 additions & 3 deletions lzero/policy/alphazero.py
@@ -162,7 +162,7 @@ def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
mcts_probs = mcts_probs.to(device=self._device, dtype=torch.float)
reward = reward.to(device=self._device, dtype=torch.float)

action_probs, values = self._learn_model.compute_prob_value(state_batch)
action_probs, values = self._learn_model.compute_policy_value(state_batch)
log_probs = torch.log(action_probs)

# calculate policy entropy, for monitoring only
@@ -258,7 +258,9 @@ def _init_eval(self) -> None:
self._get_simulation_env()
import copy
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
mcts_eval_config.num_simulations = mcts_eval_config.num_simulations * 2
# TODO(pu): how to set proper num_simulations for evaluation
# mcts_eval_config.num_simulations = mcts_eval_config.num_simulations
mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4)
self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
self._eval_model = self._model
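Note: the new evaluation budget quadruples the training-time simulation count but caps it at 800. A small illustration (not from the diff):

for train_num_sim in (25, 50, 200, 400):
    print(train_num_sim, '->', min(800, train_num_sim * 4))  # 25 -> 100, 50 -> 200, 200 -> 800, 400 -> 800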

@@ -323,7 +325,7 @@ def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]:
device=self._device, dtype=torch.float
).unsqueeze(0)
with torch.no_grad():
action_probs, value = self._policy_model.compute_prob_value(current_state_scale)
action_probs, value = self._policy_model.compute_policy_value(current_state_scale)
action_probs_dict = dict(zip(legal_actions, action_probs.squeeze(0)[legal_actions].detach().cpu().numpy()))
return action_probs_dict, value.item()
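Note: a tiny standalone example (not from the diff) of the legal-action masking performed here, assuming a 9-action space with legal_actions = [0, 4, 8]:

import numpy as np

action_probs = np.array([0.05, 0.10, 0.05, 0.10, 0.40, 0.05, 0.05, 0.10, 0.10])
legal_actions = [0, 4, 8]
action_probs_dict = dict(zip(legal_actions, action_probs[legal_actions]))
# -> {0: 0.05, 4: 0.4, 8: 0.1}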

(The remaining changed files were not loaded in this view.)
