optimization: expand is now almost twice as fast as before. Removed a function, and corresponding tests. Small change to select as well, which gives a slight speedup.
ChristianFredrikJohnsen committed Apr 27, 2024
1 parent e652b0a commit 7846aa5
Showing 3 changed files with 24 additions and 54 deletions.
26 changes: 17 additions & 9 deletions src/alphazero/tree_search_methods/expand.py
@@ -1,9 +1,11 @@
import torch
from src.alphazero.node import Node
from src.utils.game_context import GameContext
from src.utils.tensor_utils import normalize_policy_values, normalize_policy_values_with_noise
from src.utils.tensor_utils import normalize_policy_values_with_noise
from src.utils.random_utils import generate_dirichlet_noise


@profile
def expand(node: Node, nn_policy_values: torch.Tensor) -> None:
"""
Takes in a node, and adds all children.
@@ -14,13 +16,18 @@ def expand(node: Node, nn_policy_values: torch.Tensor) -> None:
Therefore, we normalize the policy values by applying the softmax normalization function to form a probability
distribution for action selection.
"""
legal_actions = node.state.legal_actions()
normalize_policy_values(nn_policy_values, legal_actions)
node.set_children_policy_values(nn_policy_values[legal_actions].to("cpu"))
for action in legal_actions: # Add the children with correct policy values
new_state = node.state.clone()
state = node.state
legal_actions = state.legal_actions()
nn_policy_values = nn_policy_values.cpu()
policy_values = torch.softmax(nn_policy_values[legal_actions], dim=0)
node.set_children_policy_values(policy_values)

children = node.children
for action, policy_value in zip(legal_actions, policy_values):
new_state = state.clone()
new_state.apply_action(action)
node.children.append(Node(node, new_state, action, nn_policy_values[action].item()))
children.append(Node(node, new_state, action, policy_value))
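
The rewritten expand applies softmax only to the logits of the legal actions, after a single move to CPU, rather than normalizing the whole policy tensor in place. A minimal sketch of that normalization step, using made-up tensor sizes and action indices rather than the project's real policy dimensions:

import torch

nn_policy_values = torch.randn(100)        # hypothetical policy-head output
legal_actions = torch.tensor([5, 17, 42])  # hypothetical legal-move indices

# Slice out the legal logits, then normalize only those values.
policy_values = torch.softmax(nn_policy_values[legal_actions], dim=0)
assert torch.isclose(policy_values.sum(), torch.tensor(1.0))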


def dirichlet_expand(context: GameContext, node: Node, nn_policy_values: torch.Tensor, alpha: float, epsilon: float) -> None:
"""
@@ -51,11 +58,12 @@ def dirichlet_expand(context: GameContext, node: Node, nn_policy_values: torch.Tensor, alpha: float, epsilon: float) -> None:
legal_actions = node.state.legal_actions()
noise = generate_dirichlet_noise(context, len(legal_actions), alpha)
normalize_policy_values_with_noise(nn_policy_values, legal_actions, noise, epsilon)
node.set_children_policy_values(nn_policy_values[legal_actions].to("cpu"))
policy_values = nn_policy_values.to("cpu")
node.set_children_policy_values(policy_values[legal_actions])

for action in legal_actions: # Add the children with correct policy values
new_state = node.state.clone()
new_state.apply_action(action)
node.children.append(
Node(node, new_state, action, nn_policy_values[action])
Node(node, new_state, action, policy_values[action])
)
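
dirichlet_expand mixes the network priors with Dirichlet noise before expanding the root. The call above suggests the standard AlphaZero root-noise recipe, P'(a) = (1 - epsilon) * P(a) + epsilon * eta(a) with eta ~ Dirichlet(alpha). A self-contained sketch of that mixing; the in-place details of the project's normalize_policy_values_with_noise may differ:

import torch
from torch.distributions.dirichlet import Dirichlet

priors = torch.softmax(torch.tensor([0.2, 1.0, -0.5]), dim=0)  # hypothetical priors over 3 legal actions
alpha, epsilon = 0.3, 0.75

noise = Dirichlet(torch.full((3,), alpha)).sample()
mixed = (1 - epsilon) * priors + epsilon * noise  # convex mix, still sums to 1
assert torch.isclose(mixed.sum(), torch.tensor(1.0))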
11 changes: 6 additions & 5 deletions src/alphazero/tree_search_methods/select.py
@@ -1,7 +1,8 @@
import torch
from src.alphazero.node import Node

def vectorized_select(node: Node, c: float) -> Node: # OPTIMIZATION for GPU, great speedup is expected when number of children is large.

def vectorized_select(node: Node, c: float) -> Node:
"""
Select stage of MCTS.
Go through the game tree, layer by layer.
@@ -15,11 +16,11 @@ def vectorized_select(node: Node, c: float) -> Node: # OPTIMIZATION for GPU, great speedup is expected when number of children is large.
return node.children[torch.argmax(node.children_policy_values).item()]

const = c * node.visits**0.5

kids_visits = node.children_visits

# Compute PUCT for all children in a vectorized manner
Q = torch.where(node.children_visits > 0, node.children_values / node.children_visits, torch.zeros_like(node.children_values))
Q.add_(const * (node.children_policy_values / node.children_visits.add(torch.ones_like(node.children_visits)) ))
PUCT = torch.where(kids_visits > 0, node.children_values / kids_visits, kids_visits).add_(node.children_policy_values.div(kids_visits.add(torch.ones_like(kids_visits))).mul(const))

node = node.children[torch.argmax(Q).item()] # Return the best child node based on PUCT value
node = node.children[torch.argmax(PUCT).item()] # Return the best child node based on PUCT value

return node
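
The fused expression above computes the usual PUCT score, PUCT(a) = Q(a) + c * sqrt(N_parent) * P(a) / (1 + N(a)), where Q is the mean backed-up value and unvisited children get Q = 0. An equivalent, more readable restatement with made-up child statistics:

import torch

values = torch.tensor([3.0, 0.0, 5.0])  # hypothetical summed backed-up values per child
visits = torch.tensor([4.0, 0.0, 9.0])  # hypothetical visit counts per child
priors = torch.tensor([0.5, 0.3, 0.2])  # hypothetical policy priors per child
c, parent_visits = 1.4, 13

const = c * parent_visits**0.5
Q = torch.where(visits > 0, values / visits, visits)  # mean value; 0.0 for unvisited children
U = const * priors / (visits + 1)                     # exploration bonus that shrinks with visits
best_child_index = torch.argmax(Q + U).item()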
41 changes: 1 addition & 40 deletions test/utils/test_tensor_utils.py
@@ -1,7 +1,7 @@
import torch
from torch.distributions.dirichlet import Dirichlet
from torch.testing import assert_close
from src.utils.tensor_utils import normalize_policy_values, normalize_policy_values_with_noise
from src.utils.tensor_utils import normalize_policy_values_with_noise


"""
@@ -11,45 +11,6 @@
Additionally, the function is checking that both tensors have the same shape, dtype, and device.
"""

def test_normalize_policy_values_basic():
"""
Tests basic functionality of normalize_policy_values to ensure it normalizes
policy values correctly for specified legal actions.
"""
nn_policy_values = torch.tensor([-0.4, 0, 0.7, 0], dtype=torch.float)
legal_actions = torch.tensor([1, 3], dtype=torch.long)
expected_output = torch.tensor([-0.4, 0.5, 0.7, 0.5]) # Expected softmax values for indices 1 and 3

normalize_policy_values(nn_policy_values, legal_actions)

assert_close(nn_policy_values, expected_output)

def test_normalize_policy_values_all_legal():
"""
Tests the normalize_policy_values function when all actions are legal,
ensuring the entire tensor is normalized.
"""
nn_policy_values = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float)
legal_actions = torch.tensor([0, 1, 2, 3], dtype=torch.long)
expected_output = torch.softmax(nn_policy_values, dim=0)

normalize_policy_values(nn_policy_values, legal_actions)

assert_close(nn_policy_values, expected_output)

def test_normalize_policy_values_no_legal():
"""
Tests the normalize_policy_values function with an empty legal_actions tensor,
expecting the original nn_policy_values tensor to remain unchanged.
"""
nn_policy_values = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float)
legal_actions = torch.tensor([], dtype=torch.long)
expected_output = nn_policy_values.clone()

normalize_policy_values(nn_policy_values, legal_actions)

torch.testing.assert_close(nn_policy_values, expected_output)

alpha = 0.3
epsilon = 0.75
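
The surviving tests lean on torch.testing.assert_close, which, as the docstring above notes, compares values within tolerance and also checks that both tensors share shape, dtype, and device. A standalone illustration, not taken from the repository:

import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0000001])
assert_close(a, b)             # passes: difference is within the default float32 tolerance
# assert_close(a, b.double())  # would raise: dtypes differ even though the values match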

