Skip to content

Commit

Permalink
polish(pu): polish unused debug code
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Jan 11, 2024
1 parent f8d88e6 commit c13fea7
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 205 deletions.
2 changes: 2 additions & 0 deletions lzero/entry/train_muzero_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ def train_muzero_gpt(
# NOTE: TODO
# TODO: for the batched world model, to improve kv-cache reuse, we could avoid resetting
policy._learn_model.world_model.past_keys_values_cache.clear()

torch.cuda.empty_cache() # TODO

# if collector.envstep > 0:
# # TODO: only for debug
Expand Down
8 changes: 2 additions & 6 deletions lzero/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,18 +293,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# print('cont embedings before last_linear', x.max(), x.min(), x.mean())

# NOTE: very important. for muzero_gpt atari 64,8,8 = 4096 -> 1024
x = self.last_linear(x.contiguous().view(-1, 64*8*8))
# x = self.last_linear(x.contiguous().view(-1, 64*8*8))
x = self.last_linear(x.reshape(-1, 64*8*8)) # TODO

x = x.view(-1, self.embedding_dim)
# print(x.max(), x.min())
# x = renormalize(x)

# print('cont embedings before renormalize', x.max(), x.min(), x.mean())
# x = AvgL1Norm(x)
# print('after AvgL1Norm', x.max(), x.min())
# x = torch.tanh(x)
x = renormalize(x)

# print('after renormalize', x.max(), x.min(),x.mean())

return x
Expand Down
2 changes: 1 addition & 1 deletion lzero/model/gpt_models/cfg_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:2',
"device": 'cuda:1',
# "device": 'cpu',
'support_size': 21,
'action_shape': 6,  # TODO: for Atari
Expand Down
4 changes: 2 additions & 2 deletions lzero/model/gpt_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def __init__(self, **kwargs):
# self.latent_kl_loss_weight = 0.1 # for lunarlander
self.latent_kl_loss_weight = 0. # for lunarlander

# self.latent_recon_loss_weight = 1
self.latent_recon_loss_weight = 0.1
self.latent_recon_loss_weight = 0.
# self.latent_recon_loss_weight = 0.1


# Initialize the total loss tensor on the correct device
Expand Down
96 changes: 8 additions & 88 deletions lzero/model/gpt_models/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,131 +639,51 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW
if len(batch['observations'][0, 0].shape) == 3:
# obs is a 3-dimensional image
pass
# elif len(batch['observations'][0, 0].shape) == 1:
# # print('obs is a 1-dimensional vector.')
# # TODO()
# # obs is a 1-dimensional vector
# original_shape = list(batch['observations'].shape)
# desired_shape = original_shape + [64, 64]
# expanded_observations = batch['observations'].unsqueeze(-1).unsqueeze(-1)
# expanded_observations = expanded_observations.expand(*desired_shape)
# batch['observations'] = expanded_observations

# with torch.no_grad():
# # 目前这里是没有梯度的
# obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens # (BL, K)

# NOTE: 这里是需要梯度的
# obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens # (BL, K)
with torch.no_grad(): # TODO
with torch.no_grad(): # TODO: 非常重要
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E)

# obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络
# obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络

# Assume that 'cont_embeddings' and 'original_images' are available from prior code
# Decode the embeddings to reconstruct the images
reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)

# Calculate the reconstruction loss
latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].contiguous().view(-1, 4, 64, 64), reconstructed_images)


# 计算KL散度损失
# 假设 obs_embeddings.shape = (160, 1, 256)
# 这里我们首先计算每个特征维度的均值和方差
mean = obs_embeddings.mean(dim=0, keepdim=True)
std = obs_embeddings.std(dim=0, keepdim=True)
# 创建标准正态分布作为先验分布
prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))


