From a4c436b35873487a5cb20b25723109855c5ebcc6 Mon Sep 17 00:00:00 2001
From: PaParaZz1
Date: Thu, 23 Jan 2025 18:30:20 +0800
Subject: [PATCH] fix(pu): fix smz and sez config for pixel-based dmc

---
 lzero/mcts/buffer/game_buffer.py              |  18 +-
 .../buffer/game_buffer_sampled_unizero.py     |  34 +-
 lzero/model/common.py                         |   1 +
 lzero/model/sampled_efficientzero_model.py    |   3 +-
 lzero/model/sampled_muzero_model.py           | 444 ++++++++++++++++++
 lzero/policy/sampled_muzero.py                |   2 +-
 .../config/dmc2gym_pixels_sez_config.py       |  18 +-
 .../config/dmc2gym_pixels_smz_config.py       |  33 +-
 8 files changed, 513 insertions(+), 40 deletions(-)
 create mode 100644 lzero/model/sampled_muzero_model.py

diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py
index 2678066e9..7291be21e 100644
--- a/lzero/mcts/buffer/game_buffer.py
+++ b/lzero/mcts/buffer/game_buffer.py
@@ -559,14 +559,18 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None:
         # print(f'valid_len is {valid_len}')

         if meta['priorities'] is None:
-            max_prio = self.game_pos_priorities.max() if self.game_segment_buffer else 1
+            if self.game_segment_buffer:
+                # Guard against an empty priority array: fall back to 1 if no priorities have been stored yet.
+                max_prio = self.game_pos_priorities.max() if len(self.game_pos_priorities) > 0 else 1
+            else:
+                max_prio = 1

             # if no 'priorities' provided, set the valid part of the new-added game history the max_prio
-            self.game_pos_priorities = np.concatenate(
-                (
-                    self.game_pos_priorities, [max_prio
-                                               for _ in range(valid_len)] + [0. for _ in range(valid_len, data_length)]
-                )
-            )
+            self.game_pos_priorities = np.concatenate(
+                (self.game_pos_priorities, [max_prio for _ in range(valid_len)] + [0. for _ in range(valid_len, data_length)])
+            )
         else:
             assert data_length == len(meta['priorities']), " priorities should be of same length as the game steps"
             priorities = meta['priorities'].copy().reshape(-1)
diff --git a/lzero/mcts/buffer/game_buffer_sampled_unizero.py b/lzero/mcts/buffer/game_buffer_sampled_unizero.py
index abb7c92a8..c9937723f 100644
--- a/lzero/mcts/buffer/game_buffer_sampled_unizero.py
+++ b/lzero/mcts/buffer/game_buffer_sampled_unizero.py
@@ -479,15 +479,14 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
             m_output = model.initial_inference(batch_obs, action_batch[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
             # =======================================================================
-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # obtain the scalar form of the value/reward (this is now done unconditionally, not only outside training)
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)

@@ -638,15 +637,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any
             m_output = model.initial_inference(batch_obs, action_batch)

             # ======================================================================
-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # obtain the scalar form of the value/reward (this is now done unconditionally, not only outside training)
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)

         if self._cfg.use_root_value:
diff --git a/lzero/model/common.py b/lzero/model/common.py
index 306fc9f22..33ef24b63 100644
--- a/lzero/model/common.py
+++ b/lzero/model/common.py
@@ -562,6 +562,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # for atari 64,8,8 = 4096 -> 768
         x = self.sim_norm(x)
+
         return x

diff --git a/lzero/model/sampled_efficientzero_model.py b/lzero/model/sampled_efficientzero_model.py
index d8f35655e..0bd14c6d2 100644
--- a/lzero/model/sampled_efficientzero_model.py
+++ b/lzero/model/sampled_efficientzero_model.py
@@ -224,8 +224,7 @@ def __init__(
                 # (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is
                 # (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus,
                 # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304
-                self.projection_input_dim = num_channels * math.ceil(observation_shape[1] / 16
-                                                                     ) * math.ceil(observation_shape[2] / 16)
+                self.projection_input_dim = num_channels * latent_size
             else:
                 self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2]

diff --git a/lzero/model/sampled_muzero_model.py b/lzero/model/sampled_muzero_model.py
new file mode 100644
index 000000000..6511a285b
--- /dev/null
+++ b/lzero/model/sampled_muzero_model.py
@@ -0,0 +1,444 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from ding.model.common import ReparameterizationHead
+from ding.torch_utils import MLP
+from ding.utils import MODEL_REGISTRY, SequenceType
+
+from .common import MZNetworkOutput, RepresentationNetwork
+from .sampled_efficientzero_model import PredictionNetwork
+
+from .muzero_model import DynamicsNetwork
+from .utils import renormalize
+
+
+@MODEL_REGISTRY.register('SampledMuZeroModel')
+class SampledMuZeroModel(nn.Module):
+
+    def __init__(
+        self,
+        observation_shape: SequenceType = (12, 96, 96),
+        action_space_size: int = 6,
+        num_res_blocks: int = 1,
+        num_channels: int = 64,
+        latent_state_dim: int = 256,
+        reward_head_channels: int = 16,
+        value_head_channels: int = 16,
+        policy_head_channels: int = 16,
+        reward_head_hidden_channels: SequenceType = [256],
+        value_head_hidden_channels: SequenceType = [256],
+        policy_head_hidden_channels: SequenceType = [256],
+        reward_support_size: int = 601,
+        value_support_size: int = 601,
+        proj_hid: int = 1024,
+        proj_out: int = 1024,
+        pred_hid: int = 512,
+        pred_out: int = 1024,
+        self_supervised_learning_loss: bool = True,
+        categorical_distribution: bool = True,
+        activation: Optional[nn.Module] = nn.GELU(approximate='tanh'),
+        last_linear_layer_init_zero: bool = True,
+        state_norm: bool = False,
+        downsample: bool = False,
+        # ==============================================================
+        # specific sampled related config
+        # ==============================================================
+        continuous_action_space: bool = False,
+        num_of_sampled_actions: int = 6,
+        sigma_type='conditioned',
+        fixed_sigma_value: float = 0.3,
+        bound_type: str = None,
+        norm_type: str = 'LN',
+        discrete_action_encoding_type: str = 'one_hot',
+        res_connection_in_dynamics: bool = True,
+        use_sim_norm: bool = False,
+        *args,
+        **kwargs,
+    ):
+        """
+        Overview:
+            The network model of Sampled MuZero for image (3D) observations.
+        Arguments:
+            - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. (12, 96, 96) for Atari-style stacked RGB frames.
+            - action_space_size (:obj:`int`): Action space size, which is an integer number. For discrete action spaces, it is the number of discrete actions, \
+                e.g. 4 for LunarLander. For continuous action spaces, it is the dimension of the continuous action, e.g. 4 for BipedalWalker.
+            - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256.
+            - reward_head_hidden_channels (:obj:`SequenceType`): The sizes of the hidden layers of the reward head (MLP head).
+            - value_head_hidden_channels (:obj:`SequenceType`): The sizes of the hidden layers used in the value head (MLP head).
+            - policy_head_hidden_channels (:obj:`SequenceType`): The sizes of the hidden layers used in the policy head (MLP head).
+            - reward_support_size (:obj:`int`): The size of the categorical reward output.
+            - value_support_size (:obj:`int`): The size of the categorical value output.
+            - proj_hid (:obj:`int`): The size of projection hidden layer.
+            - proj_out (:obj:`int`): The size of projection output layer.
+            - pred_hid (:obj:`int`): The size of prediction hidden layer.
+            - pred_out (:obj:`int`): The size of prediction output layer.
+            - self_supervised_learning_loss (:obj:`bool`): Whether to use the self_supervised_learning related networks in the Sampled MuZero model; defaults to True.
+            - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distributions for value and reward.
+            - activation (:obj:`Optional[nn.Module]`): Activation function used in the network; defaults to ``nn.GELU(approximate='tanh')``.
+            - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of the value/policy mlp; defaults to True.
+            - state_norm (:obj:`bool`): Whether to use normalization for latent states; defaults to False.
+            # ==============================================================
+            # specific sampled related config
+            # ==============================================================
+            - continuous_action_space (:obj:`bool`): Whether the action space is continuous; defaults to False.
+            - num_of_sampled_actions (:obj:`int`): The number of sampled actions, i.e. the K in the original Sampled MuZero paper.
+            # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments.
+            - sigma_type (:obj:`str`): The type of sigma in the policy head of the prediction network, options={'conditioned', 'fixed'}.
+            - fixed_sigma_value (:obj:`float`): The fixed sigma value used in the policy head when ``sigma_type='fixed'``.
+            - bound_type (:obj:`str`): The type of bound in networks; defaults to None.
+            - norm_type (:obj:`str`): The type of normalization in networks; defaults to 'LN'.
+            - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete actions; defaults to 'one_hot'. options = {'one_hot', 'not_one_hot'}
+            - res_connection_in_dynamics (:obj:`bool`): Whether to use a residual connection in the dynamics network; defaults to True.
+        """
+        super(SampledMuZeroModel, self).__init__()
+        if not categorical_distribution:
+            self.reward_support_size = 1
+            self.value_support_size = 1
+        else:
+            self.reward_support_size = reward_support_size
+            self.value_support_size = value_support_size
+
+        self.continuous_action_space = continuous_action_space
+        self.observation_shape = observation_shape
+        self.action_space_size = action_space_size
+        # The dim of action space. For discrete action space, it is 1.
+        # For continuous action space, it is the dimension of continuous action.
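+        # For example, in the pixel-based DMC cartpole tasks targeted by this patch the action space is a
+        # 1-dimensional continuous space, so action_space_dim == action_space_size == 1.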
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.latent_state_dim = latent_state_dim + self.reward_head_hidden_channels = reward_head_hidden_channels + self.value_head_hidden_channels = value_head_hidden_channels + self.policy_head_hidden_channels = policy_head_hidden_channels + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.downsample = downsample + + self.self_supervised_learning_loss = self_supervised_learning_loss + + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.norm_type = norm_type + self.num_of_sampled_actions = num_of_sampled_actions + self.res_connection_in_dynamics = res_connection_in_dynamics + self.activation = activation + + + if observation_shape[1] == 96: + latent_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16) + elif observation_shape[1] == 84: + latent_size = math.ceil(observation_shape[1] / 14) * math.ceil(observation_shape[2] / 14) + elif observation_shape[1] == 64: + latent_size = math.ceil(observation_shape[1] / 8) * math.ceil(observation_shape[2] / 8) + else: + raise ValueError("Invalid observation shape, only support 64, 84, 96 for now.") + + flatten_input_size_for_reward_head = ( + (reward_head_channels * latent_size) if downsample else + (reward_head_channels * observation_shape[1] * observation_shape[2]) + ) + flatten_input_size_for_value_head = ( + (value_head_channels * latent_size) if downsample else + (value_head_channels * observation_shape[1] * observation_shape[2]) + ) + flatten_input_size_for_policy_head = ( + (policy_head_channels * latent_size) if downsample else + (policy_head_channels * observation_shape[1] * observation_shape[2]) + ) + + self.representation_network = RepresentationNetwork( + observation_shape, + num_res_blocks, + num_channels, + downsample, + norm_type=self.norm_type, + use_sim_norm=use_sim_norm, + ) + + self.dynamics_network = DynamicsNetwork( + observation_shape, + self.action_encoding_dim, + num_res_blocks, + num_channels + self.action_encoding_dim, + reward_head_channels, + reward_head_hidden_channels, + self.reward_support_size, + flatten_input_size_for_reward_head, + downsample, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) + + self.prediction_network = PredictionNetwork( + observation_shape, + self.continuous_action_space, + action_space_size, + num_res_blocks, + num_channels, + value_head_channels, + policy_head_channels, + value_head_hidden_channels, + policy_head_hidden_channels, + self.value_support_size, + flatten_input_size_for_value_head, + flatten_input_size_for_policy_head, + downsample, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + bound_type=self.bound_type, + norm_type=self.norm_type, + ) + + if 
self.self_supervised_learning_loss:
+            if self.downsample:
+                # In Atari, if the observation_shape is set to (12, 96, 96), which indicates the original shape of
+                # (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is
+                # (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus,
+                # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304
+                self.projection_input_dim = num_channels * latent_size
+            else:
+                self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2]
+
+            self.projection = nn.Sequential(
+                nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
+                nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
+                nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out)
+            )
+            self.prediction_head = nn.Sequential(
+                nn.Linear(self.proj_out, self.pred_hid),
+                nn.BatchNorm1d(self.pred_hid),
+                activation,
+                nn.Linear(self.pred_hid, self.pred_out),
+            )
+
+    def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
+        """
+        Overview:
+            Initial inference of the Sampled MuZero model, i.e. the first step of a search rollout.
+            To perform the initial inference, we first use the representation network to obtain the ``latent_state`` of the observation.
+            Then we use the prediction network to predict the ``value`` and ``policy_logits`` of the ``latent_state``;
+            the predicted ``reward`` of the initial step is set to zero.
+        Arguments:
+            - obs (:obj:`torch.Tensor`): The stacked image observation data.
+        Returns (MZNetworkOutput):
+            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
+            - reward (:obj:`torch.Tensor`): The predicted reward of the input state. \
+                In initial inference, we set it to a zero vector.
+            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+        Shapes:
+            - obs (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch_size and C is the number of stacked image channels.
+            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
+            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number of channels, \
+                and H_, W_ are the spatial sizes of the latent state.
+        """
+        batch_size = obs.size(0)
+        latent_state = self._representation(obs)
+        policy_logits, value = self._prediction(latent_state)
+        return MZNetworkOutput(
+            value,
+            [0. for _ in range(batch_size)],
+            policy_logits,
+            latent_state,
+        )
+
+    def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput:
+        """
+        Overview:
+            Recurrent inference of the Sampled MuZero model, i.e. the rollout step of the model.
+            To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state`` and
+            ``reward`` from the given current ``latent_state`` and ``action``.
+            We then use the prediction network to predict the ``value`` and ``policy_logits`` of ``next_latent_state``.
+        Arguments:
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input obs.
+            - action (:obj:`torch.Tensor`): The predicted action to rollout.
+        Returns (MZNetworkOutput):
+            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
+            - reward (:obj:`torch.Tensor`): The predicted reward for input state.
+            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
+            - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state.
+        Shapes:
+            - action (:obj:`torch.Tensor`): :math:`(B, )` for discrete actions or :math:`(B, action_dim)` for continuous actions, where B is batch_size.
+            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
+            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number of channels, \
+                and H_, W_ are the spatial sizes of the latent state.
+            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, the same shape as ``latent_state``.
+        """
+        next_latent_state, reward = self._dynamics(latent_state, action)
+        policy_logits, value = self._prediction(next_latent_state)
+        return MZNetworkOutput(value, reward, policy_logits, next_latent_state)
+
+    def _representation(self, observation: torch.Tensor) -> torch.Tensor:
+        """
+        Overview:
+            Use the representation network to encode the observation into a latent state.
+        Arguments:
+            - observation (:obj:`torch.Tensor`): The stacked image observation data.
+        Returns:
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+        Shapes:
+            - observation (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch_size.
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number of channels, \
+                and H_, W_ are the spatial sizes of the latent state.
+        """
+        latent_state = self.representation_network(observation)
+        if self.state_norm:
+            latent_state = renormalize(latent_state)
+        return latent_state
+
+    def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Overview:
+            Use the prediction network to predict ``policy_logits`` and ``value`` from the given ``latent_state``.
+        Arguments:
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+        Returns:
+            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
+            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
+        Shapes:
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number of channels, \
+                and H_, W_ are the spatial sizes of the latent state.
+            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
+            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+        """
+        policy, value = self.prediction_network(latent_state)
+        return policy, value
+
+    def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Overview:
+            Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state``
+            and ``reward``.
+        Arguments:
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+            - action (:obj:`torch.Tensor`): The predicted action to rollout.
+        Returns:
+            - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep.
+            - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and action.
+        Shapes:
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number of channels, \
+                and H_, W_ are the spatial sizes of the latent state.
+            - action (:obj:`torch.Tensor`): :math:`(B, )` for discrete actions or :math:`(B, action_dim)` for continuous actions, where B is batch_size.
+            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, the same shape as ``latent_state``.
+            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
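+        Examples:
+            An illustrative sketch of the expected shapes for a continuous-action rollout step; it assumes a batch of
+            8, a 64-channel 6x6 latent state and a 2-dimensional action space, and is not taken from a real run.
+            >>> latent_state = torch.randn(8, 64, 6, 6)
+            >>> action = torch.randn(8, 2)
+            >>> next_latent_state, reward = self._dynamics(latent_state, action)
+            >>> next_latent_state.shape  # torch.Size([8, 64, 6, 6])
+            >>> reward.shape             # (8, reward_support_size), e.g. torch.Size([8, 601]) with the default support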
+        """
+        # NOTE: the discrete action encoding type is important for some environments
+
+        if not self.continuous_action_space:
+            # discrete action space
+            if self.discrete_action_encoding_type == 'one_hot':
+                # Stack latent_state with the one hot encoded action
+                if len(action.shape) == 1:
+                    # (batch_size, ) -> (batch_size, 1)
+                    # e.g., torch.Size([8]) -> torch.Size([8, 1])
+                    action = action.unsqueeze(-1)
+
+                # transform action to one-hot encoding.
+                # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4)
+                action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device)
+                # transform action to torch.int64
+                action = action.long()
+                action_one_hot.scatter_(1, action, 1)
+                action_encoding = action_one_hot
+            elif self.discrete_action_encoding_type == 'not_one_hot':
+                action_encoding = action / self.action_space_size
+                if len(action_encoding.shape) == 1:
+                    # (batch_size, ) -> (batch_size, 1)
+                    # e.g., torch.Size([8]) -> torch.Size([8, 1])
+                    action_encoding = action_encoding.unsqueeze(-1)
+        else:
+            # continuous action space
+            if len(action.shape) == 1:
+                # (batch_size,) -> (batch_size, action_dim=1, 1, 1)
+                # e.g., torch.Size([8]) -> torch.Size([8, 1, 1, 1])
+                action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+            elif len(action.shape) == 2:
+                # (batch_size, action_dim) -> (batch_size, action_dim, 1, 1)
+                # e.g., torch.Size([8, 2]) -> torch.Size([8, 2, 1, 1])
+                action = action.unsqueeze(-1).unsqueeze(-1)
+            elif len(action.shape) == 3:
+                # (batch_size, action_dim, 1) -> (batch_size, action_dim, 1, 1)
+                # e.g., torch.Size([8, 2, 1]) -> torch.Size([8, 2, 1, 1])
+                action = action.unsqueeze(-1)
+
+            # broadcast the action over the spatial dimensions of the latent state
+            action_encoding_tmp = action
+            action_encoding = action_encoding_tmp.expand(
+                latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3]
+            )
+
+        action_encoding = action_encoding.to(latent_state.device).float()
+        # state_action_encoding shape: (batch_size, latent_state[1] + action_dim, latent_state[2], latent_state[3]) or
+        # (batch_size, latent_state[1] + action_space_size, latent_state[2], latent_state[3]) depending on the discrete_action_encoding_type.
+        state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)
+
+        next_latent_state, reward = self.dynamics_network(state_action_encoding)
+
+        if not self.state_norm:
+            return next_latent_state, reward
+        else:
+            next_latent_state_normalized = renormalize(next_latent_state)
+            return next_latent_state_normalized, reward
+
+    def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor:
+        """
+        Overview:
+            Project the latent state to a lower-dimensional embedding to compute the self-supervised consistency loss,
+            following the projection used in EfficientZero.
+            For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``.
+ Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64, 6, 6) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + + .. note:: + for Atari: + observation_shape = (12, 96, 96), # original shape is (3,96,96), frame_stack_num=4 + if downsample is True, latent_state.shape: (batch_size, num_channel, obs_shape[1] / 16, obs_shape[2] / 16) + i.e., (256, 64, 96 / 16, 96 / 16) = (256, 64, 6, 6) + latent_state reshape: (256, 64, 6, 6) -> (256,64*6*6) = (256, 2304) + # self.projection_input_dim = 64*6*6 = 2304 + # self.projection_output_dim = 1024 + """ + latent_state = latent_state.reshape(latent_state.shape[0], -1) + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + diff --git a/lzero/policy/sampled_muzero.py b/lzero/policy/sampled_muzero.py index 636683e1d..878f544a6 100644 --- a/lzero/policy/sampled_muzero.py +++ b/lzero/policy/sampled_muzero.py @@ -528,7 +528,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: 'total_loss': loss.mean().item(), 'policy_loss': policy_loss.mean().item(), 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy / (self._cfg.num_unroll_steps + 1), 'reward_loss': reward_loss.mean().item(), 'value_loss': value_loss.mean().item(), 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, diff --git a/zoo/dmc2gym/config/dmc2gym_pixels_sez_config.py b/zoo/dmc2gym/config/dmc2gym_pixels_sez_config.py index 62af52e15..5a3e16e90 100644 --- a/zoo/dmc2gym/config/dmc2gym_pixels_sez_config.py +++ b/zoo/dmc2gym/config/dmc2gym_pixels_sez_config.py @@ -14,18 +14,29 @@ batch_size = 256 max_env_step = int(5e6) reanalyze_ratio = 0. 
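+# NOTE: the overrides below are small-scale debug settings; comment them out (as done in the smz config) to run
+# the full-scale training values defined above.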
+ +# ======== debug config ======== +collector_env_num = 2 +n_episode = 2 +evaluator_env_num = 2 +continuous_action_space = True +K = 2 # num_of_sampled_actions +num_simulations = 5 +replay_ratio = 0.05 +update_per_collect =2 +batch_size = 4 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== dmc2gym_pixels_sampled_efficientzero_config = dict( - exp_name=f'data_sez/dmc2gym_pixels_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_seed0', + exp_name=f'data_sez_debug/dmc2gym_pixels_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_seed0', env=dict( env_id='dmc2gym-v0', domain_name="cartpole", task_name="swingup", from_pixels=True, # pixel/image obs - frame_skip=4, + frame_skip=8, warp_frame=True, scale=True, frame_stack_num=3, @@ -85,7 +96,8 @@ type='dmc2gym_lightzero', import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], ), - env_manager=dict(type='subprocess'), + # env_manager=dict(type='subprocess'), + env_manager=dict(type='base'), policy=dict( type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], diff --git a/zoo/dmc2gym/config/dmc2gym_pixels_smz_config.py b/zoo/dmc2gym/config/dmc2gym_pixels_smz_config.py index 0fb901b82..3ffc0c69b 100644 --- a/zoo/dmc2gym/config/dmc2gym_pixels_smz_config.py +++ b/zoo/dmc2gym/config/dmc2gym_pixels_smz_config.py @@ -4,7 +4,7 @@ # ============================================================== from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map -env_id = 'cartpole-swingup' # You can specify any DMC tasks here +env_id = 'cartpole-balance' # You can specify any DMC tasks here action_space_size = dmc_state_env_action_space_map[env_id] obs_space_size = dmc_state_env_obs_space_map[env_id] @@ -19,24 +19,34 @@ num_simulations = 50 update_per_collect = None replay_ratio = 0.25 -batch_size = 1024 +batch_size = 64 max_env_step = int(1e6) -reanalyze_ratio = 0. 
norm_type = 'LN' seed = 0 + +# ======== debug config ======== +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 2 +# continuous_action_space = True +# K = 5 # num_of_sampled_actions +# num_simulations = 5 +# replay_ratio = 0.05 +# update_per_collect =2 +# batch_size = 4 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== dmc2gym_pixels_cont_sampled_muzero_config = dict( - exp_name=f'data_smz/dmc2gym_{env_id}_state_cont_sampled_muzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_{norm_type}_seed{seed}', + exp_name=f'data_smz/dmc2gym_{env_id}_pixel_cont_sampled_muzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_{norm_type}_seed{seed}', env=dict( env_id='dmc2gym-v0', continuous=True, domain_name=domain_name, task_name=task_name, from_pixels=True, # pixel/image obs - frame_skip=2, + frame_skip=8, frame_stack_num=3, warp_frame=True, scale=True, @@ -50,6 +60,7 @@ model=dict( model_type='conv', observation_shape=(9, 84, 84), + downsample=True, image_channel=3, frame_stack_num=3, action_space_size=action_space_size, @@ -57,8 +68,8 @@ num_of_sampled_actions=K, sigma_type='conditioned', norm_type=norm_type, + self_supervised_learning_loss=True, ), - # (str) The path of the pretrained model. If None, the model will be initialized by the default model. model_path=None, cuda=True, env_type='not_board_games', @@ -66,15 +77,19 @@ update_per_collect=update_per_collect, batch_size=batch_size, optim_type='AdamW', + use_priority=False, cos_lr_scheduler=True, learning_rate=0.0001, num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - policy_entropy_weight=5e-3, + reanalyze_ratio=0, + policy_entropy_weight=5e-2, + grad_clip_value=5, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), n_episode=n_episode, eval_freq=int(2e3), replay_ratio=replay_ratio, - replay_buffer_size=int(1e6), + replay_buffer_size=int(1e5), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, ),