Commit
polish(pu): polish jericho uz config
puyuan1996 authored and puyuan committed Jan 7, 2025
1 parent 99b361d commit 86a1f7a
Showing 6 changed files with 125 additions and 43 deletions.
36 changes: 31 additions & 5 deletions lzero/model/common.py
@@ -17,7 +17,8 @@
from ding.torch_utils import MLP, ResBlock
from ding.utils import SequenceType
from ditk import logging

from ding.utils import set_pkg_seed, get_rank, get_world_size
import torch

# use dataclass to make the output of network more convenient to use
@dataclass
@@ -278,7 +279,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Uses google-bert/bert-base-uncased; the model input is token ids.
"""
class HFLanguageRepresentationNetwork(nn.Module):
def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8):
def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8, tokenizer=None):
"""
Initialize the language representation network.
@@ -287,9 +288,33 @@ def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: i
- embedding_size (int): dimension of the output embedding, default 768.
"""
super().__init__()
from transformers import AutoModel
from transformers import AutoModel, AutoTokenizer

print(f"="*20)
print(f"url:{url}")
print(f"="*20)

# Load the pretrained Hugging Face model
self.model = AutoModel.from_pretrained(url)

# Only rank 0 downloads the model
if get_rank() == 0:
self.model = AutoModel.from_pretrained(url)
if get_world_size() > 1:
# Wait for rank 0 to finish loading the model
torch.distributed.barrier()
if get_rank() != 0:  # Non-rank-0 processes load from the local cache
self.model = AutoModel.from_pretrained(url)


if tokenizer is None:
# Only rank 0 downloads the tokenizer
if get_rank() == 0:
self.tokenizer = AutoTokenizer.from_pretrained(url)
if get_world_size() > 1:
# Wait for rank 0 to finish loading the tokenizer
torch.distributed.barrier()
if get_rank() != 0:  # Non-rank-0 processes load from the local cache
self.tokenizer = AutoTokenizer.from_pretrained(url)

# Set the embedding dimension; if the target dimension is not 768, add a linear layer to project down or up
self.embedding_size = embedding_size
@@ -309,11 +334,12 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
Returns:
- torch.Tensor: the processed language embedding, of shape [batch_size, embedding_size].
"""
attention_mask = x != self.tokenizer.pad_token_id
if no_grad:
# Disable gradient computation under no_grad to save GPU memory
with torch.no_grad():
x = x.long()  # Ensure the input tensor is of long (integer) dtype
outputs = self.model(x)  # Get the model output
outputs = self.model(x, attention_mask=attention_mask)  # Get the model output

# The model's last_hidden_state has shape [batch_size, seq_len, hidden_size]
# We usually take the vector for the [CLS] token, i.e. outputs.last_hidden_state[:, 0, :]
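The downloads above are serialized so that only rank 0 touches the network while the other DDP ranks wait at a barrier and then read from the local Hugging Face cache. A minimal standalone sketch of that pattern, assuming torch.distributed is already initialized and all ranks share the same cache directory (the helper name is illustrative, not part of this commit):

import torch
import torch.distributed as dist
from transformers import AutoModel, AutoTokenizer

def load_pretrained_ddp_safe(url: str):
    # Hypothetical helper: rank 0 downloads first, the rest load from the shared cache.
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    model = tokenizer = None
    if rank == 0:
        model = AutoModel.from_pretrained(url)
        tokenizer = AutoTokenizer.from_pretrained(url)
    if world_size > 1:
        dist.barrier()  # everyone else waits until rank 0 has populated the cache
    if rank != 0:
        model = AutoModel.from_pretrained(url)
        tokenizer = AutoTokenizer.from_pretrained(url)
    return model, tokenizer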
32 changes: 27 additions & 5 deletions lzero/model/common_noserve_input_id.py
@@ -17,7 +17,8 @@
from ding.torch_utils import MLP, ResBlock
from ding.utils import SequenceType
from ditk import logging

from ding.utils import set_pkg_seed, get_rank, get_world_size
import torch

# use dataclass to make the output of network more convenient to use
@dataclass
@@ -278,7 +279,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Uses google-bert/bert-base-uncased; the model input is token ids.
"""
class HFLanguageRepresentationNetwork(nn.Module):
def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8):
def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8, tokenizer=None):
"""
Initialize the language representation network.
@@ -287,9 +288,29 @@ def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: i
- embedding_size (int): dimension of the output embedding, default 768.
"""
super().__init__()
from transformers import AutoModel
from transformers import AutoModel, AutoTokenizer

# Load the pretrained Hugging Face model
self.model = AutoModel.from_pretrained(url)

# Only rank 0 downloads the model
if get_rank() == 0:
self.model = AutoModel.from_pretrained(url)
if get_world_size() > 1:
# Wait for rank 0 to finish loading the model
torch.distributed.barrier()
if get_rank() != 0:  # Non-rank-0 processes load from the local cache
self.model = AutoModel.from_pretrained(url)


if tokenizer is None:
# Only rank 0 downloads the tokenizer
if get_rank() == 0:
self.tokenizer = AutoTokenizer.from_pretrained(url)
if get_world_size() > 1:
# Wait for rank 0 to finish loading the tokenizer
torch.distributed.barrier()
if get_rank() != 0:  # Non-rank-0 processes load from the local cache
self.tokenizer = AutoTokenizer.from_pretrained(url)

# Set the embedding dimension; if the target dimension is not 768, add a linear layer to project down or up
self.embedding_size = embedding_size
@@ -309,11 +330,12 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
Returns:
- torch.Tensor: the processed language embedding, of shape [batch_size, embedding_size].
"""
attention_mask = x != self.tokenizer.pad_token_id
if no_grad:
# Disable gradient computation under no_grad to save GPU memory
with torch.no_grad():
x = x.long()  # Ensure the input tensor is of long (integer) dtype
outputs = self.model(x)  # Get the model output
outputs = self.model(x, attention_mask=attention_mask)  # Get the model output

# The model's last_hidden_state has shape [batch_size, seq_len, hidden_size]
# We usually take the vector for the [CLS] token, i.e. outputs.last_hidden_state[:, 0, :]
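The forward pass in both files now derives the attention mask from the token ids by comparing against the tokenizer's pad token, instead of letting BERT attend to padding. A small isolated sketch of that masking step (the sentences are just examples; the model name follows the default used above):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
model = AutoModel.from_pretrained('google-bert/bert-base-uncased')

# Pad two sentences to the same length so the batch actually contains pad tokens.
batch = tokenizer(["go north", "open the mailbox and read the leaflet"], padding=True, return_tensors="pt")
x = batch["input_ids"]                        # [batch_size, seq_len] token ids
attention_mask = x != tokenizer.pad_token_id  # True for real tokens, False for padding

with torch.no_grad():
    outputs = model(x, attention_mask=attention_mask)
cls_embedding = outputs.last_hidden_state[:, 0, :]  # [batch_size, hidden_size], the [CLS] vector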
4 changes: 4 additions & 0 deletions lzero/model/unizero_model.py
@@ -89,6 +89,10 @@ def __init__(
print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
print('==' * 20)
elif world_model_cfg.obs_type == 'text':
print(f"="*20)
print(f"kwargs['encoder_url']:{kwargs['encoder_url']}")
print(f"="*20)

self.representation_network = HFLanguageRepresentationNetwork(url=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim)
self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,)
self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
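For text observations, the representation network is now built from kwargs['encoder_url']. A hypothetical smoke test of that encoder on its own (the fake batch, sequence length, and the expected shape are assumptions; the embedding size follows the 768 default used above):

import torch
from lzero.model.common import HFLanguageRepresentationNetwork

encoder = HFLanguageRepresentationNetwork(url='google-bert/bert-base-uncased', embedding_size=768)
# Fake batch of tokenized observations: [batch_size=2, seq_len=64] token ids.
token_ids = torch.randint(0, encoder.tokenizer.vocab_size, (2, 64))
embeddings = encoder(token_ids, no_grad=True)
print(embeddings.shape)  # expected: torch.Size([2, 768])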
11 changes: 6 additions & 5 deletions lzero/worker/muzero_collector.py
@@ -721,19 +721,20 @@ def collect(self,

collected_duration = sum([d['time'] for d in self._episode_info])

# Before allreduce
self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}")


# reduce data when enables DDP
if self._world_size > 1:
# Before allreduce
self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}")

dist.barrier()
# print(f"Rank {dist.get_rank()} collected_step: {collected_step}, collected_episode: {collected_episode}, collected_duration: {collected_duration}")
collected_step = allreduce_data(collected_step, 'sum')
collected_episode = allreduce_data(collected_episode, 'sum')
collected_duration = allreduce_data(collected_duration, 'sum')

# After allreduce
self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}")
# After allreduce
self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}")

self._total_envstep_count += collected_step
self._total_episode_count += collected_episode
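The counters are now only reduced when running under DDP, with the per-rank values logged before and after the reduction. The reduction itself is a plain sum over ranks; a minimal sketch of the same idea with raw torch.distributed (the collector actually uses DI-engine's allreduce_data helper, which is not reproduced here):

import torch
import torch.distributed as dist

def sum_across_ranks(value: float) -> float:
    # Sum a scalar counter (e.g. collected_step) over all DDP ranks.
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single-process fallback
    tensor = torch.tensor([value], dtype=torch.float64)
    # NOTE: with the NCCL backend the tensor must live on the current CUDA device.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor.item()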
49 changes: 31 additions & 18 deletions zoo/jericho/configs/jericho_unizero_config.py
@@ -1,27 +1,35 @@
import os
from easydict import EasyDict
import os
os.environ["HF_HOME"] = "/mnt/afs/zhangshenghan/.cache/huggingface/hub"
# import os
# os.environ["HF_HOME"] = "/mnt/afs/zhangshenghan/.cache/huggingface/hub"

def main(env_id='detective.z5', seed=0):
action_space_size = 50
max_steps = 51
# action_space_size = 50
action_space_size = 10
max_steps = 50

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 2
num_segments = 2
game_segment_length = 20
# collector_env_num = 8
# n_episode = 8
collector_env_num = 4
n_episode = 4
evaluator_env_num = 2
num_simulations = 50
max_env_step = int(10e6)
batch_size = 32
num_unroll_steps = 10
infer_context_length = 4

# batch_size = 8
# num_unroll_steps = 10
# infer_context_length = 4

batch_size = 16
num_unroll_steps = 5
infer_context_length = 2

num_layers = 2
replay_ratio = 0.25
update_per_collect = 20 # NOTE: very important for ddp
update_per_collect = None # NOTE: very important for ddp
embed_dim = 768
# Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
# buffer_reanalyze_freq = 1/10
@@ -30,8 +38,8 @@ def main(env_id='detective.z5', seed=0):
reanalyze_batch_size = 160
# The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
reanalyze_partition = 0.75
# model_name = 'BAAI/bge-base-en-v1.5'
model_name = 'google-bert/bert-base-uncased'
model_name = 'BAAI/bge-base-en-v1.5'
# model_name = 'google-bert/bert-base-uncased'
# =========== TODO: only for debug ===========
# collector_env_num = 2
# num_segments = 2
@@ -84,7 +92,7 @@ def main(env_id='detective.z5', seed=0):
model_type='mlp',
continuous_action_space=False,
world_model_cfg=dict(
policy_entropy_weight=5e-3,
policy_entropy_weight=5e-2,
continuous_action_space=False,
max_blocks=num_unroll_steps,
# NOTE: each timestep has 2 tokens: obs and action
@@ -107,11 +115,16 @@ def main(env_id='detective.z5', seed=0):
replay_ratio=replay_ratio,
batch_size=batch_size,
learning_rate=0.0001,
cos_lr_scheduler=True,
manual_temperature_decay=True,
threshold_training_steps_for_final_temperature=int(2.5e4),
num_simulations=num_simulations,
num_segments=num_segments,
# num_segments=num_segments,
n_episode=n_episode,
train_start_after_envsteps=0, # TODO
game_segment_length=game_segment_length,
replay_buffer_size=int(1e6),
# game_segment_length=game_segment_length,
# replay_buffer_size=int(1e6),
replay_buffer_size=int(1e5),
eval_freq=int(5e3),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
@@ -143,7 +156,7 @@ def main(env_id='detective.z5', seed=0):
main_config = jericho_unizero_config
create_config = jericho_unizero_create_config

main_config.exp_name = f'data_unizero_detective_20241220/{env_id[:8]}_ms{max_steps}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_uz_nlayer{num_layers}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
from lzero.entry import train_unizero
train_unizero([main_config, create_config], seed=seed,
model_path=main_config.policy.model_path, max_env_step=max_env_step)
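The experiment name now records the encoder model and action-space size, and update_per_collect is switched from 20 to None (the inline NOTE flags this as important for DDP). Assuming the module keeps the main(env_id, seed) entry point shown above and is importable from the repository root, a minimal launch sketch would be:

# Hypothetical launcher; not part of this commit.
from zoo.jericho.configs.jericho_unizero_config import main

if __name__ == '__main__':
    # Train UniZero on the detective.z5 Jericho game with the defaults from this commit.
    main(env_id='detective.z5', seed=0)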
36 changes: 26 additions & 10 deletions zoo/jericho/envs/jericho_env.py
@@ -11,7 +11,7 @@
from ding.envs import BaseEnv, BaseEnvTimestep
from jericho import FrotzEnv
from ding.utils import set_pkg_seed, get_rank, get_world_size

import torch

@ENV_REGISTRY.register('jericho')
class JerichoEnv(BaseEnv):
@@ -29,8 +29,20 @@ def __init__(self, cfg):
self.max_action_num = cfg.max_action_num
self.max_seq_len = cfg.max_seq_len


# Get the current world_size and rank
self.world_size = get_world_size()
self.rank = get_rank()

if JerichoEnv.tokenizer is None:
JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path)
# Only rank 0 downloads the tokenizer
if self.rank == 0:
JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path)
if self.world_size > 1:
# Wait for rank 0 to finish loading the tokenizer
torch.distributed.barrier()
if self.rank != 0:  # Non-rank-0 processes load from the local cache
JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path)

self._env = FrotzEnv(self.game_path)
self._action_list = None
@@ -93,14 +105,17 @@ def __repr__(self) -> str:
return "LightZero Jericho Env"

def step(self, action: int, return_str: bool = False):
try:
action_str = self._action_list[action]
except Exception as e:
# TODO: why exits illegal action
print('='*20)
print(e, f'rank {self.rank}, action {action} is illegal now we randomly choose a legal action from {self._action_list}!')
action = np.random.choice(len(self._action_list))
action_str = self._action_list[action]
if isinstance(action, str):
action_str = action
else:
try:
action_str = self._action_list[action]
except Exception as e:
# TODO: investigate why illegal actions occur
print('='*20)
print(e, f'rank {self.rank}: action {action} is illegal; randomly choosing a legal action from {self._action_list}!')
action = np.random.choice(len(self._action_list))
action_str = self._action_list[action]

observation, reward, done, info = self._env.step(action_str)
self.env_step += 1
@@ -165,4 +180,5 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
obs, reward, done, info = env.step(action_id, return_str=True)
print(f'[OBS]:\n{obs["observation"]}')
if done:
action_id = input('Would you like to RESTART, RESTORE a saved game, give the FULL score for that game or QUIT?')
break
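The environment now shares one class-level tokenizer across processes, loaded with the same rank-0-first pattern as the encoder. The observation-encoding step itself is not part of this hunk, but a sketch of how a raw Jericho observation might be turned into the fixed-length id tensor the text encoder expects (max_seq_len mirrors the cfg.max_seq_len field read above; everything else is an assumption):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
max_seq_len = 512  # would come from cfg.max_seq_len in the real env

obs_text = "You are standing in an open field west of a white house."
encoded = tokenizer(obs_text, truncation=True, padding='max_length', max_length=max_seq_len, return_tensors='pt')
obs_ids = encoded['input_ids'].squeeze(0).long()  # shape: [max_seq_len]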
