polish(pu): polish jericho_env and add jericho_ppo_config
PaParaZz1 committed Jan 21, 2025
1 parent e18a169 commit bf095a4
Showing 6 changed files with 213 additions and 50 deletions.
4 changes: 3 additions & 1 deletion lzero/entry/train_unizero.py
@@ -239,7 +239,9 @@ def train_unizero(
if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

train_data.append({'train_which_component': 'transformer'})
# train_data.append({'train_which_component': 'transformer'})
train_data.append(learner.train_iter)

log_vars = learner.train(train_data, collector.envstep)
if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
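The appended iteration counter is consumed on the policy side (see the lzero/policy/unizero.py diff below), where it drives the new gradient-accumulation schedule. A minimal sketch of the assumed data flow, with placeholder values standing in for the real batches:

# Sketch only: the entry appends the current training iteration to the batch ...
current_batch, target_batch = "current_batch", "target_batch"   # placeholders
train_iter_counter = 123                                        # learner.train_iter in the entry

train_data = [current_batch, target_batch]
train_data.append(train_iter_counter)                           # entry side (train_unizero.py)

# ... and the policy unpacks it in _forward_learn to gate optimizer updates.
current_batch, target_batch, train_iter = train_data
assert train_iter == 123
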
4 changes: 2 additions & 2 deletions lzero/model/common.py
@@ -323,8 +323,8 @@ def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: i

self.sim_norm = SimNorm(simnorm_dim=group_size)

# def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
def forward(self, x: torch.Tensor, no_grad: bool = False) -> torch.Tensor: # TODO ======
def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: # TODO: train projection ======
# def forward(self, x: torch.Tensor, no_grad: bool = False) -> torch.Tensor: # TODO: train encoder ======

"""
Forward pass that obtains the language representation of the input sequence.
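For context, the no_grad default decides whether the pretrained language backbone stays frozen (the "train projection" TODO, where only the projection head and SimNorm receive gradients) or is fine-tuned end to end (the "train encoder" TODO). A minimal sketch of that switch, with a generic backbone standing in for the HF model and a LayerNorm standing in for SimNorm:

import torch
import torch.nn as nn

class FrozenOrTrainableEncoder(nn.Module):
    # Sketch under assumptions: backbone is any module mapping token ids to a hidden vector.
    def __init__(self, backbone: nn.Module, hidden_dim: int, embedding_size: int):
        super().__init__()
        self.model = backbone                          # placeholder for the pretrained language model
        self.projection = nn.Linear(hidden_dim, embedding_size)
        self.sim_norm = nn.LayerNorm(embedding_size)   # stand-in for SimNorm in this sketch

    def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
        if no_grad:
            with torch.no_grad():                      # backbone frozen: only projection/norm get gradients
                hidden = self.model(x)
        else:
            hidden = self.model(x)                     # fine-tune the whole encoder
        return self.sim_norm(self.projection(hidden))
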
59 changes: 42 additions & 17 deletions lzero/policy/unizero.py
@@ -220,6 +220,8 @@ class UniZeroPolicy(MuZeroPolicy):
policy_loss_weight=1,
# (float) The weight of ssl (self-supervised learning) loss.
ssl_loss_weight=0,
# (bool) Whether to use the cosine learning rate decay.
cos_lr_scheduler=False,
# (bool) Whether to use piecewise constant learning rate decay.
# i.e. lr: 0.2 -> 0.02 -> 0.002
piecewise_decay_lr_scheduler=False,
@@ -300,6 +302,11 @@ def _init_learn(self) -> None:
betas=(0.9, 0.95),
)

if self._cfg.cos_lr_scheduler:
from torch.optim.lr_scheduler import CosineAnnealingLR
# TODO: check the total training steps
self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1)

# use model_wrapper for specialized demands of different modes
self._target_model = copy.deepcopy(self._model)
# Ensure that the installed torch version is greater than or equal to 2.0
Expand Down Expand Up @@ -335,6 +342,9 @@ def _init_learn(self) -> None:
# TODO: add the model to wandb
wandb.watch(self._learn_model.representation_network, log="all")

# TODO: ========
self.accumulation_steps = 4  # number of gradient accumulation steps

# @profile
def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
"""
@@ -352,7 +362,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
self._learn_model.train()
self._target_model.train()

current_batch, target_batch, _ = data
# current_batch, target_batch, _ = data
current_batch, target_batch, train_iter = data

obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time = current_batch
target_reward, target_value, target_policy = target_batch

@@ -386,14 +398,14 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# print(f'transformed_target_value:{transformed_target_value}')
# print("self.value_support:", self.value_support)

try:
target_value_categorical = phi_transform(self.value_support, transformed_target_value)
except Exception as e:
print('='*20)
print(e)
# print(f'transformed_target_value:{transformed_target_value}')
# print("self.value_support:", self.value_support)
print('='*20)
# try:
target_value_categorical = phi_transform(self.value_support, transformed_target_value)
# except Exception as e:
# print('='*20)
# print(e)
# # print(f'transformed_target_value:{transformed_target_value}')
# # print("self.value_support:", self.value_support)
# print('='*20)
# target_value_categorical = phi_transform(self.value_support, transformed_target_value)


@@ -455,7 +467,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"

# Core learn model update step
self._optimizer_world_model.zero_grad()
if train_iter % self.accumulation_steps == 0:  # update the parameters once every accumulation_steps steps
# print(f'train_iter:{train_iter}')
self._optimizer_world_model.zero_grad()

weighted_total_loss = weighted_total_loss / self.accumulation_steps  # scale the loss when accumulating gradients

weighted_total_loss.backward()

# ========== for debugging ==========
@@ -471,15 +488,23 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in

total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
self._cfg.grad_clip_value)
if self._cfg.multi_gpu:
self.sync_gradients(self._learn_model)

self._optimizer_world_model.step()
if self._cfg.piecewise_decay_lr_scheduler:
self.lr_scheduler.step()
if train_iter % self.accumulation_steps == 0:  # update the parameters once every accumulation_steps steps
# print(f'pos 2 train_iter:{train_iter}')

if self._cfg.multi_gpu:
self.sync_gradients(self._learn_model)

self._optimizer_world_model.step()

if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
self.lr_scheduler.step()

# Core target model update step
self._target_model.update(self._learn_model.state_dict())

# Core target model update step
self._target_model.update(self._learn_model.state_dict())
if self.accumulation_steps > 1:
torch.cuda.empty_cache()

if torch.cuda.is_available():
torch.cuda.synchronize()
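Taken together, the changes above spread one optimizer update over several _forward_learn calls: the loss is scaled by 1 / accumulation_steps, gradients are accumulated with backward(), and zero_grad(), gradient clipping, optimizer.step(), and the optional cosine LR scheduler step are gated on the iteration counter. A self-contained sketch of the standard accumulation pattern with a cosine schedule, using a placeholder model, optimizer, and synthetic data (the class code above gates both zero_grad and step on train_iter % accumulation_steps == 0 instead):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(8, 1)                                           # placeholder world model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
scheduler = CosineAnnealingLR(optimizer, T_max=int(1e5), eta_min=0)     # T_max is the TODO in the diff
loader = [(torch.randn(32, 8), torch.randn(32, 1)) for _ in range(8)]   # synthetic batches

accumulation_steps = 4
optimizer.zero_grad()
for train_iter, (x, y) in enumerate(loader):
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accumulation_steps).backward()             # scale so accumulated grads match one large batch
    if (train_iter + 1) % accumulation_steps == 0:     # one optimizer update per accumulation window
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
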
102 changes: 102 additions & 0 deletions zoo/jericho/configs/jericho_ppo_config.py
@@ -0,0 +1,102 @@
from easydict import EasyDict
import torch.nn as nn

action_space_size = 10
max_steps = 50
model_name = 'BAAI/bge-base-en-v1.5'
env_id = 'detective.z5'
evaluator_env_num = 2

# proj train
collector_env_num = 4
batch_size = 32

# all train
# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 2
# batch_size = 4
# num_unroll_steps = 5
# infer_context_length = 2
jericho_ppo_config = dict(
# exp_name=f"data_ppo_detective/jericho_ppo_projtrain_bs{batch_size}_seed0",
exp_name=f"data_ppo_detective/jericho_add-loc-inv_ppo_projtrain_bs{batch_size}_seed0",
env=dict(
remove_stuck_actions=False,
# remove_stuck_actions=True,
# add_location_and_inventory=True,
add_location_and_inventory=False,

stop_value=int(1e6),
observation_shape=512,
max_steps=max_steps,
max_action_num=action_space_size,
tokenizer_path=model_name,
# tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594",
max_seq_len=512,
# game_path="z-machine-games-master/jericho-game-suite/" + env_id,
game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + env_id,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
multi_agent=True,
action_space='discrete',
model=dict(
obs_shape=(26, 5, 4),  # has no effect here
action_shape=action_space_size,
action_space='discrete',
encoder_hidden_size_list=[512],  # encoder_hidden_size_list[-1] is the input dimension of the head
actor_head_hidden_size=512,
critic_head_hidden_size=512,
),
learn=dict(
epoch_per_collect=4,
batch_size=batch_size,
learning_rate=0.0005,
# entropy_weight=0.01,
entropy_weight=0.05,
value_norm=True,
grad_clip_value=10,
),
collect=dict(
# n_sample=1024,
n_sample=320, # TODO: DEBUG
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=5000, )),
),
)
jericho_ppo_config = EasyDict(jericho_ppo_config)
main_config = jericho_ppo_config
cartpole_ppo_create_config = dict(
env=dict(
type='jericho',
import_names=['zoo.jericho.envs.jericho_env'],
),
# env_manager=dict(type='subprocess'),
env_manager=dict(type='base'),
policy=dict(type='ppo'),
)
cartpole_ppo_create_config = EasyDict(cartpole_ppo_create_config)
create_config = cartpole_ppo_create_config


if __name__ == "__main__":
from ding.entry import serial_pipeline_onpolicy
from ding.model.template import VAC
m = main_config.policy.model
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from lzero.model.common import HFLanguageRepresentationNetwork
encoder = HFLanguageRepresentationNetwork(url=model_name, embedding_size=512)

model = VAC(obs_shape=m.obs_shape, action_shape=m.action_shape, action_space=m.action_space, encoder_hidden_size_list=m.encoder_hidden_size_list,
actor_head_hidden_size=m.actor_head_hidden_size,
critic_head_hidden_size=m.critic_head_hidden_size, encoder=encoder)
serial_pipeline_onpolicy([main_config, create_config], seed=0, model=model)
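Note on the setup above: the HFLanguageRepresentationNetwork encoder is passed directly into DI-engine's VAC actor-critic, so encoder_hidden_size_list[-1] and the actor/critic head hidden sizes are all set to the encoder's embedding_size of 512 to keep the feature dimensions consistent, while obs_shape in the model config is presumably ignored once a custom encoder is supplied, as the inline comment indicates.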
42 changes: 22 additions & 20 deletions zoo/jericho/configs/jericho_unizero_config.py
@@ -18,18 +18,20 @@ def main(env_id='detective.z5', seed=0):
num_simulations = 50
max_env_step = int(10e6)

# collector_env_num = 4
# n_episode = 4
# batch_size = 16 # proj train
# num_unroll_steps = 10
# infer_context_length = 4

collector_env_num = 2
n_episode = 2
evaluator_env_num = 2
batch_size = 4 # all train
num_unroll_steps = 5
infer_context_length = 2
# proj train
collector_env_num = 4
n_episode = 4
batch_size = 16
num_unroll_steps = 10
infer_context_length = 4

# all train
# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 2
# batch_size = 4
# num_unroll_steps = 5
# infer_context_length = 2

# batch_size = 16
# num_unroll_steps = 5
@@ -52,27 +54,24 @@ def main(env_id='detective.z5', seed=0):
# model_name = 'google-bert/bert-base-uncased'

# =========== TODO: only for debug ===========
# collector_env_num = 2
# num_segments = 2
# game_segment_length = 20
# evaluator_env_num = 2
# max_env_step = int(5e5)
# batch_size = 10
# num_simulations = 5
# num_simulations = 2
# num_unroll_steps = 5
# infer_context_length = 2
# max_steps = 10
# num_layers = 1
# replay_ratio = 0.05
# embed_dim = 768
# TODO: the action space inside MCTS is restricted to the legal actions of the root node

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
jericho_unizero_config = dict(
env=dict(
remove_stuck_actions=False,
# remove_stuck_actions=False,
remove_stuck_actions=True,

stop_value=int(1e6),
observation_shape=512,
max_steps=max_steps,
@@ -168,7 +167,10 @@ def main(env_id='detective.z5', seed=0):
main_config = jericho_unizero_config
create_config = jericho_unizero_create_config

main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_all-train_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}-remove-novalid_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
# main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'

# main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_all-train_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
# main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_remove-novalid-action_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
from lzero.entry import train_unizero
train_unizero([main_config, create_config], seed=seed,