diff --git a/lzero/entry/train_muzero_gpt.py b/lzero/entry/train_muzero_gpt.py
index 1732db27e..d6c3aada7 100644
--- a/lzero/entry/train_muzero_gpt.py
+++ b/lzero/entry/train_muzero_gpt.py
@@ -167,30 +167,51 @@ def train_muzero_gpt(
         # remove the oldest data if the replay buffer is full.
         replay_buffer.remove_oldest_data_to_fit()

-        # Learn policy from collected data.
-        for i in range(update_per_collect):
-            # Learner will train ``update_per_collect`` times in one iteration.
-            if replay_buffer.get_num_of_transitions() > batch_size:
-                train_data = replay_buffer.sample(batch_size, policy)
-            else:
-                logging.warning(
-                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
-                    f'batch_size: {batch_size}, '
-                    f'{replay_buffer} '
-                    f'continue to collect now ....'
-                )
-                break
-
-            # The core train steps for MCTS+RL algorithms.
-            log_vars = learner.train(train_data, collector.envstep)
-            if cfg.policy.use_priority:
-                replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+        if collector.envstep > cfg.policy.transformer_start_after_envsteps:
+            # TODO: update the transformer and the tokenizer alternately.
+            # Learn policy from collected data.
+            for i in range(cfg.policy.update_per_collect_transformer):
+                # Learner will train ``update_per_collect`` times in one iteration.
+                if replay_buffer.get_num_of_transitions() > batch_size:
+                    train_data = replay_buffer.sample(batch_size, policy)
+                    train_data.append({'train_which_component':'transformer'})
+                else:
+                    logging.warning(
+                        f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                        f'batch_size: {batch_size}, '
+                        f'{replay_buffer} '
+                        f'continue to collect now ....'
+                    )
+                    break
+                # The core train steps for MCTS+RL algorithms.
+                log_vars = learner.train(train_data, collector.envstep)
+                if cfg.policy.use_priority:
+                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

            # NOTE: TODO
            # TODO: for batch world model ,to improve kv reuse, we could donot reset
            policy._learn_model.world_model.past_keys_values_cache.clear()

+        if collector.envstep > cfg.policy.tokenizer_start_after_envsteps:
+            # Learn policy from collected data.
+            for i in range(cfg.policy.update_per_collect_tokenizer):
+                # Learner will train ``update_per_collect`` times in one iteration.
+                if replay_buffer.get_num_of_transitions() > batch_size:
+                    train_data = replay_buffer.sample(batch_size, policy)
+                    train_data.append({'train_which_component':'tokenizer'})
+                else:
+                    logging.warning(
+                        f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                        f'batch_size: {batch_size}, '
+                        f'{replay_buffer} '
+                        f'continue to collect now ....'
+                    )
+                    break
+                # The core train steps for MCTS+RL algorithms.
+ log_vars = learner.train(train_data, collector.envstep) + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: break diff --git a/lzero/entry/train_muzero_gpt_same-data.py b/lzero/entry/train_muzero_gpt_same-data.py new file mode 100644 index 000000000..b7b2d9390 --- /dev/null +++ b/lzero/entry/train_muzero_gpt_same-data.py @@ -0,0 +1,200 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.rl_utils import get_epsilon_greedy_fn +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from .utils import random_collect + + +def train_muzero_gpt( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel Muzero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero_gpt', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \ + "train_muzero_gpt entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'" + + if create_cfg.policy.type == 'muzero_gpt': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if cfg.policy.update_per_collect is not None: + update_per_collect = cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely. 
+    # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
+    if cfg.policy.random_collect_episode_num > 0:
+        random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+    while True:
+        log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+        collect_kwargs = {}
+        # set temperature for visit count distributions according to the train_iter,
+        # please refer to Appendix D in MuZero paper for details.
+        collect_kwargs['temperature'] = visit_count_temperature(
+            policy_config.manual_temperature_decay,
+            policy_config.fixed_temperature_value,
+            policy_config.threshold_training_steps_for_final_temperature,
+            trained_steps=learner.train_iter
+        )
+
+        if policy_config.eps.eps_greedy_exploration_in_collect:
+            epsilon_greedy_fn = get_epsilon_greedy_fn(
+                start=policy_config.eps.start,
+                end=policy_config.eps.end,
+                decay=policy_config.eps.decay,
+                type_=policy_config.eps.type
+            )
+            collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+        else:
+            collect_kwargs['epsilon'] = 0.0
+
+        # Evaluate policy performance.
+        if evaluator.should_eval(learner.train_iter):
+            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            if stop:
+                break
+
+        # Collect data by default config n_sample/n_episode.
+        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        if cfg.policy.update_per_collect is None:
+            # If update_per_collect is None, set it to the number of collected transitions multiplied by the model_update_ratio.
+            collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+            update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+        # save returned new_data collected by the collector
+        replay_buffer.push_game_segments(new_data)
+        # remove the oldest data if the replay buffer is full.
+        replay_buffer.remove_oldest_data_to_fit()
+
+
+        # Learn policy from collected data.
+        for i in range(update_per_collect):
+            # Learner will train ``update_per_collect`` times in one iteration.
+            if replay_buffer.get_num_of_transitions() > batch_size:
+                train_data = replay_buffer.sample(batch_size, policy)
+            else:
+                logging.warning(
+                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                    f'batch_size: {batch_size}, '
+                    f'{replay_buffer} '
+                    f'continue to collect now ....'
+                )
+                break
+
+            # The core train steps for MCTS+RL algorithms.
+            log_vars = learner.train(train_data, collector.envstep)
+
+            if cfg.policy.use_priority:
+                replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+            # NOTE: TODO
+            # TODO: for the batch world model, to improve KV-cache reuse, we could avoid resetting the cache here.
+            policy._learn_model.world_model.past_keys_values_cache.clear()
+
+        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            break
+
+    # Learner's after_run hook.
+ learner.call_hook('after_run') + return policy diff --git a/lzero/model/gpt_models/cfg.py b/lzero/model/gpt_models/cfg.py index 5f2429038..42247d5a7 100644 --- a/lzero/model/gpt_models/cfg.py +++ b/lzero/model/gpt_models/cfg.py @@ -1,19 +1,20 @@ -# large net +# small net cfg = {} - cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer', - 'vocab_size': 512, # TODO: for atari - 'embed_dim': 512, + # 'vocab_size': 512, # TODO: for atari + # 'embed_dim': 512, + # 'vocab_size': 256, # TODO: for atari debug + # 'embed_dim': 256, # 'vocab_size': 128, # TODO: for atari debug # 'embed_dim': 128, - # 'vocab_size': 64, # TODO: for atari debug - # 'embed_dim': 64, + 'vocab_size': 64, # TODO: for atari debug + 'embed_dim': 64, 'encoder': - {'resolution': 64, 'in_channels': 3, 'z_channels': 512, 'ch': 64, + {'resolution': 64, 'in_channels': 3, 'z_channels': 64, 'ch': 64, 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16], 'out_ch': 3, 'dropout': 0.0},# TODO:for atari debug 'decoder': - {'resolution': 64, 'in_channels': 3, 'z_channels': 512, 'ch': 64, + {'resolution': 64, 'in_channels': 3, 'z_channels': 64, 'ch': 64, 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16], 'out_ch': 3, 'dropout': 0.0}} # TODO:for atari debug # {'resolution': 64, 'in_channels': 1, 'z_channels': 512, 'ch': 64, @@ -33,22 +34,17 @@ # 'num_layers': 10,# TODO:for atari 'num_layers': 2, # TODO:for atari debug 'num_heads': 4, - 'embed_dim': 256, # TODO:for atari # 'embed_dim': 128, # TODO:for atari - # 'embed_dim': 64, # TODO:for atari debug + 'embed_dim': 64, # TODO:for atari debug 'embed_pdrop': 0.1, 'resid_pdrop': 0.1, 'attn_pdrop': 0.1, - "device": 'cuda:3', + "device": 'cuda:0', # "device": 'cpu', 'support_size': 601, 'action_shape': 6,# TODO:for atari - # 'max_cache_size':5000, 'max_cache_size':500, - "env_num":8, - } - from easydict import EasyDict -cfg = EasyDict(cfg) \ No newline at end of file +cfg = EasyDict(cfg) diff --git a/lzero/model/gpt_models/cfg_atari.py b/lzero/model/gpt_models/cfg_atari.py index 3ca25fa10..a83687544 100644 --- a/lzero/model/gpt_models/cfg_atari.py +++ b/lzero/model/gpt_models/cfg_atari.py @@ -53,7 +53,6 @@ # # mediumnet # cfg = {} - # cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer', # # 'vocab_size': 512, # TODO: for atari # # 'embed_dim': 512, @@ -148,9 +147,6 @@ 'action_shape': 6,# TODO:for atari 'max_cache_size':5000, "env_num":8, - } - from easydict import EasyDict - cfg = EasyDict(cfg) diff --git a/lzero/model/gpt_models/cfg_cartpole.py b/lzero/model/gpt_models/cfg_cartpole.py index 99cedf3cf..10928395e 100644 --- a/lzero/model/gpt_models/cfg_cartpole.py +++ b/lzero/model/gpt_models/cfg_cartpole.py @@ -42,6 +42,8 @@ from easydict import EasyDict cfg = EasyDict(cfg) + + # # mediumnet # cfg = {} # cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer', diff --git a/lzero/model/gpt_models/tokenizer/tokenizer.py b/lzero/model/gpt_models/tokenizer/tokenizer.py index f57b68829..672c71626 100644 --- a/lzero/model/gpt_models/tokenizer/tokenizer.py +++ b/lzero/model/gpt_models/tokenizer/tokenizer.py @@ -56,8 +56,25 @@ def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postp return outputs.z, outputs.z_quantized, reconstructions def compute_loss(self, batch, **kwargs: Any) -> LossWithIntermediateLosses: + if len(batch['observations'][0, 0].shape) == 3: + # obs is a 3-dimensional image + pass + elif len(batch['observations'][0, 0].shape) == 1: + # print('obs is a 1-dimensional vector.') + # TODO() + # 
obs is a 1-dimensional vector + original_shape = list(batch['observations'].shape) + desired_shape = original_shape + [64, 64] + expanded_observations = batch['observations'].unsqueeze(-1).unsqueeze(-1) + expanded_observations = expanded_observations.expand(*desired_shape) + batch['observations'] = expanded_observations + assert self.lpips is not None + # IRIS original code observations = self.preprocess_input(rearrange(batch['observations'], 'b t c h w -> (b t) c h w')) + # TODO + # observations = rearrange(batch['observations'], 'b t c h w -> (b t) c h w') + z, z_quantized, reconstructions = self(observations, should_preprocess=False, should_postprocess=False) # Codebook loss. Notes: @@ -65,13 +82,13 @@ def compute_loss(self, batch, **kwargs: Any) -> LossWithIntermediateLosses: # - VQVAE uses 0.25 by default beta = 1.0 commitment_loss = (z.detach() - z_quantized).pow(2).mean() + beta * (z - z_quantized.detach()).pow(2).mean() - + # L1 loss reconstruction_loss = torch.abs(observations - reconstructions).mean() - # TODO: cartpole + # TODO: for atari pong perceptual_loss = torch.mean(self.lpips(observations, reconstructions)) + # TODO: only for cartpole # perceptual_loss = torch.zeros_like(reconstruction_loss) - return LossWithIntermediateLosses(commitment_loss=commitment_loss, reconstruction_loss=reconstruction_loss, perceptual_loss=perceptual_loss) def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> TokenizerEncoderOutput: @@ -112,8 +129,10 @@ def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should def preprocess_input(self, x: torch.Tensor) -> torch.Tensor: """x is supposed to be channels first and in [0, 1]""" + # [0,1] -> [-1, 1] return x.mul(2).sub(1) def postprocess_output(self, y: torch.Tensor) -> torch.Tensor: """y is supposed to be channels first and in [-1, 1]""" + # [-1, 1] -> [0,1] return y.add(1).div(2) diff --git a/lzero/model/gpt_models/transformer.py b/lzero/model/gpt_models/transformer.py index c2509dc2f..fde55c7e3 100644 --- a/lzero/model/gpt_models/transformer.py +++ b/lzero/model/gpt_models/transformer.py @@ -110,21 +110,21 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KVCache] = None) -> torch. 
k, v = kv_cache.get() # TODO - self.flash=False - # self.flash=True - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - if self.flash: - # TODO - # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.config.attn_drop if self.training else 0, is_causal=True) - else: - # manual implementation of attention - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.mask[L:L + T, :L + T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - att = self.attn_drop(att) - y = att @ v + # self.flash=False + # # self.flash=True + # # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + # if self.flash: + # # TODO + # # efficient attention using Flash Attention CUDA kernels + # with torch.backends.cuda.sdp_kernel(enable_flash=True): + # y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.config.attn_drop if self.training else 0, is_causal=True) + # else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.mask[L:L + T, :L + T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v y = rearrange(y, 'b h t e -> b t (h e)') diff --git a/lzero/model/gpt_models/world_model.py b/lzero/model/gpt_models/world_model.py index 4541d9899..f606f5a7a 100644 --- a/lzero/model/gpt_models/world_model.py +++ b/lzero/model/gpt_models/world_model.py @@ -194,6 +194,7 @@ def render_batch(self) -> List[Image.Image]: frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) return [Image.fromarray(frame) for frame in frames] + # only foe inference now, now is invalid @torch.no_grad() def decode_obs_tokens(self) -> List[Image.Image]: embedded_tokens = self.tokenizer.embedding(self.obs_tokens) # (B, K, E) @@ -219,6 +220,7 @@ def reset(self) -> torch.FloatTensor: @torch.no_grad() def reset_from_initial_observations(self, observations: torch.FloatTensor) -> torch.FloatTensor: + # NOTE: should_preprocess=True is important obs_tokens = self.tokenizer.encode(observations, should_preprocess=True).tokens # (B, C, H, W) -> (B, K) _, num_observations_tokens = obs_tokens.shape if self.num_observations_tokens is None: @@ -379,7 +381,7 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_ del self.obs_tokens self.obs_tokens = torch.cat(obs_tokens, dim=1) # (B, K) - obs = self.decode_obs_tokens() if should_predict_next_obs else None + # obs = self.decode_obs_tokens() if should_predict_next_obs else None # cache_key = hash(self.obs_tokens.detach().cpu().numpy().astype('int')) cache_key = hash(self.obs_tokens.detach().cpu().numpy()) diff --git a/lzero/policy/muzero_gpt.py b/lzero/policy/muzero_gpt.py index 085849020..e692a5a55 100644 --- a/lzero/policy/muzero_gpt.py +++ b/lzero/policy/muzero_gpt.py @@ -294,6 +294,7 @@ def _init_learn(self) -> None: model=self._model.world_model, learning_rate=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay, + # weight_decay=0.01, exclude_submodules=['tokenizer'] ) @@ -354,6 +355,30 @@ def _init_learn(self) -> None: ) def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. 
+        The loss is calculated by the loss function and the loss is backpropagated to update the model.
+        Arguments:
+            - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors.
+                The first element is the current_batch, the second is the target_batch, and the last is a dict indicating which component to train.
+        Returns:
+            - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \
+                current learning loss and learning statistics.
+        """
+        # current_batch, target_batch, train_which_component_dict = data
+        if data[-1]['train_which_component'] == 'transformer':
+            return_loss_dict = self._forward_learn_transformer(data)
+        elif data[-1]['train_which_component'] == 'tokenizer':
+            return_loss_dict = self._forward_learn_tokenizer(data)
+        else:
+            raise ValueError('Unknown component type')
+
+        return return_loss_dict
+
+
+    def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
         """
         Overview:
             The forward function for learning policy in learn mode, which is the core of the learning process.
@@ -371,7 +396,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         if self._cfg.use_rnd_model:
             self._target_model_for_intrinsic_reward.train()

-        current_batch, target_batch = data
+        # current_batch, target_batch = data
+        current_batch, target_batch, train_which_component_dict = data
+
+
         obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch
         target_reward, target_value, target_policy = target_batch
@@ -447,8 +475,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         # print(f'Average entropy: {average_entropy}')

-
-
+        # if train_which_component_dict['train_which_component'] == 'transformer':
         # ==============================================================
         # update world model
         # ==============================================================
@@ -481,18 +508,121 @@
         )
         self._optimizer_world_model.step()
         if self._cfg.lr_piecewise_constant_decay:
-            self.lr_scheduler.step()
+            self.lr_scheduler.step()
+        # ==============================================================
+        # the core target model update step.
+ # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) + + + return_loss_dict = { + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], + + 'weighted_total_loss': weighted_total_loss.item(), + 'obs_loss': obs_loss, + 'policy_loss': policy_loss, + 'target_policy_entropy': average_target_policy_entropy, + # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), + 'reward_loss': reward_loss, + 'value_loss': value_loss, + # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, + + # ============================================================== + # priority related + # ============================================================== + # 'value_priority_orig': value_priority, + 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + # 'value_priority': value_priority.mean().item(), + 'target_reward': target_reward.detach().cpu().numpy().mean().item(), + 'target_value': target_value.detach().cpu().numpy().mean().item(), + 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), + 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), + # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), + # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), + } + + return return_loss_dict + + + def _forward_learn_tokenizer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + self._learn_model.train() + self._target_model.train() + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.train() + + # current_batch, target_batch = data + current_batch, target_batch, train_which_component_dict = data + + + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # do augmentations + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .long(), in discrete action space. 
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [ + mask_batch, + target_reward.astype('float32'), + target_value.astype('float32'), target_policy, weights + ] + [mask_batch, target_reward, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + + target_reward = target_reward.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + + assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + + batch_for_gpt = {} + # TODO: for cartpole self._cfg.model.observation_shape + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) + elif len(self._cfg.model.observation_shape)==3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) + + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # (B, T-1, O) or (B, T-1, C, H, W) + + # if train_which_component_dict['train_which_component'] == 'tokenizer': + # ============================================================== # update tokenizer # ============================================================== # TODO: train tokenlizer + self._learn_model.tokenizer.train() losses_tokenizer = self._learn_model.tokenizer.compute_loss(batch_for_gpt) self._optimizer_tokenizer.zero_grad() losses_tokenizer.loss_total.backward() - total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + total_grad_norm_before_clip_tokenizer = torch.nn.utils.clip_grad_norm_( self._learn_model.tokenizer.parameters(), self._cfg.grad_clip_value ) self._optimizer_tokenizer.step() @@ -505,48 +635,51 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in reconstruction_loss = intermediate_losses_tokenizer['reconstruction_loss'] perceptual_loss = intermediate_losses_tokenizer['perceptual_loss'] - # ============================================================== - # the core target model update step. - # ============================================================== - self._target_model.update(self._learn_model.state_dict()) - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) + + # # ============================================================== + # # the core target model update step. 
+ # # ============================================================== + # self._target_model.update(self._learn_model.state_dict()) + # if self._cfg.use_rnd_model: + # self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) - return { + + return_loss_dict = { 'collect_mcts_temperature': self._collect_mcts_temperature, 'collect_epsilon': self.collect_epsilon, 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], - 'weighted_total_loss': weighted_total_loss.item(), - 'obs_loss': obs_loss, - 'policy_loss': policy_loss, - 'target_policy_entropy': average_target_policy_entropy, + # 'weighted_total_loss': weighted_total_loss.item(), + # 'obs_loss': obs_loss, + # 'policy_loss': policy_loss, + # 'target_policy_entropy': average_target_policy_entropy, # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), - 'reward_loss': reward_loss, - 'value_loss': value_loss, + # 'reward_loss': reward_loss, + # 'value_loss': value_loss, # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, # ============================================================== # priority related # ============================================================== # 'value_priority_orig': value_priority, - 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO # 'value_priority': value_priority.mean().item(), - 'target_reward': target_reward.detach().cpu().numpy().mean().item(), - 'target_value': target_value.detach().cpu().numpy().mean().item(), - 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), - 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), + # 'target_reward': target_reward.detach().cpu().numpy().mean().item(), + # 'target_value': target_value.detach().cpu().numpy().mean().item(), + # 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), + # 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), - + 'total_grad_norm_before_clip': total_grad_norm_before_clip_tokenizer.item(), 'commitment_loss':commitment_loss, 'reconstruction_loss':reconstruction_loss, 'perceptual_loss': perceptual_loss, } + return return_loss_dict + def _init_collect(self) -> None: """ Overview: diff --git a/lzero/policy/muzero_gpt_same-data.py b/lzero/policy/muzero_gpt_same-data.py new file mode 100644 index 000000000..eb89e6baa --- /dev/null +++ b/lzero/policy/muzero_gpt_same-data.py @@ -0,0 +1,847 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.distributions import Categorical +from torch.nn import L1Loss + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts import MuZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + 
DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ + prepare_obs, configure_optimizers + + +def configure_optimizer(model, learning_rate, weight_decay, exclude_submodules, *blacklist_module_names): + """Credits to https://github.com/karpathy/minGPT""" + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = [torch.nn.Linear, torch.nn.Conv1d] + blacklist_weight_modules = [torch.nn.LayerNorm, torch.nn.Embedding] + + # Here, we make sure to exclude parameters from specified submodules when creating param_dict + param_dict = {} + for mn, m in model.named_modules(): + if any(mn.startswith(module_name) for module_name in exclude_submodules): + continue # skip parameters from excluded submodules + for pn, p in m.named_parameters(recurse=False): + fpn = f'{mn}.{pn}' if mn else pn # full param name + if not any(fpn.startswith(bl_module_name) for bl_module_name in blacklist_module_names): + param_dict[fpn] = p + if 'bias' in pn: + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, tuple(whitelist_weight_modules)): + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, tuple(blacklist_weight_modules)): + no_decay.add(fpn) + else: + decay.add(fpn) # Default behavior is to add to decay + + # Validate that we considered every parameter + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" + + # Create the PyTorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate) + return optimizer + + + +@POLICY_REGISTRY.register('muzero_gpt') +class MuZeroGPTPolicy(Policy): + """ + Overview: + The policy class for MuZero. + """ + + # The default_config for MuZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=False, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=300, + # (bool) whether to learn bias in the last linear layer in value and policy head. 
+ bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=False, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + model_update_ratio=0.1, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. 
+ grad_clip_value=10, + # (int) The number of episodes in each collecting stage. + n_episode=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of policy entropy loss. + policy_entropy_loss_weight=0, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=True, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. 
note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.muzero_gpt_model.MuZeroModel`` + """ + if self._cfg.model.model_type == "conv": + # return 'MuZeroModel', ['lzero.model.muzero_gpt_model'] + return 'MuZeroModelGPT', ['lzero.model.muzero_gpt_model'] + elif self._cfg.model.model_type == "mlp": + return 'MuZeroModelGPT', ['lzero.model.muzero_gpt_model_vector_obs'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + # assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. + # if self._cfg.optim_type == 'SGD': + # self._optimizer = optim.SGD( + # self._model.parameters(), + # lr=self._cfg.learning_rate, + # momentum=self._cfg.momentum, + # weight_decay=self._cfg.weight_decay, + # ) + # elif self._cfg.optim_type == 'Adam': + # self._optimizer = optim.Adam( + # self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + # ) + # elif self._cfg.optim_type == 'AdamW': + # self._optimizer = configure_optimizers( + # model=self._model, + # weight_decay=self._cfg.weight_decay, + # learning_rate=self._cfg.learning_rate, + # device_type=self._cfg.device + # ) + + self._optimizer_tokenizer = optim.Adam( + self._model.tokenizer.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + # TODO: nanoGPT optimizer + self._optimizer_world_model = configure_optimizer( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + exclude_submodules=['tokenizer'] + ) + + # self._optimizer_world_model = configure_optimizers( + # model=self._model.world_model, + # weight_decay=self._cfg.weight_decay, + # learning_rate=self._cfg.learning_rate, + # device_type=self._cfg.device + # ) + + # if self._cfg.lr_piecewise_constant_decay: + # from torch.optim.lr_scheduler import LambdaLR + # max_step = self._cfg.threshold_training_steps_for_final_lr + # # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
+ # lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + # self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + # TODO: only for debug + # for param in self._learn_model.tokenizer.parameters(): + # param.requires_grad = False + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + self._learn_model.train() + self._target_model.train() + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.train() + + current_batch, target_batch = data + + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # do augmentations + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .long(), in discrete action space. 
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [ + mask_batch, + target_reward.astype('float32'), + target_value.astype('float32'), target_policy, weights + ] + [mask_batch, target_reward, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + + target_reward = target_reward.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + + assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # compute_loss(self, batch: Batch, tokenizer: Tokenizer, ** kwargs: Any) + + batch_for_gpt = {} + # TODO: for cartpole self._cfg.model.observation_shape + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) + elif len(self._cfg.model.observation_shape)==3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) + + + batch_for_gpt['actions'] = action_batch.squeeze(-1) # (B, T-1, A) -> (B, T-1) + + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] # (B, T, R) -> (B, T-1, R) + + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # (B, T) NOTE: 0 means invalid padding data + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] # (B, T-1) TODO + + + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # (B, T-1, O) or (B, T-1, C, H, W) + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) # (B, T-1) + + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] # (B, T-1, V) + batch_for_gpt['target_policy'] = target_policy[:, :-1] # (B, T-1, A) + # NOTE: TODO: next latent state's policy value + # batch_for_gpt['target_value'] = target_value_categorical[:, 1:] # (B, T-1, V) + # batch_for_gpt['target_policy'] = target_policy[:, 1:] # (B, T-1, A) + + # self._learn_model.world_model.train() + + # get valid target_policy data + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + # compute entropy of each policy + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + # compute average entropy + average_target_policy_entropy = target_policy_entropy.mean().item() + # print(f'Average entropy: {average_entropy}') + + + + # ============================================================== + # update world model + # ============================================================== + intermediate_losses = defaultdict(float) + losses = self._learn_model.world_model.compute_loss(batch_for_gpt, 
self._learn_model.tokenizer) + weighted_total_loss = losses.loss_total + for loss_name, loss_value in losses.intermediate_losses.items(): + intermediate_losses[f"{loss_name}"] = loss_value + # print(intermediate_losses) + obs_loss = intermediate_losses['loss_obs'] + reward_loss = intermediate_losses['loss_rewards'] + policy_loss = intermediate_losses['loss_policy'] + value_loss = intermediate_losses['loss_value'] + + # ============================================================== + # the core learn model update step. + # ============================================================== + """ + for name, parameter in self._learn_model.named_parameters(): + print(name) + """ + gradient_scale = 1 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer_world_model.zero_grad() + weighted_total_loss.backward() + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.world_model.parameters(), self._cfg.grad_clip_value + ) + self._optimizer_world_model.step() + if self._cfg.lr_piecewise_constant_decay: + self.lr_scheduler.step() + + + # ============================================================== + # update tokenizer + # ============================================================== + # TODO: train tokenlizer + losses_tokenizer = self._learn_model.tokenizer.compute_loss(batch_for_gpt) + + self._optimizer_tokenizer.zero_grad() + losses_tokenizer.loss_total.backward() + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.tokenizer.parameters(), self._cfg.grad_clip_value + ) + self._optimizer_tokenizer.step() + + intermediate_losses_tokenizer= defaultdict(float) + for loss_name, loss_value in losses_tokenizer.intermediate_losses.items(): + intermediate_losses_tokenizer[f"{loss_name}"] = loss_value + # print(intermediate_losses) + commitment_loss= intermediate_losses_tokenizer['commitment_loss'] + reconstruction_loss = intermediate_losses_tokenizer['reconstruction_loss'] + perceptual_loss = intermediate_losses_tokenizer['perceptual_loss'] + + # ============================================================== + # the core target model update step. 
+ # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + if self._cfg.use_rnd_model: + self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) + + + return { + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], + + 'weighted_total_loss': weighted_total_loss.item(), + 'obs_loss': obs_loss, + 'policy_loss': policy_loss, + 'target_policy_entropy': average_target_policy_entropy, + # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), + 'reward_loss': reward_loss, + 'value_loss': value_loss, + # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, + + # ============================================================== + # priority related + # ============================================================== + # 'value_priority_orig': value_priority, + 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + # 'value_priority': value_priority.mean().item(), + 'target_reward': target_reward.detach().cpu().numpy().mean().item(), + 'target_value': target_value.detach().cpu().numpy().mean().item(), + 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), + 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), + # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), + # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), + + 'commitment_loss':commitment_loss, + 'reconstruction_loss':reconstruction_loss, + 'perceptual_loss': perceptual_loss, + } + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self.collect_epsilon = 0.0 + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. 
+            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
+            - ready_env_id: None
+        Returns:
+            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
+                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
+        """
+        self._collect_model.eval()
+        self._collect_mcts_temperature = temperature
+        self.collect_epsilon = epsilon
+        active_collect_env_num = data.shape[0]
+        with torch.no_grad():
+            network_output = self._collect_model.initial_inference(data)
+            # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
+            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+            pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
+            latent_state_roots = latent_state_roots.detach().cpu().numpy()
+            policy_logits = policy_logits.detach().cpu().numpy().tolist()
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
+            # the only difference between collect and eval is the dirichlet noise
+            noises = [
+                np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
+                                    ).astype(np.float32).tolist() for j in range(active_collect_env_num)
+            ]
+            if self._cfg.mcts_ctree:
+                # cpp mcts_tree
+                roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
+            else:
+                # python mcts_tree
+                roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
+
+            roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
+            self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play)
+
+            # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
+            roots_visit_count_distributions = roots.get_distributions()
+            roots_values = roots.get_values()  # shape: {list: batch_size}
+
+            data_id = [i for i in range(active_collect_env_num)]
+            output = {i: None for i in data_id}
+
+            if ready_env_id is None:
+                ready_env_id = np.arange(active_collect_env_num)
+
+            for i, env_id in enumerate(ready_env_id):
+                distributions, value = roots_visit_count_distributions[i], roots_values[i]
+                if self._cfg.eps.eps_greedy_exploration_in_collect:
+                    # eps greedy collect
+                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                        distributions, temperature=self._collect_mcts_temperature, deterministic=True
+                    )
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                    if np.random.rand() < self.collect_epsilon:
+                        action = np.random.choice(legal_actions[i])
+                else:
+                    # normal collect
+                    # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
+                    # the index within the legal action set, rather than the index in the entire action set.
+                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                        distributions, temperature=self._collect_mcts_temperature, deterministic=False
+                    )
+                    # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the
+                    # entire action set.
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                output[env_id] = {
+                    'action': action,
+                    'visit_count_distributions': distributions,
+                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
+                }
+
+        return output
+
+    def _init_eval(self) -> None:
+        """
+        Overview:
+            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
+        """
+        self._eval_model = self._model
+        if self._cfg.mcts_ctree:
+            self._mcts_eval = MCTSCtree(self._cfg)
+        else:
+            self._mcts_eval = MCTSPtree(self._cfg)
+
+    def _get_target_obs_index_in_step_k(self, step):
+        """
+        Overview:
+            Get the begin index and end index of the target obs in step k.
+        Arguments:
+            - step (:obj:`int`): The current step k.
+        Returns:
+            - beg_index (:obj:`int`): The begin index of the target obs in step k.
+            - end_index (:obj:`int`): The end index of the target obs in step k.
+        Examples:
+            >>> self._cfg.model.model_type = 'conv'
+            >>> self._cfg.model.image_channel = 3
+            >>> self._cfg.model.frame_stack_num = 4
+            >>> self._get_target_obs_index_in_step_k(0)
+            (0, 12)
+        """
+        if self._cfg.model.model_type == 'conv':
+            beg_index = self._cfg.model.image_channel * step
+            end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num)
+        elif self._cfg.model.model_type == 'mlp':
+            beg_index = self._cfg.model.observation_shape * step
+            end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num)
+        return beg_index, end_index
+
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None) -> Dict:
+        """
+        Overview:
+            The forward function for evaluating the current policy in eval mode. Use the model to execute MCTS search.
+            The action with the highest value (argmax) is chosen rather than sampled during evaluation.
+        Arguments:
+            - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
+            - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected.
+            - to_play (:obj:`int`): The player to play.
+            - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
+        Shape:
+            - data (:obj:`torch.Tensor`):
+                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
+                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
+                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
+            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
+            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
+            - ready_env_id: None
+        Returns:
+            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
+                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
+        """
+        self._eval_model.eval()
+        active_eval_env_num = data.shape[0]
+        with torch.no_grad():
+            # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
+            network_output = self._eval_model.initial_inference(data)
+            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+            if not self._eval_model.training:
+                # if not in training, obtain the scalars of the value/reward
+                pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()  # shape(B, 1)
+                latent_state_roots = latent_state_roots.detach().cpu().numpy()
+                policy_logits = policy_logits.detach().cpu().numpy().tolist()  # list shape(B, A)
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
+            if self._cfg.mcts_ctree:
+                # cpp mcts_tree
+                roots = MCTSCtree.roots(active_eval_env_num, legal_actions)
+            else:
+                # python mcts_tree
+                roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
+            roots.prepare_no_noise(reward_roots, policy_logits, to_play)
+            self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play)
+
+            # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
+            roots_visit_count_distributions = roots.get_distributions()
+            roots_values = roots.get_values()  # shape: {list: batch_size}
+
+            data_id = [i for i in range(active_eval_env_num)]
+            output = {i: None for i in data_id}
+
+            if ready_env_id is None:
+                ready_env_id = np.arange(active_eval_env_num)
+
+            for i, env_id in enumerate(ready_env_id):
+                distributions, value = roots_visit_count_distributions[i], roots_values[i]
+                # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
+                # the index within the legal action set, rather than the index in the entire action set.
+                # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than
+                # sampling during the evaluation phase.
+                action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                    distributions, temperature=1, deterministic=True
+                )
+                # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the
+                # entire action set.
+                action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+
+                output[env_id] = {
+                    'action': action,
+                    'visit_count_distributions': distributions,
+                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
+                }
+
+        return output
+
+    def _monitor_vars_learn(self) -> List[str]:
+        """
+        Overview:
+            Register the variables to be monitored in learn mode. The registered variables will be logged in
+            tensorboard according to the return value of ``_forward_learn``.
+        """
+        return [
+            'collect_mcts_temperature',
+            # 'cur_lr',
+            'cur_lr_world_model',
+            'cur_lr_tokenizer',
+
+            'weighted_total_loss',
+            # 'total_loss',
+            'obs_loss',
+            'policy_loss',
+            # 'policy_entropy',
+            'target_policy_entropy',
+            'reward_loss',
+            'value_loss',
+            'consistency_loss',
+            'value_priority',
+            'target_reward',
+            'target_value',
+            # 'predicted_rewards',
+            # 'predicted_values',
+            # 'transformed_target_reward',
+            # 'transformed_target_value',
+            'total_grad_norm_before_clip',
+            # tokenizer
+            'commitment_loss',
+            'reconstruction_loss',
+            'perceptual_loss',
+        ]
+
+    def _state_dict_learn(self) -> Dict[str, Any]:
+        """
+        Overview:
+            Return the state_dict of learn mode, usually including model, target_model and optimizer.
+        Returns:
+            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + 'optimizer_tokenizer': self._optimizer_tokenizer.state_dict(), + + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + self._optimizer_tokenizer.load_state_dict(state_dict['optimizer_tokenizer']) + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass diff --git a/zoo/atari/config/atari_muzero_gpt_config.py b/zoo/atari/config/atari_muzero_gpt_config.py index 37207edcf..17282be64 100644 --- a/zoo/atari/config/atari_muzero_gpt_config.py +++ b/zoo/atari/config/atari_muzero_gpt_config.py @@ -1,6 +1,6 @@ from easydict import EasyDict import torch -torch.cuda.set_device(3) +torch.cuda.set_device(0) # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} env_name = 'PongNoFrameskip-v4' @@ -19,38 +19,34 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -collector_env_num = 8 -n_episode = 8 -evaluator_env_num = 1 -num_simulations = 50 -# update_per_collect = 2000 -update_per_collect = None -model_update_ratio = 1 -# batch_size = 32 -batch_size = 16 - -max_env_step = int(1e6) -reanalyze_ratio = 0 -num_unroll_steps = 5 -# num_unroll_steps = 20 - -# collector_env_num = 1 -# n_episode = 1 +# collector_env_num = 8 +# n_episode = 8 # evaluator_env_num = 1 # num_simulations = 50 -# update_per_collect = 100 +# update_per_collect = 1000 +# # update_per_collect = None # model_update_ratio = 1 -# batch_size = 64 +# batch_size = 32 +# # batch_size = 16 # max_env_step = int(1e6) # reanalyze_ratio = 0 # num_unroll_steps = 5 # # num_unroll_steps = 20 +# for debug +collector_env_num = 2 +n_episode = 2 +evaluator_env_num = 1 +num_simulations = 2 +update_per_collect = 1 +model_update_ratio = 1 +batch_size = 2 +max_env_step = int(1e6) +reanalyze_ratio = 0 +num_unroll_steps = 5 + # eps_greedy_exploration_in_collect = False eps_greedy_exploration_in_collect = True - - - # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -58,9 +54,12 @@ atari_muzero_config = dict( # TODO: world_model.py decode_obs_tokens # TODO: tokenizer.py: lpips loss - exp_name=f'data_mz_gpt_ctree/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd256_largenet_mcs500_batch8_obs-token-lw2_recons-obs-noaug_bs{batch_size}_adamw3e-3_seed0', - # 
diff --git a/zoo/atari/config/atari_muzero_gpt_config.py b/zoo/atari/config/atari_muzero_gpt_config.py
index 37207edcf..17282be64 100644
--- a/zoo/atari/config/atari_muzero_gpt_config.py
+++ b/zoo/atari/config/atari_muzero_gpt_config.py
@@ -1,6 +1,6 @@
 from easydict import EasyDict
 import torch
-torch.cuda.set_device(3)
+torch.cuda.set_device(0)
 
 # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
 env_name = 'PongNoFrameskip-v4'
@@ -19,38 +19,34 @@
 # ==============================================================
 # begin of the most frequently changed config specified by the user
 # ==============================================================
-collector_env_num = 8
-n_episode = 8
-evaluator_env_num = 1
-num_simulations = 50
-# update_per_collect = 2000
-update_per_collect = None
-model_update_ratio = 1
-# batch_size = 32
-batch_size = 16
-
-max_env_step = int(1e6)
-reanalyze_ratio = 0
-num_unroll_steps = 5
-# num_unroll_steps = 20
-
-# collector_env_num = 1
-# n_episode = 1
+# collector_env_num = 8
+# n_episode = 8
 # evaluator_env_num = 1
 # num_simulations = 50
-# update_per_collect = 100
+# update_per_collect = 1000
+# # update_per_collect = None
 # model_update_ratio = 1
-# batch_size = 64
+# batch_size = 32
+# # batch_size = 16
 # max_env_step = int(1e6)
 # reanalyze_ratio = 0
 # num_unroll_steps = 5
 # # num_unroll_steps = 20
+# for debug
+collector_env_num = 2
+n_episode = 2
+evaluator_env_num = 1
+num_simulations = 2
+update_per_collect = 1
+model_update_ratio = 1
+batch_size = 2
+max_env_step = int(1e6)
+reanalyze_ratio = 0
+num_unroll_steps = 5
+
 # eps_greedy_exploration_in_collect = False
 eps_greedy_exploration_in_collect = True
-
-
-
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
@@ -58,9 +54,12 @@
 atari_muzero_config = dict(
     # TODO: world_model.py decode_obs_tokens
     # TODO: tokenizer.py: lpips loss
-    exp_name=f'data_mz_gpt_ctree/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd256_largenet_mcs500_batch8_obs-token-lw2_recons-obs-noaug_bs{batch_size}_adamw3e-3_seed0',
-    # exp_name=f'data_mz_gpt_ctree/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd128_mediumnet_mcs500_batch8_obs-token-lw10_recons-obs-noaug_bs{batch_size}_adamw3e-3_seed0',
-    # exp_name=f'data_mz_gpt_ctree/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd64_smallnet_mcs5000_batch8_recons-obs-noaug_bs{batch_size}_adamw3e-3_seed0',
+    # exp_name=f'data_mz_gpt_ctree/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd128_mediumnet_mcs1e4_batch8_obs-token-lw2_recons-obs-noaug_bs{batch_size}_adamw3e-3_indep-10k_seed0',
+    # exp_name=f'data_mz_gpt_ctree/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd256_largenet_mcs500_batch8_obs-token-lw2_recons-obs-noaug_bs{batch_size}_adamw3e-3_seed0',
+    # exp_name=f'data_mz_gpt_ctree/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd64_smallnet_mcs500_batch8_recons-obs-noaug_bs{batch_size}_adamw3e-3_indep-10k_seed0',
+
+    exp_name=f'data_mz_gpt_ctree_debug/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd64_smallnet_mcs500_batch8_recons-obs-noaug_bs{batch_size}_adamw3e-3_indep-0_seed0',
+
     env=dict(
         stop_value=int(1e6),
         env_name=env_name,
@@ -73,13 +72,19 @@
         n_evaluator_episode=evaluator_env_num,
         manager=dict(shared_memory=False, ),
         # TODO: debug
-        # collect_max_episode_steps=int(20),
-        # eval_max_episode_steps=int(20),
+        collect_max_episode_steps=int(20),
+        eval_max_episode_steps=int(20),
         # # TODO
-        collect_max_episode_steps=int(2e3),
-        eval_max_episode_steps=int(5e3),
+        # collect_max_episode_steps=int(2e3),
+        # eval_max_episode_steps=int(5e3),
     ),
     policy=dict(
+        tokenizer_start_after_envsteps=int(0),
+        transformer_start_after_envsteps=int(0),
+        # transformer_start_after_envsteps=int(1e4), # 10K
+        update_per_collect_transformer=update_per_collect,
+        update_per_collect_tokenizer=update_per_collect,
+        # transformer_start_after_envsteps=int(5e3),
         num_unroll_steps=num_unroll_steps,
         model=dict(
             # observation_shape=(4, 96, 96),
@@ -115,12 +120,11 @@
             end=0.01,
             # decay=int(1e5),
             decay=int(1e4), # 10k
-
         ),
         # TODO: NOTE
         # use_augmentation=True,
         use_augmentation=False,
-        update_per_collect=update_per_collect,
+        # update_per_collect=update_per_collect,
         model_update_ratio = model_update_ratio,
         batch_size=batch_size,
         # optim_type='SGD',
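One detail from `_forward_learn` further up that is easy to miss: `weighted_total_loss.register_hook(lambda grad: grad * gradient_scale)` rescales the gradient seed that `backward()` feeds into the loss node, so every parameter gradient downstream is multiplied by the same `1 / num_unroll_steps` factor before clipping. A small self-contained check of that behaviour (toy tensors, not LightZero code):

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x = torch.randn(4)

def grad_of_loss(scale=None):
    # Rebuild the toy loss from scratch so each call uses a fresh graph.
    if w.grad is not None:
        w.grad = None
    loss = (w * x).sum() ** 2
    if scale is not None:
        # The hook receives the gradient flowing into ``loss`` (1.0 from backward()).
        loss.register_hook(lambda grad: grad * scale)
    loss.backward()
    return w.grad.clone()

g_plain = grad_of_loss()
g_scaled = grad_of_loss(scale=1 / 5)  # e.g. num_unroll_steps = 5
# All parameter gradients are scaled by the same factor.
assert torch.allclose(g_scaled, g_plain / 5)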
diff --git a/zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py b/zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py
index 67c491b78..b39466814 100644
--- a/zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py
+++ b/zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py
@@ -37,10 +37,10 @@
 # ==============================================================
 cartpole_muzero_gpt_config = dict(
-    # TODO: decode decode_obs_tokens
+    # TODO: world_model.py decode_obs_tokens
     # TODO: tokenizer: lpips loss
-    exp_name=f'data_mz_gpt_ctree_debug/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd64_smallnet_bs{batch_size}_mcs500_batch8_obs-token-lw10_recons-obs_bs{batch_size}_seed0',
-    # exp_name=f'data_mz_gpt_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd64_smallnet_bs{batch_size}_mcs500_batch8_not-fixedtokenizer_bs{batch_size}_seed0',
+    exp_name=f'data_mz_gpt_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd64_smallnet_bs{batch_size}_mcs50_batch8_obs-token-lw2_recons-obs_bs{batch_size}_indep0_trans-wd0.01_seed0',
+    # exp_name=f'data_mz_gpt_ctree_debug/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd64_smallnet_bs{batch_size}_mcs500_batch8_obs-token-lw2_recons-obs_bs{batch_size}_indep_seed0',
     env=dict(
         env_name='CartPole-v0',
         continuous=False,
@@ -51,6 +51,11 @@
         manager=dict(shared_memory=False, ),
     ),
     policy=dict(
+        tokenizer_start_after_envsteps=int(0),
+        transformer_start_after_envsteps=int(0),
+        update_per_collect_transformer=update_per_collect,
+        update_per_collect_tokenizer=update_per_collect,
+        # transformer_start_after_envsteps=int(5e3),
         # model_path='/mnt/afs/niuyazhe/code/LightZero/data_mz_gpt_ctree/cartpole_muzero_gpt_ns25_upc20-mur1_rr0_H5_nlayers2_emd128_mediumnet_bs64_mcs25_fixedtokenizer_fixloss_fixlatent_seed0/ckpt/ckpt_best.pth.tar',
         num_unroll_steps=num_unroll_steps,
         model=dict(
@@ -73,7 +78,6 @@
         # cuda=False,
         env_type='not_board_games',
         game_segment_length=50,
-        update_per_collect=update_per_collect,
         model_update_ratio=model_update_ratio,
         batch_size=batch_size,
         optim_type='Adam',
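As a closing cross-check on the policy hunk above, the slicing arithmetic in `_get_target_obs_index_in_step_k` can be reproduced in isolation: for stacked image observations, step k begins `image_channel * k` channels into the flattened stack and spans `frame_stack_num` frames; for vector observations the same holds with `observation_shape` in place of `image_channel`. A toy recomputation with the values from the docstring example (plain Python; the helper name and default values here are illustrative):

def target_obs_index_in_step_k(step, model_type='conv', image_channel=3,
                               frame_stack_num=4, observation_shape=4):
    # Same arithmetic as the policy helper, with the config values passed explicitly.
    if model_type == 'conv':
        beg_index = image_channel * step
        end_index = image_channel * (step + frame_stack_num)
    elif model_type == 'mlp':
        beg_index = observation_shape * step
        end_index = observation_shape * (step + frame_stack_num)
    return beg_index, end_index

assert target_obs_index_in_step_k(0) == (0, 12)   # matches the docstring example
assert target_obs_index_in_step_k(1) == (3, 15)   # next step: shift by image_channel
assert target_obs_index_in_step_k(2, model_type='mlp') == (8, 24)  # vector-obs (mlp) case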