
Commit

sync code
dyyoungg authored and puyuan1996 committed Dec 19, 2024
1 parent 0540c20 commit 5346860
Showing 6 changed files with 166 additions and 22 deletions.
3 changes: 2 additions & 1 deletion lzero/entry/train_unizero.py
@@ -145,7 +145,8 @@ def train_unizero(
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# Evaluate policy performance
if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
# if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
2 changes: 1 addition & 1 deletion lzero/mcts/buffer/game_buffer.py
@@ -405,7 +405,7 @@ def _preprocess_to_play_and_action_mask(
# list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
# for _ in range(unroll_steps + 1 - len(action_mask_tmp))
# ]
# TODO
# TODO: padded data
action_mask_tmp += [
list(np.zeros(self._cfg.model.action_space_size, dtype=np.int8))
for _ in range(unroll_steps + 1 - len(action_mask_tmp))
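For reference, a minimal standalone sketch of the padding step above, using illustrative values for action_space_size and unroll_steps (the real ones come from the buffer config); it is not the full _preprocess_to_play_and_action_mask logic:

import numpy as np

action_space_size = 4  # illustrative value
unroll_steps = 5       # illustrative value

# Two real steps were collected; the remaining slots up to unroll_steps + 1
# are padded with all-zero masks, i.e. every action is marked illegal.
action_mask_tmp = [[1, 1, 0, 1], [1, 0, 0, 1]]
action_mask_tmp += [
    list(np.zeros(action_space_size, dtype=np.int8))
    for _ in range(unroll_steps + 1 - len(action_mask_tmp))
]
assert len(action_mask_tmp) == unroll_steps + 1
assert action_mask_tmp[-1] == [0, 0, 0, 0]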
77 changes: 77 additions & 0 deletions lzero/model/common.py
@@ -350,6 +350,83 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
return cls_embedding



class HFLanguageRepresentationNetwork_backup(nn.Module):
def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8):
"""
初始化语言表示网络
参数:
- url (str): 预训练 Hugging Face 模型的地址,默认为 'google-bert/bert-base-uncased'。
- embedding_size (int): 输出嵌入的维度大小,默认为 768。
"""
super().__init__()
from transformers import AutoModel
# Load the pretrained Hugging Face model
self.model = AutoModel.from_pretrained(url)

# Set the embedding size; if the target dimension is not 768, add a linear layer to project the embedding down or up.
self.embedding_size = embedding_size
if self.embedding_size != 768:
self.embed_head = nn.Linear(768, self.embedding_size)

self.sim_norm = SimNorm(simnorm_dim=group_size)

def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
"""
前向传播,获取输入序列的语言表示。
参数:
- x (torch.Tensor): 输入的张量,通常是序列的 token 索引,形状为 [batch_size, seq_len]。
- no_grad (bool): 是否在无梯度模式下运行,默认为 True。
返回:
- torch.Tensor: 经过处理的语言嵌入向量,形状为 [batch_size, embedding_size]。
"""
if no_grad:
# Disable gradient computation in no_grad mode to save GPU memory.
with torch.no_grad():
x = x.long()  # ensure the input tensor is of long (integer) type
outputs = self.model(x)  # run the backbone model

# The model's last_hidden_state has shape [batch_size, seq_len, hidden_size];
# we usually take the vector at the [CLS] token, i.e. outputs.last_hidden_state[:, 0, :].
cls_embedding = outputs.last_hidden_state[:, 0, :]

# If the target embedding_size is not 768, apply the linear projection.
if self.embedding_size == 768:
# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding
else:
cls_embedding = self.embed_head(cls_embedding)

# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding
else:
# With no_grad disabled, gradients are tracked through the backbone.
x = x.long()  # ensure the input tensor is of long (integer) type
outputs = self.model(x)
cls_embedding = outputs.last_hidden_state[:, 0, :]

# If the target embedding_size is not 768, apply the linear projection.
if self.embedding_size == 768:
# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding
else:
cls_embedding = self.embed_head(cls_embedding)

# NOTE: very important for training stability.
cls_embedding = self.sim_norm(cls_embedding)

return cls_embedding


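A minimal usage sketch for the class above; it assumes transformers is installed and that the class is importable from lzero.model.common, and the tokenizer choice simply mirrors the default url:

import torch
from transformers import AutoTokenizer
from lzero.model.common import HFLanguageRepresentationNetwork_backup  # assumed import path

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
# Tokenize a small batch into token-id tensors of shape [batch_size, seq_len].
batch = tokenizer(
    ["go north", "open the mailbox"],
    padding=True, truncation=True, max_length=512, return_tensors='pt',
)

net = HFLanguageRepresentationNetwork_backup(embedding_size=768)
# forward() takes token indices and returns the SimNorm-ed [CLS] embedding.
embedding = net(batch['input_ids'], no_grad=True)
print(embedding.shape)  # torch.Size([2, 768])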
class RepresentationNetworkUniZero(nn.Module):

def __init__(
88 changes: 75 additions & 13 deletions zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py
@@ -15,22 +15,56 @@ def main(env_id, seed):

continuous_action_space = True
K = 20 # num_of_sampled_actions
# K = 16 # num_of_sampled_actions

collector_env_num = 8
n_episode = 8
num_segments = 8
game_segment_length = 100
game_segment_length = 125
# game_segment_length = 500

collector_env_num = 16
n_episode = 16
num_segments = 16
game_segment_length = 125

# collector_env_num = 16
# n_episode = 16
# num_segments = 16
# game_segment_length = 125


evaluator_env_num = 3
num_simulations = 50
replay_ratio = 0.1
max_env_step = int(5e5)
num_simulations = 50 # TODO

# max_env_step = int(5e5)
max_env_step = int(1e6)
# max_env_step = int(3e6) # TODO

reanalyze_ratio = 0
batch_size = 64
num_layers = 2
# num_layers = 4

num_unroll_steps = 5
# num_unroll_steps = 10
infer_context_length = 2

# replay_ratio = 0.25
# num_unroll_steps = 10
# infer_context_length = 4

norm_type = 'LN'

# Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means twice per epoch, and 1/10 means once every ten epochs.
buffer_reanalyze_freq = 1/100000
buffer_reanalyze_freq = 1/1000000000 # TODO
# replay_ratio = 0.1
replay_ratio = 0.25


# buffer_reanalyze_freq = 1/10
# replay_ratio = 0.1

# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
reanalyze_batch_size = 160
# The partition of reanalyze. E.g., 1 means the reanalyze batch is sampled from the whole buffer, 0.5 means it is sampled from the first half of the buffer.
@@ -54,7 +88,9 @@ def main(env_id, seed):
domain_name=domain_name,
task_name=task_name,
from_pixels=False, # vector/state obs
frame_skip=2,
# from_pixels=True, # vector/state obs
# frame_skip=2,
frame_skip=8,
continuous=True,
save_replay_gif=False,
replay_path_gif='./replay_gif',
@@ -75,13 +111,18 @@ def main(env_id, seed):
num_of_sampled_actions=K,
model_type='mlp',
world_model_cfg=dict(
policy_loss_type='kl',
num_simulations=num_simulations,
policy_loss_type='kl', # 'simple'
# policy_loss_type='simple', # 'simple'
obs_type='vector',
num_unroll_steps=num_unroll_steps,
# policy_entropy_weight=0,
# policy_entropy_weight=5e-3,
policy_entropy_weight=5e-2,
continuous_action_space=continuous_action_space,
num_of_sampled_actions=K,
sigma_type='conditioned',
# sigma_type='fixed',
fixed_sigma_value=0.5,
bound_type=None,
model_type='mlp',
@@ -93,7 +134,8 @@ def main(env_id, seed):
action_space_size=action_space_size,
num_layers=num_layers,
num_heads=8,
embed_dim=768,
embed_dim=768, # original
# embed_dim=512,
env_num=max(collector_env_num, evaluator_env_num),
),
),
@@ -108,21 +150,31 @@ def main(env_id, seed):
replay_ratio=replay_ratio,
batch_size=batch_size,
discount_factor=0.99,
td_steps=5,
piecewise_decay_lr_scheduler=False,
# discount_factor=1,
# td_steps=5,
# td_steps=10,
td_steps=game_segment_length, # TODO

lr_piecewise_constant_decay=False,
learning_rate=1e-4,
grad_clip_value=5,
# grad_clip_value=0.3, # TODO
# manual_temperature_decay=False,
manual_temperature_decay=True,
threshold_training_steps_for_final_temperature=int(2.5e4),
cos_lr_scheduler=True,

# cos_lr_scheduler=True,
cos_lr_scheduler=False,

num_segments=num_segments,
train_start_after_envsteps=2000,
game_segment_length=game_segment_length,
num_simulations=num_simulations,
reanalyze_ratio=0,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(5e3),
replay_buffer_size=int(1e6),
# replay_buffer_size=int(5e4),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
# ============= The key different params for ReZero =============
@@ -151,7 +203,8 @@ def main(env_id, seed):

# ============ use muzero_segment_collector instead of muzero_collector =============
from lzero.entry import train_unizero_segment
main_config.exp_name=f'data_sampled_unizero/dmc2gym_{env_id}_brf{buffer_reanalyze_freq}_state_cont_suz_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_K{K}_ns{num_simulations}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_{norm_type}_seed{seed}_learnsigma'
main_config.exp_name=f'data_suz_1216/dmc2gym_{env_id}_state_cont_suz_fs8_act-simnorm_td{game_segment_length}_dc099_learn-sigma_gcv5_rbs1e6_no-corlr_embed768_temp2.5e4_pew5e-2_19prior1flatten_obs10value01_clamp4_brf{buffer_reanalyze_freq}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_K{K}_ns{num_simulations}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_{norm_type}_seed{seed}'

train_unizero_segment([main_config, create_config], model_path=main_config.policy.model_path, seed=seed, max_env_step=max_env_step)


@@ -161,6 +214,15 @@

parser.add_argument('--env', type=str, help='The environment to use', default='cartpole-swingup')
parser.add_argument('--seed', type=int, help='The seed to use', default=0)

args = parser.parse_args()

# args.env = 'cheetah-run'
# args.env = 'walker-walk'
# args.env = 'finger-spin'
# args.env = 'pendulum-swingup'

# args.env = 'hopper-hop'
# args.env = 'acrobot-swingup'

main(args.env, args.seed)
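The reanalyze-related comments above describe the intended semantics of the frequency, batch size, and partition knobs; the sketch below only illustrates that reading with hypothetical names and values and is not LightZero's actual scheduling code:

# Hypothetical illustration of the reanalyze settings described in the comments above.
def reanalyze_plan(buffer_reanalyze_freq: float,
                   reanalyze_batch_size: int,
                   reanalyze_partition: float,
                   num_unroll_steps: int,
                   buffer_num_transitions: int):
    # A frequency of 1/100000 reanalyzes per epoch means roughly one pass
    # every 100000 training epochs, i.e. effectively disabled in this config.
    epochs_between_passes = round(1 / buffer_reanalyze_freq)
    # Each pass refreshes reanalyze_batch_size sequences of num_unroll_steps transitions each.
    transitions_per_pass = reanalyze_batch_size * num_unroll_steps
    # Samples are drawn from the first `reanalyze_partition` fraction of the buffer.
    sampling_window = int(buffer_num_transitions * reanalyze_partition)
    return epochs_between_passes, transitions_per_pass, sampling_window

print(reanalyze_plan(1 / 100000, 160, 0.75, 5, int(1e6)))  # -> (100000, 800, 750000)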
12 changes: 8 additions & 4 deletions zoo/jericho/configs/jericho_unizero_config.py
@@ -14,12 +14,13 @@ def main(env_id='detective.z5', seed=0):
game_segment_length = 20
evaluator_env_num = 2
num_simulations = 50
max_env_step = int(5e5)
batch_size = 64
max_env_step = int(10e6)
# batch_size = 64
batch_size = 32
num_unroll_steps = 10
infer_context_length = 4
num_layers = 2
replay_ratio = 0.1
replay_ratio = 0.25
embed_dim = 768
# Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means twice per epoch, and 1/10 means once every ten epochs.
# buffer_reanalyze_freq = 1/10
@@ -65,7 +66,10 @@ def main(env_id='detective.z5', seed=0):
manager=dict(shared_memory=False, )
),
policy=dict(
multi_gpu=False, # ======== Very important for ddp =============
# multi_gpu=True, # ======== Very important for ddp =============
# default is 10000
use_wandb=False,
learn=dict(learner=dict(
hook=dict(save_ckpt_after_iter=1000000, ), ), ),
model=dict(
@@ -135,7 +139,7 @@ def main(env_id='detective.z5', seed=0):
main_config = jericho_unizero_config
create_config = jericho_unizero_create_config

main_config.exp_name = f'data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
main_config.exp_name = f'data_unizero_detective_20241219/{env_id[:8]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
from lzero.entry import train_unizero
train_unizero([main_config, create_config], seed=seed,
model_path=main_config.policy.model_path, max_env_step=max_env_step)
6 changes: 3 additions & 3 deletions zoo/jericho/envs/jericho_env.py
@@ -47,7 +47,7 @@ def prepare_obs(self, obs, return_str: bool = False):
if not return_str:
full_obs = JerichoEnv.tokenizer(
[full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len)
# obs_attn_mask = full_obs['attention_mask']
obs_attn_mask = full_obs['attention_mask']
full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32) # TODO: attn_mask
if len(self._action_list) <= self.max_action_num:
action_mask = [1] * len(self._action_list) + [0] * \
@@ -56,8 +56,8 @@ def prepare_obs(self, obs, return_str: bool = False):
action_mask = [1] * len(self._action_list)

action_mask = np.array(action_mask, dtype=np.int8)
return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
# return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask,'action_mask': action_mask, 'to_play': -1}
# return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask,'action_mask': action_mask, 'to_play': -1}

def reset(self, return_str: bool = False):
initial_observation, info = self._env.reset()
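A standalone sketch of what the tokenization in prepare_obs produces, assuming a Hugging Face tokenizer such as bert-base-uncased; the actual tokenizer, max_seq_len, and max_action_num of JerichoEnv are configured elsewhere and only assumed here:

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')  # illustrative stand-in for JerichoEnv.tokenizer
max_seq_len = 512
max_action_num = 10

full_obs_text = "You are standing in an open field west of a white house."
encoded = tokenizer([full_obs_text], truncation=True, padding="max_length", max_length=max_seq_len)

full_obs = np.array(encoded['input_ids'][0], dtype=np.int32)            # token ids, padded to max_seq_len
obs_attn_mask = np.array(encoded['attention_mask'][0], dtype=np.int32)  # 1 for real tokens, 0 for padding

action_list = ["north", "open mailbox", "read leaflet"]
action_mask = np.array([1] * len(action_list) + [0] * (max_action_num - len(action_list)), dtype=np.int8)

obs_dict = {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}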
