diff --git a/src/alphazero/tree_search_methods/expand.py b/src/alphazero/tree_search_methods/expand.py
index 5ae460f..436a648 100644
--- a/src/alphazero/tree_search_methods/expand.py
+++ b/src/alphazero/tree_search_methods/expand.py
@@ -1,9 +1,11 @@
 import torch
 from src.alphazero.node import Node
 from src.utils.game_context import GameContext
-from src.utils.tensor_utils import normalize_policy_values, normalize_policy_values_with_noise
+from src.utils.tensor_utils import normalize_policy_values_with_noise
 from src.utils.random_utils import generate_dirichlet_noise
 
+
+@profile
 def expand(node: Node, nn_policy_values: torch.Tensor) -> None:
     """
     Takes in a node, and adds all children.
@@ -14,13 +16,18 @@ def expand(node: Node, nn_policy_values: torch.Tensor) -> None:
     Therefore, we normalize the policy values by applying the softmax normalization function
     to form a probability distribution for action selection.
     """
-    legal_actions = node.state.legal_actions()
-    normalize_policy_values(nn_policy_values, legal_actions)
-    node.set_children_policy_values(nn_policy_values[legal_actions].to("cpu"))
-    for action in legal_actions: # Add the children with correct policy values
-        new_state = node.state.clone()
+    state = node.state
+    legal_actions = state.legal_actions()
+    nn_policy_values = nn_policy_values.cpu()
+    policy_values = torch.softmax(nn_policy_values[legal_actions], dim=0)
+    node.set_children_policy_values(policy_values)
+
+    children = node.children
+    for action, policy_value in zip(legal_actions, policy_values):
+        new_state = state.clone()
         new_state.apply_action(action)
-        node.children.append(Node(node, new_state, action, nn_policy_values[action].item()))
+        children.append(Node(node, new_state, action, policy_value))
+
 
 def dirichlet_expand(context: GameContext, node: Node, nn_policy_values: torch.Tensor, alpha: float, epsilon: float) -> None:
     """
@@ -51,11 +58,12 @@ def dirichlet_expand(context: GameContext, node: Node, nn_policy_values: torch.T
     legal_actions = node.state.legal_actions()
     noise = generate_dirichlet_noise(context, len(legal_actions), alpha)
     normalize_policy_values_with_noise(nn_policy_values, legal_actions, noise, epsilon)
-    node.set_children_policy_values(nn_policy_values[legal_actions].to("cpu"))
+    policy_values = nn_policy_values.to("cpu")
+    node.set_children_policy_values(policy_values[legal_actions])
 
     for action in legal_actions: # Add the children with correct policy values
         new_state = node.state.clone()
         new_state.apply_action(action)
         node.children.append(
-            Node(node, new_state, action, nn_policy_values[action])
+            Node(node, new_state, action, policy_values[action])
         )
\ No newline at end of file
diff --git a/src/alphazero/tree_search_methods/select.py b/src/alphazero/tree_search_methods/select.py
index a5f0e01..af15b04 100644
--- a/src/alphazero/tree_search_methods/select.py
+++ b/src/alphazero/tree_search_methods/select.py
@@ -1,7 +1,8 @@
 import torch
 from src.alphazero.node import Node
 
-def vectorized_select(node: Node, c: float) -> Node: # OPTIMIZATION for GPU, great speedup is expected when number of children is large.
+
+def vectorized_select(node: Node, c: float) -> Node:
     """
     Select stage of MCTS.
     Go through the game tree, layer by layer.
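Reviewer note on the expand.py change above: a minimal standalone sketch of the inline softmax normalization that the rewritten expand() now performs, assuming nn_policy_values holds raw network logits and legal_actions is a list of action indices. The concrete numbers, and the absence of a real Node/state object, are illustrative assumptions, not part of the patch.

import torch

# Illustrative policy logits over 4 actions and the indices of the legal ones;
# the values are made up for the demo.
nn_policy_values = torch.tensor([-0.4, 0.0, 0.7, 0.2])
legal_actions = [1, 3]

# New approach in expand(): softmax only over the legal-action logits, so the
# children's policy values already form a probability distribution.
policy_values = torch.softmax(nn_policy_values[legal_actions], dim=0)

print(policy_values)        # e.g. tensor([0.4502, 0.5498])
print(policy_values.sum())  # tensor(1.) -- mass only on legal actions
for action, p in zip(legal_actions, policy_values):
    print(action, float(p))  # one probability per child, as in the new loop

Masking before the softmax gives probability mass only to legal actions, which is presumably why the separate normalize_policy_values helper (and its tests, removed below) is no longer needed.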
@@ -15,11 +16,11 @@ def vectorized_select(node: Node, c: float) -> Node: # OPTIMIZATION for GPU, gre
         return node.children[torch.argmax(node.children_policy_values).item()]
 
     const = c * node.visits**0.5
-
+    kids_visits = node.children_visits
+
     # Compute PUCT for all children in a vectorized manner
-    Q = torch.where(node.children_visits > 0, node.children_values / node.children_visits, torch.zeros_like(node.children_values))
-    Q.add_(const * (node.children_policy_values / node.children_visits.add(torch.ones_like(node.children_visits)) ))
+    PUCT = torch.where(kids_visits > 0, node.children_values / kids_visits, kids_visits).add_(node.children_policy_values.div(kids_visits.add(torch.ones_like(kids_visits))).mul(const))
 
-    node = node.children[torch.argmax(Q).item()] # Return the best child node based on PUCT value
+    node = node.children[torch.argmax(PUCT).item()] # Return the best child node based on PUCT value
 
     return node
\ No newline at end of file
diff --git a/test/utils/test_tensor_utils.py b/test/utils/test_tensor_utils.py
index cda8db0..15bc8ee 100644
--- a/test/utils/test_tensor_utils.py
+++ b/test/utils/test_tensor_utils.py
@@ -1,7 +1,7 @@
 import torch
 from torch.distributions.dirichlet import Dirichlet
 from torch.testing import assert_close
-from src.utils.tensor_utils import normalize_policy_values, normalize_policy_values_with_noise
+from src.utils.tensor_utils import normalize_policy_values_with_noise
 
 
 """
@@ -11,45 +11,6 @@
 Additionally, the function is checking that both tensors have the same shape,
 dtype, and device.
 """
 
-def test_normalize_policy_values_basic():
-    """
-    Tests basic functionality of normalize_policy_values to ensure it normalizes
-    policy values correctly for specified legal actions.
-    """
-    nn_policy_values = torch.tensor([-0.4, 0, 0.7, 0], dtype=torch.float)
-    legal_actions = torch.tensor([1, 3], dtype=torch.long)
-    expected_output = torch.tensor([-0.4, 0.5, 0.7, 0.5]) # Expected softmax values for indices 1 and 3
-
-    normalize_policy_values(nn_policy_values, legal_actions)
-
-    assert_close(nn_policy_values, expected_output)
-
-def test_normalize_policy_values_all_legal():
-    """
-    Tests the normalize_policy_values function when all actions are legal,
-    ensuring the entire tensor is normalized.
-    """
-    nn_policy_values = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float)
-    legal_actions = torch.tensor([0, 1, 2, 3], dtype=torch.long)
-    expected_output = torch.softmax(nn_policy_values, dim=0)
-
-    normalize_policy_values(nn_policy_values, legal_actions)
-
-    assert_close(nn_policy_values, expected_output)
-
-def test_normalize_policy_values_no_legal():
-    """
-    Tests the normalize_policy_values function with an empty legal_actions tensor,
-    expecting the original nn_policy_values tensor to remain unchanged.
-    """
-    nn_policy_values = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float)
-    legal_actions = torch.tensor([], dtype=torch.long)
-    expected_output = nn_policy_values.clone()
-
-    normalize_policy_values(nn_policy_values, legal_actions)
-
-    torch.testing.assert_close(nn_policy_values, expected_output)
-
 alpha = 0.3
 epsilon = 0.75
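Reviewer note on the select.py change above: a minimal sketch checking that the fused PUCT expression matches the previous two-step Q computation, assuming children_values, children_visits, and children_policy_values are 1-D float tensors as stored on Node; the sample numbers are made up.

import torch

# Illustrative stand-ins for node.children_values / node.children_visits /
# node.children_policy_values; the values are made up for the comparison.
children_values = torch.tensor([1.5, 0.0, 2.0])
children_visits = torch.tensor([3.0, 0.0, 4.0])
children_policy_values = torch.tensor([0.2, 0.5, 0.3])
c, parent_visits = 1.25, 7
const = c * parent_visits**0.5

# Previous two-step form: mean value Q, then the exploration bonus added in place.
Q = torch.where(children_visits > 0,
                children_values / children_visits,
                torch.zeros_like(children_values))
Q.add_(const * (children_policy_values / children_visits.add(torch.ones_like(children_visits))))

# Fused form from the diff: children_visits is zero exactly where visits == 0,
# so it doubles as the "else" branch of torch.where.
PUCT = torch.where(children_visits > 0,
                   children_values / children_visits,
                   children_visits).add_(
    children_policy_values.div(children_visits.add(torch.ones_like(children_visits))).mul(const))

assert torch.allclose(Q, PUCT)
print(int(torch.argmax(PUCT)))  # index of the child that would be selected

Because kids_visits is zero wherever visits == 0, reusing it as the else branch of torch.where is equivalent to torch.zeros_like, so only the intermediate allocations change, not which child gets selected.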