diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py
index 5ec466209..f373739c6 100644
--- a/lzero/model/unizero_world_models/kv_caching.py
+++ b/lzero/model/unizero_world_models/kv_caching.py
@@ -140,11 +140,11 @@ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, devi
         self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
         self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
 
-        self.register_token_num = 2  # Number of register tokens TODO======
+        # self.register_token_num = 2  # Number of register tokens TODO======
 
-    def set_register_token_num(self, num: int) -> None:
-        """Set the number of register tokens."""
-        self.register_token_num = num
+    # def set_register_token_num(self, num: int) -> None:
+    #     """Set the number of register tokens."""
+    #     self.register_token_num = num
 
     @property
     def shape(self) -> Tuple[int, int, int, int]:
diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py
index abc2e5a10..7397df45b 100644
--- a/lzero/model/unizero_world_models/transformer.py
+++ b/lzero/model/unizero_world_models/transformer.py
@@ -60,11 +60,29 @@ def __init__(self, config: TransformerConfig, task_embed=None) -> None:
         self.task_embed = task_embed
         self.task_embed_option = self.config.task_embed_option  # Strategy for task embeddings
 
+        self.register_token_shared = True
+
+        # TODO: in shared mode, all tasks use the same register-token parameters
+
         if self.task_embed_option == "register_task_embed":
             self.use_register_token = True  # TODO
             # Register token setup
             self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4
-            self.sim_norm = SimNorm(simnorm_dim=config.embed_dim)  # Normalization for task embeddings
+
+            # Decide whether register tokens are shared across tasks
+            self.register_token_shared = getattr(config, "register_token_shared", True)
+            if self.register_token_shared:
+                # print(f'self.register_token_shared:{self.register_token_shared}')
+                # print(f'='*20)
+                # Shared mode: all tasks use a single register_tokens parameter of shape (register_token_num, embed_dim)
+                self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim))
+                nn.init.xavier_uniform_(self.register_tokens)
+            else:
+                # Non-shared mode: rely on the externally provided task_embed module to generate a task embedding,
+                # normalize it with SimNorm, and replicate it into register tokens
+                self.task_embed = task_embed  # externally provided module, e.g. nn.Embedding
+                self.sim_norm = SimNorm(simnorm_dim=config.embed_dim)  # Normalization for task embeddings
+
         else:
             self.use_register_token = False  # TODO
 
@@ -83,16 +101,19 @@ def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Te
         B = sequences.size(0)
         device = sequences.device
 
-        # Generate a learnable task embedding
-        # and apply SimNorm
-        task_embedding = self.task_embed(torch.tensor([task_id], device=device))  # (1, C)
-        task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1)  # (C,)
-        # Replicate it register_token_num times
-        register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1)  # (register_token_num, C)
-        register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, C)
+        if self.register_token_shared:
+            # Shared mode: use the single shared register_tokens parameter directly,
+            # shape (register_token_num, embed_dim)
+            register_tokens = self.register_tokens
+            register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, embed_dim)
+        else:
+            # Non-shared mode: generate the task embedding via task_embed, then replicate it into register tokens
+            task_embedding = self.task_embed(torch.tensor([task_id], device=device))  # (1, embed_dim)
+            task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1)  # (embed_dim,)
+            register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1)  # (register_token_num, embed_dim)
+            register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, register_token_num, embed_dim)
 
-        # Concatenate: append the register tokens at the end
-        new_sequences = torch.cat([sequences, register_tokens], dim=1)  # (B, register_token_num + T, C)
+        new_sequences = torch.cat([sequences, register_tokens], dim=1)  # append register tokens at the sequence tail: (B, register_token_num + T, C)
         return new_sequences
 
     def remove_register_tokens_from_kv(self, past_keys_values: KeysValues) -> None:
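For reference, the two register-token modes introduced above can be summarized in a minimal, self-contained sketch. This is not the lzero implementation: `SimNorm` is stood in for by plain L2 normalization, and the class name and demo shapes are made up for illustration.

```python
# Minimal sketch of the shared vs. per-task register-token modes above.
# Assumption: F.normalize stands in for lzero's SimNorm (which applies a
# grouped softmax), so outputs differ from the real module.
import torch
import torch.nn as nn
import torch.nn.functional as F

class RegisterTokenDemo(nn.Module):
    def __init__(self, num_tasks: int, register_token_num: int = 2,
                 embed_dim: int = 768, shared: bool = True) -> None:
        super().__init__()
        self.register_token_num = register_token_num
        self.shared = shared
        if shared:
            # Shared mode: one learnable (register_token_num, embed_dim) matrix for all tasks.
            self.register_tokens = nn.Parameter(torch.empty(register_token_num, embed_dim))
            nn.init.xavier_uniform_(self.register_tokens)
        else:
            # Non-shared mode: a per-task embedding, normalized then replicated.
            self.task_embed = nn.Embedding(num_tasks, embed_dim)

    def forward(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor:
        B = sequences.size(0)
        if self.shared:
            tokens = self.register_tokens.unsqueeze(0).expand(B, -1, -1)
        else:
            e = self.task_embed(torch.tensor([task_id], device=sequences.device))
            e = F.normalize(e, dim=-1).view(-1)  # stand-in for SimNorm, (embed_dim,)
            tokens = e.expand(self.register_token_num, -1).unsqueeze(0).expand(B, -1, -1)
        # Register tokens are appended at the tail of the sequence.
        return torch.cat([sequences, tokens], dim=1)  # (B, T + register_token_num, C)

seq = torch.randn(4, 10, 768)                      # (B, T, C)
out = RegisterTokenDemo(num_tasks=8)(seq, task_id=0)
print(out.shape)                                   # torch.Size([4, 12, 768])
```

The shared parameter keeps the token count and memory constant regardless of the number of tasks, which is why the code above enables it by default via `getattr(config, "register_token_shared", True)`.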
diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
index da866e44c..ff23a4968 100644
--- a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
+++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
@@ -54,10 +54,13 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             obs_type='vector',
             # use_shared_projection=True,  # TODO
             use_shared_projection=False,
-            task_embed_option='concat_task_embed',  # ==============TODO: none ==============
-            # task_embed_option='register_task_embed',  # ==============TODO: none ==============
+            # task_embed_option='concat_task_embed',  # ==============TODO: none ==============
+            task_embed_option='register_task_embed',  # ==============TODO: none ==============
+            # task_embed_option=None,  # ==============TODO: none ==============
 
-            register_token_num=2,  # TODO: update register_token_num in kv_caching accordingly
+            # register_token_num=4,
+            register_token_num=2,
+
             use_task_embed=True,  # TODO
             # use_task_embed=False,  # ==============TODO==============
             num_unroll_steps=num_unroll_steps,
@@ -77,11 +80,11 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             # num_layers=2,
             # num_layers=4,  # TODO
 
-            # num_layers=8,  # TODO
-            # num_heads=8,
+            num_layers=8,  # TODO
+            num_heads=8,
 
-            num_layers=12,  # TODO
-            num_heads=12,
+            # num_layers=12,  # TODO
+            # num_heads=12,
 
             embed_dim=768,
             env_num=max(collector_env_num, evaluator_env_num),
@@ -158,9 +161,9 @@ def generate_configs(env_id_list: List[str], total_batch_size: int):
     configs = []
     # TODO: debug
 
-    exp_name_prefix = f'data_suz_mt_20250123/{len(env_id_list)}tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    # exp_name_prefix = f'data_suz_mt_20250123/{len(env_id_list)}tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
-    # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-4_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    exp_name_prefix = f'data_suz_mt_20250207/ddp_8gpu_nlayer8_{len(env_id_list)}tasks_upc200_no-taskweight-obsloss-temp1_register-task-embed-2-shared_infer{infer_context_length}_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_no-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
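Because `infer_context_length` (set below) no longer budgets for the register tokens, the patch relies on them being stripped from the KV cache after each forward pass; `remove_register_tokens_from_kv` in transformer.py handles this, but its body is not shown in this diff. A hypothetical sketch of such trimming on plain tensors, with made-up cache shapes:

```python
# Hypothetical sketch: drop the trailing register-token slots from KV tensors.
# The real code operates on lzero's KeysValues object, not raw tensors.
import torch

def trim_register_tokens(k_cache: torch.Tensor, v_cache: torch.Tensor,
                         register_token_num: int):
    """Remove the last register_token_num positions from (B, num_heads, T, head_dim) caches."""
    return (k_cache[:, :, :-register_token_num, :],
            v_cache[:, :, :-register_token_num, :])

k = torch.randn(1, 8, 7, 96)   # 5 content tokens + 2 register tokens
v = torch.randn(1, 8, 7, 96)
k, v = trim_register_tokens(k, v, register_token_num=2)
print(k.shape)                 # torch.Size([1, 8, 5, 96])
```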
@@ -273,26 +276,26 @@ def create_env_manager():
     ]
 
     # DMC 18games
-    env_id_list = [
-        'acrobot-swingup',
-        'cartpole-balance',
-        'cartpole-balance_sparse',
-        'cartpole-swingup',
-        'cartpole-swingup_sparse',
-        'cheetah-run',
-        "ball_in_cup-catch",
-        "finger-spin",
-        "finger-turn_easy",
-        "finger-turn_hard",
-        'hopper-hop',
-        'hopper-stand',
-        'pendulum-swingup',
-        'reacher-easy',
-        'reacher-hard',
-        'walker-run',
-        'walker-stand',
-        'walker-walk',
-    ]
+    # env_id_list = [
+    #     'acrobot-swingup',
+    #     'cartpole-balance',
+    #     'cartpole-balance_sparse',
+    #     'cartpole-swingup',
+    #     'cartpole-swingup_sparse',
+    #     'cheetah-run',
+    #     "ball_in_cup-catch",
+    #     "finger-spin",
+    #     "finger-turn_easy",
+    #     "finger-turn_hard",
+    #     'hopper-hop',
+    #     'hopper-stand',
+    #     'pendulum-swingup',
+    #     'reacher-easy',
+    #     'reacher-hard',
+    #     'walker-run',
+    #     'walker-stand',
+    #     'walker-walk',
+    # ]
 
     # Get the action_space_size and observation_shape of each environment
     action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
@@ -303,20 +306,25 @@ def create_env_manager():
     n_episode = 8
    evaluator_env_num = 3
     num_simulations = 50
-    # max_env_step = int(5e5)
-    max_env_step = int(1e6)
+    max_env_step = int(5e5)
+    # max_env_step = int(1e6)
     reanalyze_ratio = 0.0
 
-    # nlayer=4
+    # nlayer=4/8
     total_batch_size = 512
     batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
 
-    # nlayer=8/12
-    total_batch_size = 256
-    batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
+    # nlayer=12
+    # total_batch_size = 256
+    # batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
 
     num_unroll_steps = 5
-    infer_context_length = 4  # with 4 register tokens at the tail, the effective infer_context_length is still 2
+    infer_context_length = 5  # the register tokens at the tail have already been removed from the kv_cache
+
+    # Original settings
+    # num_unroll_steps = 5
+    # infer_context_length = 2
+
     norm_type = 'LN'
     buffer_reanalyze_freq = 1 / 100000
     reanalyze_batch_size = 160
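The per-task batch size above is derived from `total_batch_size` as `min(64, total_batch_size / len(env_id_list))`. A quick worked example of that rule, with the task counts assumed from this config's 8-game filename and the 18-game list above:

```python
# Worked example of the per-task batch-size rule used in this config.
total_batch_size = 512
for num_tasks in (8, 18):
    batch_size = [int(min(64, total_batch_size / num_tasks)) for _ in range(num_tasks)]
    print(num_tasks, batch_size[0], sum(batch_size))
# 8 tasks  -> 64 per task (512 total)
# 18 tasks -> 28 per task (504 total, rounded down from 512/18)
```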