diff --git a/lzero/entry/train_muzero_multitask_segment_ddp.py b/lzero/entry/train_muzero_multitask_segment_ddp.py
index e63e659a2..375dbd291 100644
--- a/lzero/entry/train_muzero_multitask_segment_ddp.py
+++ b/lzero/entry/train_muzero_multitask_segment_ddp.py
@@ -380,8 +380,7 @@ def train_muzero_multitask_segment_ddp(
         collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

         # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
-        if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
-
+        if learner.train_iter > 1 and evaluator.should_eval(learner.train_iter):
             print('=' * 20)
             print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')
diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py
index d062bd865..a83aacf85 100644
--- a/lzero/mcts/buffer/game_buffer_muzero.py
+++ b/lzero/mcts/buffer/game_buffer_muzero.py
@@ -64,7 +64,10 @@ def __init__(self, cfg: dict):
         if hasattr(self._cfg, 'task_id'):
             self.task_id = self._cfg.task_id
             print(f"Task ID is set to {self.task_id}.")
-            self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
+            try:
+                self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
+            except Exception as e:
+                self.action_space_size = self._cfg.model.action_space_size
         else:
             self.task_id = None
diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py
index 3ad7c9ae7..585cc5b27 100644
--- a/lzero/mcts/buffer/game_buffer_unizero.py
+++ b/lzero/mcts/buffer/game_buffer_unizero.py
@@ -52,8 +52,10 @@ def __init__(self, cfg: dict):
         if hasattr(self._cfg, 'task_id'):
             self.task_id = self._cfg.task_id
             print(f"Task ID is set to {self.task_id}.")
-            self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
-
+            try:
+                self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
+            except Exception as e:
+                self.action_space_size = self._cfg.model.action_space_size
         else:
             self.task_id = None
             print("No task_id found in configuration. Task ID is set to None.")
diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py
index 004e664ed..57cb47af0 100644
--- a/lzero/mcts/buffer/game_segment.py
+++ b/lzero/mcts/buffer/game_segment.py
@@ -60,12 +60,15 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
             # image obs input, e.g. atari environments
             self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
         else:
-            if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1:
-                # for vector obs input, e.g. classical control and box2d environments
-                self.zero_obs_shape = config.model.observation_shape_list[task_id]
-            elif len(config.model.observation_shape_list[task_id]) == 3:
-                # image obs input, e.g. atari environments
-                self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
+            if hasattr(config.model, "observation_shape_list"):
+                if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1:
+                    # for vector obs input, e.g. classical control and box2d environments
+                    self.zero_obs_shape = config.model.observation_shape_list[task_id]
+                elif len(config.model.observation_shape_list[task_id]) == 3:
+                    # image obs input, e.g. atari environments
+                    self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
+            else:
+                self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

         self.obs_segment = []
         self.action_segment = []
diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py
index 8f2e9f300..2bdcc7b7e 100644
--- a/lzero/model/unizero_model_multitask.py
+++ b/lzero/model/unizero_model_multitask.py
@@ -112,6 +112,16 @@ def __init__(
                     embedding_dim=world_model_cfg.embed_dim,
                     group_size=world_model_cfg.group_size,
                 ))
+            # self.representation_network = RepresentationNetworkUniZero(
+            #     observation_shape,
+            #     num_res_blocks,
+            #     num_channels,
+            #     self.downsample,
+            #     activation=self.activation,
+            #     norm_type=norm_type,
+            #     embedding_dim=world_model_cfg.embed_dim,
+            #     group_size=world_model_cfg.group_size,
+            # )

             # TODO: we should change the output_shape to the real observation shape
             # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64))
