Commit a20deba

feature(pu): add muzero_gpt for atari, polish world_model.past_keys_values_cache

puyuan1996 committed Nov 23, 2023
1 parent 8779039 commit a20deba
Showing 13 changed files with 899 additions and 40 deletions.
3 changes: 3 additions & 0 deletions lzero/entry/train_muzero_gpt.py
@@ -186,6 +186,9 @@ def train_muzero_gpt(

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

# TODO: clear the world model's KV cache after each training update, so that entries computed with stale transformer weights are not reused
policy._learn_model.world_model.past_keys_values_cache.clear()

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
3 changes: 2 additions & 1 deletion lzero/mcts/tree_search/mcts_ctree.py
@@ -281,8 +281,9 @@ def search(

state_action_history = []  # initialize the state_action_history variable
last_latent_state = latent_state_roots
# TODO
# You may need to clear past_keys_values_cache at the start of each search to keep the cache from growing too large:
model.world_model.past_keys_values_cache.clear()  # clear the cache
# model.world_model.past_keys_values_cache.clear()  # clear the cache
for simulation_index in range(self._cfg.num_simulations):
# In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.

38 changes: 17 additions & 21 deletions lzero/model/gpt_models/cfg.py
@@ -1,42 +1,38 @@
cfg = {}

cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer',
# 'vocab_size': 512,
# 'vocab_size': 512, # TODO: for atari
# 'embed_dim': 512,
'vocab_size': 128, # TODO
# 'vocab_size': 256, # TODO: for atari debug
# 'embed_dim': 256,
'vocab_size': 128, # TODO: for cartpole
'embed_dim': 128,
'encoder':
# {'resolution': 64, 'in_channels': 3, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0},
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},
'out_ch': 3, 'dropout': 0.0},# TODO: for cartpole
'decoder':
# {'resolution': 64, 'in_channels': 3, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0}} # TODO:
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}}
'out_ch': 3, 'dropout': 0.0}} # TODO: for cartpole

cfg['world_model'] = {'tokens_per_block': 17,
# 'max_blocks': 20,
# "max_tokens": 17 * 20, # TODO: horizon
'max_blocks': 5,
"max_tokens": 17 * 5, # TODO: horizon
cfg['world_model'] = {
'tokens_per_block': 17,
'max_blocks': 20,
"max_tokens": 17 * 20, # TODO: horizon
'attention': 'causal',
# 'num_layers': 10,
'num_layers': 2, # TODO:
# 'num_layers': 10,# TODO:for atari
'num_layers': 2, # TODO:for debug
'num_heads': 4,
# 'embed_dim': 256, # TODO:
'embed_dim': 128,
'embed_dim': 128, # TODO: for cartpole
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
# "device": 'cuda:0',
"device": 'cpu',
"device": 'cuda:0',
# "device": 'cpu',
'support_size': 21,
'action_shape': 2,# TODO: for cartpole

}

from easydict import EasyDict
50 changes: 50 additions & 0 deletions lzero/model/gpt_models/cfg_atari.py
@@ -0,0 +1,50 @@
cfg = {}

cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer',
# 'vocab_size': 512, # TODO: for atari
# 'embed_dim': 512,
# 'vocab_size': 256, # TODO: for atari debug
# 'embed_dim': 256,
'vocab_size': 128, # TODO: for atari debug
'embed_dim': 64,
'encoder':
{'resolution': 64, 'in_channels': 3, 'z_channels': 64, 'ch': 32,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},# TODO:for atari debug
'decoder':
{'resolution': 64, 'in_channels': 3, 'z_channels': 64, 'ch': 32,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}} # TODO:for atari debug
# {'resolution': 64, 'in_channels': 1, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0},# TODO:for atari
# 'decoder':
# {'resolution': 64, 'in_channels': 1, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0}} # TODO:for atari
cfg['world_model'] = {
'tokens_per_block': 17,
'max_blocks': 20,
"max_tokens": 17 * 20, # TODO: horizon
# 'max_blocks': 5,
# "max_tokens": 17 * 5, # TODO: horizon
'attention': 'causal',
# 'num_layers': 10,# TODO:for atari
# 'num_heads': 4,
'num_layers': 2, # TODO:for atari debug
'num_heads': 2,
# 'embed_dim': 256, # TODO:for atari
'embed_dim': 64, # TODO:for atari debug
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:0',
# "device": 'cpu',
'support_size': 601,
'action_shape': 6,# TODO:for atari

}

from easydict import EasyDict

cfg = EasyDict(cfg)
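
The Atari and CartPole configs above share the same structure and differ only in the encoder/decoder shapes, the transformer width, support_size, and action_shape. Below is a minimal sketch (not part of the commit) of loading one of these modules and sanity-checking it; the import path simply mirrors the files added here, and the check that the tokenizer and world-model embed_dim match is an assumption inferred from the values, not something the code states.

# Minimal sketch: load one config module and check its internal consistency.
from lzero.model.gpt_models.cfg_atari import cfg  # or: from lzero.model.gpt_models.cfg_cartpole import cfg

wm = cfg.world_model
assert wm.max_tokens == wm.tokens_per_block * wm.max_blocks  # 17 tokens per block, horizon of 20 blocks
assert wm.embed_dim == cfg.tokenizer.embed_dim  # assumption: obs tokens are embedded at the transformer width
print(wm.device, wm.action_shape, wm.support_size)  # 'cuda:0', 6, 601 for the Atari config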
40 changes: 40 additions & 0 deletions lzero/model/gpt_models/cfg_cartpole.py
@@ -0,0 +1,40 @@
cfg = {}

cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer',
# 'vocab_size': 512, # TODO: for atari
# 'embed_dim': 512,
# 'vocab_size': 256, # TODO: for atari debug
# 'embed_dim': 256,
'vocab_size': 128, # TODO: for cartpole
'embed_dim': 128,
'encoder':
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},# TODO: for cartpole
'decoder':
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}} # TODO: for cartpole

cfg['world_model'] = {
'tokens_per_block': 17,
'max_blocks': 20,
"max_tokens": 17 * 20, # TODO: horizon
'attention': 'causal',
# 'num_layers': 10,# TODO:for atari
'num_layers': 2, # TODO:for debug
'num_heads': 4,
'embed_dim': 128, # TODO: for cartpole
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:0',
# "device": 'cpu',
'support_size': 21,
'action_shape': 2,# TODO: for cartpole

}

from easydict import EasyDict

cfg = EasyDict(cfg)
27 changes: 19 additions & 8 deletions lzero/model/gpt_models/world_model.py
@@ -22,6 +22,7 @@
import torch
from torch.distributions.categorical import Categorical
import torchvision
from joblib import hash


@dataclass
@@ -44,9 +45,10 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer

self.transformer = Transformer(config)
self.num_observations_tokens = 16
# self.device = 'cpu'
self.device = config.device
self.support_size = config.support_size
self.action_shape = config.action_shape


all_but_last_obs_tokens_pattern = torch.ones(config.tokens_per_block)
all_but_last_obs_tokens_pattern[-2] = 0
@@ -119,7 +121,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
head_module=nn.Sequential(
nn.Linear(config.embed_dim, config.embed_dim),
nn.ReLU(),
nn.Linear(config.embed_dim, 2) # TODO(pu); action shape
nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape
)
)
self.head_value = Head(
@@ -149,7 +151,11 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
nn.init.zeros_(layer.bias)
break

self.past_keys_values_cache = {}
# self.past_keys_values_cache = {}
from collections import deque
# TODO: the cache should be cleared after the Transformer is updated
self.max_cache_size = 10000
self.past_keys_values_cache = deque(maxlen=self.max_cache_size)

def __repr__(self) -> str:
return "world_model"
@@ -296,8 +302,6 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
# the action_history starting from this position
action_history_from_last_root = state_action_history[last_root_position:]
cache_key = tuple(action_history_from_last_root)

from joblib import hash
cache_key = hash(action_history_from_last_root)

if cache_key in self.past_keys_values_cache:
@@ -344,8 +348,14 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
obs = self.decode_obs_tokens() if should_predict_next_obs else None
# return outputs_wm.output_sequence, outputs_wm.logits_observations, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value

# TODO: update the cache after the computation finishes
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
# TODO: update the cache after the computation finishes; is deepcopy necessary?
# self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)

# Whenever a new key-value pair is about to be added, check the cache size and pop the oldest entry if necessary
if len(self.past_keys_values_cache) >= self.max_cache_size:
self.past_keys_values_cache.popleft()
# This keeps the cache within a fixed memory budget and automatically evicts the oldest entries.
self.past_keys_values_cache.append((cache_key, copy.deepcopy(self.keys_values_wm)))

return outputs_wm.output_sequence, self.obs_tokens, reward, outputs_wm.logits_policy, outputs_wm.logits_value

@@ -395,6 +405,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIn
# obs is a 3-dimensional image
pass
elif len(batch['observations'][0, 0].shape) == 1:
print('obs is a 1-dimensional vector.')
# TODO()
# obs is a 1-dimensional vector
original_shape = list(batch['observations'].shape)
@@ -498,5 +509,5 @@ def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, ta

mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value)
labels_value = target_value.masked_fill(mask_fill_value, -100)
return labels_policy.reshape(-1, 2), labels_value.reshape(-1, self.support_size) # TODO(pu)
return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu)
# return labels_policy.reshape(-1, ), labels_value.reshape(-1)
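
Taken together, the cache changes above replace the unbounded dict with a deque(maxlen=10000) of (cache_key, keys_values) pairs, where cache_key is a joblib hash of the action history since the last root. Below is a small, self-contained sketch of the same bounded-cache idea, written against an OrderedDict; the class and method names are illustrative and not part of the repository.

# Illustrative sketch of a bounded key/value cache (not the repository's API).
# WorldModel caches transformer KV states keyed by joblib.hash of the recent
# action history; plain strings stand in for those objects here.
import copy
from collections import OrderedDict

class BoundedKVCache:
    def __init__(self, max_size: int = 10000):
        self.max_size = max_size
        self._cache = OrderedDict()  # insertion order == age, oldest entry first

    def get(self, key):
        # Return the cached entry for `key`, or None on a miss.
        return self._cache.get(key)

    def put(self, key, value):
        # Drop the oldest entry once full, then store a deep copy so later
        # in-place updates of `value` cannot corrupt the cached state.
        if key in self._cache:
            self._cache.pop(key)
        elif len(self._cache) >= self.max_size:
            self._cache.popitem(last=False)
        self._cache[key] = copy.deepcopy(value)

cache = BoundedKVCache(max_size=3)
for step in range(4):
    cache.put(f'history-{step}', f'kv-{step}')
assert cache.get('history-0') is None  # the oldest entry was evicted
assert cache.get('history-3') == 'kv-3'

Relative to a deque of (key, value) tuples, this keeps a `cache_key in cache` style membership test constant-time; with a plain deque, finding an existing key requires scanning the stored pairs.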