feature(pu): add register_token_shared option
puyuan committed Feb 7, 2025
1 parent 557e8f9 commit 8fe1a6d
Showing 3 changed files with 79 additions and 50 deletions.
8 changes: 4 additions & 4 deletions lzero/model/unizero_world_models/kv_caching.py
@@ -140,11 +140,11 @@ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, devi
         self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
         self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
 
-        self.register_token_num = 2  # Number of register tokens TODO======
+        # self.register_token_num = 2  # Number of register tokens TODO======
 
-    def set_register_token_num(self, num: int) -> None:
-        """Set the number of register tokens."""
-        self.register_token_num = num
+    # def set_register_token_num(self, num: int) -> None:
+    #     """Set the number of register tokens."""
+    #     self.register_token_num = num
 
     @property
     def shape(self) -> Tuple[int, int, int, int]:
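Note on this change: with `register_token_num` and its setter commented out, the cache no longer tracks register tokens itself; the transformer side (see `remove_register_tokens_from_kv` in the next file) is responsible for stripping them. A minimal toy sketch of that tail-trimming idea — `ToyKVCache` is hypothetical, not the repo's `Cache`/`KVCache` API — assuming keys/values of shape (B, heads, T, head_dim) and a write cursor:

import torch

class ToyKVCache:
    """Hypothetical stand-in for a preallocated KV cache (not the repo's class)."""

    def __init__(self, batch: int, heads: int, max_tokens: int, head_dim: int):
        self._k = torch.zeros(batch, heads, max_tokens, head_dim)
        self._v = torch.zeros(batch, heads, max_tokens, head_dim)
        self._size = 0  # number of valid cached positions

    def update(self, k: torch.Tensor, v: torch.Tensor) -> None:
        # Append t new positions at the current cursor.
        t = k.shape[2]
        self._k[:, :, self._size:self._size + t] = k
        self._v[:, :, self._size:self._size + t] = v
        self._size += t

    def trim_tail(self, n: int) -> None:
        # Dropping the last n positions (e.g. register tokens appended at the
        # end of the sequence) only moves the cursor back; the stale entries
        # are overwritten by the next update().
        self._size = max(0, self._size - n)

cache = ToyKVCache(batch=2, heads=4, max_tokens=16, head_dim=8)
cache.update(torch.randn(2, 4, 6, 8), torch.randn(2, 4, 6, 8))  # 4 real + 2 register tokens
cache.trim_tail(2)
assert cache._size == 4

Because the cache is preallocated up to max_tokens, removing the tail is pure cursor arithmetic; no tensor copy is needed.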
41 changes: 31 additions & 10 deletions lzero/model/unizero_world_models/transformer.py
@@ -60,11 +60,29 @@ def __init__(self, config: TransformerConfig, task_embed=None) -> None:
 
         self.task_embed = task_embed
         self.task_embed_option = self.config.task_embed_option  # Strategy for task embeddings
+        self.register_token_shared = True
+
+        # TODO: in shared mode, all tasks use the same parameters
 
         if self.task_embed_option == "register_task_embed":
             self.use_register_token = True  # TODO
             # Register token setup
             self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4
-            self.sim_norm = SimNorm(simnorm_dim=config.embed_dim)  # Normalization for task embeddings
+
+            # Decide whether to use the shared mode
+            self.register_token_shared = getattr(config, "register_token_shared", True)
+            if self.register_token_shared:
+                # print(f'self.register_token_shared:{self.register_token_shared}')
+                # print(f'='*20)
+                # Shared mode: all tasks use a single register_tokens parameter of shape (register_token_num, embed_dim)
+                self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim))
+                nn.init.xavier_uniform_(self.register_tokens)
+            else:
+                # Non-shared mode: rely on the externally provided task_embed module (e.g. nn.Embedding)
+                # to generate the task embedding, normalize it with SimNorm, and replicate it into register tokens
+                self.task_embed = task_embed
+                self.sim_norm = SimNorm(simnorm_dim=config.embed_dim)  # Normalization for task embeddings
+
         else:
             self.use_register_token = False  # TODO

@@ -83,16 +101,19 @@ def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Te
 
         B = sequences.size(0)
         device = sequences.device
 
-        # Generate a learnable task embedding
-        # and apply SimNorm
-        task_embedding = self.task_embed(torch.tensor([task_id], device=device))  # (1, C)
-        task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1)  # (C,)
-        # Expand to register_token_num copies
-        register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1)  # (register_token_num, C)
-        register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, C)
+        if self.register_token_shared:
+            # Shared mode: directly use the single shared set of register_tokens parameters,
+            # of shape (register_token_num, embed_dim)
+            register_tokens = self.register_tokens
+            register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, embed_dim)
+        else:
+            # Non-shared mode: dynamically generate the task embedding via task_embed,
+            # then replicate it into register tokens
+            task_embedding = self.task_embed(torch.tensor([task_id], device=device))  # (1, embed_dim)
+            task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1)  # (embed_dim,)
+            register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1)  # (register_token_num, embed_dim)
+            register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, embed_dim)
 
-        # Concatenate: append the register tokens at the end
-        new_sequences = torch.cat([sequences, register_tokens], dim=1)  # (B, register_token_num + T, C)
+        new_sequences = torch.cat([sequences, register_tokens], dim=1)  # append register tokens at the tail: (B, register_token_num + T, C)
         return new_sequences
 
     def remove_register_tokens_from_kv(self, past_keys_values: KeysValues) -> None:
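The two branches above are easy to sanity-check in isolation. Below is a self-contained sketch — `RegisterTokenDemo` is a hypothetical stand-in, and plain L2 normalization replaces the repo's SimNorm — of the shared vs. non-shared register-token construction:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RegisterTokenDemo(nn.Module):
    """Hypothetical stand-in for the Transformer's register-token logic (not the repo's class)."""

    def __init__(self, num_tasks: int = 8, register_token_num: int = 2, embed_dim: int = 768, shared: bool = True):
        super().__init__()
        self.register_token_num = register_token_num
        self.shared = shared
        if shared:
            # Shared mode: one learnable (register_token_num, embed_dim) parameter for every task.
            self.register_tokens = nn.Parameter(torch.empty(register_token_num, embed_dim))
            nn.init.xavier_uniform_(self.register_tokens)
        else:
            # Non-shared mode: a per-task embedding, normalized and replicated register_token_num times.
            self.task_embed = nn.Embedding(num_tasks, embed_dim)

    def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor:
        B = sequences.size(0)
        if self.shared:
            reg = self.register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, R, C)
        else:
            emb = F.normalize(self.task_embed(torch.tensor([task_id], device=sequences.device)), dim=-1)  # (1, C)
            reg = emb.expand(self.register_token_num, -1).unsqueeze(0).expand(B, -1, -1)  # (B, R, C)
        return torch.cat([sequences, reg], dim=1)  # (B, T + R, C)

x = torch.randn(4, 10, 768)
for shared in (True, False):
    out = RegisterTokenDemo(shared=shared).add_register_tokens(x, task_id=3)
    assert out.shape == (4, 12, 768)

The trade-off mirrors the new register_token_shared flag: shared mode trains a single token set across all tasks, while non-shared mode derives task-specific tokens from a per-task embedding, as in add_register_tokens above.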
80 changes: 44 additions & 36 deletions zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
@@ -54,10 +54,13 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             obs_type='vector',
             # use_shared_projection=True, # TODO
             use_shared_projection=False,
-            task_embed_option='concat_task_embed',  # ==============TODO: none ==============
-            # task_embed_option='register_task_embed',  # ==============TODO: none ==============
+            # task_embed_option='concat_task_embed',  # ==============TODO: none ==============
+            task_embed_option='register_task_embed',  # ==============TODO: none ==============
+
             # task_embed_option=None, # ==============TODO: none ==============
-            register_token_num=2,  # TODO: change register_token_num in kv_caching accordingly
+            # register_token_num=4,
+            register_token_num=2,
+
             use_task_embed=True,  # TODO
             # use_task_embed=False, # ==============TODO==============
             num_unroll_steps=num_unroll_steps,
@@ -77,11 +80,11 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             # num_layers=2,
             # num_layers=4, # TODO
 
-            # num_layers=8, # TODO
-            # num_heads=8,
+            num_layers=8,  # TODO
+            num_heads=8,
 
-            num_layers=12, # TODO
-            num_heads=12,
+            # num_layers=12, # TODO
+            # num_heads=12,
 
             embed_dim=768,
             env_num=max(collector_env_num, evaluator_env_num),
@@ -158,9 +161,9 @@ def generate_configs(env_id_list: List[str],
                      total_batch_size: int):
     configs = []
     # TODO: debug
-    exp_name_prefix = f'data_suz_mt_20250123/{len(env_id_list)}tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    # exp_name_prefix = f'data_suz_mt_20250123/{len(env_id_list)}tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-4_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    exp_name_prefix = f'data_suz_mt_20250207/ddp_8gpu_nlayer8_{len(env_id_list)}tasks_upc200_no-taskweight-obsloss-temp1_register-task-embed-2-shared_infer{infer_context_length}_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_no-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
@@ -273,26 +276,26 @@ def create_env_manager():
 ]
 
 # DMC 18games
-env_id_list = [
-    'acrobot-swingup',
-    'cartpole-balance',
-    'cartpole-balance_sparse',
-    'cartpole-swingup',
-    'cartpole-swingup_sparse',
-    'cheetah-run',
-    "ball_in_cup-catch",
-    "finger-spin",
-    "finger-turn_easy",
-    "finger-turn_hard",
-    'hopper-hop',
-    'hopper-stand',
-    'pendulum-swingup',
-    'reacher-easy',
-    'reacher-hard',
-    'walker-run',
-    'walker-stand',
-    'walker-walk',
-]
+# env_id_list = [
+#     'acrobot-swingup',
+#     'cartpole-balance',
+#     'cartpole-balance_sparse',
+#     'cartpole-swingup',
+#     'cartpole-swingup_sparse',
+#     'cheetah-run',
+#     "ball_in_cup-catch",
+#     "finger-spin",
+#     "finger-turn_easy",
+#     "finger-turn_hard",
+#     'hopper-hop',
+#     'hopper-stand',
+#     'pendulum-swingup',
+#     'reacher-easy',
+#     'reacher-hard',
+#     'walker-run',
+#     'walker-stand',
+#     'walker-walk',
+# ]
 
 # Get the action_space_size and observation_shape of each environment
 action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
@@ -303,20 +306,25 @@
 n_episode = 8
 evaluator_env_num = 3
 num_simulations = 50
-# max_env_step = int(5e5)
-max_env_step = int(1e6)
+max_env_step = int(5e5)
+# max_env_step = int(1e6)
 reanalyze_ratio = 0.0
 
-# nlayer=4
+# nlayer=4/8
 total_batch_size = 512
 batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
 
-# nlayer=8/12
-total_batch_size = 256
-batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
+# nlayer=12
+# total_batch_size = 256
+# batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
 
 num_unroll_steps = 5
-infer_context_length = 4  # 4 register tokens at the tail, so effectively infer_context_length is still 2
+infer_context_length = 5  # 4 register tokens at the tail, but they have already been removed from the kv_cache
+
+# Original settings
+# num_unroll_steps = 5
+# infer_context_length = 2
 
 norm_type = 'LN'
 buffer_reanalyze_freq = 1 / 100000
 reanalyze_batch_size = 160
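For reference, the per-task batch-size arithmetic in the active branch above works out as follows, assuming the 8-task setup implied by the exp_name and the "8games" filename (the stand-in env_id_list below is hypothetical):

# Hypothetical stand-in for the (collapsed) active 8-task env_id_list.
env_id_list = ['cartpole-swingup'] * 8
total_batch_size = 512
batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
print(batch_size)       # [64, 64, 64, 64, 64, 64, 64, 64]
print(sum(batch_size))  # 512 -- the per-task cap of 64 binds exactly at 8 tasks

# With the commented-out nlayer=12 setting (total_batch_size = 256), the cap never
# binds: min(64, 256 / 8) = 32 per task, summing to 256.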