# 创建模型输出的分布
model_dist = torch.distributions.Normal(mean, std)
# 计算KL散度损失,对每个样本的每个特征进行计算
kl_loss = torch.distributions.kl.kl_divergence(model_dist, prior_dist)
# 因为 kl_loss 的形状是 (1, 1, 256),我们可以对所有特征求平均来得到一个标量损失
latent_kl_loss = kl_loss.mean()
# print(f'latent_kl_loss:, {latent_kl_loss}')
if torch.isnan(latent_kl_loss) or torch.isinf(latent_kl_loss):
print("NaN or inf detected in latent_kl_loss!")
# 使用 torch.tensor(0) 创建一个同设备,同数据类型的零张量,并确保不需要梯度
latent_kl_loss = torch.tensor(0., device=latent_kl_loss.device, dtype=latent_kl_loss.dtype)
# latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].contiguous().view(-1, 4, 64, 64), reconstructed_images)
latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO

# TODO
# obs_embeddings = AvgL1Norm(obs_embeddings)

# second to last 增加高斯噪声 TODO
# noise_std = 0.1
# obs_embeddings = obs_embeddings.view(32, 5, -1)
# noise = torch.randn_like(obs_embeddings[:, 1:, :]) * noise_std
# # 修改后的代码,不使用原地操作
# obs_embeddings = obs_embeddings.clone() # 克隆obs_embeddings来创建一个新的变量
# obs_embeddings[:, 1:, :] = obs_embeddings[:, 1:, :] + noise
latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype)


act_tokens = rearrange(batch['actions'], 'b l -> b l 1')

# tokens = rearrange(torch.cat((obs_tokens, act_tokens), dim=2), 'b l k1 -> b (l k1)') # (B, L(K+1))
# outputs = self.forward(tokens, is_root=False)
# TODO: 只提供重建损失更新表征网络
outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False)

labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'],
batch['ends'],
batch['mask_padding'])

"""
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> loss = F.cross_entropy(input, target)
>>> loss.backward()
"""
logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o')
labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO:
# loss_obs = F.cross_entropy(logits_observations, labels_observations)
# labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO:
labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO:

# TODO: EZ consistency loss; TWM loss
# loss_obs = self.negative_cosine_similarity(logits_observations, labels_observations.detach()) # 2528 = 32 * 79 = 32, 5*16-1


# obs_projection = self.projection(logits_observations)
# obs_prediction = self.prediction_head(obs_projection)
# obs_target = self.projection(labels_observations).detach()
# loss_obs = self.negative_cosine_similarity(obs_prediction, obs_target)


loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1)


# batch['mask_padding'] shape 32, 5
# loss_obs = (loss_obs* batch['mask_padding']).mean()

# Step 1: 扩展mask_padding
# 除去最后一个time step,每个time step 重复16次 NOTE检查shape是否reshape正确
# mask_padding_expanded = batch['mask_padding'].unsqueeze(-1).repeat(1, 1, self.num_observations_tokens).reshape(32, -1)[:, :-1].contiguous().view(-1)

mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO:
# mask_padding_expanded = batch['mask_padding'][:, :-1].contiguous().view(-1)
# mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1)

# 应用mask到loss_obs
# 使用inverted mask,因为我们想要保留非padding的loss
loss_obs = (loss_obs * mask_padding_expanded).mean(-1)
# if loss_obs > 10:
# print('debug')

# loss_ends = F.cross_entropy(rearrange(outputs.logits_ends, 'b t e -> (b t) e'), labels_ends)


labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'],
batch['target_policy'],
batch['mask_padding'])

loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards')
loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy')
"""torch.eq(labels_observations, logits_observations.argmax(-1)).sum().item() / labels_observations.shape[0]
F.cross_entropy(logits_observations, logits_observations.argmax(-1))
"""
loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value')

