polish(pu): polish slicer
puyuan1996 committed Nov 27, 2023
1 parent a20deba commit 641680c
Showing 12 changed files with 243 additions and 69 deletions.
2 changes: 1 addition & 1 deletion lzero/mcts/tree_search/mcts_ctree.py
@@ -283,7 +283,7 @@ def search(
last_latent_state = latent_state_roots
# TODO
# You may need to clear past_keys_values_cache at the start of each search to keep the cache from growing too large:
# model.world_model.past_keys_values_cache.clear() # clear the cache
model.world_model.past_keys_values_cache.clear() # clear the cache
for simulation_index in range(self._cfg.num_simulations):
# In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.

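Editor's note: a minimal sketch of the pattern this hunk enables, namely clearing the world model's key/value cache before the simulation loop so entries from earlier searches do not accumulate. The stub classes below are hypothetical stand-ins, not part of the repository.

class _WorldModelStub:
    """Hypothetical stand-in exposing only the cache touched by this hunk."""
    def __init__(self):
        self.past_keys_values_cache = {}

class _ModelStub:
    def __init__(self):
        self.world_model = _WorldModelStub()

def run_search(model, num_simulations: int) -> None:
    # Start every search with an empty cache, as the now-active line above does.
    model.world_model.past_keys_values_cache.clear()
    for simulation_index in range(num_simulations):
        pass  # expand one node per simulation, as in the original search loop

run_search(_ModelStub(), num_simulations=25)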
38 changes: 24 additions & 14 deletions lzero/model/gpt_models/cfg.py
@@ -5,34 +5,44 @@
# 'embed_dim': 512,
# 'vocab_size': 256, # TODO: for atari debug
# 'embed_dim': 256,
'vocab_size': 128, # TODO: for cartpole
'vocab_size': 128, # TODO: for atari debug
'embed_dim': 128,
'encoder':
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},# TODO: for cartpole
'decoder':
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}} # TODO: for cartpole

{'resolution': 64, 'in_channels': 3, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},# TODO:for atari debug
'decoder':
{'resolution': 64, 'in_channels': 3, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}} # TODO:for atari debug
# {'resolution': 64, 'in_channels': 1, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0},# TODO:for atari
# 'decoder':
# {'resolution': 64, 'in_channels': 1, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0}} # TODO:for atari
cfg['world_model'] = {
'tokens_per_block': 17,
'max_blocks': 20,
"max_tokens": 17 * 20, # TODO: horizon
# 'max_blocks': 5,
# "max_tokens": 17 * 5, # TODO: horizon
'attention': 'causal',
# 'num_layers': 10,# TODO:for atari
'num_layers': 2, # TODO:for debug
'num_layers': 2, # TODO:for atari debug
'num_heads': 4,
'embed_dim': 128, # TODO: for cartpole
'embed_dim': 128, # TODO:for atari
# 'embed_dim': 64, # TODO:for atari debug
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:0',
# "device": 'cpu',
'support_size': 21,
'action_shape': 2,# TODO: for cartpole

'support_size': 601,
'action_shape': 6,# TODO:for atari
# 'max_cache_size':5000,
'max_cache_size':20,
}

from easydict import EasyDict
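Editor's note: a hedged usage sketch of how a config dict like the one above is typically consumed once wrapped in EasyDict. The access pattern below is illustrative; the values follow the settings shown in this diff (embed_dim 128, 17 tokens per block, 20 blocks, max_cache_size 20).

from easydict import EasyDict

cfg = {'world_model': {'tokens_per_block': 17, 'max_blocks': 20,
                       'max_tokens': 17 * 20, 'embed_dim': 128,
                       'max_cache_size': 20}}
cfg = EasyDict(cfg)                     # nested keys become attribute-accessible
print(cfg.world_model.embed_dim)        # 128
print(cfg.world_model.max_tokens)       # 340 = 17 tokens per block * 20 blocks
print(cfg.world_model.max_cache_size)   # 20, the bound read by world_model.py below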
15 changes: 7 additions & 8 deletions lzero/model/gpt_models/cfg_atari.py
@@ -6,13 +6,13 @@
# 'vocab_size': 256, # TODO: for atari debug
# 'embed_dim': 256,
'vocab_size': 128, # TODO: for atari debug
'embed_dim': 64,
'embed_dim': 128,
'encoder':
{'resolution': 64, 'in_channels': 3, 'z_channels': 64, 'ch': 32,
{'resolution': 64, 'in_channels': 3, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},# TODO:for atari debug
'decoder':
{'resolution': 64, 'in_channels': 3, 'z_channels': 64, 'ch': 32,
{'resolution': 64, 'in_channels': 3, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}} # TODO:for atari debug
# {'resolution': 64, 'in_channels': 1, 'z_channels': 512, 'ch': 64,
@@ -30,19 +30,18 @@
# "max_tokens": 17 * 5, # TODO: horizon
'attention': 'causal',
# 'num_layers': 10,# TODO:for atari
# 'num_heads': 4,
'num_layers': 2, # TODO:for atari debug
'num_heads': 2,
# 'embed_dim': 256, # TODO:for atari
'embed_dim': 64, # TODO:for atari debug
'num_heads': 4,
'embed_dim': 128, # TODO:for atari
# 'embed_dim': 64, # TODO:for atari debug
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:0',
# "device": 'cpu',
'support_size': 601,
'action_shape': 6,# TODO:for atari

'max_cache_size':5000,
}

from easydict import EasyDict
15 changes: 9 additions & 6 deletions lzero/model/gpt_models/cfg_cartpole.py
@@ -3,10 +3,10 @@
cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer',
# 'vocab_size': 512, # TODO: for atari
# 'embed_dim': 512,
# 'vocab_size': 256, # TODO: for atari debug
# 'embed_dim': 256,
'vocab_size': 128, # TODO: for cartpole
'vocab_size': 128, # TODO: for atari debug
'embed_dim': 128,
# 'vocab_size': 64, # TODO: for cartpole
# 'embed_dim': 64,
'encoder':
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
@@ -21,18 +21,21 @@
'max_blocks': 20,
"max_tokens": 17 * 20, # TODO: horizon
'attention': 'causal',
'num_heads': 4,
# 'num_layers': 10,# TODO:for atari
'num_layers': 2, # TODO:for debug
'num_heads': 4,
'embed_dim': 128, # TODO: for cartpole
# 'embed_dim': 64, # TODO: for cartpole
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:0',
# "device": 'cpu',
'support_size': 21,
'support_size': 601,
# 'support_size': 21,
'action_shape': 2,# TODO: for cartpole

'max_cache_size':5000,
# 'max_cache_size':20,
}

from easydict import EasyDict
15 changes: 11 additions & 4 deletions lzero/model/gpt_models/slicer.py
@@ -1,6 +1,5 @@
import math
from typing import List

from typing import List, Dict
import torch
import torch.nn as nn

@@ -13,12 +12,20 @@ def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
kept_indices = torch.where(block_mask)[0].repeat(max_blocks)
offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
self.register_buffer('indices', kept_indices + block_mask.size(0) * offsets)
self.cache: Dict[str, torch.Tensor] = {}

def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
cache_key = f"{num_steps}_{prev_steps}"
if cache_key in self.cache:
return self.cache[cache_key]

total_steps = num_steps + prev_steps
num_blocks = math.ceil(total_steps / self.block_size)
indices = self.indices[:num_blocks * self.num_kept_tokens]
return indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps
result = indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps

self.cache[cache_key] = result
return result

def forward(self, *args, **kwargs):
raise NotImplementedError
@@ -51,4 +58,4 @@ def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torc
for slicer, emb in zip(self.slicers, self.embedding_tables):
s = slicer.compute_slice(num_steps, prev_steps)
output[:, s] = emb(tokens[:, s])
return output
return output
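Editor's note: a small usage sketch of the memoized compute_slice added above. The block mask and step counts are illustrative, not taken from the repository config.

import torch
from lzero.model.gpt_models.slicer import Slicer

block_mask = torch.tensor([1, 1, 0], dtype=torch.bool)    # keep 2 of every 3 tokens
slicer = Slicer(max_blocks=4, block_mask=block_mask)

first = slicer.compute_slice(num_steps=5, prev_steps=0)
second = slicer.compute_slice(num_steps=5, prev_steps=0)  # served from slicer.cache
assert first is second   # repeated queries return the cached tensor, skipping recomputation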
54 changes: 54 additions & 0 deletions lzero/model/gpt_models/slicer_bkp.py
@@ -0,0 +1,54 @@
import math
from typing import List

import torch
import torch.nn as nn


class Slicer(nn.Module):
def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
super().__init__()
self.block_size = block_mask.size(0)
self.num_kept_tokens = block_mask.sum().long().item()
kept_indices = torch.where(block_mask)[0].repeat(max_blocks)
offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
self.register_buffer('indices', kept_indices + block_mask.size(0) * offsets)

def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
total_steps = num_steps + prev_steps
num_blocks = math.ceil(total_steps / self.block_size)
indices = self.indices[:num_blocks * self.num_kept_tokens]
return indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps

def forward(self, *args, **kwargs):
raise NotImplementedError


class Head(Slicer):
def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None:
super().__init__(max_blocks, block_mask)
assert isinstance(head_module, nn.Module)
self.head_module = head_module

def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
x_sliced = x[:, self.compute_slice(num_steps, prev_steps)] # x is (B, T, E)
return self.head_module(x_sliced)


class Embedder(nn.Module):
def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None:
super().__init__()
assert len(block_masks) == len(embedding_tables)
assert (sum(block_masks) == 1).all() # block mask are a partition of a block
self.embedding_dim = embedding_tables[0].embedding_dim
assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables])
self.embedding_tables = embedding_tables
self.slicers = [Slicer(max_blocks, block_mask) for block_mask in block_masks]

def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
assert tokens.ndim == 2 # x is (B, T)
output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device)
for slicer, emb in zip(self.slicers, self.embedding_tables):
s = slicer.compute_slice(num_steps, prev_steps)
output[:, s] = emb(tokens[:, s])
return output
61 changes: 61 additions & 0 deletions lzero/model/gpt_models/slicer_v1.py
@@ -0,0 +1,61 @@
import math
from typing import List, Dict
import torch
import torch.nn as nn


class Slicer(nn.Module):
def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
super().__init__()
self.block_size = block_mask.size(0)
self.num_kept_tokens = block_mask.sum().long().item()
kept_indices = torch.where(block_mask)[0].repeat(max_blocks)
offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
self.register_buffer('indices', kept_indices + block_mask.size(0) * offsets)
self.cache: Dict[str, torch.Tensor] = {}

def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
cache_key = f"{num_steps}_{prev_steps}"
if cache_key in self.cache:
return self.cache[cache_key]

total_steps = num_steps + prev_steps
num_blocks = math.ceil(total_steps / self.block_size)
indices = self.indices[:num_blocks * self.num_kept_tokens]
result = indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps

self.cache[cache_key] = result
return result

def forward(self, *args, **kwargs):
raise NotImplementedError


class Head(Slicer):
def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None:
super().__init__(max_blocks, block_mask)
assert isinstance(head_module, nn.Module)
self.head_module = head_module

def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
x_sliced = x[:, self.compute_slice(num_steps, prev_steps)] # x is (B, T, E)
return self.head_module(x_sliced)


class Embedder(nn.Module):
def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None:
super().__init__()
assert len(block_masks) == len(embedding_tables)
assert (sum(block_masks) == 1).all() # block mask are a partition of a block
self.embedding_dim = embedding_tables[0].embedding_dim
assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables])
self.embedding_tables = embedding_tables
self.slicers = [Slicer(max_blocks, block_mask) for block_mask in block_masks]

def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
assert tokens.ndim == 2 # x is (B, T)
output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device)
for slicer, emb in zip(self.slicers, self.embedding_tables):
s = slicer.compute_slice(num_steps, prev_steps)
output[:, s] = emb(tokens[:, s])
return output
25 changes: 13 additions & 12 deletions lzero/model/gpt_models/world_model.py
@@ -48,7 +48,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
self.device = config.device
self.support_size = config.support_size
self.action_shape = config.action_shape

self.max_cache_size = config.max_cache_size

all_but_last_obs_tokens_pattern = torch.ones(config.tokens_per_block)
all_but_last_obs_tokens_pattern[-2] = 0
@@ -151,11 +151,12 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
nn.init.zeros_(layer.bias)
break

# self.past_keys_values_cache = {}
from collections import deque
# TODO: the cache should be cleared after the Transformer is updated
self.max_cache_size = 10000
self.past_keys_values_cache = deque(maxlen=self.max_cache_size)
self.past_keys_values_cache = {}
# from collections import deque
# # TODO: the cache should be cleared after the Transformer is updated
# # self.max_cache_size = 10000
# # self.max_cache_size = 20*200
# self.past_keys_values_cache = deque(maxlen=self.max_cache_size)

def __repr__(self) -> str:
return "world_model"
@@ -349,13 +350,13 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
# return outputs_wm.output_sequence, outputs_wm.logits_observations, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value

# TODO: update the cache after the computation finishes; is a deepcopy necessary?
# self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)

# Each time a new key-value pair is added, check the cache size and pop the oldest entry if needed
if len(self.past_keys_values_cache) >= self.max_cache_size:
self.past_keys_values_cache.popleft()
# This keeps the cache within a fixed memory footprint and automatically evicts old entries.
self.past_keys_values_cache.append((cache_key, copy.deepcopy(self.keys_values_wm)))
# if len(self.past_keys_values_cache) >= self.max_cache_size:
# self.past_keys_values_cache.popleft()
# # This keeps the cache within a fixed memory footprint and automatically evicts old entries.
# self.past_keys_values_cache.append((cache_key, copy.deepcopy(self.keys_values_wm)))

return outputs_wm.output_sequence, self.obs_tokens, reward, outputs_wm.logits_policy, outputs_wm.logits_value

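Editor's note: a minimal, self-contained sketch of the two caching strategies this hunk toggles between: the now-active unbounded dict, which is cleared externally (see the mcts_ctree.py change above), versus the commented-out bounded deque that silently drops its oldest entry. Key and value types below are illustrative assumptions.

import copy
from collections import deque

class DictKVCache:
    """Unbounded: grows until cleared explicitly, e.g. at the start of each search."""
    def __init__(self):
        self.store = {}
    def put(self, cache_key, keys_values):
        self.store[cache_key] = copy.deepcopy(keys_values)

class DequeKVCache:
    """Bounded: keeps at most max_cache_size entries; the oldest is evicted automatically."""
    def __init__(self, max_cache_size: int = 20):
        self.store = deque(maxlen=max_cache_size)
    def put(self, cache_key, keys_values):
        self.store.append((cache_key, copy.deepcopy(keys_values)))

bounded = DequeKVCache(max_cache_size=2)
for step in range(3):
    bounded.put(f"key_{step}", {"kv": step})
assert len(bounded.store) == 2   # key_0 was evicted once the bound was reached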
@@ -405,7 +406,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIn
# obs is a 3-dimensional image
pass
elif len(batch['observations'][0, 0].shape) == 1:
print('obs is a 1-dimensional vector.')
# print('obs is a 1-dimensional vector.')
# TODO()
# obs is a 1-dimensional vector
original_shape = list(batch['observations'].shape)
11 changes: 7 additions & 4 deletions lzero/policy/muzero_gpt.py
@@ -255,8 +255,8 @@ def _init_learn(self) -> None:
self._learn_model = self._model

# TODO: only for debug
for param in self._learn_model.tokenizer.parameters():
param.requires_grad = False
# for param in self._learn_model.tokenizer.parameters():
# param.requires_grad = False

if self._cfg.use_augmentation:
self.image_transforms = ImageTransforms(
@@ -345,8 +345,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in

batch_for_gpt = {}
# TODO: for cartpole self._cfg.model.observation_shape
# batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W)
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W)
if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1:
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W)
elif len(self._cfg.model.observation_shape)==3:
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W)


batch_for_gpt['actions'] = action_batch.squeeze(-1) # (B, T-1, A) -> (B, T-1)
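Editor's note: a standalone sketch of the reshape branching added above, handling both vector (CartPole-style) and image (Atari-style) observation shapes. The tensor sizes in the example are illustrative assumptions, not values from the config.

import torch

def reshape_observations(obs_batch, obs_target_batch, batch_size, observation_shape):
    """Concatenate current and target obs, then flatten to (B, T, O) or (B, T, C, H, W)."""
    stacked = torch.cat((obs_batch, obs_target_batch), dim=1)
    if isinstance(observation_shape, int):
        return stacked.reshape(batch_size, -1, observation_shape)     # vector obs, e.g. CartPole
    if len(observation_shape) == 1:
        return stacked.reshape(batch_size, -1, observation_shape[0])  # vector obs given as (O,)
    if len(observation_shape) == 3:
        return stacked.reshape(batch_size, -1, *observation_shape)    # image obs, e.g. Atari
    raise ValueError(f"unsupported observation_shape: {observation_shape}")

# Illustrative shapes: batch of 2, with 3 + 2 stacked steps of 4-dim vector observations.
obs = torch.zeros(2, 3 * 4)
obs_target = torch.zeros(2, 2 * 4)
print(reshape_observations(obs, obs_target, batch_size=2, observation_shape=4).shape)  # (2, 5, 4)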
@@ -548,6 +550,7 @@

return output


def _init_eval(self) -> None:
"""
Overview: