Skip to content

Commit

Permalink
sync code
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 authored and puyuan committed Jan 23, 2025
1 parent 911ed12 commit 557e8f9
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 49 deletions.
3 changes: 2 additions & 1 deletion lzero/model/unizero_world_models/world_model_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,14 +514,15 @@ def _initialize_projection_input_dim(self) -> None:
if self.num_observations_tokens == 16:
self.projection_input_dim = 128
elif self.num_observations_tokens == 1:
# self.projection_input_dim = self.config.embed_dim

if self.task_embed_option == "concat_task_embed":
self.projection_input_dim = self.config.embed_dim - 96
elif self.task_embed_option == "register_task_embed":
self.projection_input_dim = self.config.embed_dim
elif self.task_embed_option == "add_task_embed":
self.projection_input_dim = self.config.embed_dim
else:
self.projection_input_dim = self.config.embed_dim

def _initialize_statistics(self) -> None:
"""Initialize counters for hit count and query count statistics."""
Expand Down
62 changes: 32 additions & 30 deletions zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
obs_type='vector',
# use_shared_projection=True, # TODO
use_shared_projection=False,
# task_embed_option='concat_task_embed', # ==============TODO: none ==============
task_embed_option='concat_task_embed', # ==============TODO: none ==============
# task_embed_option='register_task_embed', # ==============TODO: none ==============
task_embed_option=None, # ==============TODO: none ==============
# task_embed_option=None, # ==============TODO: none ==============
register_token_num=2, # TODO: modify register_token_num in kv_caching
# use_task_embed=True, # TODO
use_task_embed=False, # ==============TODO==============
use_task_embed=True, # TODO
# use_task_embed=False, # ==============TODO==============
num_unroll_steps=num_unroll_steps,
policy_entropy_weight=5e-2,
continuous_action_space=True,
Expand All @@ -77,11 +77,11 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
# num_layers=2,
# num_layers=4, # TODO

num_layers=8, # TODO
num_heads=8,
# num_layers=8, # TODO
# num_heads=8,

# num_layers=12, # TODO
# num_heads=12,
num_layers=12, # TODO
num_heads=12,

embed_dim=768,
env_num=max(collector_env_num, evaluator_env_num),
Expand Down Expand Up @@ -158,10 +158,12 @@ def generate_configs(env_id_list: List[str],
total_batch_size: int):
configs = []
# TODO: debug
exp_name_prefix = f'data_suz_mt_20250123/{len(env_id_list)}tasks_ddp_8gpu_nlayer12_upc200_no-taskweight_concat-task-embed_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

# exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-4_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

# exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_register-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_no-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
# exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_no-task-embed-2-pos0_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

# exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
# exp_name_prefix = f'data_suz_mt_20250113_debug/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
Expand Down Expand Up @@ -271,26 +273,26 @@ def create_env_manager():
]

# DMC 18games
# env_id_list = [
# 'acrobot-swingup',
# 'cartpole-balance',
# 'cartpole-balance_sparse',
# 'cartpole-swingup',
# 'cartpole-swingup_sparse',
# 'cheetah-run',
# "ball_in_cup-catch",
# "finger-spin",
# "finger-turn_easy",
# "finger-turn_hard",
# 'hopper-hop',
# 'hopper-stand',
# 'pendulum-swingup',
# 'reacher-easy',
# 'reacher-hard',
# 'walker-run',
# 'walker-stand',
# 'walker-walk',
# ]
env_id_list = [
'acrobot-swingup',
'cartpole-balance',
'cartpole-balance_sparse',
'cartpole-swingup',
'cartpole-swingup_sparse',
'cheetah-run',
"ball_in_cup-catch",
"finger-spin",
"finger-turn_easy",
"finger-turn_hard",
'hopper-hop',
'hopper-stand',
'pendulum-swingup',
'reacher-easy',
'reacher-hard',
'walker-run',
'walker-stand',
'walker-walk',
]

# Get the action_space_size and observation_shape of each environment
action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
Expand All @@ -310,7 +312,7 @@ def create_env_manager():
batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

# nlayer=8/12
total_batch_size = 512
total_batch_size = 256
batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

num_unroll_steps = 5
Expand Down
36 changes: 18 additions & 18 deletions zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,24 @@ def create_env_manager():

# DMC 18games
env_id_list = [
'acrobot-swingup',
'cartpole-balance',
'cartpole-balance_sparse',
'cartpole-swingup',
'cartpole-swingup_sparse',
'cheetah-run',
"ball_in_cup-catch",
"finger-spin",
"finger-turn_easy",
"finger-turn_hard",
'hopper-hop',
'hopper-stand',
'pendulum-swingup',
'reacher-easy',
'reacher-hard',
'walker-run',
'walker-stand',
'walker-walk',
'acrobot-swingup', # 0
'cartpole-balance', # 1
'cartpole-balance_sparse', # 2
'cartpole-swingup', # 3
'cartpole-swingup_sparse', # 4 bad
'cheetah-run', # 5 bad
"ball_in_cup-catch", # 6
"finger-spin", # 7 bad
"finger-turn_easy", # 8 fluctuating
"finger-turn_hard", # 9 fluctuating
'hopper-hop', # 10 bad
'hopper-stand', # 11
'pendulum-swingup', # 12 bad
'reacher-easy', # 13
'reacher-hard', # 14 fluctuating
'walker-run', # 15 slightly worse
'walker-stand', # 16
'walker-walk', # 17
]

# Get the action_space_size and observation_shape of each environment
Expand Down

0 comments on commit 557e8f9

Please sign in to comment.