fix(pu): fix update_per_collect in ddp setting (#321)
* fix(pu): fix ddp config when update_per_collect is None in config

* polish(pu): polish update_per_collect in ddp setting

* fix(pu): fix typo

---------

Co-authored-by: PaParaZz1 <niuyazhe314@outlook.com>
puyuan1996 and PaParaZz1 authored Jan 27, 2025
1 parent 8a142a9 commit 8099be9
Showing 15 changed files with 266 additions and 60 deletions.
7 changes: 5 additions & 2 deletions lzero/config/utils.py
@@ -13,6 +13,9 @@ def lz_to_ddp_config(cfg: EasyDict) -> EasyDict:
- cfg (:obj:`EasyDict`): The converted config
"""
w = get_world_size()
cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
cfg.policy.n_episode = int(np.ceil(cfg.policy.n_episode) / w)
# Generalized handling for multiple keys
keys_to_scale = ['batch_size', 'n_episode', 'num_segments']
for key in keys_to_scale:
if key in cfg.policy:
cfg.policy[key] = int(np.ceil(cfg.policy[key] / w))
return cfg
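To make the effect of the generalized key scaling concrete, here is a minimal self-contained sketch. The config values and the world size of 4 are hypothetical, and get_world_size() is replaced by a plain variable:

import numpy as np
from easydict import EasyDict

# Hypothetical per-run config; only the keys relevant to DDP scaling are shown.
cfg = EasyDict(dict(policy=dict(batch_size=256, n_episode=8, num_segments=8)))
world_size = 4  # e.g. launched with torchrun --nproc_per_node=4

# Mirror of the generalized scaling above: each key is ceil-divided by the number of
# processes, so the global batch/episode budget stays roughly constant across ranks.
for key in ['batch_size', 'n_episode', 'num_segments']:
    if key in cfg.policy:
        cfg.policy[key] = int(np.ceil(cfg.policy[key] / world_size))

print(cfg.policy)  # {'batch_size': 64, 'n_episode': 2, 'num_segments': 2}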
11 changes: 4 additions & 7 deletions lzero/entry/train_muzero.py
@@ -19,7 +19,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect
from .utils import random_collect, calculate_update_per_collect


def train_muzero(
@@ -186,12 +186,9 @@ def train_muzero(

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
# The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
# On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)

# Determine updates per collection
update_per_collect = calculate_update_per_collect(cfg, new_data)

# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
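The same replacement appears in train_muzero_segment, train_rezero, train_unizero and train_unizero_segment below. The reason this has to be a shared, synchronized computation is that under DDP each rank collects a different number of transitions: deriving update_per_collect from the local count would make the ranks run different numbers of optimizer steps, and DDP's gradient all-reduce would then wait forever on the ranks that finished early. A toy illustration with made-up numbers:

# Toy numbers only: why update_per_collect must agree across DDP ranks.
replay_ratio = 0.1
local_counts = [350, 400, 377]          # transitions collected by ranks 0, 1, 2

# Old behaviour: each rank derives its own step count from its local collection.
per_rank_steps = [int(c * replay_ratio) for c in local_counts]
print(per_rank_steps)                   # [35, 40, 37] -> ranks run different numbers of
                                        # backward passes and the gradient sync hangs

# New behaviour (calculate_update_per_collect): all-reduce the counts first.
shared_steps = int(sum(local_counts) * replay_ratio)
print(shared_steps)                     # 112 on every rank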
11 changes: 4 additions & 7 deletions lzero/entry/train_muzero_segment.py
@@ -19,7 +19,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from .utils import random_collect
from .utils import random_collect, calculate_update_per_collect

timer = EasyTimer()

@@ -180,13 +180,10 @@ def train_muzero_segment(

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
# The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
# On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)

# Determine updates per collection
update_per_collect = calculate_update_per_collect(cfg, new_data)

# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
10 changes: 3 additions & 7 deletions lzero/entry/train_rezero.py
@@ -17,7 +17,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect
from .utils import random_collect, calculate_update_per_collect


def train_rezero(
@@ -152,12 +152,8 @@ def train_rezero(
collect_with_pure_policy=cfg.policy.collect_with_pure_policy
)

if update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
# The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
# On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
# Determine updates per collection
update_per_collect = calculate_update_per_collect(cfg, new_data)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
10 changes: 2 additions & 8 deletions lzero/entry/train_unizero.py
@@ -20,7 +20,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
from .utils import random_collect
from .utils import random_collect, calculate_update_per_collect


def train_unizero(
@@ -154,13 +154,7 @@ def train_unizero(
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = cfg.policy.update_per_collect
if update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
# The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
# On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
update_per_collect = calculate_update_per_collect(cfg, new_data)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
10 changes: 2 additions & 8 deletions lzero/entry/train_unizero_segment.py
@@ -20,7 +20,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from .utils import random_collect
from .utils import random_collect, calculate_update_per_collect

timer = EasyTimer()

@@ -151,13 +151,7 @@ def train_unizero_segment(
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = cfg.policy.update_per_collect
if update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
# The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
# On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
update_per_collect = calculate_update_per_collect(cfg, new_data)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
84 changes: 83 additions & 1 deletion lzero/entry/utils.py
@@ -2,11 +2,93 @@
from typing import Optional, Callable

import psutil
import torch
import torch.distributed as dist
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter
from typing import Optional, Callable


import torch
import torch.distributed as dist

def is_ddp_enabled():
    """
    Check if Distributed Data Parallel (DDP) is enabled by verifying if
    PyTorch's distributed package is available and initialized.
    """
    return dist.is_available() and dist.is_initialized()

def ddp_synchronize():
    """
    Perform a barrier synchronization across all processes in DDP mode.
    Ensures all processes reach this point before continuing.
    """
    if is_ddp_enabled():
        dist.barrier()

def ddp_all_reduce_sum(tensor):
    """
    Perform an all-reduce operation (sum) on the given tensor across
    all processes in DDP mode. Returns the reduced tensor.
    Arguments:
        - tensor (:obj:`torch.Tensor`): The input tensor to be reduced.
    Returns:
        - torch.Tensor: The reduced tensor, summed across all processes.
    """
    if is_ddp_enabled():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor

def calculate_update_per_collect(cfg, new_data):
    """
    Calculate the number of updates to perform per data collection in a
    Distributed Data Parallel (DDP) setting. This ensures that all GPUs
    compute the same `update_per_collect` value, synchronized across processes.
    Arguments:
        - cfg: Configuration object containing policy settings.
        - new_data (list): The newly collected data segments.
    Returns:
        - int: The number of updates to perform per collection.
    """
    # Retrieve the update_per_collect setting from the configuration.
    update_per_collect = cfg.policy.update_per_collect

    if update_per_collect is None:
        # If update_per_collect is not explicitly set, calculate it based on
        # the number of collected transitions and the replay ratio.
        # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller
        # than cfg.policy.game_segment_length if it represents the final segment of the game.
        # On the other hand, its length will be less than cfg.policy.game_segment_length +
        # padding_length when it is not the last game segment. Typically, padding_length is
        # the sum of unroll_steps and td_steps.
        collected_transitions_num = sum(
            min(len(game_segment), cfg.policy.game_segment_length)
            for game_segment in new_data[0]
        )

        if torch.cuda.is_available():
            # Convert the collected transitions count to a GPU tensor for DDP operations.
            collected_transitions_tensor = torch.tensor(
                collected_transitions_num, dtype=torch.int64, device='cuda'
            )

            # Synchronize the collected transitions count across all GPUs using all-reduce.
            total_collected_transitions = ddp_all_reduce_sum(
                collected_transitions_tensor
            ).item()

            # Calculate update_per_collect based on the total synchronized transitions count.
            update_per_collect = int(total_collected_transitions * cfg.policy.replay_ratio)

            # Ensure the computed update_per_collect is positive.
            assert update_per_collect > 0, "update_per_collect must be positive"
        else:
            # If not using DDP, calculate update_per_collect directly from the local count.
            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)

    return update_per_collect

def initialize_zeros_batch(observation_shape, batch_size, device):
    """
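A quick way to sanity-check the new helpers outside of a distributed launch (assuming LightZero is installed): no process group is initialized here, so ddp_all_reduce_sum degrades to a no-op and the result equals the purely local computation. The config values are hypothetical, and plain lists stand in for GameSegment objects since calculate_update_per_collect only looks at their lengths:

from easydict import EasyDict

from lzero.entry.utils import calculate_update_per_collect, is_ddp_enabled

cfg = EasyDict(dict(policy=dict(
    update_per_collect=None,   # force the derived path
    replay_ratio=0.25,
    game_segment_length=400,
)))

# Two hypothetical segments of 400 and 384 transitions; only new_data[0] is inspected.
new_data = [[list(range(400)), list(range(384))], None]

print(is_ddp_enabled())                             # False outside a torchrun launch
print(calculate_update_per_collect(cfg, new_data))  # int((400 + 384) * 0.25) = 196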
17 changes: 8 additions & 9 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -522,15 +522,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
m_output = model.initial_inference(batch_obs, batch_action)
# ======================================================================

if not model.training:
# if not in training, obtain the scalars of the value/reward
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
[
m_output.latent_state,
inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
m_output.policy_logits
]
)
# if not in training, obtain the scalars of the value/reward
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
[
m_output.latent_state,
inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
m_output.policy_logits
]
)

network_output.append(m_output)

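For context on the hunk above: inverse_scalar_transform maps the network's categorical value output back to a scalar before the results are detached to CPU numpy, and with the training-mode guard removed this conversion now always runs when computing targets. A rough conceptual sketch of that mapping, purely for illustration (the real LightZero function also inverts the MuZero value-scaling transform, which is omitted here):

import torch

def approx_inverse_scalar_transform(value_logits: torch.Tensor, support_scale: int) -> torch.Tensor:
    # value_logits: (B, 2 * support_scale + 1) categorical logits over the value support.
    # Returns a (B, 1) scalar per sample as the expectation of the softmax distribution.
    support = torch.arange(-support_scale, support_scale + 1, dtype=torch.float32)
    probs = torch.softmax(value_logits, dim=-1)
    return (probs * support).sum(dim=-1, keepdim=True)

# Example: batch of 2 with a hypothetical support_scale of 300 (i.e. 601 bins).
logits = torch.randn(2, 601)
print(approx_inverse_scalar_transform(logits, 300).shape)  # torch.Size([2, 1])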
9 changes: 4 additions & 5 deletions lzero/policy/unizero.py
@@ -728,11 +728,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data)
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

if not self._eval_model.training:
# if not in training, obtain the scalars of the value/reward
pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
latent_state_roots = latent_state_roots.detach().cpu().numpy()
policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
# if not in training, obtain the scalars of the value/reward
pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
latent_state_roots = latent_state_roots.detach().cpu().numpy()
policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)

legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
if self._cfg.mcts_ctree:
Expand Down
@@ -84,7 +84,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py
torchrun --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_ddp_config.py
"""
from ding.utils import DDPContext
from lzero.entry import train_muzero
@@ -100,7 +100,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_muzero_multigpu_ddp_config.py
torchrun --nproc_per_node=2 ./zoo/atari/config/atari_muzero_ddp_config.py
"""
from ding.utils import DDPContext
from lzero.entry import train_muzero
@@ -103,8 +103,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py
torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py
torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py
"""
from ding.utils import DDPContext
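The launch-command updates above replace the deprecated python -m torch.distributed.launch with torchrun; what the training code sees (rank, world size, local rank) is unchanged, since torchrun exports the same environment variables. A small, LightZero-independent sanity check that can be launched the same way to verify the multi-GPU setup (the file name check_ddp.py is arbitrary):

# Launch with: torchrun --nproc_per_node=2 check_ddp.py
import os

import torch
import torch.distributed as dist


def main():
    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every spawned process;
    # init_process_group's default "env://" init method picks them up.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device("cpu")
    if backend == "nccl":
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")

    # The same collective that ddp_all_reduce_sum() wraps: every rank ends up with the sum.
    t = torch.tensor([dist.get_rank() + 1], dtype=torch.int64, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {dist.get_rank()}/{dist.get_world_size()}: all-reduced value = {t.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()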