@@ -187,8 +197,8 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_
                 latent state, W_ is the width of latent state.
         """
         batch_size = obs_batch.size(0)
-        print('=here 5='*20)
-        import ipdb; ipdb.set_trace()
+        # print('=here 5='*20)
+        # import ipdb; ipdb.set_trace()
         obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch}
         _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id)
         latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value
diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py
index d0f5e0483..dd45da6bf 100644
--- a/lzero/model/unizero_world_models/tokenizer.py
+++ b/lzero/model/unizero_world_models/tokenizer.py
@@ -96,8 +96,8 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten
                 obs_embeddings = self.encoder(x, task_id=task_id)  # TODO: for dmc multitask
                 # obs_embeddings = self.encoder[task_id](x)
             except Exception as e:
-                print(e)
-                obs_embeddings = self.encoder(x)  # TODO: for memory env
+                # print(e)
+                obs_embeddings = self.encoder[0](x)  # TODO: for atari/memory env
             obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
         elif len(shape) == 5:
@@ -106,7 +106,7 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten
             try:
                 obs_embeddings = self.encoder[task_id](x)
             except Exception as e:
-                obs_embeddings = self.encoder(x)  # TODO: for memory env
+                obs_embeddings = self.encoder[0](x)  # TODO: for atari/memory env
             obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
         else:
             raise ValueError(f"Invalid input shape: {shape}")
diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py
index ee9b62677..481c989d4 100644
--- a/lzero/model/unizero_world_models/world_model_multitask.py
+++ b/lzero/model/unizero_world_models/world_model_multitask.py
@@ -681,7 +681,7 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu
             if len(act_tokens.shape) == 3:
                 act_tokens = act_tokens.squeeze(1)
             num_steps = act_tokens.size(1)
-            if self.task_num >= 1:
+            if self.task_num >= 1 and self.continuous_action_space:
                 act_embeddings = self.act_embedding_table[task_id](act_tokens)
             else:
                 act_embeddings = self.act_embedding_table(act_tokens)
@@ -862,7 +862,8 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, ta
                                                                   -1)

         num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1))
-        act_embeddings = self.act_embedding_table[task_id](act_tokens)
+        # act_embeddings = self.act_embedding_table[task_id](act_tokens)
+        act_embeddings = self.act_embedding_table(act_tokens)

         B, L, K, E = obs_embeddings.size()
         obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device)
diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py
index 880ab67b7..9cfc71ff3 100644
--- a/lzero/policy/unizero_multitask.py
+++ b/lzero/policy/unizero_multitask.py
@@ -892,7 +892,7 @@ def _forward_collect(
         if active_collect_env_num < self.collector_env_num:
             print('==========collect_forward============')
             print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}')
-            self._reset_collect(reset_init_data=True)
+            self._reset_collect(reset_init_data=True, task_id=task_id)

         return output

@@ -1001,7 +1001,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
         return output

     #@profile
-    def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None:
+    def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None:
         """
         Overview:
             This method resets the collection process for a specific environment. It clears caches and memory
@@ -1085,21 +1085,21 @@ def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_dat
             - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
