feat: try-except blocks for keyboard interrupt.
ChristianFredrikJohnsen committed Apr 19, 2024
1 parent d88fc75 commit 8f551a8
Showing 4 changed files with 130 additions and 88 deletions.
47 changes: 29 additions & 18 deletions main.py
@@ -6,34 +6,45 @@
from src.alphazero.alphazero_train_model import train_alphazero_model
from src.play_vs_alphazero import main as play_vs_alphazero

if __name__ == '__main__': # Needed for multiprocessing to work
test_overfit = False
train = True
self_play = False
play = False
mp.set_start_method('spawn')
if test_overfit:
train_alphazero_model(
num_games=1,
### Idea, make each game generation a longer task.
# Instead of running one function per game generation, run a function that generates multiple games.
# This will make the overhead of creating a new multiprocessing process less significant.


def test_overfit():
train_alphazero_model(
num_games=10,
num_simulations=1000,
epochs=1000,
batch_size=16,
model_path=None
)

if train:
for i in range(20):
def train():
try:
for i in range(int(1e6)):
train_alphazero_model(
num_games=24,
num_simulations=5,
num_simulations=300,
epochs=2,
batch_size=16,
model_path=None
model_path="./models/good_nn"
)
print(f'Training session {i} finished!')
print(f'Training session {i + 1} finished!')
print(torch.cuda.memory_summary())
torch.cuda.empty_cache()
except KeyboardInterrupt:
print('Training interrupted!')

if self_play:
play_alphazero("./models/good_nn")
def self_play():
play_alphazero("./models/good_nn")

if play:
play_vs_alphazero("./models/good_nn")
def play():
play_vs_alphazero("./models/good_nn")

if __name__ == '__main__': # Needed for multiprocessing to work
mp.set_start_method('spawn')
# test_overfit()
train()
# self_play()
# play()
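
A minimal sketch (not part of the commit) of the batching idea described in the comment at the top of the new main.py: let each worker call produce several games so the cost of spawning a process is paid once per batch. The names play_one_game, play_games_batch, generate_batched and the games_per_task parameter are illustrative stand-ins, not functions from this repository.

import multiprocessing as mp


def play_one_game(num_simulations: int) -> list:
    # Stand-in for a real self-play game; returns dummy (state, policy, reward) tuples.
    return [("state", "policy", 0.0) for _ in range(3)]


def play_games_batch(num_games_in_batch: int, num_simulations: int) -> list:
    # Worker task: generate several games per call so process start-up
    # overhead is amortized over the whole batch instead of one game.
    batch = []
    for _ in range(num_games_in_batch):
        batch.extend(play_one_game(num_simulations))
    return batch


def generate_batched(num_games: int, num_simulations: int, games_per_task: int = 4) -> list:
    # Split the workload into fewer, longer tasks.
    full, rest = divmod(num_games, games_per_task)
    sizes = [games_per_task] * full + ([rest] if rest else [])
    training_data = []
    with mp.Pool(mp.cpu_count()) as pool:
        for batch in pool.starmap(play_games_batch, [(n, num_simulations) for n in sizes]):
            training_data.extend(batch)
    return training_data


if __name__ == '__main__':
    mp.set_start_method('spawn')
    print(len(generate_batched(num_games=24, num_simulations=300)))
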
84 changes: 59 additions & 25 deletions src/alphazero/alphazero_generate_training_data.py
@@ -42,7 +42,7 @@ def play_alphazero_game(

# print(state, '\n~~~~~~~~~~~~~~~')
rewards = state.returns()
return [
training_data = [
(
state,
probability_visits,
@@ -53,6 +53,8 @@
for i, (state, probability_visits) in enumerate(game_data)
]

return training_data


def generate_training_data(nn: NeuralNetwork, num_games: int, num_simulations: int = 100) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
@@ -73,37 +75,69 @@ def generate_training_data(nn: NeuralNetwork, num_games: int, num_simulations: i
Instead of returning a list of tuples, we are just returning three huge tensors.
"""

alphazero_mcts = AlphaZero()
nn.to(alphazero_mcts.device)
training_data = []

# print(f"Generating training data with {mp.cpu_count()} threads...")

start_time = time.time()
with mp.Pool(24) as pool:
result_list = list(tqdm(pool.starmap(play_alphazero_game, [(alphazero_mcts, nn, num_simulations) for _ in range(num_games)])))
end_time = time.time()
print(f"Generated training data with {mp.cpu_count()} threads in {end_time - start_time:.2f} seconds.")

start_time = time.time()
with mp.Pool(1) as pool:
result_list = list(tqdm(pool.starmap(play_alphazero_game, [(alphazero_mcts, nn, num_simulations) for _ in range(num_games)])))
end_time = time.time()
print(f"Generated training data with 1 thread in {end_time - start_time:.2f} seconds.")

for i in range(len(result_list)):
try:
print(f"Generating training data with {mp.cpu_count()} threads...")
start_time = time.time()
with mp.Pool(mp.cpu_count()) as pool:
result_list = list(tqdm(pool.starmap(play_alphazero_game, [(alphazero_mcts, nn, num_simulations) for _ in range(num_games)])))
end_time = time.time()
print(f"Generated training data with {mp.cpu_count()} threads in {end_time - start_time:.2f} seconds.")

# Process results only if data generation was successful
for i in range(len(result_list)):
training_data.extend(result_list[i])

num_actions = alphazero_mcts.game.num_distinct_actions()
states = [item[0] for item in training_data]
probabilities = [item[1] for item in training_data]
rewards = [item[2] for item in training_data]
num_actions = alphazero_mcts.game.num_distinct_actions()
states = [item[0] for item in training_data]
probabilities = [item[1] for item in training_data]
rewards = [item[2] for item in training_data]

state_tensors = torch.cat(states, dim=0)
probability_tensors = torch.cat(probabilities, dim=0).reshape(-1, num_actions)
reward_tensors = torch.cat(rewards, dim=0).reshape(-1, 1)

return state_tensors, probability_tensors, reward_tensors

except KeyboardInterrupt:
print("KeyboardInterrupt: Terminating training data generation...")
raise
# pool = None # Initialize pool to None, so we can terminate it if KeyboardInterrupt is raised.
# alphazero_mcts = AlphaZero()
# nn.to(alphazero_mcts.device)
# training_data = []

# try:
# print(f"Generating training data with {mp.cpu_count()} threads...")
# start_time = time.time()
# pool = mp.Pool(mp.cpu_count())
# result_list = list(tqdm(pool.starmap(play_alphazero_game, [(alphazero_mcts, nn, num_simulations) for _ in range(num_games)])))
# end_time = time.time()
# print(f"Generated training data with {mp.cpu_count()} threads in {end_time - start_time:.2f} seconds.")

# except KeyboardInterrupt:
# print("KeyboardInterrupt: Terminating training data generation...")
# if pool is not None:
# pool.terminate() # Terminates all processes in the pool.
# pool.join() # Waits for all processes to finish.
# pool.close()
# raise # Reraise the KeyboardInterrupt, enables parent process to do its cleanup-process as well.

# for i in range(len(result_list)):
# training_data.extend(result_list[i])

# num_actions = alphazero_mcts.game.num_distinct_actions()
# states = [item[0] for item in training_data]
# probabilities = [item[1] for item in training_data]
# rewards = [item[2] for item in training_data]

state_tensors = torch.cat(states, dim=0)
probability_tensors = torch.cat(probabilities, dim=0).reshape(-1, num_actions)
reward_tensors = torch.cat(rewards, dim=0).reshape(-1, 1)
# state_tensors = torch.cat(states, dim=0)
# probability_tensors = torch.cat(probabilities, dim=0).reshape(-1, num_actions)
# reward_tensors = torch.cat(rewards, dim=0).reshape(-1, 1)

return state_tensors, probability_tensors, reward_tensors
# return state_tensors, probability_tensors, reward_tensors
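
The commented-out variant kept at the bottom of generate_training_data spells out the clean-up protocol for a manually created pool: on KeyboardInterrupt, terminate the workers, join them, and re-raise so the parent process can do its own clean-up. A standalone sketch of that pattern; worker and run_tasks are illustrative stand-ins for play_alphazero_game and the data-generation loop.

import multiprocessing as mp


def worker(task_id: int) -> int:
    # Stand-in for play_alphazero_game; any long-running picklable function works.
    return task_id * task_id


def run_tasks(num_tasks: int) -> list:
    pool = None
    try:
        pool = mp.Pool(mp.cpu_count())
        results = pool.starmap(worker, [(i,) for i in range(num_tasks)])
        pool.close()   # No new tasks will be submitted.
        pool.join()    # Wait for the workers to exit.
        return results
    except KeyboardInterrupt:
        if pool is not None:
            pool.terminate()  # Stop all worker processes immediately.
            pool.join()       # Reap the terminated processes.
        raise                 # Re-raise so the caller can run its own cleanup.


if __name__ == '__main__':
    mp.set_start_method('spawn')
    print(run_tasks(8))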


11 changes: 2 additions & 9 deletions src/alphazero/alphazero_train_model.py
@@ -49,9 +49,6 @@ def train_alphazero_model(num_games: int, num_simulations: int, epochs: int, bat

num_samples = state_tensors.size(0)

last_batch_size = num_samples % batch_size
num_steps = num_samples // batch_size + (1 if last_batch_size > 0 else 0)

try:

for epoch in range(epochs):
@@ -86,16 +83,12 @@ def train_alphazero_model(num_games: int, num_simulations: int, epochs: int, bat
# Track losses
policy_loss_tot += policy_loss.item(); value_loss_tot += value_loss.item(); total_loss += loss.item()


print(
f"Epoch {epoch+1}, (Per sample) Total Loss: {total_loss}, Policy Loss: {policy_loss_tot}, Value Loss: {value_loss_tot}"
)
print(
f"Epoch {epoch+1}\n(Per batch) Total Loss: {total_loss / num_steps}, Policy Loss: {policy_loss_tot / num_steps}, Value Loss: {value_loss_tot / num_steps}\n(Per sample) Total Loss: {total_loss / num_samples}, Policy Loss: {policy_loss_tot / num_samples}, Value Loss: {value_loss_tot / num_samples}"
f"Epoch {epoch+1}\n(Per sample) Total Loss: {total_loss / num_samples}, Policy Loss: {policy_loss_tot / num_samples}, Value Loss: {value_loss_tot / num_samples}"
)

nn.save(model_path)
print("\nModel saved!")
print(f"\nEpoch {epoch + 1}: Model saved!")

except KeyboardInterrupt:
nn.save(model_path)
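
The visible tail of train_alphazero_model saves the network inside the except KeyboardInterrupt handler, so an interrupted run keeps the weights trained so far. A generic sketch of that checkpoint-on-interrupt pattern, not the repository's code: torch.save stands in for the project's nn.save, and the loss and optimizer details are illustrative.

import torch


def train_with_checkpoint(model, optimizer, data_loader, epochs: int, checkpoint_path: str):
    # Save at the end of every epoch and again on Ctrl-C, so an
    # interrupted run never loses its progress.
    try:
        for epoch in range(epochs):
            for states, targets in data_loader:
                optimizer.zero_grad()
                loss = torch.nn.functional.mse_loss(model(states), targets)
                loss.backward()
                optimizer.step()
            torch.save(model.state_dict(), checkpoint_path)  # checkpoint each epoch
            print(f"Epoch {epoch + 1}: model saved!")
    except KeyboardInterrupt:
        torch.save(model.state_dict(), checkpoint_path)      # last-chance save
        print("Training interrupted, checkpoint written.")
        raise
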
76 changes: 40 additions & 36 deletions src/alphazero/alphazero_training_agent.py
@@ -15,6 +15,7 @@

import pyspiel
import torch
import os

from src.alphazero.node import Node
from src.neuralnet.neural_network import NeuralNetwork
@@ -193,39 +194,42 @@ def run_simulation(
Selection, expansion & evaluation, backpropagation.
Returns an action to be played.
"""

root_node = Node(
parent=None, state=state, action=None, policy_value=None
) # Initialize root node, and do dirichlet expand to get some exploration
policy, value = self.evaluate(root_node, neural_network)
self.dirichlet_expand(root_node, policy)

for _ in range(
num_simulations - 1
): # Do the selection, expansion & evaluation, backpropagation

node = self.vectorized_select(root_node) # Get desired child node
if not node.state.is_terminal() and not node.has_children():
policy, value = self.evaluate(
node, neural_network
) # Evaluate the node, using the neural network
self.expand(node, policy) # creates all its children
winner = value
else:
player = (
node.parent.state.current_player()
) # Here state is terminal, so we get the winning player
winner = node.state.returns()[player]
self.backpropagate(node, winner)

normalized_root_node_children_visits = generate_probabilty_target(root_node, self.num_actions, self.device)

# if move_number > self.temperature_moves:
return (
max(root_node.children, key=lambda node: node.visits).action,
normalized_root_node_children_visits,
) # The best action is the one with the most visits
# else:
# probabilities = torch.softmax(normalized_root_node_children_visits, dim=0) # Temperature-like exploration
# return root_node.children[torch.multinomial(probabilities, num_samples=1).item()].action, normalized_root_node_children_visits

try:
root_node = Node(
parent=None, state=state, action=None, policy_value=None
) # Initialize root node, and do dirichlet expand to get some exploration
policy, value = self.evaluate(root_node, neural_network)
self.dirichlet_expand(root_node, policy)

for _ in range(
num_simulations - 1
): # Do the selection, expansion & evaluation, backpropagation

node = self.vectorized_select(root_node) # Get desired child node
if not node.state.is_terminal() and not node.has_children():
policy, value = self.evaluate(
node, neural_network
) # Evaluate the node, using the neural network
self.expand(node, policy) # creates all its children
winner = value
else:
player = (
node.parent.state.current_player()
) # Here state is terminal, so we get the winning player
winner = node.state.returns()[player]
self.backpropagate(node, winner)

normalized_root_node_children_visits = generate_probabilty_target(root_node, self.num_actions, self.device)

# if move_number > self.temperature_moves:
return (
max(root_node.children, key=lambda node: node.visits).action,
normalized_root_node_children_visits,
) # The best action is the one with the most visits
# else:
# probabilities = torch.softmax(normalized_root_node_children_visits, dim=0) # Temperature-like exploration
# return root_node.children[torch.multinomial(probabilities, num_samples=1).item()].action, normalized_root_node_children_visits

except KeyboardInterrupt:
print(f'Simulation (run_simulation) interrupted! PID: {os.getpid()}')
raise
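
The wrapper added to run_simulation follows the same convention as the rest of the commit: catch KeyboardInterrupt inside the worker, log the process ID, and re-raise so the pool in the parent can shut down. A self-contained sketch of that worker-side pattern with a dummy workload; simulate is an illustrative stand-in for the MCTS simulation loop.

import multiprocessing as mp
import os


def simulate(task_id: int) -> int:
    # Catch the interrupt inside the worker, log which process it hit,
    # then re-raise so the Pool in the parent can shut down cleanly.
    try:
        total = 0
        for i in range(10_000_000):  # stand-in for the simulation loop
            total += i
        return total
    except KeyboardInterrupt:
        print(f"Simulation interrupted! PID: {os.getpid()}")
        raise


if __name__ == '__main__':
    mp.set_start_method('spawn')
    with mp.Pool(4) as pool:
        print(pool.map(simulate, range(8)))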
