
Commit

Merge branch 'main' of https://github.com/opendilab/LightZero into dev-az-ctree
puyuan1996 committed Nov 13, 2023
2 parents 4e08396 + cee2849 commit c01ad9f
Showing 25 changed files with 1,332 additions and 430 deletions.
2 changes: 1 addition & 1 deletion lzero/entry/eval_alphazero.py
@@ -87,7 +87,7 @@ def eval_alphazero(
if print_seed_details:
print("=" * 20)
print(f'In seed {seed}, returns: {returns}')
-if cfg.policy.env_type == 'board_games':
+if cfg.policy.simulation_env_name in ['tictactoe', 'connect4', 'gomoku', 'chess']:
print(
f'win rate: {len(np.where(returns == 1.)[0]) / num_episodes_each_seed}, draw rate: {len(np.where(returns == 0.)[0]) / num_episodes_each_seed}, lose rate: {len(np.where(returns == -1.)[0]) / num_episodes_each_seed}'
)
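For reference (not part of this commit), a standalone sketch of the win/draw/lose rate computation in the hunk above, assuming board-game returns of 1, 0 and -1:

import numpy as np

returns = np.array([1., 0., -1., 1., 1.])  # returns of 5 evaluation episodes (made-up values)
num_episodes_each_seed = len(returns)
win_rate = len(np.where(returns == 1.)[0]) / num_episodes_each_seed
draw_rate = len(np.where(returns == 0.)[0]) / num_episodes_each_seed
lose_rate = len(np.where(returns == -1.)[0]) / num_episodes_each_seed
print(f'win rate: {win_rate}, draw rate: {draw_rate}, lose rate: {lose_rate}')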
8 changes: 4 additions & 4 deletions lzero/mcts/buffer/game_buffer.py
@@ -9,14 +9,14 @@
from easydict import EasyDict

if TYPE_CHECKING:
-from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumeblMuZeroPolicy
+from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy


@BUFFER_REGISTRY.register('game_buffer')
class GameBuffer(ABC, object):
"""
Overview:
-The base game buffer class for MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumeblMuZeroPolicy.
+The base game buffer class for MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy.
"""

@classmethod
@@ -69,14 +69,14 @@ def __init__(self, cfg: dict):

@abstractmethod
def sample(
-self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumeblMuZeroPolicy"]
+self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumbelMuZeroPolicy"]
) -> List[Any]:
"""
Overview:
sample data from ``GameBuffer`` and prepare the current and target batch for training.
Arguments:
- batch_size (:obj:`int`): batch size.
-- policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumeblMuZeroPolicy"]`): policy.
+- policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumbelMuZeroPolicy"]`): policy.
Returns:
- train_data (:obj:`List`): List of train data, including current_batch and target_batch.
"""
6 changes: 3 additions & 3 deletions lzero/mcts/buffer/game_segment.py
@@ -170,13 +170,13 @@ def pad_over(
"""
assert len(next_segment_observations) <= self.num_unroll_steps
assert len(next_segment_child_visits) <= self.num_unroll_steps
-assert len(next_segment_root_values) <= self.num_unroll_steps + self.num_unroll_steps
-assert len(next_segment_rewards) <= self.num_unroll_steps + self.num_unroll_steps - 1
+assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
+assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
# ==============================================================
# The core difference between GumbelMuZero and MuZero
# ==============================================================
if self.gumbel_algo:
-assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.num_unroll_steps
+assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps

# NOTE: next block observation should start from (stacked_observation - 1) in next trajectory
for observation in next_segment_observations:
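A quick numeric illustration (not from this commit) of the corrected bounds asserted above, with arbitrary example values:

num_unroll_steps, td_steps = 5, 3

# Per the assertions, root values and improved policies from the next segment may
# contain up to num_unroll_steps + td_steps entries, rewards one entry fewer,
# and observations at most num_unroll_steps entries.
max_root_values = num_unroll_steps + td_steps        # 8
max_rewards = num_unroll_steps + td_steps - 1        # 7
max_observations = num_unroll_steps                  # 5
assert max_rewards == max_root_values - 1
print(max_root_values, max_rewards, max_observations)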
1 change: 1 addition & 0 deletions lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
@@ -385,6 +385,7 @@ namespace tree
#else
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#endif
+// disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
}

std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
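For intuition (not part of this commit), a rough Python analogue of building and sorting the (action, probability) pairs handled in this C++ loop; the probability values and the descending order are assumptions, since the real ordering is defined by the cmp comparator in cnode.cpp:

# Assumed stand-in for disturbed_probs.
disturbed_probs = [0.2, 0.5, 0.1, 0.2]
disc_action_with_probs = list(enumerate(disturbed_probs))   # (action index, probability) pairs
disc_action_with_probs.sort(key=lambda pair: pair[1], reverse=True)
print(disc_action_with_probs)  # [(1, 0.5), (0, 0.2), (3, 0.2), (2, 0.1)]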
2 changes: 1 addition & 1 deletion lzero/model/alphazero_model.py
@@ -211,7 +211,7 @@ class PredictionNetwork(nn.Module):
def __init__(
self,
action_space_size: int,
-continuous_action_space,
+continuous_action_space: bool,
num_res_blocks: int,
num_channels: int,
value_head_channels: int,
6 changes: 4 additions & 2 deletions lzero/model/tests/test_alphazero_model.py
@@ -53,8 +53,9 @@ def output_check(self, model, outputs):
prediction_network_args
)
def test_prediction_network(
-self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, policy_head_channels,
-fc_value_layers, fc_policy_layers, output_support_size
+self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels,
+policy_head_channels,
+fc_value_layers, fc_policy_layers, output_support_size
):
obs = torch.rand(batch_size, num_channels, 3, 3)
flatten_output_size_for_value_head = value_head_channels * observation_shape[1] * observation_shape[2]
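As a standalone check (not part of the commit), the flattened value-head size computed above, assuming a (channels, height, width) observation shape for a 3x3 board:

observation_shape = (3, 3, 3)   # assumed (C, H, W) for a 3x3 board
value_head_channels = 16
flatten_output_size_for_value_head = value_head_channels * observation_shape[1] * observation_shape[2]
print(flatten_output_size_for_value_head)  # 144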
@@ -64,6 +65,7 @@ def test_prediction_network(
# print('='*20)
prediction_network = PredictionNetwork(
action_space_size=action_space_size,
+continuous_action_space=False,
num_res_blocks=num_res_blocks,
num_channels=num_channels,
value_head_channels=value_head_channels,
6 changes: 5 additions & 1 deletion lzero/policy/gumbel_muzero.py
@@ -20,7 +20,7 @@


@POLICY_REGISTRY.register('gumbel_muzero')
-class GumeblMuZeroPolicy(MuZeroPolicy):
+class GumbelMuZeroPolicy(MuZeroPolicy):
"""
Overview:
The policy class for Gumbel MuZero proposed in the paper https://openreview.net/forum?id=bERaNdoegnO.
@@ -486,6 +486,7 @@ def _forward_collect(
action_mask: list = None,
temperature: float = 1,
to_play: List = [-1],
+epsilon: float = 0.25,
ready_env_id: np.array = None,
) -> Dict:
"""
@@ -514,6 +515,7 @@
"""
self._collect_model.eval()
self._collect_mcts_temperature = temperature
+self.collect_epsilon = epsilon
active_collect_env_num = data.shape[0]
with torch.no_grad():
# data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
@@ -560,6 +562,7 @@

for i, env_id in enumerate(ready_env_id):
distributions, value, improved_policy_probs = roots_visit_count_distributions[i], roots_values[i], roots_improved_policy_probs[i]

roots_completed_value = roots_completed_values[i]
# NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
# the index within the legal action set, rather than the index in the entire action set.
Expand All @@ -570,6 +573,7 @@ def _forward_collect(
# entire action set.
valid_value = np.where(action_mask[i] == 1.0, improved_policy_probs, 0.0)
action = np.argmax([v for v in valid_value])

output[env_id] = {
'action': action,
'visit_count_distributions': distributions,
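For reference (not part of this commit), a self-contained sketch of the masked argmax used above to pick the most probable legal action from the improved policy (probabilities and mask are made-up values):

import numpy as np

improved_policy_probs = np.array([0.1, 0.4, 0.3, 0.2])
action_mask = np.array([1.0, 0.0, 1.0, 1.0])   # action 1 is illegal in this toy example

# Zero out illegal actions, then take the argmax over what remains.
valid_value = np.where(action_mask == 1.0, improved_policy_probs, 0.0)
action = np.argmax(valid_value)
print(action)  # 2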
2 changes: 2 additions & 0 deletions lzero/policy/muzero.py
@@ -526,6 +526,7 @@ def _forward_collect(
- action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected.
- temperature (:obj:`float`): The temperature of the policy.
- to_play (:obj:`int`): The player to play.
+- epsilon (:obj:`float`): The epsilon of the eps greedy exploration.
- ready_env_id (:obj:`list`): The id of the env that is ready to collect.
Shape:
- data (:obj:`torch.Tensor`):
@@ -535,6 +536,7 @@
- action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
- temperature: :math:`(1, )`.
- to_play: :math:`(N, 1)`, where N is the number of collect_env.
+- epsilon: :math:`(1, )`.
- ready_env_id: None
Returns:
- output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
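To illustrate the newly documented epsilon argument, a generic eps-greedy sketch (not the exact logic inside _forward_collect; the visit counts are made up):

import numpy as np

rng = np.random.default_rng(0)
epsilon = 0.25                             # probability of taking a random exploratory action
visit_counts = np.array([10, 3, 7, 0])     # assumed root visit counts over 4 actions

if rng.random() < epsilon:
    action = int(rng.integers(len(visit_counts)))   # explore: uniform random action
else:
    action = int(np.argmax(visit_counts))           # exploit: most visited action
print(action)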
41 changes: 21 additions & 20 deletions lzero/policy/sampled_alphazero.py
@@ -399,7 +399,7 @@ def _init_eval(self) -> None:
self._get_simulation_env()
# TODO(pu): use double num_simulations for evaluation
if self._cfg.mcts_ctree:
-self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, min(4 * self._cfg.mcts.num_simulations, 800),
+self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, min(800, mcts_eval_config.num_simulations * 4),
self._cfg.mcts.pb_c_base,
self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha,
self._cfg.mcts.root_noise_weight, self.simulate_env)
@@ -409,8 +409,9 @@ def _init_eval(self) -> None:
else:
from lzero.mcts.ptree.ptree_az import MCTS
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
-# The number of simulations for evaluation should be larger than that for collecting data.
-mcts_eval_config.num_simulations = min(mcts_eval_config.num_simulations * 4, 800)
+# TODO(pu): how to set proper num_simulations for evaluation
+# mcts_eval_config.num_simulations = mcts_eval_config.num_simulations
+mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4)

self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
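A small numeric illustration (not part of this commit) of the evaluation simulation budget set above:

num_simulations = 50                                   # assumed value used during collection
eval_num_simulations = min(800, num_simulations * 4)   # quadrupled for evaluation, capped at 800
print(eval_num_simulations)  # 200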

@@ -464,46 +465,46 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
return output

def _get_simulation_env(self):
-assert self._cfg.simulate_env_name in ['tictactoe', 'gomoku', 'go'], self._cfg.simulate_env_name
-assert self._cfg.simulate_env_config_type in ['play_with_bot', 'self_play', 'league', 'sampled_play_with_bot'], self._cfg.simulate_env_config_type
-if self._cfg.simulate_env_name == 'tictactoe':
+assert self._cfg.simulation_env_name in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_name
+assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league', 'sampled_play_with_bot'], self._cfg.simulation_env_config_type
+if self._cfg.simulation_env_name == 'tictactoe':
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
-if self._cfg.simulate_env_config_type == 'play_with_bot':
+if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \
tictactoe_alphazero_config
-elif self._cfg.simulate_env_config_type == 'self_play':
+elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \
tictactoe_alphazero_config
-elif self._cfg.simulate_env_config_type == 'league':
+elif self._cfg.simulation_env_config_type == 'league':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_league_config import \
tictactoe_alphazero_config
-elif self._cfg.simulate_env_config_type == 'sampled_play_with_bot':
+elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
from zoo.board_games.tictactoe.config.tictactoe_sampled_alphazero_bot_mode_config import \
tictactoe_sampled_alphazero_config as tictactoe_alphazero_config

self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)

-elif self._cfg.simulate_env_name == 'gomoku':
+elif self._cfg.simulation_env_name == 'gomoku':
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
-if self._cfg.simulate_env_config_type == 'play_with_bot':
+if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
-elif self._cfg.simulate_env_config_type == 'self_play':
+elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config
-elif self._cfg.simulate_env_config_type == 'league':
+elif self._cfg.simulation_env_config_type == 'league':
from zoo.board_games.gomoku.config.gomoku_alphazero_league_config import gomoku_alphazero_config
-elif self._cfg.simulate_env_config_type == 'sampled_play_with_bot':
+elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import gomoku_sampled_alphazero_config as gomoku_alphazero_config

self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
-elif self._cfg.simulate_env_name == 'go':
+elif self._cfg.simulation_env_name == 'go':
from zoo.board_games.go.envs.go_env import GoEnv
-if self._cfg.simulate_env_config_type == 'play_with_bot':
+if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.go.config.go_alphazero_bot_mode_config import go_alphazero_config
-elif self._cfg.simulate_env_config_type == 'self_play':
+elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.go.config.go_alphazero_sp_mode_config import go_alphazero_config
-elif self._cfg.simulate_env_config_type == 'league':
+elif self._cfg.simulation_env_config_type == 'league':
from zoo.board_games.go.config.go_alphazero_league_config import go_alphazero_config
-elif self._cfg.simulate_env_config_type == 'sampled_play_with_bot':
+elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
from zoo.board_games.go.config.go_sampled_alphazero_bot_mode_config import \
go_sampled_alphazero_config as go_alphazero_config

@@ -41,6 +41,8 @@
screen_scaling=9,
render_mode=None,
# ==============================================================
screen_scaling=9,
render_mode=None,
),
policy=dict(
mcts_ctree=mcts_ctree,
@@ -80,10 +80,6 @@
type='gumbel_muzero',
import_names=['lzero.policy.gumbel_muzero'],
),
-collector=dict(
-type='gumbel_muzero',
-import_names=['lzero.worker.gumbel_muzero_collector'],
-)
)
gomoku_gumbel_muzero_create_config = EasyDict(gomoku_gumbel_muzero_create_config)
create_config = gomoku_gumbel_muzero_create_config
@@ -17,18 +17,6 @@
max_env_step = int(10e6)
prob_random_action_in_bot = 0.5
mcts_ctree = False

-# board_size = 6
-# collector_env_num = 2
-# n_episode = 2
-# evaluator_env_num = 2
-# num_simulations = 5
-# update_per_collect = 2
-# batch_size = 2
-# max_env_step = int(5e5)
-# prob_random_action_in_bot = 0.5
-# mcts_ctree = False
-# num_of_sampled_actions = 5
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -41,20 +29,31 @@
battle_mode='play_with_bot_mode',
bot_action_type='v0',
prob_random_action_in_bot=prob_random_action_in_bot,
channel_last=False, # NOTE
channel_last=False,
use_katago_bot=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
env_name="gomoku",
# ==============================================================
# for the creation of simulation env
agent_vs_human=False,
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
agent_vs_human=False,
check_action_to_connect4_in_bot_v0=False,
simulation_env_name="gomoku",
# ==============================================================
mcts_ctree=mcts_ctree,
screen_scaling=9,
render_mode=None,
),
policy=dict(
# ==============================================================
# for the creation of simulation env
simulation_env_name='gomoku',
simulation_env_config_type='sampled_play_with_bot',
# ==============================================================
torch_compile=False,
tensor_float_32=False,
model=dict(
@@ -65,8 +64,6 @@
),
sampled_algo=True,
mcts_ctree=mcts_ctree,
-simulate_env_config_type='sampled_play_with_bot',
-simulate_env_name="gomoku",
policy_loss_type='KL',
cuda=True,
board_size=board_size,
@@ -31,14 +31,25 @@
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
env_name="Gomoku",
# ==============================================================
# for the creation of simulation env
agent_vs_human=False,
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
agent_vs_human=False,
check_action_to_connect4_in_bot_v0=False,
simulation_env_name="gomoku",
# ==============================================================
mcts_ctree=mcts_ctree,
screen_scaling=9,
render_mode=None,
),
policy=dict(
# ==============================================================
# for the creation of simulation env
simulation_env_name='gomoku',
simulation_env_config_type='sampled_self_play',
# ==============================================================
torch_compile=False,
tensor_float_32=False,
model=dict(
@@ -48,9 +59,7 @@
num_channels=32,
),
sampled_algo=True,
simulate_env_name="gomoku",
mcts_ctree=mcts_ctree,
simulate_env_config_type='sampled_self_play',
policy_loss_type='KL',
cuda=True,
board_size=board_size,
3 changes: 2 additions & 1 deletion zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py
@@ -12,7 +12,8 @@
seeds = [0]
num_episodes_each_seed = 5
# If True, you can play with the agent.
-main_config.env.agent_vs_human = False
+main_config.env.agent_vs_human = True
+main_config.env.render_mode = 'image_realtime_mode'
create_config.env_manager.type = 'base'
main_config.env.evaluator_env_num = 1
main_config.env.n_evaluator_episode = 1