From f8d88e68efbe9ce736c06f967d84eaa18627fca3 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Thu, 11 Jan 2024 12:33:54 +0800 Subject: [PATCH] fix(pu): latent state gradient times 0.2, set grad_clip_value to 0.5, add obs reconstruction loss, add load muzero representation net utils --- lzero/entry/train_muzero_gpt.py | 4 +- lzero/mcts/tree_search/mcts_ctree.py | 23 +- lzero/mcts/tree_search/mcts_ctree_orig.py | 478 ++++++++++++++++++ lzero/model/common.py | 93 +++- .../gpt_models/plot_sequence_frame_grey.py | 5 +- lzero/model/gpt_models/plot_weight_grad.py | 2 - lzero/model/gpt_models/tokenizer/tokenizer.py | 18 +- lzero/model/gpt_models/utils.py | 13 +- lzero/model/gpt_models/world_model.py | 33 +- .../gpt_models/world_model_bkp_20240105.py | 4 +- lzero/model/muzero_gpt_model.py | 8 +- lzero/policy/muzero_gpt.py | 92 +++- zoo/atari/config/atari_muzero_config_46464.py | 2 +- .../config/atari_muzero_gpt_config_stack4.py | 13 +- .../atari_muzero_gpt_config_stack4_debug.py | 13 +- 15 files changed, 727 insertions(+), 74 deletions(-) create mode 100644 lzero/mcts/tree_search/mcts_ctree_orig.py diff --git a/lzero/entry/train_muzero_gpt.py b/lzero/entry/train_muzero_gpt.py index 6eee4558e..f14312338 100644 --- a/lzero/entry/train_muzero_gpt.py +++ b/lzero/entry/train_muzero_gpt.py @@ -240,11 +240,11 @@ def train_muzero_gpt( # TODO: for batch world model ,to improve kv reuse, we could donot reset policy._learn_model.world_model.past_keys_values_cache.clear() - # if collector.envstep > 10000: + # if collector.envstep > 0: # # TODO: only for debug # for param in policy._learn_model.world_model.tokenizer.parameters(): # param.requires_grad = False - # print("train some steps before collector.envstep > 10000, then fixed") + # print("train some steps before collector.envstep > 0, then fixed") if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: break diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 20336277c..553430674 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -279,6 +279,13 @@ def search( min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size) min_max_stats_lst.set_delta(self._cfg.value_delta_max) + state_action_history = [] # 初始化 state_action_history 变量 + last_latent_state = latent_state_roots + # NOTE: very important, from the right init key-value-cache + # forward_initial_inference()以及执行了下面的操作 + # _ = model.world_model.refresh_keys_values_with_initial_obs_tokens(model.world_model.obs_tokens) + + # model.world_model.past_keys_values_cache.clear() # 清除缓存 for simulation_index in range(self._cfg.num_simulations): # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. @@ -305,8 +312,14 @@ def search( latent_states.append(latent_state_batch_in_search_path[ix][iy]) latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float() - # .long() is only for discrete action + # TODO: .long() is only for discrete action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() + + # TODO + # 在每次模拟后更新 state_action_history + # state_action_history.append((last_latent_state, last_actions.detach().cpu().numpy())) + state_action_history.append((latent_states.detach().cpu().numpy(), last_actions.detach().cpu().numpy())) + """ MCTS stage 2: Expansion At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. 
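# A minimal, self-contained sketch (not LightZero code) of the bookkeeping introduced in the
# hunk above: one (latent_states, last_actions) pair is appended per simulation, and the whole
# list is then handed to ``model.recurrent_inference(state_action_history)``. The batch size
# and embedding dimension below are assumptions chosen purely for illustration.
import numpy as np

batch_size, embed_dim = 8, 1024
state_action_history = []
for simulation_index in range(3):
    latent_states = np.zeros((batch_size, embed_dim), dtype=np.float32)  # leaf latent states of this simulation
    last_actions = np.zeros((batch_size,), dtype=np.int64)               # actions selected to reach those leaves
    state_action_history.append((latent_states, last_actions))

assert len(state_action_history) == 3  # one entry per completed simulation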
@@ -314,7 +327,9 @@ def search( MCTS stage 3: Backup At the end of the simulation, the statistics along the trajectory are updated. """ - network_output = model.recurrent_inference(latent_states, last_actions) + # network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero + # network_output = model.recurrent_inference(last_actions) # TODO: for muzero_gpt latent_states is not used in the model. + network_output = model.recurrent_inference(state_action_history) # TODO: latent_states is not used in the model. network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -322,6 +337,10 @@ def search( network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward)) latent_state_batch_in_search_path.append(network_output.latent_state) + + # TODO + # last_latent_state = network_output.latent_state + # tolist() is to be compatible with cpp datatype. reward_batch = network_output.reward.reshape(-1).tolist() value_batch = network_output.value.reshape(-1).tolist() diff --git a/lzero/mcts/tree_search/mcts_ctree_orig.py b/lzero/mcts/tree_search/mcts_ctree_orig.py new file mode 100644 index 000000000..20336277c --- /dev/null +++ b/lzero/mcts/tree_search/mcts_ctree_orig.py @@ -0,0 +1,478 @@ +import copy +from typing import TYPE_CHECKING, List, Any, Union + +import numpy as np +import torch +from easydict import EasyDict + +from lzero.mcts.ctree.ctree_efficientzero import ez_tree as tree_efficientzero +from lzero.mcts.ctree.ctree_muzero import mz_tree as tree_muzero +from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as tree_gumbel_muzero +from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy + +if TYPE_CHECKING: + from lzero.mcts.ctree.ctree_efficientzero import ez_tree as ez_ctree + from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree + from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as gmz_ctree + +# ============================================================== +# EfficientZero +# ============================================================== + + +class EfficientZeroMCTSCtree(object): + """ + Overview: + MCTSCtree for EfficientZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. + Interfaces: + __init__, roots, search + + """ + + config = dict( + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_base=19652, + # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_init=1.25, + # (float) The maximum change in value allowed during the backup step of the search tree update. + value_delta_max=0.01, + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: EasyDict = None) -> None: + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. 
Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + @classmethod + def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "ez_ctree.Roots": + """ + Overview: + The initialization of CRoots with root num and legal action lists. + Arguments: + - root_num (:obj:'int'): the number of the current root. + - legal_action_list (:obj:'List'): the vector of the legal action of this root. + """ + from lzero.mcts.ctree.ctree_efficientzero import ez_tree as ctree + return ctree.Roots(active_collect_env_num, legal_actions) + + def search( + self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], + reward_hidden_state_roots: List[Any], to_play_batch: Union[int, List[Any]] + ) -> None: + """ + Overview: + Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. + Use the cpp ctree. + Arguments: + - roots (:obj:`Any`): a batch of expanded root nodes + - latent_state_roots (:obj:`list`): the hidden states of the roots + - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots + - to_play_batch (:obj:`list`): the to_play_batch list used in self-play-mode board games + """ + with torch.no_grad(): + model.eval() + + # preparation some constant + batch_size = roots.num + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor + + # the data storage of latent states: storing the latent state of all the nodes in one search. + latent_state_batch_in_search_path = [latent_state_roots] + # the data storage of value prefix hidden states in LSTM + reward_hidden_state_c_batch = [reward_hidden_state_roots[0]] + reward_hidden_state_h_batch = [reward_hidden_state_roots[1]] + + # minimax value storage + min_max_stats_lst = tree_efficientzero.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(self._cfg.value_delta_max) + + for simulation_index in range(self._cfg.num_simulations): + # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. + + latent_states = [] + hidden_states_c_reward = [] + hidden_states_h_reward = [] + + # prepare a result wrapper to transport results between python and c++ parts + results = tree_efficientzero.ResultsWrapper(num=batch_size) + + # latent_state_index_in_search_path: the first index of leaf node states in latent_state_batch_in_search_path, i.e. is current_latent_state_index in one the search. + # latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``. + # e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index. + # The index of value prefix hidden state of the leaf node is in the same manner. + """ + MCTS stage 1: Selection + Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. 
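            At every internal node the child with the highest PUCT score is followed; roughly,
            score(a) = Q(s, a) + P(s, a) * sqrt(N(s)) / (1 + N(s, a)) * (pb_c_init + log((N(s) + pb_c_base + 1) / pb_c_base)),
            which is what the ``pb_c_base`` / ``pb_c_init`` constants configured above parameterise
            (the exact selection rule lives in the C++ ``batch_traverse``).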
+ """ + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero.batch_traverse( + roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, + copy.deepcopy(to_play_batch) + ) + # obtain the search horizon for leaf nodes + search_lens = results.get_search_len() + + # obtain the latent state for leaf node + for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): + latent_states.append(latent_state_batch_in_search_path[ix][iy]) + hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy]) + hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy]) + + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float() + hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device + ).unsqueeze(0) + hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device + ).unsqueeze(0) + # .long() is only for discrete action + last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() + """ + MCTS stage 2: Expansion + At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. + Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) + MCTS stage 3: Backup + At the end of the simulation, the statistics along the trajectory are updated. + """ + network_output = model.recurrent_inference( + latent_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions + ) + + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) + network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) + network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix)) + + network_output.reward_hidden_state = ( + network_output.reward_hidden_state[0].detach().cpu().numpy(), + network_output.reward_hidden_state[1].detach().cpu().numpy() + ) + + latent_state_batch_in_search_path.append(network_output.latent_state) + # tolist() is to be compatible with cpp datatype. + value_prefix_batch = network_output.value_prefix.reshape(-1).tolist() + value_batch = network_output.value.reshape(-1).tolist() + policy_logits_batch = network_output.policy_logits.tolist() + + reward_latent_state_batch = network_output.reward_hidden_state + # reset the hidden states in LSTM every ``lstm_horizon_len`` steps in one search. + # which enable the model only need to predict the value prefix in a range (e.g.: [s0,...,s5]) + assert self._cfg.lstm_horizon_len > 0 + reset_idx = (np.array(search_lens) % self._cfg.lstm_horizon_len == 0) + assert len(reset_idx) == batch_size + reward_latent_state_batch[0][:, reset_idx, :] = 0 + reward_latent_state_batch[1][:, reset_idx, :] = 0 + is_reset_list = reset_idx.astype(np.int32).tolist() + reward_hidden_state_c_batch.append(reward_latent_state_batch[0]) + reward_hidden_state_h_batch.append(reward_latent_state_batch[1]) + + # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and + # ``reward`` predicted by the model, then perform backpropagation along the search path to update the + # statistics. 
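            # Worked example of the ``reset_idx`` rule a few lines above (numbers are
            # illustrative only): with lstm_horizon_len = 5 and search_lens = [3, 5, 10, 7],
            # ``np.array(search_lens) % 5 == 0`` gives [False, True, True, False], so the
            # value-prefix LSTM hidden states of the 2nd and 3rd leaves are zeroed before the
            # next expansion, keeping every value-prefix prediction within a short horizon.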
+ + # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. + current_latent_state_index = simulation_index + 1 + tree_efficientzero.batch_backpropagate( + current_latent_state_index, discount_factor, value_prefix_batch, value_batch, policy_logits_batch, + min_max_stats_lst, results, is_reset_list, virtual_to_play_batch + ) + + +# ============================================================== +# MuZero +# ============================================================== + + +class MuZeroMCTSCtree(object): + """ + Overview: + MCTSCtree for MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. + + Interfaces: + __init__, roots, search + """ + + config = dict( + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_base=19652, + # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_init=1.25, + # (float) The maximum change in value allowed during the backup step of the search tree update. + value_delta_max=0.01, + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: EasyDict = None) -> None: + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + @classmethod + def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "mz_ctree": + """ + Overview: + The initialization of CRoots with root num and legal action lists. + Arguments: + - root_num (:obj:`int`): the number of the current root. + - legal_action_list (:obj:`list`): the vector of the legal action of this root. + """ + from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree + return ctree.Roots(active_collect_env_num, legal_actions) + + def search( + self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, + List[Any]] + ) -> None: + """ + Overview: + Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. + Use the cpp ctree. + Arguments: + - roots (:obj:`Any`): a batch of expanded root nodes + - latent_state_roots (:obj:`list`): the hidden states of the roots + - to_play_batch (:obj:`list`): the to_play_batch list used in in self-play-mode board games + """ + with torch.no_grad(): + model.eval() + + # preparation some constant + batch_size = roots.num + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor + # the data storage of latent states: storing the latent state of all the nodes in the search. 
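            # (Layout, with assumed shapes for readability: entry 0 holds the root latent states,
            # one per environment in the batch; entry i >= 1 holds the latent states expanded
            # during simulation i - 1, so every node is addressed by the index pair that
            # ``batch_traverse`` returns.)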
+ latent_state_batch_in_search_path = [latent_state_roots] + + # minimax value storage + min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(self._cfg.value_delta_max) + + for simulation_index in range(self._cfg.num_simulations): + # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. + + latent_states = [] + + # prepare a result wrapper to transport results between python and c++ parts + results = tree_muzero.ResultsWrapper(num=batch_size) + + # latent_state_index_in_search_path: the first index of leaf node states in latent_state_batch_in_search_path, i.e. is current_latent_state_index in one the search. + # latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``. + # e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index. + # The index of value prefix hidden state of the leaf node are in the same manner. + """ + MCTS stage 1: Selection + Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. + """ + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_muzero.batch_traverse( + roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, + copy.deepcopy(to_play_batch) + ) + + # obtain the latent state for leaf node + for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): + latent_states.append(latent_state_batch_in_search_path[ix][iy]) + + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float() + # .long() is only for discrete action + last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() + """ + MCTS stage 2: Expansion + At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. + Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) + MCTS stage 3: Backup + At the end of the simulation, the statistics along the trajectory are updated. + """ + network_output = model.recurrent_inference(latent_states, last_actions) + + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) + network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) + network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward)) + + latent_state_batch_in_search_path.append(network_output.latent_state) + # tolist() is to be compatible with cpp datatype. + reward_batch = network_output.reward.reshape(-1).tolist() + value_batch = network_output.value.reshape(-1).tolist() + policy_logits_batch = network_output.policy_logits.tolist() + + # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and + # ``reward`` predicted by the model, then perform backpropagation along the search path to update the + # statistics. + + # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. 
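            # Concretely: before any simulation the storage holds only entry 0 (the roots), and
            # the leaf expanded during simulation ``k`` is appended as entry ``k + 1``, which is
            # exactly the index ``batch_backpropagate`` needs to locate the newly expanded nodes.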
+ current_latent_state_index = simulation_index + 1 + tree_muzero.batch_backpropagate( + current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, + min_max_stats_lst, results, virtual_to_play_batch + ) + +class GumbelMuZeroMCTSCtree(object): + """ + Overview: + MCTSCtree for Gumbel MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. + Interfaces: + __init__, roots, search + + """ + config = dict( + # (int) The max limitation of simluation times during the simulation. + num_simulations=50, + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (float) The maximum change in value allowed during the backup step of the search tree update. + value_delta_max=0.01, + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: EasyDict = None) -> None: + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + @classmethod + def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "gmz_ctree": + """ + Overview: + The initialization of CRoots with root num and legal action lists. + Arguments: + - root_num (:obj:`int`): the number of the current root. + - legal_action_list (:obj:`list`): the vector of the legal action of this root. + """ + from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as ctree + return ctree.Roots(active_collect_env_num, legal_actions) + + def search(self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, + List[Any]] + ) -> None: + """ + Overview: + Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. + Use the cpp tree. + Arguments: + - roots (:obj:`Any`): a batch of expanded root nodes + - latent_state_roots (:obj:`list`): the hidden states of the roots + - to_play_batch (:obj:`list`): the to_play_batch list used in two_player mode board games + """ + with torch.no_grad(): + model.eval() + + # preparation some constant + batch_size = roots.num + device = self._cfg.device + discount_factor = self._cfg.discount_factor + # the data storage of hidden states: storing the states of all the tree nodes + latent_state_batch_in_search_path = [latent_state_roots] + + # minimax value storage + min_max_stats_lst = tree_gumbel_muzero.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(self._cfg.value_delta_max) + + for simulation_index in range(self._cfg.num_simulations): + # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. 
+ + latent_states = [] + + # prepare a result wrapper to transport results between python and c++ parts + results = tree_gumbel_muzero.ResultsWrapper(num=batch_size) + + # traverse to select actions for each root + # hidden_state_index_x_lst: the first index of leaf node states in hidden_state_pool + # hidden_state_index_y_lst: the second index of leaf node states in hidden_state_pool + # the hidden state of the leaf node is hidden_state_pool[x, y]; value prefix states are the same + """ + MCTS stage 1: Selection + Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. + In gumbel muzero, the action at the root node is selected using the Sequential Halving algorithm, while the action + at the interier node is selected based on the completion of the action values. + """ + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_gumbel_muzero.batch_traverse( + roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor, results, copy.deepcopy(to_play_batch) + ) + + # obtain the states for leaf nodes + for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): + latent_states.append(latent_state_batch_in_search_path[ix][iy]) + + latent_states = torch.from_numpy(np.asarray(latent_states)).to(device).float() + # .long() is only for discrete action + last_actions = torch.from_numpy(np.asarray(last_actions)).to(device).unsqueeze(1).long() + """ + MCTS stage 2: Expansion + At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. + Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) + MCTS stage 3: Backup + At the end of the simulation, the statistics along the trajectory are updated. + """ + network_output = model.recurrent_inference(latent_states, last_actions) + + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) + network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) + network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward)) + + latent_state_batch_in_search_path.append(network_output.latent_state) + # tolist() is to be compatible with cpp datatype. + reward_batch = network_output.reward.reshape(-1).tolist() + value_batch = network_output.value.reshape(-1).tolist() + policy_logits_batch = network_output.policy_logits.tolist() + + # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and + # ``reward`` predicted by the model, then perform backpropagation along the search path to update the + # statistics. + + # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. 
+ current_latent_state_index = simulation_index + 1 + + # backpropagation along the search path to update the attributes + tree_gumbel_muzero.batch_back_propagate( + current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, + min_max_stats_lst, results, virtual_to_play_batch + ) diff --git a/lzero/model/common.py b/lzero/model/common.py index 991dbf0cf..47f2ccbd2 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -139,6 +139,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output +# EZ original # def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor: # """ # Overview: @@ -158,17 +159,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # return flat_input.view(*input.shape) -# def renormalize(x): # min-max -# # x is a 2D tensor of shape (batch_size, num_features) -# # Compute the min and max for each feature across the batch -# x_min = torch.min(x, dim=0, keepdim=True).values -# x_max = torch.max(x, dim=0, keepdim=True).values +def renormalize(x): # min-max + # x is a 2D tensor of shape (batch_size, num_features) + # Compute the min and max for each feature across the batch + x_min = torch.min(x, dim=0, keepdim=True).values + x_max = torch.max(x, dim=0, keepdim=True).values -# # Apply min-max normalization -# x_std = (x - x_min) / (x_max - x_min + 1e-8) # Add a small epsilon to avoid division by zero -# x_scaled = x_std * (1 - 0) + 0 # Assuming you want to scale between 0 and 1 + # Apply min-max normalization + x_std = (x - x_min) / (x_max - x_min + 1e-8) # Add a small epsilon to avoid division by zero + x_scaled = x_std * (1 - 0) + 0 # Assuming you want to scale between 0 and 1 -# return x_scaled + return x_scaled # def renormalize(x): # z-score # # x is a 2D tensor of shape (batch_size, num_features) @@ -181,19 +182,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # return x_normalized -def renormalize(x): # robust scaling - # x is a 2D tensor of shape (batch_size, num_features) - # Compute the 1st and 3rd quartile - q1 = torch.quantile(x, 0.25, dim=0, keepdim=True) - q3 = torch.quantile(x, 0.75, dim=0, keepdim=True) +# def renormalize(x): # robust scaling +# # x is a 2D tensor of shape (batch_size, num_features) +# # Compute the 1st and 3rd quartile +# q1 = torch.quantile(x, 0.25, dim=0, keepdim=True) +# q3 = torch.quantile(x, 0.75, dim=0, keepdim=True) - # Compute the interquartile range (IQR) - iqr = q3 - q1 +# # Compute the interquartile range (IQR) +# iqr = q3 - q1 - # Apply robust scaling - x_scaled = (x - q1) / (iqr + 1e-8) # Again, add epsilon to avoid division by zero +# # Apply robust scaling +# x_scaled = (x - q1) / (iqr + 1e-8) # Again, add epsilon to avoid division by zero - return x_scaled +# return x_scaled def AvgL1Norm(x, eps=1e-8): return x/x.abs().mean(-1,keepdim=True).clamp(min=eps) @@ -261,14 +262,15 @@ def __init__( # self.last_linear = nn.Linear(64*4*4, 64*4*4) # self.last_linear = nn.Linear(64*4*4, 256) - self.last_linear = nn.Linear(64*8*8, self.embedding_dim) + # self.last_linear = nn.Linear(64*8*8, self.embedding_dim) + self.last_linear = nn.Linear(64*8*8, self.embedding_dim, bias=False) # TODO # Initialize weights using He initialization init.kaiming_normal_(self.last_linear.weight, mode='fan_out', nonlinearity='relu') # Initialize biases to zero - init.zeros_(self.last_linear.bias) + # init.zeros_(self.last_linear.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -284,10 +286,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = 
self.conv(x) x = self.norm(x) x = self.activation(x) - print('after downsample_net:', x.max(), x.min(), x.mean()) + # print('after downsample_net:', x.max(), x.min(), x.mean()) for block in self.resblocks: x = block(x) + # print('cont embedings before last_linear', x.max(), x.min(), x.mean()) + # NOTE: very important. for muzero_gpt atari 64,8,8 = 4096 -> 1024 x = self.last_linear(x.contiguous().view(-1, 64*8*8)) @@ -295,12 +299,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # print(x.max(), x.min()) # x = renormalize(x) - print('cont embedings', x.max(), x.min(), x.mean()) - + # print('cont embedings before renormalize', x.max(), x.min(), x.mean()) # x = AvgL1Norm(x) # print('after AvgL1Norm', x.max(), x.min()) # x = torch.tanh(x) - # print('after tanh', x.max(), x.min(),x.mean()) + x = renormalize(x) + + # print('after renormalize', x.max(), x.min(),x.mean()) return x @@ -318,6 +323,44 @@ def get_param_mean(self) -> float: return mean +class LatentDecoder(nn.Module): + def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64): + super().__init__() + self.embedding_dim = embedding_dim + self.output_shape = output_shape # (C, H, W) + self.num_channels = num_channels + + # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256 + # We will reverse the process of the representation network + self.initial_size = (num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder + self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size)) + + # Upsampling blocks + self.conv_blocks = nn.ModuleList([ + # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4) + nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.ReLU(), + nn.BatchNorm2d(num_channels // 2), + # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2) + nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.ReLU(), + nn.BatchNorm2d(num_channels // 4), + # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W) + nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1), + ]) + + def forward(self, embeddings: torch.Tensor) -> torch.Tensor: + # Map embeddings back to the image space + x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8) + x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8) + + # Apply conv blocks + for block in self.conv_blocks: + x = block(x) # Upsample progressively + + # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2]) + return x + class RepresentationNetworkMLP(nn.Module): def __init__( diff --git a/lzero/model/gpt_models/plot_sequence_frame_grey.py b/lzero/model/gpt_models/plot_sequence_frame_grey.py index 995f0a26f..a8250e1b0 100644 --- a/lzero/model/gpt_models/plot_sequence_frame_grey.py +++ b/lzero/model/gpt_models/plot_sequence_frame_grey.py @@ -1,7 +1,8 @@ import matplotlib.pyplot as plt from PIL import Image import numpy as np -batch_observations = batch['observations'][:,:,0:1,:,:] +batch_observations = reconstructed_images.detach().view(32, 5, 4, 64, 64)[:,:,0:1,:,:] +# batch_observations = batch['observations'][:,:,0:1,:,:] B, N, C, H, W = batch_observations.shape # 自动检测维度 # 分隔条的宽度(可以根据需要调整) @@ -35,7 +36,7 @@ plt.show() # 保存图像到文件 - concat_image.save(f'sample_{i+1}_0110.png') + 
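# A quick shape sanity-check (sizes are assumptions, matching the instantiation in
# lzero/model/muzero_gpt_model.py) for the ``LatentDecoder`` added in lzero/model/common.py
# above, which produces the ``reconstructed_images`` visualised by this script:
import torch
from lzero.model.common import LatentDecoder

decoder = LatentDecoder(embedding_dim=1024, output_shape=(4, 64, 64))
embeddings = torch.randn(32 * 5, 1024)        # (B * L, E); B=32, L=5 as in the view() above
reconstructed_images = decoder(embeddings)    # -> (B * L, 4, 64, 64) after three stride-2 upsamples
assert reconstructed_images.shape == (32 * 5, 4, 64, 64)

# The L1 reconstruction loss added in Tokenizer.reconstruction_loss is then simply:
original_images = torch.rand(32 * 5, 4, 64, 64)
latent_recon_loss = torch.abs(original_images - reconstructed_images).mean()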
concat_image.save(f'sample_{i+1}_recs_0110.png') diff --git a/lzero/model/gpt_models/plot_weight_grad.py b/lzero/model/gpt_models/plot_weight_grad.py index 3691bd588..5fbedc953 100644 --- a/lzero/model/gpt_models/plot_weight_grad.py +++ b/lzero/model/gpt_models/plot_weight_grad.py @@ -30,8 +30,6 @@ - - x = torch.randn(192, 64, 8, 8).to('cuda:0') def check_layer_output(model, x): diff --git a/lzero/model/gpt_models/tokenizer/tokenizer.py b/lzero/model/gpt_models/tokenizer/tokenizer.py index 66e077574..0412406f9 100644 --- a/lzero/model/gpt_models/tokenizer/tokenizer.py +++ b/lzero/model/gpt_models/tokenizer/tokenizer.py @@ -35,7 +35,7 @@ class TokenizerEncoderOutput: class Tokenizer(nn.Module): - def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True, representation_network = None) -> None: + def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True, representation_network = None, decoder_network =None) -> None: super().__init__() self.vocab_size = vocab_size self.encoder = encoder @@ -46,6 +46,8 @@ def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: D self.embedding.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size) self.lpips = LPIPS().eval() if with_lpips else None self.representation_network = representation_network + self.decoder_network = decoder_network + def __repr__(self) -> str: return "tokenizer" @@ -184,12 +186,20 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, should_preprocess: bool = Fa # obs_embeddings = rearrange(obs_embeddings, 'b c h w -> b 1 (c h w)') # (160,1,1024) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') # (4,1,256) # TODO - - #=============== - return obs_embeddings + def decode_to_obs(self, embeddings: torch.Tensor) -> torch.Tensor: + return self.decoder_network(embeddings) + + + def reconstruction_loss(self, original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor: + # Mean Squared Error (MSE) is commonly used as a reconstruction loss + # loss = nn.MSELoss()(original_images, reconstructed_images) # L1 loss + loss = torch.abs(original_images - reconstructed_images).mean() + return loss + + def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor: shape = z_q.shape # (..., E, h, w) z_q = z_q.view(-1, *shape[-3:]) diff --git a/lzero/model/gpt_models/utils.py b/lzero/model/gpt_models/utils.py index c769300f5..f52f3fa17 100644 --- a/lzero/model/gpt_models/utils.py +++ b/lzero/model/gpt_models/utils.py @@ -122,9 +122,12 @@ def __init__(self, **kwargs): self.policy_loss_weight = 1. # self.ends_loss_weight = 1. self.ends_loss_weight = 0. - self.rep_kl_loss_weight = 0.1 # for lunarlander - # self.rep_kl_loss_weight = 0.5 + # self.latent_kl_loss_weight = 0.1 # for lunarlander + self.latent_kl_loss_weight = 0. 
# for lunarlander + + # self.latent_recon_loss_weight = 1 + self.latent_recon_loss_weight = 0.1 # Initialize the total loss tensor on the correct device @@ -140,8 +143,10 @@ def __init__(self, **kwargs): self.loss_total += self.value_loss_weight * v elif k == 'loss_ends': self.loss_total += self.ends_loss_weight * v - elif k == 'rep_kl_loss': - self.loss_total += self.rep_kl_loss_weight * v + elif k == 'latent_kl_loss': + self.loss_total += self.latent_kl_loss_weight * v + elif k == 'latent_recon_loss': + self.loss_total += self.latent_recon_loss_weight * v else: raise ValueError(f"Unknown loss type : {k}") diff --git a/lzero/model/gpt_models/world_model.py b/lzero/model/gpt_models/world_model.py index 80f8aa4fa..7f846f0ba 100644 --- a/lzero/model/gpt_models/world_model.py +++ b/lzero/model/gpt_models/world_model.py @@ -107,7 +107,8 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer # nn.ReLU(), # nn.Linear(config.embed_dim, obs_vocab_size) nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim) + nn.Linear(config.embed_dim, self.obs_per_embdding_dim), + # nn.Tanh(), # TODO ) ) @@ -654,7 +655,17 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW # NOTE: 这里是需要梯度的 # obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens # (BL, K) - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=True) # (B, C, H, W) -> (B, K, E) + with torch.no_grad(): # TODO + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) + + # obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 + + # Assume that 'cont_embeddings' and 'original_images' are available from prior code + # Decode the embeddings to reconstruct the images + reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # Calculate the reconstruction loss + latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].contiguous().view(-1, 4, 64, 64), reconstructed_images) # 计算KL散度损失 @@ -663,8 +674,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW mean = obs_embeddings.mean(dim=0, keepdim=True) std = obs_embeddings.std(dim=0, keepdim=True) # 创建标准正态分布作为先验分布 - # prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std)) - prior_dist = torch.distributions.Normal(torch.ones_like(mean)*0.1, torch.ones_like(std)) + prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std)) # 创建模型输出的分布 @@ -672,12 +682,12 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW # 计算KL散度损失,对每个样本的每个特征进行计算 kl_loss = torch.distributions.kl.kl_divergence(model_dist, prior_dist) # 因为 kl_loss 的形状是 (1, 1, 256),我们可以对所有特征求平均来得到一个标量损失 - rep_kl_loss = kl_loss.mean() - print(f'rep_kl_loss:, {rep_kl_loss}') - if torch.isnan(rep_kl_loss) or torch.isinf(rep_kl_loss): - print("NaN or inf detected in rep_kl_loss!") + latent_kl_loss = kl_loss.mean() + # print(f'latent_kl_loss:, {latent_kl_loss}') + if torch.isnan(latent_kl_loss) or torch.isinf(latent_kl_loss): + print("NaN or inf detected in latent_kl_loss!") # 使用 torch.tensor(0) 创建一个同设备,同数据类型的零张量,并确保不需要梯度 - rep_kl_loss = torch.tensor(0., device=rep_kl_loss.device, dtype=rep_kl_loss.dtype) + latent_kl_loss = torch.tensor(0., device=latent_kl_loss.device, dtype=latent_kl_loss.dtype) # TODO # obs_embeddings = 
AvgL1Norm(obs_embeddings) @@ -695,7 +705,8 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW # tokens = rearrange(torch.cat((obs_tokens, act_tokens), dim=2), 'b l k1 -> b (l k1)') # (B, L(K+1)) # outputs = self.forward(tokens, is_root=False) - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) + # TODO: 只提供重建损失更新表征网络 + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], batch['ends'], @@ -756,7 +767,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, rep_kl_loss=rep_kl_loss) + loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss) def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): # Assume outputs.logits_rewards and labels are your predictions and targets diff --git a/lzero/model/gpt_models/world_model_bkp_20240105.py b/lzero/model/gpt_models/world_model_bkp_20240105.py index 05fc32f15..8b1f64fc4 100644 --- a/lzero/model/gpt_models/world_model_bkp_20240105.py +++ b/lzero/model/gpt_models/world_model_bkp_20240105.py @@ -677,7 +677,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW # 计算KL散度损失,对每个样本的每个特征进行计算 kl_loss = torch.distributions.kl.kl_divergence(model_dist, prior_dist) # 因为 kl_loss 的形状是 (1, 1, 256),我们可以对所有特征求平均来得到一个标量损失 - rep_kl_loss = kl_loss.mean() + latent_kl_loss = kl_loss.mean() # second to last 增加高斯噪声 noise_std = 0.1 @@ -753,7 +753,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, rep_kl_loss=rep_kl_loss) + loss_policy=loss_policy, latent_kl_loss=latent_kl_loss) def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): # Assume outputs.logits_rewards and labels are your predictions and targets diff --git a/lzero/model/muzero_gpt_model.py b/lzero/model/muzero_gpt_model.py index 26518a891..7cc214144 100644 --- a/lzero/model/muzero_gpt_model.py +++ b/lzero/model/muzero_gpt_model.py @@ -11,7 +11,7 @@ from ding.utils import MODEL_REGISTRY, SequenceType from numpy import ndarray -from .common import MZNetworkOutput, RepresentationNetworkGPT, PredictionNetwork +from .common import MZNetworkOutput, RepresentationNetworkGPT, PredictionNetwork, LatentDecoder from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean @@ -164,10 +164,14 @@ def __init__( # embedding_dim=cfg.embedding_dim, embedding_dim=cfg.world_model.embed_dim, ) + # Instantiate the decoder + decoder_network = LatentDecoder(embedding_dim=1024, output_shape=(4, 64, 64)) + Encoder = Encoder(cfg.tokenizer.encoder) Decoder = Decoder(cfg.tokenizer.decoder) - self.tokenizer = Tokenizer(cfg.tokenizer.vocab_size, cfg.tokenizer.embed_dim, Encoder, Decoder, with_lpips=True, representation_network=self.representation_network) + self.tokenizer = Tokenizer(cfg.tokenizer.vocab_size, cfg.tokenizer.embed_dim, Encoder, Decoder, 
with_lpips=True, representation_network=self.representation_network, + decoder_network=decoder_network) self.world_model = WorldModel(obs_vocab_size=self.tokenizer.vocab_size, act_vocab_size=self.action_space_size, config=cfg.world_model, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.tokenizer.parameters())} parameters in agent.tokenizer') diff --git a/lzero/policy/muzero_gpt.py b/lzero/policy/muzero_gpt.py index efe85ebf3..23b4de99a 100644 --- a/lzero/policy/muzero_gpt.py +++ b/lzero/policy/muzero_gpt.py @@ -322,8 +322,8 @@ def _init_learn(self) -> None: # ) self._optimizer_world_model = configure_optimizer( model=self._model.world_model, - learning_rate=3e-3, - # learning_rate=1e-4, + # learning_rate=3e-3, + learning_rate=1e-4, weight_decay=self._cfg.weight_decay, # weight_decay=0.01, exclude_submodules=['none'] # NOTE @@ -440,7 +440,9 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni self._learn_model.train() self._target_model.train() - self._learn_model.tokenizer.eval() + # self._learn_model.tokenizer.eval() + self._learn_model.tokenizer.train() + if self._cfg.use_rnd_model: self._target_model_for_intrinsic_reward.train() @@ -544,7 +546,9 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni reward_loss = intermediate_losses['loss_rewards'] policy_loss = intermediate_losses['loss_policy'] value_loss = intermediate_losses['loss_value'] - rep_kl_loss = intermediate_losses['rep_kl_loss'] + latent_kl_loss = intermediate_losses['latent_kl_loss'] + latent_recon_loss = intermediate_losses['latent_recon_loss'] + # ============================================================== # the core learn model update step. @@ -561,12 +565,18 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni # 在训练循环中使用 # self.monitor_weights_and_grads(self._learn_model.tokenizer.representation_network) + # print('torch.cuda.memory_summary():', torch.cuda.memory_summary()) if self._cfg.multi_gpu: self.sync_gradients(self._learn_model) total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( self._learn_model.world_model.parameters(), self._cfg.grad_clip_value ) + + total_grad_norm_before_clip_rep_net = torch.nn.utils.clip_grad_norm_(self._learn_model.tokenizer.representation_network.parameters(), max_norm=1.0) + # print('total_grad_norm_before_clip_rep_net:', total_grad_norm_before_clip_rep_net) + + self._optimizer_world_model.step() if self._cfg.lr_piecewise_constant_decay: self.lr_scheduler.step() @@ -588,7 +598,8 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni 'weighted_total_loss': weighted_total_loss.item(), 'obs_loss': obs_loss, - 'rep_kl_loss': rep_kl_loss, + 'latent_kl_loss': latent_kl_loss, + 'latent_recon_loss':latent_recon_loss, 'policy_loss': policy_loss, 'target_policy_entropy': average_target_policy_entropy, # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), @@ -609,6 +620,7 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + 'total_grad_norm_before_clip_rep_net': total_grad_norm_before_clip_rep_net.item(), } return return_loss_dict @@ -693,9 +705,12 @@ def _forward_learn_tokenizer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union 
weighted_total_loss_tokenizer = losses_tokenizer.loss_total weighted_total_loss_tokenizer.backward() # losses_tokenizer.loss_total.backward() + total_grad_norm_before_clip_tokenizer = torch.nn.utils.clip_grad_norm_( self._learn_model.tokenizer.parameters(), self._cfg.grad_clip_value ) + + self._optimizer_tokenizer.step() intermediate_losses_tokenizer= defaultdict(float) @@ -999,7 +1014,8 @@ def _monitor_vars_learn(self) -> List[str]: # 'total_loss', 'obs_loss', 'policy_loss', - 'rep_kl_loss', + 'latent_kl_loss', + 'latent_recon_loss', # 'policy_entropy', 'target_policy_entropy', 'reward_loss', @@ -1014,6 +1030,7 @@ def _monitor_vars_learn(self) -> List[str]: # 'transformed_target_value', 'total_grad_norm_before_clip_tokenizer', 'total_grad_norm_before_clip_wm', + 'total_grad_norm_before_clip_rep_net', # tokenizer 'commitment_loss', 'reconstruction_loss', @@ -1034,18 +1051,71 @@ def _state_dict_learn(self) -> Dict[str, Any]: 'optimizer_tokenizer': self._optimizer_tokenizer.state_dict(), } + # TODO: + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Load the state_dict variable into policy learn mode. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + # """ + # self._learn_model.load_state_dict(state_dict['model']) + # self._target_model.load_state_dict(state_dict['target_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + # self._optimizer_tokenizer.load_state_dict(state_dict['optimizer_tokenizer']) def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ Overview: - Load the state_dict variable into policy learn mode. + Load the state_dict variable into policy learn mode, specifically loading only the + representation network of the tokenizer within model and target_model. Arguments: - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. """ - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) - self._optimizer_tokenizer.load_state_dict(state_dict['optimizer_tokenizer']) + # Extract the relevant sub-state-dicts for representation_network from the state_dict + # model_rep_network_state = state_dict['model']['tokenizer']['representation_network'] + # target_model_rep_network_state = state_dict['target_model']['tokenizer']['representation_network'] + + # # Load the state into the model's representation network + # self._learn_model.tokenizer.representation_network.load_state_dict(model_rep_network_state) + # self._target_model.tokenizer.representation_network.load_state_dict(target_model_rep_network_state) + + # Assuming self._learn_model and self._target_model have a 'representation_network' submodule + self._load_representation_network_state(state_dict['model'], self._learn_model.tokenizer.representation_network) + self._load_representation_network_state(state_dict['target_model'], self._target_model.tokenizer.representation_network) + + + def _load_representation_network_state(self, state_dict, model_submodule): + """ + This function filters the state_dict to only include the state of the representation_network + and loads it into the given model submodule. 
+ """ + from collections import OrderedDict + + # Filter the state_dict to only include keys that start with 'representation_network' + representation_network_keys = {k: v for k, v in state_dict.items() if k.startswith('representation_network')} + + # Load the state into the model's representation_network submodule + # model_submodule.load_state_dict(OrderedDict(representation_network_keys)) + + # 去掉键名前缀 + new_state_dict = OrderedDict() + for key, value in representation_network_keys.items(): + new_key = key.replace('representation_network.', '') # 去掉前缀 + new_state_dict[new_key] = value + + # # 如果模型在特定的设备上,确保状态字典也在那个设备上 + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # new_state_dict = {key: value.to(device) for key, value in new_state_dict.items()} + + # 尝试加载状态字典 + try: + # model_submodule.load_state_dict(new_state_dict) + # 使用 strict=False 参数忽略缺少的键 + model_submodule.load_state_dict(new_state_dict, strict=False) + except RuntimeError as e: + print("加载失败: ", e) + def _process_transition(self, obs, policy_output, timestep): # be compatible with DI-engine Policy class diff --git a/zoo/atari/config/atari_muzero_config_46464.py b/zoo/atari/config/atari_muzero_config_46464.py index 77a484a83..3e07f771d 100644 --- a/zoo/atari/config/atari_muzero_config_46464.py +++ b/zoo/atari/config/atari_muzero_config_46464.py @@ -1,6 +1,6 @@ from easydict import EasyDict import torch -torch.cuda.set_device(3) +torch.cuda.set_device(2) # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} env_name = 'PongNoFrameskip-v4' diff --git a/zoo/atari/config/atari_muzero_gpt_config_stack4.py b/zoo/atari/config/atari_muzero_gpt_config_stack4.py index a0a2d9931..d1e291f84 100644 --- a/zoo/atari/config/atari_muzero_gpt_config_stack4.py +++ b/zoo/atari/config/atari_muzero_gpt_config_stack4.py @@ -66,7 +66,12 @@ atari_muzero_config = dict( # TODO: world_model.py decode_obs_tokens # TODO: tokenizer.py: lpips loss - exp_name=f'data_mz_gpt_ctree_0110/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kll01-0.1_stack4_seed0', + exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-biasfalse-minmax-iter60k-fixed_stack4_seed0', + # exp_name=f'data_mz_gpt_ctree_0110/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-onlyreconslossw1-biasfalse-minmax_stack4_seed0', + + # exp_name=f'data_mz_gpt_ctree_0110/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-reconslossw1-minmax-latentgrad0.2_stack4_seed0', + + # 
exp_name=f'data_mz_gpt_ctree_0110/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-reconslossw01-tanh-lr1e-4-gcv05_stack4_seed0', # exp_name=f'data_mz_gpt_ctree_0110/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-10k-thenfixed_obsmseloss_obsloss2_rep-no-avgl1norm-no-neg1_kll01-inf2zero-noseclatstd01_stack4_seed0', @@ -100,7 +105,8 @@ clip_rewards=True, ), policy=dict( - model_path=None, + # model_path=None, + model_path='/mnt/afs/niuyazhe/code/LightZero/data_mz_ctree/Pong_muzero_ns50_upc1000_rr0.0_46464_seed0_240110_140819/ckpt/iteration_60000.pth.tar', # model_path='/mnt/afs/niuyazhe/code/LightZero/data_mz_gpt_ctree/Pong_muzero_gpt_ns5_upcNone-mur0.5_rr0_H5_orignet_tran-nlayers2-emd128-nh2_mcs500_batch8_bs16_lr1e-4_tokenizer-wd0_perl_tokenizer-only_seed0/ckpt/iteration_150000.pth.tar', # tokenizer_start_after_envsteps=int(9e9), # not train tokenizer tokenizer_start_after_envsteps=int(0), @@ -176,6 +182,9 @@ # learning_rate=0.003, learning_rate=0.0001, target_update_freq=100, + + grad_clip_value = 0.5, # TODO + num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, ssl_loss_weight=2, # default is 0 diff --git a/zoo/atari/config/atari_muzero_gpt_config_stack4_debug.py b/zoo/atari/config/atari_muzero_gpt_config_stack4_debug.py index bda18b479..3b53441c3 100644 --- a/zoo/atari/config/atari_muzero_gpt_config_stack4_debug.py +++ b/zoo/atari/config/atari_muzero_gpt_config_stack4_debug.py @@ -28,11 +28,11 @@ # update_per_collect = None # model_update_ratio = 0.25 model_update_ratio = 0.25 -num_simulations = 50 +# num_simulations = 50 # num_simulations = 25 # TODO: debug -# num_simulations = 1 +num_simulations = 1 max_env_step = int(1e6) reanalyze_ratio = 0 @@ -66,7 +66,7 @@ atari_muzero_config = dict( # TODO: world_model.py decode_obs_tokens # TODO: tokenizer.py: lpips loss - exp_name=f'data_mz_gpt_ctree_0110_debug/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-6488-lsd1024-10k-thenfixed_obsmseloss_obsloss2_rep-no-avgl1norm-no-neg1_kll01-inf2zero-noseclatstd01_stack4_seed0', + exp_name=f'data_mz_gpt_ctree_0110_debug/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-6488-lsd1024_obsmseloss_obsloss2_kllw0_stack4_seed0', # exp_name=f'data_mz_gpt_ctree_0105/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs1000_contembdings_mz-repenet-lastlinear-lsd1024_obsmseloss_obsloss2_rep-no-avgl1norm-no-neg1_kll01-inf2zero-noseclatstd01_stack4_noapply_seed0', # exp_name=f'data_mz_gpt_ctree_0105_debug/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs1000_contembdings_mz-repenet-lastlinear-lsd256_obsmseloss_rep-no-avgl1norm-no-neg1_obsloss2_kll01-inf2zero-noseclatstd01_stack4_aug_seed0', @@ -98,7 +98,8 @@ clip_rewards=True, ), policy=dict( - model_path=None, + # 
model_path=None, + model_path='/mnt/afs/niuyazhe/code/LightZero/data_mz_ctree/Pong_muzero_ns50_upc1000_rr0.0_46464_seed0_240110_140819/ckpt/iteration_20000.pth.tar', # model_path='/mnt/afs/niuyazhe/code/LightZero/data_mz_gpt_ctree/Pong_muzero_gpt_ns5_upcNone-mur0.5_rr0_H5_orignet_tran-nlayers2-emd128-nh2_mcs500_batch8_bs16_lr1e-4_tokenizer-wd0_perl_tokenizer-only_seed0/ckpt/iteration_150000.pth.tar', # tokenizer_start_after_envsteps=int(9e9), # not train tokenizer tokenizer_start_after_envsteps=int(0), @@ -174,6 +175,10 @@ # learning_rate=0.003, learning_rate=0.0001, target_update_freq=100, + + grad_clip_value = 0.5, # TODO + + num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, ssl_loss_weight=2, # default is 0