""" if reset_init_data: - if task_id is not None: - self.last_batch_obs_eval = initialize_zeros_batch( - self._cfg.model.observation_shape_list[task_id], - self._cfg.evaluator_env_num, - self._cfg.device - ) - print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) - - else: - self.last_batch_obs_eval = initialize_zeros_batch( - self._cfg.model.observation_shape, - self._cfg.evaluator_env_num, - self._cfg.device - ) - print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + # if task_id is not None: + # self.last_batch_obs_eval = initialize_zeros_batch( + # self._cfg.model.observation_shape_list[task_id], + # self._cfg.evaluator_env_num, + # self._cfg.device + # ) + # print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + # else: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] diff --git a/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py index cdddabc14..fb35e2f03 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py @@ -22,6 +22,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # eval_max_episode_steps=int(30), ), policy=dict( + use_moco=False, # ==============TODO============== multi_gpu=True, # Very important for ddp learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), grad_correct_params=dict( @@ -37,24 +38,39 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu num_res_blocks=2, num_channels=256, world_model_cfg=dict( + + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + use_shared_projection=False, + + max_blocks=num_unroll_steps, max_tokens=2 * num_unroll_steps, context_length=2 * infer_context_length, device='cuda', action_space_size=action_space_size, # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 - num_layers=12, - num_heads=24, + # num_layers=12, + # num_heads=24, + + num_layers=8, + num_heads=8, + embed_dim=768, obs_type='image', env_num=8, task_num=len(env_id_list), use_normal_head=True, use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, moe_in_transformer=False, + multiplication_moe_in_transformer=False, num_experts_of_moe_in_transformer=4, ), ), + use_task_exploitation_weight=False, # TODO + task_complexity_weight=False, # TODO total_batch_size=total_batch_size, allocated_batch_sizes=False, train_start_after_envsteps=int(0), @@ -87,7 +103,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): configs = [] - exp_name_prefix = f'data_unizero_mt_ddp-8gpu/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/' + exp_name_prefix = f'data_unizero_atari_mt_20250212/atari_{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/' for task_id, env_id in 
enumerate(env_id_list): config = create_config( @@ -118,7 +134,7 @@ def create_env_manager(): Overview: This script should be executed with GPUs. Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py + python -m torch.distributed.launch --nproc_per_node=5 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py """ diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py index 90193b781..5f23c2660 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py @@ -17,19 +17,18 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec action_space_size_list=action_space_size_list, from_pixels=False, # ===== only for debug ===== - # frame_skip=50, # 100 - frame_skip=2, + # frame_skip=50, # episode_length:20 + # ===== only for debug ===== + frame_skip=2, # episode_length:500 continuous=True, # Assuming all DMC tasks use continuous action spaces collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False), game_segment_length=100, # As per single-task config - # ===== only for debug ===== - # collect_max_episode_steps=int(20), - # eval_max_episode_steps=int(20), ), policy=dict( + # multi_gpu=False, # TODO: nable multi-GPU for DDP multi_gpu=True, # TODO: nable multi-GPU for DDP learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), grad_correct_params=dict( @@ -54,9 +53,13 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec obs_type='vector', # use_shared_projection=True, # TODO use_shared_projection=False, + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== # use_task_embed=True, # TODO + task_embed_option=None, # ==============TODO: none ============== use_task_embed=False, # ==============TODO============== + num_unroll_steps=num_unroll_steps, policy_entropy_weight=5e-2, continuous_action_space=True, @@ -104,7 +107,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec cuda=True, model_path=None, num_unroll_steps=num_unroll_steps, - # update_per_collect=2, # TODO: 80 + # update_per_collect=2, # TODO: debug update_per_collect=200, # TODO: 8*100*0.25=200 # update_per_collect=80, # TODO: 8*100*0.1=80 replay_ratio=reanalyze_ratio, @@ -154,10 +157,12 @@ def generate_configs(env_id_list: List[str], total_batch_size: int): configs = [] # TODO: debug - # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + # exp_name_prefix = f'data_suz_mt_20250207/ddp_moco-fix-paramv2_nlayer8_upc200_notaskweight_concat-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + exp_name_prefix = f'data_suz_mt_20250207/ddp_moco-fix-paramv2_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' - exp_name_prefix = 
f'data_suz_mt_20250207/ddp_8gpu-moco_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' - + # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + # exp_name_prefix = f'data_suz_mt_20250207/ddp_moco-fix-paramv2_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + # exp_name_prefix = f'data_suz_mt_20250207/ddp_1gpu-moco_multigpu_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' # exp_name_prefix = f'data_suz_mt_20250113/ddp_3gpu_3games_nlayer8_upc200_notusp_notaskweight-symlog-01-05-eval1e3_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] @@ -208,7 +213,7 @@ def create_env_manager(): Overview: This script should be executed with GPUs. Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py """ @@ -226,7 +231,6 @@ def create_env_manager(): ] - # env_id_list = [ # # 'acrobot-swingup', # # 'cartpole-balance', @@ -251,26 +255,26 @@ def create_env_manager(): ] # DMC 18games - # env_id_list = [ - # 'acrobot-swingup', - # 'cartpole-balance', - # 'cartpole-balance_sparse', - # 'cartpole-swingup', - # 'cartpole-swingup_sparse', - # 'cheetah-run', - # "ball_in_cup-catch", - # "finger-spin", - # "finger-turn_easy", - # "finger-turn_hard", - # 'hopper-hop', - # 'hopper-stand', - # 'pendulum-swingup', - # 'reacher-easy', - # 'reacher-hard', - # 'walker-run', - # 'walker-stand', - # 'walker-walk', - # ] + env_id_list = [ + 'acrobot-swingup', + 'cartpole-balance', + 'cartpole-balance_sparse', + 'cartpole-swingup', + 'cartpole-swingup_sparse', + 'cheetah-run', + "ball_in_cup-catch", + "finger-spin", + "finger-turn_easy", + "finger-turn_hard", + 'hopper-hop', + 'hopper-stand', + 'pendulum-swingup', + 'reacher-easy', + 'reacher-hard', + 'walker-run', + 'walker-stand', + 'walker-walk', + ] # 获取各环境的 action_space_size 和 observation_shape action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] @@ -286,13 +290,13 @@ def create_env_manager(): reanalyze_ratio = 0.0 - # nlayer=4/8 + # nlayer=4/8 8games total_batch_size = 512 batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - # # nlayer=12 - # total_batch_size = 256 - # batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + # # # nlayer=12 18games + total_batch_size = 256 + batch_size = [int(min(16, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] num_unroll_steps = 5 infer_context_length = 2