return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value,
Expand Down
42 changes: 0 additions & 42 deletions lzero/model/muzero_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,48 +178,6 @@ def __init__(
print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')


# self.prediction_network = PredictionNetwork(
# observation_shape,
# action_space_size,
# num_res_blocks,
# num_channels,
# value_head_channels,
# policy_head_channels,
# fc_value_layers,
# fc_policy_layers,
# self.value_support_size,
# flatten_output_size_for_value_head,
# flatten_output_size_for_policy_head,
# downsample,
# last_linear_layer_init_zero=self.last_linear_layer_init_zero,
# activation=activation,
# norm_type=norm_type
# )

# if self.self_supervised_learning_loss:
# # projection used in EfficientZero
# if self.downsample:
# # In Atari, if the observation_shape is set to (12, 96, 96), which indicates the original shape of
# # (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is
# # (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus,
# # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304
# ceil_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16)
# self.projection_input_dim = num_channels * ceil_size
# else:
# self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2]

# self.projection = nn.Sequential(
# nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
# nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
# nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out)
# )
# self.prediction_head = nn.Sequential(
# nn.Linear(self.proj_out, self.pred_hid),
# nn.BatchNorm1d(self.pred_hid),
# activation,
# nn.Linear(self.pred_hid, self.pred_out),
# )

def initial_inference(self, obs: torch.Tensor, action_batch=None) -> MZNetworkOutput:
"""
Overview:
Expand Down
60 changes: 2 additions & 58 deletions lzero/policy/muzero_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,52 +265,10 @@ def _init_learn(self) -> None:
Overview:
Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils.
"""
# assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type
# # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'.
# if self._cfg.optim_type == 'SGD':
# self._optimizer = optim.SGD(
# self._model.parameters(),
# lr=self._cfg.learning_rate,
# momentum=self._cfg.momentum,
# weight_decay=self._cfg.weight_decay,
# )
# elif self._cfg.optim_type == 'Adam':
# self._optimizer = optim.Adam(
# self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay
# )
# elif self._cfg.optim_type == 'AdamW':
# self._optimizer = configure_optimizers(
# model=self._model,
# weight_decay=self._cfg.weight_decay,
# learning_rate=self._cfg.learning_rate,
# device_type=self._cfg.device
# )

# self._optimizer_tokenizer = optim.Adam(
# self._model.tokenizer.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay
# )

# self._optimizer_tokenizer = optim.Adam(
# self._model.tokenizer.parameters(), lr=self._cfg.learning_rate # weight_decay=0
# )

# # TODO: nanoGPT optimizer
# self._optimizer_world_model = configure_optimizer(
# model=self._model.world_model,
# learning_rate=self._cfg.learning_rate,
# weight_decay=self._cfg.weight_decay,
# # weight_decay=0.01,
# exclude_submodules=['tokenizer']
# )

self._optimizer_tokenizer = optim.Adam(
self._model.tokenizer.parameters(), lr=1e-4 # weight_decay=0
)

# self._optimizer_tokenizer = optim.Adam(
# self._model.tokenizer.parameters(), lr=3e-3 # weight_decay=0
# )

# TODO: nanoGPT optimizer
# self._optimizer_world_model = configure_optimizer(
# model=self._model.world_model,
Expand All @@ -322,26 +280,13 @@ def _init_learn(self) -> None:
# )
self._optimizer_world_model = configure_optimizer(
model=self._model.world_model,
# learning_rate=3e-3,
learning_rate=1e-4,
learning_rate=3e-3,
# learning_rate=1e-4, # NOTE: TODO
weight_decay=self._cfg.weight_decay,
# weight_decay=0.01,
exclude_submodules=['none'] # NOTE
)

# self._optimizer_world_model = configure_optimizers(
# model=self._model.world_model,
# weight_decay=self._cfg.weight_decay,
# learning_rate=self._cfg.learning_rate,
# device_type=self._cfg.device
# )

# if self._cfg.lr_piecewise_constant_decay:
# from torch.optim.lr_scheduler import LambdaLR
# max_step = self._cfg.threshold_training_steps_for_final_lr
# # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr.
# lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa
# self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda)

# use model_wrapper for specialized demands of different modes
self._target_model = copy.deepcopy(self._model)
Expand All @@ -359,7 +304,6 @@ def _init_learn(self) -> None:
)
self._learn_model = self._model


# TODO: only for debug
# for param in self._learn_model.tokenizer.parameters():
# param.requires_grad = False
Expand Down
Loading

0 comments on commit c13fea7

Please sign in to comment.