
Commit

fix(pu): world_model.past_keys_values_cache.clear() per 200 env steps to fix cuda oom bug

puyuan1996 committed Jan 13, 2024
1 parent 01c8385 commit a538a80
Showing 9 changed files with 115 additions and 63 deletions.
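
For reference, a minimal sketch of the cache-clearing pattern this commit applies in the collector and training loops (the helper name is illustrative; it assumes only a world model exposing a dict-like `past_keys_values_cache`, as in the diffs below):

```python
import torch


def maybe_clear_kv_cache(world_model, env_steps: int, clear_interval: int = 200) -> None:
    """Periodically drop cached transformer key/value entries to avoid CUDA OOM."""
    if env_steps % clear_interval == 0:
        world_model.past_keys_values_cache.clear()
        torch.cuda.empty_cache()  # return freed blocks to the CUDA allocator
```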
4 changes: 4 additions & 0 deletions lzero/entry/train_muzero_gpt.py
@@ -240,6 +240,10 @@ def train_muzero_gpt(
# TODO: for the batch world model, to improve kv reuse we could avoid resetting
policy._learn_model.world_model.past_keys_values_cache.clear() # very important
del policy._learn_model.world_model.keys_values_wm

policy._collect_model.world_model.past_keys_values_cache.clear() # very important
# del policy._collect_model.world_model.keys_values_wm

torch.cuda.empty_cache() # TODO: NOTE

# if collector.envstep > 0:
3 changes: 2 additions & 1 deletion lzero/mcts/tree_search/mcts_ctree.py
@@ -286,7 +286,6 @@ def search(
# _ = model.world_model.refresh_keys_values_with_initial_obs_tokens(model.world_model.obs_tokens)

# model.world_model.past_keys_values_cache.clear() # clear the cache
# del model.world_model.keys_values_wm # TODO: clear the cache
for simulation_index in range(self._cfg.num_simulations):
# In each simulation we expand a new node, so one search creates at most ``num_simulations`` nodes.

@@ -357,6 +356,8 @@ def search(
current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch,
min_max_stats_lst, results, virtual_to_play_batch
)
# del model.world_model.keys_values_wm # TODO: clear the cache


class GumbelMuZeroMCTSCtree(object):
"""
10 changes: 5 additions & 5 deletions lzero/model/gpt_models/cfg_atari.py
@@ -46,13 +46,13 @@
# 'embed_dim': 128, # TODO:for atari

# 'tokens_per_block': 2,
# 'max_blocks': 10,
# "max_tokens": 2 * 10, # TODO: horizon
# 'embed_dim': 1024, # TODO:for atari
# 'max_blocks': 20,
# "max_tokens": 2 * 20, # TODO: horizon

'tokens_per_block': 2,
'max_blocks': 5,
"max_tokens": 2 * 5, # TODO: horizon

# 'embed_dim':2048, # TODO:for atari
'embed_dim':1024, # TODO:for atari
# 'embed_dim':256, # TODO:for atari
@@ -71,9 +71,9 @@
'support_size': 21,
'action_shape': 6,# TODO:for atari
'max_cache_size':500,
# 'max_cache_size':100,
# 'max_cache_size':1000,
# 'max_cache_size':5000,
"env_num":8,
'latent_recon_loss_weight':0,
}
from easydict import EasyDict
cfg = EasyDict(cfg)
3 changes: 3 additions & 0 deletions lzero/model/gpt_models/transformer.py
@@ -12,6 +12,7 @@
from torch.nn import functional as F

from .kv_caching import KeysValues, KVCache
from line_profiler import line_profiler


@dataclass
@@ -93,6 +94,8 @@ def __init__(self, config: TransformerConfig) -> None:
block_causal_mask = torch.max(causal_mask, torch.block_diag(*[torch.ones(config.tokens_per_block, config.tokens_per_block) for _ in range(config.max_blocks)]))
self.register_buffer('mask', causal_mask if config.attention == 'causal' else block_causal_mask)


# @profile
def forward(self, x: torch.Tensor, kv_cache: Optional[KVCache] = None) -> torch.Tensor:
B, T, C = x.size()
if kv_cache is not None:
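
For context, the `__init__` change above sits next to the mask registration: the block-causal variant is the element-wise max of a causal mask and a block-diagonal mask of size `tokens_per_block`. A self-contained sketch of that construction, assuming the usual lower-triangular causal mask and the cfg_atari.py values above:

```python
import torch

tokens_per_block, max_blocks = 2, 5          # values taken from cfg_atari.py above
max_tokens = tokens_per_block * max_blocks

# Causal mask: position i may attend to positions <= i.
causal_mask = torch.tril(torch.ones(max_tokens, max_tokens))

# Block-causal mask: tokens within the same block may additionally attend to each other.
block_causal_mask = torch.max(
    causal_mask,
    torch.block_diag(*[torch.ones(tokens_per_block, tokens_per_block) for _ in range(max_blocks)]),
)
```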
4 changes: 2 additions & 2 deletions lzero/model/gpt_models/utils.py
@@ -96,7 +96,7 @@ def compute_lambda_returns(rewards, values, ends, gamma, lambda_):


class LossWithIntermediateLosses:
def __init__(self, **kwargs):
def __init__(self, latent_recon_loss_weight=0, **kwargs):
# self.loss_total = sum(kwargs.values())

# Ensure that kwargs is not empty
@@ -126,7 +126,7 @@ def __init__(self, **kwargs):
# self.latent_kl_loss_weight = 0.1 # for lunarlander
self.latent_kl_loss_weight = 0. # for lunarlander

self.latent_recon_loss_weight = 0.
self.latent_recon_loss_weight = latent_recon_loss_weight
# self.latent_recon_loss_weight = 0.1
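
For context, a minimal sketch of how the reconstruction-loss weight is now threaded from the config through the world model into this loss container instead of being hard-coded to 0 (class and attribute names taken from the diff; everything else is elided):

```python
class LossWithIntermediateLosses:
    def __init__(self, latent_recon_loss_weight=0, **kwargs):
        # The weight is now supplied by the caller instead of being fixed to 0.
        self.latent_recon_loss_weight = latent_recon_loss_weight


class WorldModel:
    def __init__(self, config):
        # cfg_atari.py above sets 'latent_recon_loss_weight': 0
        self.latent_recon_loss_weight = config.latent_recon_loss_weight

    def compute_loss(self, batch, **kwargs):
        # ... compute the individual losses as in the real compute_loss ...
        return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight)
```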


90 changes: 57 additions & 33 deletions lzero/model/gpt_models/world_model.py
@@ -22,8 +22,9 @@
from .tokenizer import Tokenizer
from .transformer import Transformer, TransformerConfig
from .utils import LossWithIntermediateLosses, init_weights

from ding.torch_utils import to_device
# from memory_profiler import profile
from line_profiler import line_profiler

@dataclass
class WorldModelOutput:
@@ -47,6 +48,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
# self.num_observations_tokens = 16
self.num_observations_tokens = config.tokens_per_block -1

self.latent_recon_loss_weight = config.latent_recon_loss_weight
self.device = config.device
self.support_size = config.support_size
self.action_shape = config.action_shape
@@ -233,6 +235,7 @@ def __repr__(self) -> str:
# is_root=False) -> WorldModelOutput:
# def forward(self, obs_embeddings, act_tokens, past_keys_values: Optional[KeysValues] = None,
# is_root=False) -> WorldModelOutput:
# @profile
def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None,
is_root=False) -> WorldModelOutput:

@@ -344,6 +347,8 @@ def reset(self) -> torch.FloatTensor:
0) # (1, C, H, W) in [0., 1.]
return self.reset_from_initial_observations(obs)


# @profile
@torch.no_grad()
def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor:
if isinstance(obs_act_dict, dict):
@@ -373,6 +378,7 @@ def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> to
return outputs_wm, self.obs_tokens


# @profile
@torch.no_grad()
def refresh_keys_values_with_initial_obs_tokens_for_init_infer(self, obs_tokens: torch.LongTensor, buffer_action=None) -> torch.FloatTensor:
n, num_observations_tokens, _ = obs_tokens.shape
@@ -443,6 +449,8 @@ def refresh_keys_values_with_initial_obs_tokens_for_init_infer(self, obs_tokens:
# return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value)
return outputs_wm


# @profile
@torch.no_grad()
def refresh_keys_values_with_initial_obs_tokens(self, obs_tokens: torch.LongTensor) -> torch.FloatTensor:
n, num_observations_tokens, _ = obs_tokens.shape
@@ -455,6 +463,7 @@ def refresh_keys_values_with_initial_obs_tokens(self, obs_tokens: torch.LongTens
# return outputs_wm.output_sequence # (B, K, E)
return outputs_wm

# @profile
@torch.no_grad()
def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True):

@@ -495,14 +504,14 @@ def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predi
# cache_key = hash(obs_tokens.squeeze(0).detach().cpu().numpy())
cache_key = hash(obs_tokens.detach().cpu().numpy())

self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu'))
# elif obs_tokens.shape[0] == self.env_num:
# elif obs_tokens.shape[0] > self.env_num:
elif obs_tokens.shape[0] > 1 and obs_tokens.shape[0] <= self.env_num:
# This branch will be executed only when env_num=8
cache_key = hash(obs_tokens.detach().cpu().numpy())
# Store the KV_cache for all 8 samples together
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu'))

# return outputs_wm.output_sequence, outputs_wm.logits_observations, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value
return outputs_wm.output_sequence, obs_tokens, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value
@@ -518,6 +527,9 @@ def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predi

# @profile
# TODO: only for inference, not for training


# @profile
@torch.no_grad()
def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True):
# Generally, within one MCTS search we need to maintain a context of length H to run inference with the transformer.
@@ -534,23 +546,20 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
matched_value = self.past_keys_values_cache.get(hash_latest_state)
if matched_value is not None:
# If a matching value is found, do something with it
self.keys_values_wm = copy.deepcopy(matched_value)

# self.keys_values_wm = copy.deepcopy(matched_value)
self.keys_values_wm = copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda') )

# print('recurrent_inference:find matched_value!')
else:
# If no matching value is found, handle the case accordingly
# NOTE: very important
_ = self.refresh_keys_values_with_initial_obs_tokens(torch.tensor(latest_state, dtype=torch.float32).to(self.device))
# Depending on the shape of obs_tokens, create a cache key and store a deep copy of keys_values_wm
if latest_state.shape[0] == 1:
# This branch will be executed only when env_num=1
# cache_key = hash(latest_state.squeeze(0))
cache_key = hash(latest_state)
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
elif latest_state.shape[0] > 1 and latest_state.shape[0] <= self.env_num:
# This branch will be executed only when env_num=8
cache_key = hash(latest_state)
# Store the KV_cache for all 8 samples together
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
# This branch will be executed only when env_num=1
# cache_key = hash(latest_state.squeeze(0))
cache_key = hash(latest_state)
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu'))
# print('recurrent_inference:not find matched_value!')


@@ -561,20 +570,14 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
output_sequence, obs_tokens = [], []

if self.keys_values_wm.size + num_passes > self.config.max_tokens:
del self.keys_values_wm # TODO
# TODO: the impact
# _ = self.refresh_keys_values_with_initial_obs_tokens(self.obs_tokens)
_ = self.refresh_keys_values_with_initial_obs_tokens(torch.tensor(latest_state, dtype=torch.float32).to(self.device))
# Depending on the shape of obs_tokens, create a cache key and store a deep copy of keys_values_wm
if latest_state.shape[0] == 1:
# This branch will be executed only when env_num=1
# cache_key = hash(latest_state.squeeze(0))
cache_key = hash(latest_state)
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
elif latest_state.shape[0] > 1 and latest_state.shape[0] <= self.env_num:
# This branch will be executed only when env_num=8
cache_key = hash(latest_state)
# Store the KV_cache for all 8 samples together
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
cache_key = hash(latest_state)
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu'))


# TODO
action = state_action_history[-1][-1]
@@ -625,16 +628,16 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
# cache_key = hash(self.obs_tokens.detach().cpu().numpy())
cache_key = hash(self.obs_tokens.detach().cpu().numpy())

# TODO: after the computation finishes, update the cache. Is a deepcopy needed?
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
# TODO: after the computation finishes, should the latest cache be updated? Is a deepcopy needed?
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu'))
if len(self.past_keys_values_cache) > self.max_cache_size:
# TODO: lru_cache
self.past_keys_values_cache.popitem(last=False) # Removes the earliest inserted item
# self.past_keys_values_cache.popitem(last=False) # Removes the earliest inserted item
# popitem returns a key-value pair; the second element is the value
# _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False)
_, popped_kv_cache = self.past_keys_values_cache.popitem(last=False)
# If popped_kv_cache is a container holding tensors or other complex objects, you may need to delete those objects explicitly as well
# For example:
# del popped_kv_cache # do not keep this line
del popped_kv_cache # do not keep this line
# torch.cuda.empty_cache() # Note: frequent calls may hurt performance; del alone does not release the occupied ~2MB of cache

# Example usage:
@@ -646,6 +649,26 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
return outputs_wm.output_sequence, self.obs_tokens, reward, outputs_wm.logits_policy, outputs_wm.logits_value


def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues:
"""
Transfer all KVCache objects within the KeysValues object to a certain device.
Arguments:
- keys_values (KeysValues): The KeysValues object to be transferred.
- device (str): The device to transfer to.
Returns:
- keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device.
"""
# Check if CUDA is available and select the first available CUDA device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for kv_cache in keys_values:
kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device)
kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device)
return keys_values


# Function for computing GPU memory usage
def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int):
total_memory_bytes = 0
@@ -672,16 +695,17 @@ def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int):



# @profile
def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses:

if len(batch['observations'][0, 0].shape) == 3:
# obs is a 3-dimensional image
pass

# NOTE: gradients are needed here
with torch.no_grad(): # TODO: very important
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E)
# obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO: only let the reconstruction loss update the representation network
# with torch.no_grad(): # TODO: very important
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E)
obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO: only let the reconstruction loss update the representation network
# obs_embeddings.register_hook(lambda grad: grad * 1) # TODO: only let the reconstruction loss update the representation network

# Assume that 'cont_embeddings' and 'original_images' are available from prior code
@@ -724,7 +748,7 @@ def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossW
loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy')
loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value')

return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value,
return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value,
loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss)

def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'):
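
Taken together, the world-model changes above aim to park cached key/value tensors on CPU when they are stored (note the `'cpu'` argument to `to_device_for_kvcache`) and to keep the cache bounded by evicting the oldest entry once `max_cache_size` is exceeded. A minimal, self-contained sketch of that pattern, using a simplified stand-in for the real `KeysValues` structure (names and shapes are illustrative):

```python
import copy
from collections import OrderedDict

import torch


class SimpleKVCache:
    """Simplified stand-in for one layer's key/value cache."""

    def __init__(self, k: torch.Tensor, v: torch.Tensor):
        self._k_cache, self._v_cache = k, v

    def to_device(self, device: torch.device) -> "SimpleKVCache":
        self._k_cache = self._k_cache.to(device)
        self._v_cache = self._v_cache.to(device)
        return self


past_keys_values_cache: "OrderedDict[int, SimpleKVCache]" = OrderedDict()
MAX_CACHE_SIZE = 500  # matches 'max_cache_size' in cfg_atari.py above


def store_kv(cache_key: int, kv: SimpleKVCache) -> None:
    # Deep-copy and park the cache on CPU so GPU memory is not held between reuses.
    past_keys_values_cache[cache_key] = copy.deepcopy(kv).to_device(torch.device("cpu"))
    if len(past_keys_values_cache) > MAX_CACHE_SIZE:
        _, popped = past_keys_values_cache.popitem(last=False)  # evict the oldest entry (FIFO)
        del popped


def load_kv(cache_key: int, device: torch.device):
    # Restore a cached entry onto the compute device; returns None on a cache miss.
    kv = past_keys_values_cache.get(cache_key)
    return copy.deepcopy(kv).to_device(device) if kv is not None else None
```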
7 changes: 6 additions & 1 deletion lzero/policy/muzero_gpt.py
@@ -938,8 +938,13 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
# the index within the legal action set, rather than the index in the entire action set.
# Setting deterministic=True implies choosing the action with the highest value (argmax) rather than
# sampling during the evaluation phase.

# action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
# distributions, temperature=1, deterministic=True
# )
# TODO: eval
action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
distributions, temperature=1, deterministic=True
distributions, temperature=self._collect_mcts_temperature, deterministic=False
)
# NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the
# entire action set.
5 changes: 4 additions & 1 deletion lzero/worker/muzero_collector.py
@@ -511,7 +511,10 @@ def collect(self,

eps_steps_lst[env_id] += 1

if eps_steps_lst[env_id] % 200==0:
if eps_steps_lst[env_id] % 200 == 0:
self._policy.get_attribute('collect_model').world_model.past_keys_values_cache.clear()
# self._policy._learn_model.world_model.past_keys_values_cache.clear() # very important
# del self._policy.get_attribute('collect_model').world_model.keys_values_wm
torch.cuda.empty_cache() # TODO: NOTE
print('torch.cuda.empty_cache()')
# print(f'eps_steps_lst[{env_id}]:{eps_steps_lst[env_id]}')