Commit

fix(pu): latent state gradient times 0.2, set grad_clip_value to 0.5, add obs reconstruction loss, add load muzero representation net utils
puyuan1996 committed Jan 11, 2024
1 parent 31efa2f commit f8d88e6
Showing 15 changed files with 727 additions and 74 deletions.
4 changes: 2 additions & 2 deletions lzero/entry/train_muzero_gpt.py
@@ -240,11 +240,11 @@ def train_muzero_gpt(
# TODO: for the batch world model, to improve KV-cache reuse, we could avoid resetting
policy._learn_model.world_model.past_keys_values_cache.clear()

# if collector.envstep > 10000:
# if collector.envstep > 0:
# # TODO: only for debug
# for param in policy._learn_model.world_model.tokenizer.parameters():
# param.requires_grad = False
# print("train some steps before collector.envstep > 10000, then fixed")
# print("train some steps before collector.envstep > 0, then fixed")

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
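The TODO above proposes improving KV reuse instead of clearing the whole cache every iteration. As a minimal sketch under that assumption, a bounded cache with FIFO eviction might look like the following (the class name and max_size are illustrative assumptions, not part of this commit):

from collections import OrderedDict

class BoundedKVCache:
    """Hypothetical bounded key-value cache: evicts the oldest entry
    instead of clearing everything after each train step."""

    def __init__(self, max_size: int = 64):
        self.max_size = max_size
        self._cache = OrderedDict()

    def get(self, key):
        return self._cache.get(key)

    def put(self, key, value):
        if key in self._cache:
            self._cache.move_to_end(key)  # refresh recency
        self._cache[key] = value
        if len(self._cache) > self.max_size:
            self._cache.popitem(last=False)  # drop the oldest entry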
23 changes: 21 additions & 2 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -279,6 +279,13 @@ def search(
min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size)
min_max_stats_lst.set_delta(self._cfg.value_delta_max)

state_action_history = []  # Initialize the state_action_history variable
last_latent_state = latent_state_roots
# NOTE: very important: start from the correctly initialized key-value cache.
# forward_initial_inference() has already executed the following:
# _ = model.world_model.refresh_keys_values_with_initial_obs_tokens(model.world_model.obs_tokens)

# model.world_model.past_keys_values_cache.clear()  # clear the cache
for simulation_index in range(self._cfg.num_simulations):
# In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.

@@ -305,23 +312,35 @@
latent_states.append(latent_state_batch_in_search_path[ix][iy])

latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
# .long() is only for discrete action
# TODO: .long() is only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()

# TODO
# Update state_action_history after each simulation
# state_action_history.append((last_latent_state, last_actions.detach().cpu().numpy()))
state_action_history.append((latent_states.detach().cpu().numpy(), last_actions.detach().cpu().numpy()))

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation)
MCTS stage 3: Backup
At the end of the simulation, the statistics along the trajectory are updated.
"""
network_output = model.recurrent_inference(latent_states, last_actions)
# network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero
# network_output = model.recurrent_inference(last_actions) # TODO: for muzero_gpt latent_states is not used in the model.
network_output = model.recurrent_inference(state_action_history) # TODO: latent_states is not used in the model.

network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward))

latent_state_batch_in_search_path.append(network_output.latent_state)

# TODO
# last_latent_state = network_output.latent_state

# tolist() is to be compatible with cpp datatype.
reward_batch = network_output.reward.reshape(-1).tolist()
value_batch = network_output.value.reshape(-1).tolist()
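For context, recurrent_inference now receives the accumulated state_action_history rather than a single (latent_state, action) pair. A minimal sketch of how such a history might be unpacked inside a world model (the helper name and internals are assumptions for illustration, not the repository's actual implementation):

import numpy as np
import torch

def unpack_state_action_history(state_action_history, device='cpu'):
    """Hypothetical helper: stack the per-simulation (latent, action)
    pairs into tensors a transformer world model could attend over."""
    latents = [torch.from_numpy(np.asarray(l)).float() for l, _ in state_action_history]
    actions = [torch.from_numpy(np.asarray(a)).long() for _, a in state_action_history]
    # Stack along a new leading time dimension: (num_steps, batch, ...)
    return torch.stack(latents).to(device), torch.stack(actions).to(device)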
478 changes: 478 additions & 0 deletions lzero/mcts/tree_search/mcts_ctree_orig.py

Large diffs are not rendered by default.

93 changes: 68 additions & 25 deletions lzero/model/common.py
@@ -139,6 +139,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return output

# EfficientZero (EZ) original
# def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor:
# """
# Overview:
@@ -158,17 +159,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# return flat_input.view(*input.shape)

# def renormalize(x): # min-max
# # x is a 2D tensor of shape (batch_size, num_features)
# # Compute the min and max for each feature across the batch
# x_min = torch.min(x, dim=0, keepdim=True).values
# x_max = torch.max(x, dim=0, keepdim=True).values
def renormalize(x): # min-max
# x is a 2D tensor of shape (batch_size, num_features)
# Compute the min and max for each feature across the batch
x_min = torch.min(x, dim=0, keepdim=True).values
x_max = torch.max(x, dim=0, keepdim=True).values

# # Apply min-max normalization
# x_std = (x - x_min) / (x_max - x_min + 1e-8) # Add a small epsilon to avoid division by zero
# x_scaled = x_std * (1 - 0) + 0 # Assuming you want to scale between 0 and 1
# Apply min-max normalization
x_std = (x - x_min) / (x_max - x_min + 1e-8) # Add a small epsilon to avoid division by zero
x_scaled = x_std * (1 - 0) + 0 # Assuming you want to scale between 0 and 1

# return x_scaled
return x_scaled

# def renormalize(x): # z-score
# # x is a 2D tensor of shape (batch_size, num_features)
@@ -181,19 +182,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# return x_normalized

def renormalize(x): # robust scaling
# x is a 2D tensor of shape (batch_size, num_features)
# Compute the 1st and 3rd quartile
q1 = torch.quantile(x, 0.25, dim=0, keepdim=True)
q3 = torch.quantile(x, 0.75, dim=0, keepdim=True)
# def renormalize(x): # robust scaling
# # x is a 2D tensor of shape (batch_size, num_features)
# # Compute the 1st and 3rd quartile
# q1 = torch.quantile(x, 0.25, dim=0, keepdim=True)
# q3 = torch.quantile(x, 0.75, dim=0, keepdim=True)

# Compute the interquartile range (IQR)
iqr = q3 - q1
# # Compute the interquartile range (IQR)
# iqr = q3 - q1

# Apply robust scaling
x_scaled = (x - q1) / (iqr + 1e-8) # Again, add epsilon to avoid division by zero
# # Apply robust scaling
# x_scaled = (x - q1) / (iqr + 1e-8) # Again, add epsilon to avoid division by zero

return x_scaled
# return x_scaled

def AvgL1Norm(x, eps=1e-8):
return x/x.abs().mean(-1,keepdim=True).clamp(min=eps)
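A quick usage sketch for the two normalizers above (shapes are illustrative; this assumes the active min-max renormalize defined earlier in this file):

import torch

x = torch.randn(8, 256)    # (batch_size, num_features)

y = renormalize(x)         # per-feature min-max scaling across the batch
assert y.min() >= 0.0 and y.max() <= 1.0

z = AvgL1Norm(x)           # per-sample scaling by the mean absolute value
# Each row of z now has a mean absolute value of ~1.
print(z.abs().mean(-1))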
@@ -261,14 +262,15 @@ def __init__(
# self.last_linear = nn.Linear(64*4*4, 64*4*4)

# self.last_linear = nn.Linear(64*4*4, 256)
self.last_linear = nn.Linear(64*8*8, self.embedding_dim)
# self.last_linear = nn.Linear(64*8*8, self.embedding_dim)
self.last_linear = nn.Linear(64*8*8, self.embedding_dim, bias=False)

# TODO
# Initialize weights using He initialization
init.kaiming_normal_(self.last_linear.weight, mode='fan_out', nonlinearity='relu')

# Initialize biases to zero
init.zeros_(self.last_linear.bias)
# init.zeros_(self.last_linear.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@@ -284,23 +286,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
x = self.norm(x)
x = self.activation(x)
print('after downsample_net:', x.max(), x.min(), x.mean())
# print('after downsample_net:', x.max(), x.min(), x.mean())
for block in self.resblocks:
x = block(x)

# print('cont embeddings before last_linear', x.max(), x.min(), x.mean())

# NOTE: very important. For muzero_gpt on Atari: 64*8*8 = 4096 -> 1024
x = self.last_linear(x.contiguous().view(-1, 64*8*8))

x = x.view(-1, self.embedding_dim)
# print(x.max(), x.min())
# x = renormalize(x)

print('cont embeddings', x.max(), x.min(), x.mean())

# print('cont embeddings before renormalize', x.max(), x.min(), x.mean())
# x = AvgL1Norm(x)
# print('after AvgL1Norm', x.max(), x.min())
# x = torch.tanh(x)
# print('after tanh', x.max(), x.min(),x.mean())
x = renormalize(x)

# print('after renormalize', x.max(), x.min(),x.mean())

return x

@@ -318,6 +323,44 @@ def get_param_mean(self) -> float:
return mean


class LatentDecoder(nn.Module):
def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64):
super().__init__()
self.embedding_dim = embedding_dim
self.output_shape = output_shape # (C, H, W)
self.num_channels = num_channels

# Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256
# We will reverse the process of the representation network
self.initial_size = (num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder
self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size))

# Upsampling blocks
self.conv_blocks = nn.ModuleList([
# Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4)
nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.BatchNorm2d(num_channels // 2),
# Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2)
nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.BatchNorm2d(num_channels // 4),
# Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W)
nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1),
])

def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
# Map embeddings back to the image space
x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8)
x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8)

# Apply conv blocks
for block in self.conv_blocks:
x = block(x) # Upsample progressively

# The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2])
return x
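A shape-check sketch for the new decoder (the (12, 96, 96) output shape and 256-dim embedding follow the comment in __init__ above; the batch size is arbitrary):

import torch

decoder = LatentDecoder(embedding_dim=256, output_shape=(12, 96, 96))
embeddings = torch.randn(4, 256)   # (B, embedding_dim)
recon = decoder(embeddings)
print(recon.shape)                 # torch.Size([4, 12, 96, 96])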

class RepresentationNetworkMLP(nn.Module):

def __init__(
5 changes: 3 additions & 2 deletions lzero/model/gpt_models/plot_sequence_frame_grey.py
@@ -1,7 +1,8 @@
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
batch_observations = batch['observations'][:,:,0:1,:,:]
batch_observations = reconstructed_images.detach().view(32, 5, 4, 64, 64)[:,:,0:1,:,:]
# batch_observations = batch['observations'][:,:,0:1,:,:]
B, N, C, H, W = batch_observations.shape  # automatically infer the dimensions

# Width of the separator bar (adjust as needed)
@@ -35,7 +36,7 @@
plt.show()

# Save the image to a file
concat_image.save(f'sample_{i+1}_0110.png')
concat_image.save(f'sample_{i+1}_recs_0110.png')



2 changes: 0 additions & 2 deletions lzero/model/gpt_models/plot_weight_grad.py
@@ -30,8 +30,6 @@





x = torch.randn(192, 64, 8, 8).to('cuda:0')

def check_layer_output(model, x):
18 changes: 14 additions & 4 deletions lzero/model/gpt_models/tokenizer/tokenizer.py
@@ -35,7 +35,7 @@ class TokenizerEncoderOutput:


class Tokenizer(nn.Module):
def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True, representation_network = None) -> None:
def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True, representation_network=None, decoder_network=None) -> None:
super().__init__()
self.vocab_size = vocab_size
self.encoder = encoder
@@ -46,6 +46,8 @@ def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: D
self.embedding.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)
self.lpips = LPIPS().eval() if with_lpips else None
self.representation_network = representation_network
self.decoder_network = decoder_network


def __repr__(self) -> str:
return "tokenizer"
@@ -184,12 +186,20 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, should_preprocess: bool = Fa
# obs_embeddings = rearrange(obs_embeddings, 'b c h w -> b 1 (c h w)') # (160,1,1024)
obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') # (4,1,256) # TODO


return obs_embeddings

def decode_to_obs(self, embeddings: torch.Tensor) -> torch.Tensor:
return self.decoder_network(embeddings)


def reconstruction_loss(self, original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor:
# Mean squared error (MSE) is a common reconstruction loss; here we use the mean absolute error (L1) instead.
# loss = nn.MSELoss()(original_images, reconstructed_images)  # MSE variant
loss = torch.abs(original_images - reconstructed_images).mean()  # L1 loss
return loss
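A small sanity-check sketch for the L1 reconstruction loss above (assumes tokenizer is an instantiated Tokenizer):

import torch

original = torch.rand(160, 4, 64, 64)    # e.g. a flattened (B*L, C, H, W) batch
reconstructed = original + 0.1           # shift every pixel by a constant 0.1
loss = tokenizer.reconstruction_loss(original, reconstructed)
print(loss)                              # ~0.1, the mean absolute error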


def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor:
shape = z_q.shape # (..., E, h, w)
z_q = z_q.view(-1, *shape[-3:])
13 changes: 9 additions & 4 deletions lzero/model/gpt_models/utils.py
@@ -122,9 +122,12 @@ def __init__(self, **kwargs):
self.policy_loss_weight = 1.
# self.ends_loss_weight = 1.
self.ends_loss_weight = 0.
self.rep_kl_loss_weight = 0.1 # for lunarlander

# self.rep_kl_loss_weight = 0.5
# self.latent_kl_loss_weight = 0.1 # for lunarlander
self.latent_kl_loss_weight = 0. # for lunarlander

# self.latent_recon_loss_weight = 1
self.latent_recon_loss_weight = 0.1


# Initialize the total loss tensor on the correct device
@@ -140,8 +143,10 @@
self.loss_total += self.value_loss_weight * v
elif k == 'loss_ends':
self.loss_total += self.ends_loss_weight * v
elif k == 'rep_kl_loss':
self.loss_total += self.rep_kl_loss_weight * v
elif k == 'latent_kl_loss':
self.loss_total += self.latent_kl_loss_weight * v
elif k == 'latent_recon_loss':
self.loss_total += self.latent_recon_loss_weight * v
else:
raise ValueError(f"Unknown loss type: {k}")
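A hypothetical usage sketch of the weighting above (the tensor values are illustrative, and the obs/rewards/value branches are assumed to be defined just above this hunk):

import torch

losses = LossWithIntermediateLosses(
    loss_obs=torch.tensor(0.5),
    loss_rewards=torch.tensor(0.2),
    loss_value=torch.tensor(0.3),
    loss_policy=torch.tensor(0.4),
    latent_kl_loss=torch.tensor(0.1),
    latent_recon_loss=torch.tensor(0.05),
)
# With latent_kl_loss_weight = 0. and latent_recon_loss_weight = 0.1,
# the KL term is dropped and the reconstruction term contributes 0.005.
print(losses.loss_total)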

33 changes: 22 additions & 11 deletions lzero/model/gpt_models/world_model.py
@@ -107,7 +107,8 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
# nn.ReLU(),
# nn.Linear(config.embed_dim, obs_vocab_size)
nn.LeakyReLU(negative_slope=0.01), # TODO: 2
nn.Linear(config.embed_dim, self.obs_per_embdding_dim)
nn.Linear(config.embed_dim, self.obs_per_embdding_dim),
# nn.Tanh(), # TODO
)
)

@@ -654,7 +655,17 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW

# NOTE: gradients are required here
# obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens # (BL, K)
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=True) # (B, C, H, W) -> (B, K, E)
with torch.no_grad(): # TODO
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E)

# obs_embeddings.register_hook(lambda grad: grad * 1/5)  # TODO: let only the reconstruction loss update the representation network

# Assume that 'cont_embeddings' and 'original_images' are available from prior code
# Decode the embeddings to reconstruct the images
reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)

# Calculate the reconstruction loss
latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].contiguous().view(-1, 4, 64, 64), reconstructed_images)
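The commit title's "latent state gradient times 0.2" corresponds to the register_hook line commented out above. As a sketch, re-enabling it would look like this; note it only works when obs_embeddings is computed with gradients enabled, i.e. outside the torch.no_grad() block:

# Hypothetical re-enabled hook: scale gradients flowing back into the
# representation network by 0.2 (i.e. 1/5). Requires obs_embeddings to
# have requires_grad=True, so the torch.no_grad() wrapper above would
# need to be removed first.
obs_embeddings.register_hook(lambda grad: grad * 0.2)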


# Compute the KL-divergence loss
@@ -663,21 +674,20 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW
mean = obs_embeddings.mean(dim=0, keepdim=True)
std = obs_embeddings.std(dim=0, keepdim=True)
# Create a standard normal distribution as the prior
# prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))
prior_dist = torch.distributions.Normal(torch.ones_like(mean)*0.1, torch.ones_like(std))
prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))


# Create the distribution of the model outputs
model_dist = torch.distributions.Normal(mean, std)
# Compute the KL-divergence loss per feature for each sample
kl_loss = torch.distributions.kl.kl_divergence(model_dist, prior_dist)
# Since kl_loss has shape (1, 1, 256), average over all features to obtain a scalar loss
rep_kl_loss = kl_loss.mean()
print(f'rep_kl_loss:, {rep_kl_loss}')
if torch.isnan(rep_kl_loss) or torch.isinf(rep_kl_loss):
print("NaN or inf detected in rep_kl_loss!")
latent_kl_loss = kl_loss.mean()
# print(f'latent_kl_loss: {latent_kl_loss}')
if torch.isnan(latent_kl_loss) or torch.isinf(latent_kl_loss):
print("NaN or inf detected in latent_kl_loss!")
# Use torch.tensor(0.) to create a zero tensor with the same device and dtype, and make sure it does not require grad
rep_kl_loss = torch.tensor(0., device=rep_kl_loss.device, dtype=rep_kl_loss.dtype)
latent_kl_loss = torch.tensor(0., device=latent_kl_loss.device, dtype=latent_kl_loss.dtype)
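For reference, the KL term computed above has a closed form for diagonal Gaussians against the standard-normal prior. A small sketch verifying torch against that textbook identity (not part of this commit):

import torch

mean = torch.tensor([0.3, -0.2])
std = torch.tensor([0.8, 1.5])

model_dist = torch.distributions.Normal(mean, std)
prior_dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))

kl_torch = torch.distributions.kl.kl_divergence(model_dist, prior_dist)

# KL(N(mu, sigma^2) || N(0, 1)) = -log(sigma) + (sigma^2 + mu^2) / 2 - 1/2
kl_manual = -torch.log(std) + (std ** 2 + mean ** 2) / 2 - 0.5
assert torch.allclose(kl_torch, kl_manual)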

# TODO
# obs_embeddings = AvgL1Norm(obs_embeddings)
@@ -695,7 +705,8 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW

# tokens = rearrange(torch.cat((obs_tokens, act_tokens), dim=2), 'b l k1 -> b (l k1)') # (B, L(K+1))
# outputs = self.forward(tokens, is_root=False)
outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False)
# TODO: only the reconstruction loss updates the representation network
outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False)

labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'],
batch['ends'],
@@ -756,7 +767,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW
loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value')

return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value,
loss_policy=loss_policy, rep_kl_loss=rep_kl_loss)
loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss)

def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'):
# Assume outputs.logits_rewards and labels are your predictions and targets
