diff --git a/lzero/entry/train_muzero_gpt.py b/lzero/entry/train_muzero_gpt.py
index f14312338..de8f8c8ee 100644
--- a/lzero/entry/train_muzero_gpt.py
+++ b/lzero/entry/train_muzero_gpt.py
@@ -239,6 +239,8 @@ def train_muzero_gpt(
         # NOTE: TODO
         # TODO: for the batch world model, to improve kv reuse we could avoid resetting this cache
         policy._learn_model.world_model.past_keys_values_cache.clear()
+
+        torch.cuda.empty_cache()  # TODO
 
         # if collector.envstep > 0:
         #     # TODO: only for debug
diff --git a/lzero/model/common.py b/lzero/model/common.py
index 47f2ccbd2..362eaeca7 100644
--- a/lzero/model/common.py
+++ b/lzero/model/common.py
@@ -293,18 +293,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # print('cont embeddings before last_linear', x.max(), x.min(), x.mean())
 
         # NOTE: very important. for muzero_gpt atari: (64, 8, 8) = 4096 -> 1024
-        x = self.last_linear(x.contiguous().view(-1, 64*8*8))
+        # x = self.last_linear(x.contiguous().view(-1, 64*8*8))
+        x = self.last_linear(x.reshape(-1, 64*8*8))  # TODO
 
         x = x.view(-1, self.embedding_dim)
-        # print(x.max(), x.min())
-        # x = renormalize(x)
         # print('cont embeddings before renormalize', x.max(), x.min(), x.mean())
-        # x = AvgL1Norm(x)
-        # print('after AvgL1Norm', x.max(), x.min())
         # x = torch.tanh(x)
         x = renormalize(x)
-        # print('after renormalize', x.max(), x.min(), x.mean())
 
         return x
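The `.contiguous().view(...)` -> `.reshape(...)` swap above (repeated in world_model.py below) is behavior-preserving: reshape returns a view when possible and copies only when it must, so it also tolerates non-contiguous inputs. A minimal standalone check, with shapes matching the encoder output:

import torch

x = torch.randn(2, 64, 8, 8).permute(0, 2, 3, 1)  # permute makes x non-contiguous
# x.view(-1, 64 * 8 * 8) would raise a RuntimeError here, because no
# stride-compatible view exists; reshape falls back to a copy instead.
flat = x.reshape(-1, 64 * 8 * 8)
assert flat.shape == (2, 64 * 8 * 8)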
diff --git a/lzero/model/gpt_models/cfg_atari.py b/lzero/model/gpt_models/cfg_atari.py
index 8b60f2ae6..a320ecda2 100644
--- a/lzero/model/gpt_models/cfg_atari.py
+++ b/lzero/model/gpt_models/cfg_atari.py
@@ -66,7 +66,7 @@
         'embed_pdrop': 0.1,
         'resid_pdrop': 0.1,
         'attn_pdrop': 0.1,
-        "device": 'cuda:2',
+        "device": 'cuda:1',
         # "device": 'cpu',
         'support_size': 21,
         'action_shape': 6,  # TODO: for atari
diff --git a/lzero/model/gpt_models/utils.py b/lzero/model/gpt_models/utils.py
index f52f3fa17..f1d5770dc 100644
--- a/lzero/model/gpt_models/utils.py
+++ b/lzero/model/gpt_models/utils.py
@@ -126,8 +126,8 @@ def __init__(self, **kwargs):
         # self.latent_kl_loss_weight = 0.1  # for lunarlander
         self.latent_kl_loss_weight = 0.  # for lunarlander
 
-        # self.latent_recon_loss_weight = 1
-        self.latent_recon_loss_weight = 0.1
+        self.latent_recon_loss_weight = 0.
+        # self.latent_recon_loss_weight = 0.1
 
         # Initialize the total loss tensor on the correct device
diff --git a/lzero/model/gpt_models/world_model.py b/lzero/model/gpt_models/world_model.py
index 7f846f0ba..02115ef8f 100644
--- a/lzero/model/gpt_models/world_model.py
+++ b/lzero/model/gpt_models/world_model.py
@@ -639,72 +639,27 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW
         if len(batch['observations'][0, 0].shape) == 3:
             # obs is a 3-dimensional image
             pass
-        # elif len(batch['observations'][0, 0].shape) == 1:
-        #     # print('obs is a 1-dimensional vector.')
-        #     # TODO()
-        #     # obs is a 1-dimensional vector
-        #     original_shape = list(batch['observations'].shape)
-        #     desired_shape = original_shape + [64, 64]
-        #     expanded_observations = batch['observations'].unsqueeze(-1).unsqueeze(-1)
-        #     expanded_observations = expanded_observations.expand(*desired_shape)
-        #     batch['observations'] = expanded_observations
-
-        # with torch.no_grad():
-        #     # no gradient flows here at the moment
-        #     obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens  # (BL, K)
-        # NOTE: gradients are needed here
-        # obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens  # (BL, K)
-        with torch.no_grad():  # TODO
+        with torch.no_grad():  # TODO: very important
             obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False)  # (B, C, H, W) -> (B, K, E)
-            # obs_embeddings.register_hook(lambda grad: grad * 1/5)  # TODO: update the representation network via the reconstruction loss only
+            # obs_embeddings.register_hook(lambda grad: grad * 1)  # TODO: update the representation network via the reconstruction loss only
 
         # Assume that 'cont_embeddings' and 'original_images' are available from prior code
         # Decode the embeddings to reconstruct the images
         reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)
 
         # Calculate the reconstruction loss
-        latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].contiguous().view(-1, 4, 64, 64), reconstructed_images)
-
-        # Compute the KL-divergence loss.
-        # Assume obs_embeddings.shape = (160, 1, 256);
-        # first compute the mean and std of every feature dimension.
-        mean = obs_embeddings.mean(dim=0, keepdim=True)
-        std = obs_embeddings.std(dim=0, keepdim=True)
-        # Use a standard normal distribution as the prior.
-        prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))
-        # Build the distribution given by the model output.
-        model_dist = torch.distributions.Normal(mean, std)
-        # Compute the KL divergence for every feature of every sample.
-        kl_loss = torch.distributions.kl.kl_divergence(model_dist, prior_dist)
-        # kl_loss has shape (1, 1, 256), so average over all features to get a scalar loss.
-        latent_kl_loss = kl_loss.mean()
-        # print(f'latent_kl_loss:, {latent_kl_loss}')
-        if torch.isnan(latent_kl_loss) or torch.isinf(latent_kl_loss):
-            print("NaN or inf detected in latent_kl_loss!")
-            # Create a zero tensor on the same device with the same dtype, detached from the graph.
-            latent_kl_loss = torch.tensor(0., device=latent_kl_loss.device, dtype=latent_kl_loss.dtype)
+        # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].contiguous().view(-1, 4, 64, 64), reconstructed_images)
+        latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images)  # TODO
 
-        # TODO
-        # obs_embeddings = AvgL1Norm(obs_embeddings)
-        # TODO: add Gaussian noise from the second step onwards
-        # noise_std = 0.1
-        # obs_embeddings = obs_embeddings.view(32, 5, -1)
-        # noise = torch.randn_like(obs_embeddings[:, 1:, :]) * noise_std
-        # # Avoid in-place modification: clone obs_embeddings into a new variable first.
-        # obs_embeddings = obs_embeddings.clone()
-        # obs_embeddings[:, 1:, :] = obs_embeddings[:, 1:, :] + noise
+        latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype)
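For reference, the KL branch deleted above can be kept as a standalone helper. This sketch only repackages what the removed lines computed; the helper name and the (B, 1, E) shape note are illustrative:

import torch

def latent_kl_to_standard_normal(obs_embeddings: torch.Tensor) -> torch.Tensor:
    # Per-feature batch statistics, e.g. for obs_embeddings of shape (B, 1, E).
    mean = obs_embeddings.mean(dim=0, keepdim=True)
    std = obs_embeddings.std(dim=0, keepdim=True)
    # Standard normal prior vs. the batch-level distribution of the embeddings.
    prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))
    model_dist = torch.distributions.Normal(mean, std)
    # KL per feature, averaged down to a scalar.
    latent_kl_loss = torch.distributions.kl.kl_divergence(model_dist, prior_dist).mean()
    if torch.isnan(latent_kl_loss) or torch.isinf(latent_kl_loss):
        # Degenerate batches (std -> 0) blow the KL up; fall back to zero, as before.
        latent_kl_loss = torch.tensor(0., device=obs_embeddings.device, dtype=obs_embeddings.dtype)
    return latent_kl_loss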
         act_tokens = rearrange(batch['actions'], 'b l -> b l 1')
-        # tokens = rearrange(torch.cat((obs_tokens, act_tokens), dim=2), 'b l k1 -> b (l k1)')  # (B, L(K+1))
-        # outputs = self.forward(tokens, is_root=False)
 
         # TODO: update the representation network via the reconstruction loss only
         outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False)
@@ -712,58 +667,23 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW
                                                            batch['ends'], batch['mask_padding'])
 
-        """
-        >>> # Example of target with class probabilities
-        >>> input = torch.randn(3, 5, requires_grad=True)
-        >>> target = torch.randn(3, 5).softmax(dim=1)
-        >>> loss = F.cross_entropy(input, target)
-        >>> loss.backward()
-        """
         logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o')
 
-        labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim)  # TODO:
-        # loss_obs = F.cross_entropy(logits_observations, labels_observations)
+        # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim)  # TODO:
+        labels_observations = labels_observations.reshape(-1, self.projection_input_dim)  # TODO:
 
-        # TODO: EZ consistency loss; TWM loss
-        # loss_obs = self.negative_cosine_similarity(logits_observations, labels_observations.detach())  # 2528 = 32 * 79 = 32, 5*16-1
-
-        # obs_projection = self.projection(logits_observations)
-        # obs_prediction = self.prediction_head(obs_projection)
-        # obs_target = self.projection(labels_observations).detach()
-        # loss_obs = self.negative_cosine_similarity(obs_prediction, obs_target)
 
         loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1)
 
-        # batch['mask_padding'] shape 32, 5
-        # loss_obs = (loss_obs * batch['mask_padding']).mean()
-
-        # Step 1: expand mask_padding
-        # Drop the last time step and repeat each remaining step 16 times. NOTE: check that the reshape gives the right shape.
-        # mask_padding_expanded = batch['mask_padding'].unsqueeze(-1).repeat(1, 1, self.num_observations_tokens).reshape(32, -1)[:, :-1].contiguous().view(-1)
         mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1)  # TODO:
-        # mask_padding_expanded = batch['mask_padding'][:, :-1].contiguous().view(-1)
+        # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1)
 
         # Apply the mask to loss_obs
-        # Use the inverted mask, since we want to keep the loss at non-padding positions
         loss_obs = (loss_obs * mask_padding_expanded).mean(-1)
 
-        # if loss_obs > 10:
-        #     print('debug')
-
-        # loss_ends = F.cross_entropy(rearrange(outputs.logits_ends, 'b t e -> (b t) e'), labels_ends)
-
         labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], batch['target_policy'], batch['mask_padding'])
 
         loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards')
         loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy')
-        """torch.eq(labels_observations, logits_observations.argmax(-1)).sum().item() / labels_observations.shape[0]
-        F.cross_entropy(logits_observations, logits_observations.argmax(-1))
-        """
         loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value')
 
         return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value,
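The hunk above settles on a masked next-latent MSE: predictions for step t come from the transformer, targets are the (detached) encoder embeddings of step t+1, and padded steps are zeroed out. A self-contained sketch (the function name and (B, T, E) shapes are illustrative, not the repo's API):

import torch
import torch.nn.functional as F

def masked_latent_mse(pred: torch.Tensor, target: torch.Tensor,
                      mask_padding: torch.Tensor) -> torch.Tensor:
    # pred, target: (B, T-1, E) next-step latent predictions and encoder targets;
    # mask_padding: (B, T) bool mask, True on valid (non-padded) time steps.
    per_step = F.mse_loss(pred, target.detach(), reduction='none').mean(-1)  # (B, T-1)
    mask = mask_padding[:, 1:].reshape(-1).float()  # step 0 has no prediction target
    # Note: as in the diff, padded positions still count in the denominator.
    return (per_step.reshape(-1) * mask).mean()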
diff --git a/lzero/model/muzero_gpt_model.py b/lzero/model/muzero_gpt_model.py
index 7cc214144..68a527aac 100644
--- a/lzero/model/muzero_gpt_model.py
+++ b/lzero/model/muzero_gpt_model.py
@@ -178,48 +178,6 @@ def __init__(
         print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
 
-        # self.prediction_network = PredictionNetwork(
-        #     observation_shape,
-        #     action_space_size,
-        #     num_res_blocks,
-        #     num_channels,
-        #     value_head_channels,
-        #     policy_head_channels,
-        #     fc_value_layers,
-        #     fc_policy_layers,
-        #     self.value_support_size,
-        #     flatten_output_size_for_value_head,
-        #     flatten_output_size_for_policy_head,
-        #     downsample,
-        #     last_linear_layer_init_zero=self.last_linear_layer_init_zero,
-        #     activation=activation,
-        #     norm_type=norm_type
-        # )
-
-        # if self.self_supervised_learning_loss:
-        #     # projection used in EfficientZero
-        #     if self.downsample:
-        #         # In Atari, an observation_shape of (12, 96, 96) indicates the original shape
-        #         # (3, 96, 96) with frame_stack_num=4. Due to downsampling, the encoding of the
-        #         # observation (latent_state) is (64, 96/16, 96/16), where 64 is the number of
-        #         # channels and 96/16 is the spatial size of the latent state. Thus,
-        #         # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304
-        #         ceil_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16)
-        #         self.projection_input_dim = num_channels * ceil_size
-        #     else:
-        #         self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2]
-
-        #     self.projection = nn.Sequential(
-        #         nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
-        #         nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
-        #         nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out)
-        #     )
-        #     self.prediction_head = nn.Sequential(
-        #         nn.Linear(self.proj_out, self.pred_hid),
-        #         nn.BatchNorm1d(self.pred_hid),
-        #         activation,
-        #         nn.Linear(self.pred_hid, self.pred_out),
-        #     )
-
     def initial_inference(self, obs: torch.Tensor, action_batch=None) -> MZNetworkOutput:
         """
         Overview:
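The block removed above is the EfficientZero-style projector/predictor pair for the self-supervised consistency loss. Pulled out as a standalone builder it looks roughly like this; the layer sizes are the usual EfficientZero defaults, not values taken from this config:

import torch.nn as nn

def build_ssl_heads(projection_input_dim: int = 2304, proj_hid: int = 1024,
                    proj_out: int = 1024, pred_hid: int = 512, pred_out: int = 1024):
    activation = nn.ReLU(inplace=True)
    # Three-layer projector followed by a two-layer predictor, BatchNorm throughout.
    projection = nn.Sequential(
        nn.Linear(projection_input_dim, proj_hid), nn.BatchNorm1d(proj_hid), activation,
        nn.Linear(proj_hid, proj_hid), nn.BatchNorm1d(proj_hid), activation,
        nn.Linear(proj_hid, proj_out), nn.BatchNorm1d(proj_out),
    )
    prediction_head = nn.Sequential(
        nn.Linear(proj_out, pred_hid),
        nn.BatchNorm1d(pred_hid),
        activation,
        nn.Linear(pred_hid, pred_out),
    )
    return projection, prediction_head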
diff --git a/lzero/policy/muzero_gpt.py b/lzero/policy/muzero_gpt.py
index 23b4de99a..ed00a71cc 100644
--- a/lzero/policy/muzero_gpt.py
+++ b/lzero/policy/muzero_gpt.py
@@ -265,52 +265,10 @@ def _init_learn(self) -> None:
         Overview:
             Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils.
         """
-        # assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type
-        # # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'.
-        # if self._cfg.optim_type == 'SGD':
-        #     self._optimizer = optim.SGD(
-        #         self._model.parameters(),
-        #         lr=self._cfg.learning_rate,
-        #         momentum=self._cfg.momentum,
-        #         weight_decay=self._cfg.weight_decay,
-        #     )
-        # elif self._cfg.optim_type == 'Adam':
-        #     self._optimizer = optim.Adam(
-        #         self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay
-        #     )
-        # elif self._cfg.optim_type == 'AdamW':
-        #     self._optimizer = configure_optimizers(
-        #         model=self._model,
-        #         weight_decay=self._cfg.weight_decay,
-        #         learning_rate=self._cfg.learning_rate,
-        #         device_type=self._cfg.device
-        #     )
-
-        # self._optimizer_tokenizer = optim.Adam(
-        #     self._model.tokenizer.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay
-        # )
-
-        # self._optimizer_tokenizer = optim.Adam(
-        #     self._model.tokenizer.parameters(), lr=self._cfg.learning_rate  # weight_decay=0
-        # )
-
-        # # TODO: nanoGPT optimizer
-        # self._optimizer_world_model = configure_optimizer(
-        #     model=self._model.world_model,
-        #     learning_rate=self._cfg.learning_rate,
-        #     weight_decay=self._cfg.weight_decay,
-        #     # weight_decay=0.01,
-        #     exclude_submodules=['tokenizer']
-        # )
-
         self._optimizer_tokenizer = optim.Adam(
             self._model.tokenizer.parameters(), lr=1e-4  # weight_decay=0
         )
 
-        # self._optimizer_tokenizer = optim.Adam(
-        #     self._model.tokenizer.parameters(), lr=3e-3  # weight_decay=0
-        # )
-
         # TODO: nanoGPT optimizer
         # self._optimizer_world_model = configure_optimizer(
         #     model=self._model.world_model,
@@ -322,26 +280,13 @@ def _init_learn(self) -> None:
         # )
         self._optimizer_world_model = configure_optimizer(
             model=self._model.world_model,
-            # learning_rate=3e-3,
-            learning_rate=1e-4,
+            learning_rate=3e-3,
+            # learning_rate=1e-4,  # NOTE: TODO
             weight_decay=self._cfg.weight_decay,
             # weight_decay=0.01,
             exclude_submodules=['none']  # NOTE
         )
 
-        # self._optimizer_world_model = configure_optimizers(
-        #     model=self._model.world_model,
-        #     weight_decay=self._cfg.weight_decay,
-        #     learning_rate=self._cfg.learning_rate,
-        #     device_type=self._cfg.device
-        # )
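configure_optimizer above is the repo's nanoGPT-derived helper. For orientation, here is a generic sketch of what that recipe typically does (an assumption about the pattern, not the repo's exact implementation): 2D+ weights get weight decay, biases and norm gains do not, and named submodules can be excluded:

import torch

def make_adamw(model, learning_rate, weight_decay, exclude_submodules=()):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad or any(name.startswith(m) for m in exclude_submodules):
            continue
        # Matrix-shaped weights are decayed; biases and norm parameters are not.
        (decay if p.dim() >= 2 else no_decay).append(p)
    groups = [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
    return torch.optim.AdamW(groups, lr=learning_rate)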
-        # if self._cfg.lr_piecewise_constant_decay:
-        #     from torch.optim.lr_scheduler import LambdaLR
-        #     max_step = self._cfg.threshold_training_steps_for_final_lr
-        #     # NOTE: the 1, 0.1, 0.01 are the decay rates, not the lr itself.
-        #     lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)  # noqa
-        #     self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda)
 
         # use model_wrapper for specialized demands of different modes
         self._target_model = copy.deepcopy(self._model)
@@ -359,7 +304,6 @@ def _init_learn(self) -> None:
         )
         self._learn_model = self._model
 
-        # TODO: only for debug
         # for param in self._learn_model.tokenizer.parameters():
         #     param.requires_grad = False
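The commented-out piecewise-constant decay removed above, reconstructed as runnable code; the optimizer and max_step here are placeholders for what the config would supply via threshold_training_steps_for_final_lr:

import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
max_step = 60_000  # placeholder for cfg.threshold_training_steps_for_final_lr
# Multipliers, not absolute lrs: full lr for the first half, then 0.1x, then 0.01x.
lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)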
diff --git a/zoo/atari/config/atari_muzero_gpt_config_stack4.py b/zoo/atari/config/atari_muzero_gpt_config_stack4.py
index d1e291f84..082275dac 100644
--- a/zoo/atari/config/atari_muzero_gpt_config_stack4.py
+++ b/zoo/atari/config/atari_muzero_gpt_config_stack4.py
@@ -1,6 +1,6 @@
 from easydict import EasyDict
 import torch
-torch.cuda.set_device(2)
+torch.cuda.set_device(1)
 
 # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
 env_name = 'PongNoFrameskip-v4'
@@ -25,6 +25,8 @@
 n_episode = 8
 evaluator_env_num = 1
 update_per_collect = 1000
+# update_per_collect = 2000
+
 # update_per_collect = None
 # model_update_ratio = 0.25
 model_update_ratio = 0.25
@@ -37,12 +39,8 @@
 reanalyze_ratio = 0
 
 batch_size = 32  # for num_head=2, embedding_dim=128
-# batch_size = 8  # for num_head=4, embedding_dim=256
-
 num_unroll_steps = 5
-# batch_size = 8
-# num_unroll_steps = 10
 
 # for debug
 # collector_env_num = 8
@@ -66,10 +64,12 @@
 atari_muzero_config = dict(
     # TODO: world_model.py decode_obs_tokens
    # TODO: tokenizer.py: lpips loss
-    exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-biasfalse-minmax-iter60k-fixed_stack4_seed0',
-    # exp_name=f'data_mz_gpt_ctree_0110/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-onlyreconslossw1-biasfalse-minmax_stack4_seed0',
+    # exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr3e-3-gcv05-reconslossw01-minmax-latentgrad02_stack4_upc1000_seed0',
-    # exp_name=f'data_mz_gpt_ctree_0110/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-reconslossw1-minmax-latentgrad0.2_stack4_seed0',
+    exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr3e-3-gcv05-biasfalse-minmax-iter60k-fixed_stack4_seed0',
+
+    # exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-reconslossw01-minmax-latentgrad0.2-fromm60ktrain_stack4_upc1000_seed0',
+    # exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-onlyreconslossw1-biasfalse-minmax_stack4_seed0',
 
     # exp_name=f'data_mz_gpt_ctree_0110/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-reconslossw01-tanh-lr1e-4-gcv05_stack4_seed0',
@@ -184,6 +184,8 @@
         target_update_freq=100,
         grad_clip_value = 0.5,  # TODO
+        # grad_clip_value = 10,  # TODO
+
         num_simulations=num_simulations,
         reanalyze_ratio=reanalyze_ratio,
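grad_clip_value above is consumed by the learner, which clips the global gradient norm before each optimizer step. A generic sketch of the equivalent update (the function and argument names are placeholders; DI-engine wires this up internally from the config):

import torch

def clipped_step(loss, model, optimizer, grad_clip_value=0.5):
    optimizer.zero_grad()
    loss.backward()
    # Clip the global grad norm to grad_clip_value (0.5 here; 10 is the TODO alternative).
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_value)
    optimizer.step()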