diff --git a/lzero/model/unizero_world_models/hf_transformer.py b/lzero/model/unizero_world_models/hf_transformer.py index d0ef762be..d15759073 100644 --- a/lzero/model/unizero_world_models/hf_transformer.py +++ b/lzero/model/unizero_world_models/hf_transformer.py @@ -1,13 +1,23 @@ -from typing import Optional - +from typing import Optional, List import torch -from transformers import LlamaForCausalLM -from transformers.cache_utils import DynamicCache +from torch import nn +from transformers import LlamaForCausalLM, LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel from .kv_caching import KeysValues +from transformers.cache_utils import DynamicCache + + +def kv2dc(cache: KeysValues) -> DynamicCache: + """ + 将自定义的 KeysValues 缓存转换为 Huggingface 的 DynamicCache 格式。 + Args: + cache (KeysValues): 自定义的键值缓存。 -def kv2dc(cache: KeysValues): + Returns: + DynamicCache: Huggingface 的动态缓存对象。 + """ res = DynamicCache() for kv_cache in cache: k_tensor = kv_cache._k_cache.get() @@ -17,94 +27,158 @@ def kv2dc(cache: KeysValues): return res -def update_kv(cache: KeysValues, new_cache: DynamicCache): +def update_kv(cache: KeysValues, new_cache: DynamicCache) -> None: + """ + 更新自定义的 KeysValues 缓存。 + + Args: + cache (KeysValues): 自定义的键值缓存。 + new_cache (DynamicCache): Huggingface 的动态缓存对象。 + """ for i in range(len(new_cache.key_cache)): - cache[i].update(new_cache.key_cache[-1], new_cache.value_cache[-1]) + # 更新时使用当前最新的 key 和 value + cache[i].update(new_cache.key_cache[i], new_cache.value_cache[i]) class HuggingfaceLlamaTransformer(LlamaForCausalLM): + """ + 使用预训练的 Huggingface Llama 模型作为主干的 Transformer 类。 + + 继承自 LlamaForCausalLM,并扩展自定义的缓存与投影层。 + """ + + def __init__(self, config: LlamaConfig) -> None: + super().__init__(config) + # 假设需要一个自定义的投影层,如果不需要可以移除 + # self.projection_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.projection_layer = nn.Linear(2048, 768) # TODO======= + @classmethod - def from_pretrained(cls, lzero_config, *args, **kwargs): - # Add custom logic here - model = super(HuggingfaceLlamaTransformer, cls).from_pretrained(*args, **kwargs) - model.lzero_config = lzero_config + def from_pretrained(cls, model_name_or_path: str, *args, **kwargs): + """ + 从预训练模型加载权重,并初始化自定义的 Transformer 类。 + + Args: + model_name_or_path (str): 预训练模型的名称或路径。 + + Returns: + HuggingfaceLlamaTransformer: 初始化后的模型实例。 + """ + model = super(HuggingfaceLlamaTransformer, cls).from_pretrained(model_name_or_path, *args, **kwargs) return model def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: """ - Generate a placeholder for keys and values. + 生成键值缓存的占位符。 - Arguments: - - n (:obj:`int`): Batch size. - - max_tokens (:obj:`int`): Maximum number of tokens in the sequence. + Args: + n (int): 批量大小。 + max_tokens (int): 序列的最大长度。 Returns: - - KeysValues: An object containing empty keys and values. + KeysValues: 包含空键值的对象。 """ - device = self.lzero_config.device # Assumption: All submodules are on the same device - return KeysValues(n, self.lzero_config.num_heads, max_tokens, - self.lzero_config.embed_dim, self.lzero_config.num_layers, - device, self.lzero_config.hidden_size) + device = self.device # 使用模型所在的设备 + return KeysValues( + n=n, + num_heads=self.config.num_attention_heads, + max_tokens=max_tokens, + embed_dim=self.config.hidden_size, + num_layers=self.config.num_hidden_layers, + device=device + ) - def _get_positional_embedding(self, layer, attn_type, pos_emb) -> torch.Tensor: + def _get_positional_embedding(self, layer: int, attn_type: str, pos_emb) -> torch.Tensor: """ - Helper function to get positional embedding for a given layer and attention type. + 获取指定层和注意力类型的位置信息嵌入。 - Arguments: - - layer (:obj:`int`): Layer index. - - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. + Args: + layer (int): 层索引。 + attn_type (str): 注意力类型,'key' 或 'value'。 + pos_emb: 位置信息嵌入对象。 - Returns: - - torch.Tensor: The positional embedding tensor. - """ + Returns: + torch.Tensor: 位置信息嵌入张量。 + """ if attn_type == 'key': module_name = 'k_proj' elif attn_type == 'value': module_name = 'v_proj' - elif attn_type == 'query': - module_name = 'q_proj' else: - assert False - attn_func = getattr(self.model.layers[layer].self_attn, module_name) + raise ValueError("attn_type 必须是 'key' 或 'value'") + + # 获取对应层的注意力投影模块 + attn_module = self.model.layers[layer].self_attn + attn_func = getattr(attn_module, module_name) return attn_func(pos_emb.weight) - def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ) -> torch.Tensor: """ - Forward pass of the Transformer model. + Transformer 模型的前向传播。 - Arguments: - - sequences (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). - - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). + Args: + input_ids (Optional[torch.Tensor]): 输入的 token IDs,形状为 (batch_size, seq_length)。 + attention_mask (Optional[torch.Tensor]): 注意力掩码,形状为 (batch_size, seq_length)。 + past_key_values (Optional[KeysValues]): 预计算的键值缓存,用于加速生成。 + valid_context_lengths (Optional[torch.Tensor]): 有效的上下文长度,用于掩码。 + inputs_embeds (Optional[torch.Tensor]): 输入的嵌入,形状为 (batch_size, seq_length, embed_dim)。 Returns: - - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). + torch.Tensor: 模型的输出。 """ - assert past_keys_values is None or len(past_keys_values) == len(self.model.layers) - if past_keys_values is not None: - kv_cache = kv2dc(past_keys_values) + # 将自定义的键值缓存转换为 Huggingface 的格式 + if past_key_values is not None: + kv_cache = kv2dc(past_key_values) use_cache = True else: kv_cache = None - use_cache = False + use_cache = True # 根据需求,可以设置为 False - B, T, _ = sequences.shape + # 如果提供了有效上下文长度,则构建 attention_mask if valid_context_lengths is not None: - attention_mask = torch.arange(T).expand(B, T) >= (T - valid_context_lengths.unsqueeze(1)) + B, T = input_ids.shape + # 创建一个全为 1 的 attention_mask + attention_mask = torch.ones((B, T), dtype=torch.long, device=self.device) + # 根据 valid_context_lengths 设置无效部分为 0 + for i in range(B): + attention_mask[i, :T - valid_context_lengths[i]] = 0 else: - attention_mask = torch.ones_like(sequences) - # print(valid_context_lengths.shape) - # print(attention_mask.shape) - # print(sequences.shape) - # assert False - - output = self.model.forward( + if attention_mask is None: + # 默认情况下,创建一个全为 1 的 attention_mask + if input_ids is not None: + attention_mask = torch.ones_like(input_ids, device=self.device) + elif inputs_embeds is not None: + attention_mask = torch.ones(inputs_embeds.size()[:2], device=self.device) + else: + raise ValueError("输入缺少 input_ids 或 inputs_embeds") + + # 调用 Huggingface 的前向方法 + outputs = self.model( + input_ids=input_ids, attention_mask=attention_mask, past_key_values=kv_cache, - inputs_embeds=sequences, - use_cache=use_cache + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs ) - update_kv(past_keys_values, kv_cache) - return output.logits[:, -1, :] + # 更新自定义的 KeysValues 缓存 + if past_key_values is not None: + update_kv(past_key_values, outputs.past_key_values) + + # 如果需要,可以添加自定义的投影层 + if hasattr(self, 'projection_layer') and self.projection_layer is not None: + # 确保最后一个隐藏状态的形状正确 + last_hidden_state = outputs.last_hidden_state # (batch_size, seq_length, hidden_size) + output_projection = self.projection_layer(last_hidden_state) # (batch_size, seq_length, hidden_size) + return output_projection + else: + return outputs.last_hidden_state \ No newline at end of file diff --git a/lzero/model/unizero_world_models/hf_transformer_bkp.py b/lzero/model/unizero_world_models/hf_transformer_bkp.py new file mode 100644 index 000000000..2e0680549 --- /dev/null +++ b/lzero/model/unizero_world_models/hf_transformer_bkp.py @@ -0,0 +1,183 @@ +from typing import Optional, List +import torch +from torch import nn +from transformers import LlamaForCausalLM, LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel + +from .kv_caching import KeysValues +from transformers.cache_utils import DynamicCache + + +def kv2dc(cache: KeysValues) -> DynamicCache: + """ + 将自定义的 KeysValues 缓存转换为 Huggingface 的 DynamicCache 格式。 + + Args: + cache (KeysValues): 自定义的键值缓存。 + + Returns: + DynamicCache: Huggingface 的动态缓存对象。 + """ + res = DynamicCache() + for kv_cache in cache: + k_tensor = kv_cache._k_cache.get() + v_tensor = kv_cache._v_cache.get() + res.key_cache.append(k_tensor) + res.value_cache.append(v_tensor) + return res + + +def update_kv(cache: KeysValues, new_cache: DynamicCache) -> None: + """ + 更新自定义的 KeysValues 缓存。 + + Args: + cache (KeysValues): 自定义的键值缓存。 + new_cache (DynamicCache): Huggingface 的动态缓存对象。 + """ + for i in range(len(new_cache.key_cache)): + # 更新时使用当前最新的 key 和 value + cache[i].update(new_cache.key_cache[i], new_cache.value_cache[i]) + + +class HuggingfaceLlamaTransformer(LlamaForCausalLM): + """ + 使用预训练的 Huggingface Llama 模型作为主干的 Transformer 类。 + + 继承自 LlamaForCausalLM,并扩展自定义的缓存与投影层。 + """ + + def __init__(self, config: LlamaConfig) -> None: + super().__init__(config) + # 假设需要一个自定义的投影层,如果不需要可以移除 + self.projection_layer = nn.Linear(config.hidden_size, config.hidden_size) + + @classmethod + def from_pretrained(cls, model_name_or_path: str, *args, **kwargs): + """ + 从预训练模型加载权重,并初始化自定义的 Transformer 类。 + + Args: + model_name_or_path (str): 预训练模型的名称或路径。 + + Returns: + HuggingfaceLlamaTransformer: 初始化后的模型实例。 + """ + model = super(HuggingfaceLlamaTransformer, cls).from_pretrained(model_name_or_path, *args, **kwargs) + return model + + def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: + """ + 生成键值缓存的占位符。 + + Args: + n (int): 批量大小。 + max_tokens (int): 序列的最大长度。 + + Returns: + KeysValues: 包含空键值的对象。 + """ + device = self.device # 使用模型所在的设备 + return KeysValues( + n=n, + num_heads=self.config.num_attention_heads, + max_tokens=max_tokens, + embed_dim=self.config.hidden_size, + num_layers=self.config.num_hidden_layers, + device=device + ) + + def _get_positional_embedding(self, layer: int, attn_type: str, pos_emb) -> torch.Tensor: + """ + 获取指定层和注意力类型的位置信息嵌入。 + + Args: + layer (int): 层索引。 + attn_type (str): 注意力类型,'key' 或 'value'。 + pos_emb: 位置信息嵌入对象。 + + Returns: + torch.Tensor: 位置信息嵌入张量。 + """ + if attn_type == 'key': + module_name = 'k_proj' + elif attn_type == 'value': + module_name = 'v_proj' + else: + raise ValueError("attn_type 必须是 'key' 或 'value'") + + # 获取对应层的注意力投影模块 + attn_module = self.model.layers[layer].self_attn + attn_func = getattr(attn_module, module_name) + return attn_func(pos_emb.weight) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ) -> torch.Tensor: + """ + Transformer 模型的前向传播。 + + Args: + input_ids (Optional[torch.Tensor]): 输入的 token IDs,形状为 (batch_size, seq_length)。 + attention_mask (Optional[torch.Tensor]): 注意力掩码,形状为 (batch_size, seq_length)。 + past_key_values (Optional[KeysValues]): 预计算的键值缓存,用于加速生成。 + valid_context_lengths (Optional[torch.Tensor]): 有效的上下文长度,用于掩码。 + inputs_embeds (Optional[torch.Tensor]): 输入的嵌入,形状为 (batch_size, seq_length, embed_dim)。 + + Returns: + torch.Tensor: 模型的输出。 + """ + # 将自定义的键值缓存转换为 Huggingface 的格式 + if past_key_values is not None: + kv_cache = kv2dc(past_key_values) + use_cache = True + else: + kv_cache = None + use_cache = True # 根据需求,可以设置为 False + + # 如果提供了有效上下文长度,则构建 attention_mask + if valid_context_lengths is not None: + B, T = input_ids.shape + # 创建一个全为 1 的 attention_mask + attention_mask = torch.ones((B, T), dtype=torch.long, device=self.device) + # 根据 valid_context_lengths 设置无效部分为 0 + for i in range(B): + attention_mask[i, :T - valid_context_lengths[i]] = 0 + else: + if attention_mask is None: + # 默认情况下,创建一个全为 1 的 attention_mask + if input_ids is not None: + attention_mask = torch.ones_like(input_ids, device=self.device) + elif inputs_embeds is not None: + attention_mask = torch.ones(inputs_embeds.size()[:2], device=self.device) + else: + raise ValueError("输入缺少 input_ids 或 inputs_embeds") + + # 调用 Huggingface 的前向方法 + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=kv_cache, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs + ) + + # 更新自定义的 KeysValues 缓存 + if past_key_values is not None: + update_kv(past_key_values, outputs.past_key_values) + + # 如果需要,可以添加自定义的投影层 + if hasattr(self, 'projection_layer') and self.projection_layer is not None: + # 确保最后一个隐藏状态的形状正确 + last_hidden_state = outputs.last_hidden_state # (batch_size, seq_length, hidden_size) + output_projection = self.projection_layer(last_hidden_state) # (batch_size, seq_length, hidden_size) + return output_projection + else: + return outputs.last_hidden_state \ No newline at end of file diff --git a/zoo/jericho/configs/jericho_unizero_pretrained_config.py b/zoo/jericho/configs/jericho_unizero_pretrained_config.py index c02d9072e..c64b1bbfd 100644 --- a/zoo/jericho/configs/jericho_unizero_pretrained_config.py +++ b/zoo/jericho/configs/jericho_unizero_pretrained_config.py @@ -33,19 +33,19 @@ def main(env_id='detective.z5', seed=0): # model_name = 'BAAI/bge-base-en-v1.5' model_name = 'google-bert/bert-base-uncased' # =========== TODO: only for debug =========== - # collector_env_num = 2 - # num_segments = 2 - # game_segment_length = 20 - # evaluator_env_num = 2 - # max_env_step = int(5e5) - # batch_size = 10 - # num_simulations = 5 - # num_unroll_steps = 5 - # infer_context_length = 2 - # max_steps = 10 - # num_layers = 1 - # replay_ratio = 0.05 - # embed_dim = 768 + collector_env_num = 2 + num_segments = 2 + game_segment_length = 20 + evaluator_env_num = 2 + max_env_step = int(5e5) + batch_size = 10 + num_simulations = 5 + num_unroll_steps = 5 + infer_context_length = 2 + max_steps = 10 + num_layers = 1 + replay_ratio = 0.05 + embed_dim = 768 # TODO: MCTS内部的action_space受限于root节点的legal action # ============================================================== @@ -60,7 +60,7 @@ def main(env_id='detective.z5', seed=0): tokenizer_path=model_name, # tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594", max_seq_len=512, - game_path="z-machine-games-master/jericho-game-suite/" + env_id, + game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id, # game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, @@ -93,7 +93,7 @@ def main(env_id='detective.z5', seed=0): device='cuda', action_space_size=action_space_size, use_hf=True, - pretrained_path='/data/share/Llama-3.2-1B', + pretrained_path='/mnt/afs/share/Llama-3.2-1B', # These parameters should be the same as the config file of original model. num_layers=num_layers, # Note: llama uses GQA, and the number of heads should equal to the number of key-value heads. @@ -148,7 +148,7 @@ def main(env_id='detective.z5', seed=0): main_config = jericho_unizero_config create_config = jericho_unizero_create_config - main_config.exp_name = f'data_unizero_detective_20241220/{env_id[:8]}_ms{max_steps}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'data_unizero_detective_debug/{env_id[:8]}_ms{max_steps}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' from lzero.entry import train_unizero train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)