feature(pu): add target_policy_entropy log
puyuan1996 committed Dec 1, 2023
1 parent 30e7882 commit ff6419f
Showing 4 changed files with 36 additions and 22 deletions.
4 changes: 3 additions & 1 deletion lzero/model/gpt_models/utils.py
@@ -108,7 +108,9 @@ def __init__(self, **kwargs):
self.reward_loss_weight = 1.
self.value_loss_weight = 0.25
self.policy_loss_weight = 1.
self.ends_loss_weight = 1.
# self.ends_loss_weight = 1.
self.ends_loss_weight = 0.


# Initialize the total loss tensor on the correct device
self.loss_total = torch.tensor(0., device=device)
1 change: 1 addition & 0 deletions lzero/model/gpt_models/world_model.py
@@ -391,6 +391,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIn
>>> loss.backward()
"""
logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o')
# TODO: invalid samples are padded with -100; why does that make their corresponding loss terms get ignored in this loss?
loss_obs = F.cross_entropy(logits_observations, labels_observations)
loss_ends = F.cross_entropy(rearrange(outputs.logits_ends, 'b t e -> (b t) e'), labels_ends)

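The TODO in the hunk above asks why padding invalid targets with -100 removes them from the loss: torch.nn.functional.cross_entropy skips every target equal to its ignore_index, and that argument defaults to -100, so the mean is taken over the valid positions only. A minimal self-contained sketch (the shapes and label values are illustrative assumptions, not taken from this commit):

    import torch
    import torch.nn.functional as F

    # Three flattened (batch * time) positions over a vocabulary of 4 observation tokens.
    logits = torch.randn(3, 4)
    # The last position is padding; labelling it -100 makes cross_entropy drop it,
    # because ignore_index defaults to -100.
    labels = torch.tensor([2, 0, -100])

    loss_with_padding = F.cross_entropy(logits, labels)
    loss_valid_only = F.cross_entropy(logits[:2], labels[:2])
    assert torch.allclose(loss_with_padding, loss_valid_only)  # identical: the padded position contributes nothing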
17 changes: 14 additions & 3 deletions lzero/policy/muzero_gpt.py
@@ -371,21 +371,29 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in

# self._learn_model.world_model.train()

# get valid target_policy data
valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']]
# compute entropy of each policy
target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1)
# compute average entropy
average_target_policy_entropy = target_policy_entropy.mean().item()
# print(f'Average entropy: {average_entropy}')

intermediate_losses = defaultdict(float)
losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._learn_model.tokenizer)

# TODO: train tokenizer
weighted_total_loss = losses.loss_total

for loss_name, loss_value in losses.intermediate_losses.items():
intermediate_losses[f"{loss_name}"] += loss_value
intermediate_losses[f"{loss_name}"] = loss_value

# print(intermediate_losses)
obs_loss = intermediate_losses['loss_obs']
reward_loss = intermediate_losses['loss_rewards']
policy_loss = intermediate_losses['loss_policy']
value_loss = intermediate_losses['loss_value']


# ==============================================================
# the core learn model update step.
# ==============================================================
@@ -395,6 +403,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
"""
gradient_scale = 1 / self._cfg.num_unroll_steps
weighted_total_loss.register_hook(lambda grad: grad * gradient_scale)

self._optimizer.zero_grad()
weighted_total_loss.backward()
if self._cfg.multi_gpu:
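For context on the gradient_scale hook kept in the hunk above: registering a hook on the scalar loss multiplies the gradient flowing back through it, which scales all parameter gradients by 1/num_unroll_steps without changing the logged loss value. A toy sketch, not the policy's actual code (the scale 1/5 is an illustrative assumption):

    import torch

    x = torch.tensor(2.0, requires_grad=True)
    loss = x * 3.0
    gradient_scale = 1 / 5  # e.g. 1 / num_unroll_steps
    loss.register_hook(lambda grad: grad * gradient_scale)
    loss.backward()
    print(loss.item())  # 6.0 -- the reported loss is unchanged
    print(x.grad)       # 0.6 -- the gradient (3.0) is scaled by 1/5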
@@ -421,6 +430,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
'weighted_total_loss': weighted_total_loss.item(),
'obs_loss': obs_loss,
'policy_loss': policy_loss,
'target_policy_entropy': average_target_policy_entropy,
# 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1),
'reward_loss': reward_loss,
'value_loss': value_loss,
@@ -683,7 +693,8 @@ def _monitor_vars_learn(self) -> List[str]:
# 'total_loss',
'obs_loss',
'policy_loss',
'policy_entropy',
# 'policy_entropy',
'target_policy_entropy',
'reward_loss',
'value_loss',
'consistency_loss',
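For reference, the target_policy_entropy logging added above boils down to the following standalone computation; the tensor shapes (batch of 4, 6 unroll steps, 2 CartPole actions) and the padding mask are assumptions made only for this sketch:

    import torch

    # Hypothetical batch: 4 trajectories, 6 unroll steps, 2 actions.
    target_policy = torch.softmax(torch.randn(4, 6, 2), dim=-1)  # (B, T, A) target distributions
    mask_padding = torch.ones(4, 6, dtype=torch.bool)            # (B, T) True where the step is valid
    mask_padding[:, 4:] = False                                  # pretend the last two steps are padding

    # Keep only valid steps, then compute the Shannon entropy of each target distribution.
    valid_target_policy = target_policy[mask_padding]            # (N_valid, A)
    target_policy_entropy = -torch.sum(
        valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1
    )
    average_target_policy_entropy = target_policy_entropy.mean().item()
    print(f'Average target policy entropy: {average_target_policy_entropy}')

The 1e-9 term guards against log(0) when a target policy is fully deterministic.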
36 changes: 18 additions & 18 deletions zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py
@@ -7,37 +7,37 @@
# ==============================================================


collector_env_num = 1
n_episode = 1
evaluator_env_num = 1
num_simulations = 25
update_per_collect = 20
model_update_ratio = 1
batch_size = 64
max_env_step = int(1e5)
reanalyze_ratio = 0
# num_unroll_steps = 20
num_unroll_steps = 5


# debug
# collector_env_num = 1
# n_episode = 1
# evaluator_env_num = 1
# num_simulations = 25
# update_per_collect = 2
# num_simulations = 25
# update_per_collect = 20
# model_update_ratio = 1
# batch_size = 2
# batch_size = 64
# max_env_step = int(1e5)
# reanalyze_ratio = 0
# # num_unroll_steps = 20
# num_unroll_steps = 5


# debug
collector_env_num = 1
n_episode = 1
evaluator_env_num = 1
num_simulations = 25
update_per_collect = 2
model_update_ratio = 1
batch_size = 2
max_env_step = int(1e5)
reanalyze_ratio = 0
num_unroll_steps = 5

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

cartpole_muzero_gpt_config = dict(
exp_name=f'data_mz_gpt_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd128_mediumnet_bs{batch_size}_mcs25_batch1_fixedtokenizer_fixloss_fixlatent_fixedslice_seed0',
exp_name=f'data_mz_gpt_ctree_debug/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd128_mediumnet_bs{batch_size}_mcs25_batch1_fixedtokenizer_fixloss_fixlatent_fixedslice_seed0',
# exp_name=f'data_mz_gpt_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd128_mediumnet_bs{batch_size}_mcs500_fixedtokenizer_fixloss_fixlatent_seed0',
# exp_name=f'data_mz_gpt_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd128_mediumnet_bs{batch_size}_clear-25_fixedtokenizer_fixloss_fixlatent_seed0',
env=dict(
