Skip to content

Commit

Permalink
fix(pu): fix state_action_history bug in mcts_ctree for muzero_gpt
Browse files (browse the repository at this point in the history)
  • Loading branch information
puyuan1996 committed Nov 30, 2023
1 parent 03dc66f commit 26b6255
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
4 changes: 3 additions & 1 deletion lzero/mcts/tree_search/mcts_ctree.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ def search(

# TODO
# 在每次模拟后更新 state_action_history (English: update state_action_history after each simulation)
state_action_history.append((last_latent_state, last_actions.detach().cpu().numpy()))
# state_action_history.append((last_latent_state, last_actions.detach().cpu().numpy()))
state_action_history.append((latent_states.detach().cpu().numpy(), last_actions.detach().cpu().numpy()))

# state_action_history.append(last_latent_state)
# state_action_history.append(last_actions)

Expand Down
4 changes: 2 additions & 2 deletions lzero/model/gpt_models/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
'support_size': 601,
# 'support_size': 21,
'action_shape': 2,# TODO: for cartpole
# 'max_cache_size': 5000,
'max_cache_size':25,
'max_cache_size': 500,
# 'max_cache_size':25,
}

from easydict import EasyDict
Expand Down
8 changes: 4 additions & 4 deletions lzero/model/gpt_models/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,10 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
# [(s0,a0)] -> [s0]
# [(s0,a0),(s1,a1)] -> [(s0,a0),s1]
state_action_history_from_last_root = state_action_history[last_root_position:-1] + [state_action_history[-1][0]]
if last_root_position>0:
print('='*20)
print('last_root_position>0')
print('='*20)
# if last_root_position>0:
# print('='*20)
# print('last_root_position>0')
# print('='*20)

# cache_key = tuple(state_action_history_from_last_root)
cache_key = hash(state_action_history_from_last_root)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
# num_unroll_steps = 5 # debug

# debug
# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 2
# collector_env_num = 1
# n_episode = 1
# evaluator_env_num = 1
# num_simulations = 25
# update_per_collect = 2
# model_update_ratio = 1
Expand Down Expand Up @@ -91,8 +91,7 @@
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(2e2),
# eval_freq=int(2),
eval_freq=int(1e3),
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
Expand Down

0 comments on commit 26b6255

Please sign in to comment.