fix(pu): fix smz and sez config for pixel-based dmc
PaParaZz1 committed Jan 23, 2025
1 parent 5143f08 commit a4c436b
Showing 8 changed files with 513 additions and 40 deletions.
18 changes: 11 additions & 7 deletions lzero/mcts/buffer/game_buffer.py
@@ -559,14 +559,18 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None:
         # print(f'valid_len is {valid_len}')

         if meta['priorities'] is None:
-            max_prio = self.game_pos_priorities.max() if self.game_segment_buffer else 1
+            # try:
+            if self.game_segment_buffer:
+                max_prio = self.game_pos_priorities.max() if len(self.game_pos_priorities) > 0 else 1
+            else:
+                max_prio = 1
+            # except Exception as e:
+            #     print(e)
+            #     print(f'self.game_pos_priorities:{self.game_pos_priorities}')
+            #     print(f'self.game_segment_buffer:{self.game_segment_buffer}')

             # if no 'priorities' provided, set the valid part of the new-added game history the max_prio
-            self.game_pos_priorities = np.concatenate(
-                (
-                    self.game_pos_priorities, [max_prio
-                                               for _ in range(valid_len)] + [0. for _ in range(valid_len, data_length)]
-                )
-            )
+            self.game_pos_priorities = np.concatenate((self.game_pos_priorities, [max_prio for _ in range(valid_len)] + [0. for _ in range(valid_len, data_length)]))
         else:
             assert data_length == len(meta['priorities']), " priorities should be of same length as the game steps"
             priorities = meta['priorities'].copy().reshape(-1)
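For context on the new length check above: NumPy raises a ValueError when .max() is called on a zero-size array, so guarding on len(self.game_pos_priorities) > 0 avoids that crash when the priority array is still empty. A minimal standalone sketch of the guarded lookup (the helper name is hypothetical, not from the repository):

import numpy as np

def safe_max_priority(game_pos_priorities: np.ndarray, buffer_nonempty: bool) -> float:
    # Mirror the guarded lookup in _push_game_segment: fall back to a default
    # priority of 1 when there is nothing to take the max over.
    if buffer_nonempty and len(game_pos_priorities) > 0:
        return float(game_pos_priorities.max())
    return 1.0

# Example: an empty priority array no longer raises ValueError.
assert safe_max_priority(np.array([]), buffer_nonempty=True) == 1.0
assert safe_max_priority(np.array([0.2, 0.7]), buffer_nonempty=True) == 0.7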
34 changes: 16 additions & 18 deletions lzero/mcts/buffer/game_buffer_sampled_unizero.py
@@ -479,15 +479,14 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
             m_output = model.initial_inference(batch_obs, action_batch[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
             # =======================================================================

-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)

@@ -638,15 +637,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             m_output = model.initial_inference(batch_obs, action_batch)
             # ======================================================================

-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )
             network_output.append(m_output)

         if self._cfg.use_root_value:
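Both hunks above drop the if not model.training: guard, so the reanalyze path now always detaches the network outputs and converts the categorical value logits to scalars. As background, a rough sketch of the usual MuZero-style inverse transform is shown below; the behavior is assumed from the call signature and is not copied from LightZero's inverse_scalar_transform:

import torch

def inverse_scalar_transform_sketch(value_logits: torch.Tensor, support_scale: int, eps: float = 0.001) -> torch.Tensor:
    # value_logits: (batch, 2 * support_scale + 1) categorical logits over the
    # discrete support [-support_scale, ..., support_scale].
    support = torch.arange(-support_scale, support_scale + 1, dtype=value_logits.dtype)
    # Expected value under the softmax distribution over the support.
    scaled = (torch.softmax(value_logits, dim=-1) * support).sum(dim=-1, keepdim=True)
    # Invert the value-scaling function h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x.
    return torch.sign(scaled) * (
        ((torch.sqrt(1.0 + 4 * eps * (scaled.abs() + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1
    )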
1 change: 1 addition & 0 deletions lzero/model/common.py
@@ -562,6 +562,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # for atari 64,8,8 = 4096 -> 768
         x = self.sim_norm(x)

+
         return x

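The touched line sits right after the encoder's final normalization: per the comment, the flattened 64*8*8 = 4096 features have already been projected to a 768-dim embedding before sim_norm is applied. sim_norm here is presumably a simplicial normalization (group-wise softmax) in the style of TD-MPC2; the sketch below is an assumption, not copied from lzero/model/common.py:

import torch
import torch.nn.functional as F

class SimNormSketch(torch.nn.Module):
    # Group-wise softmax: split the feature dimension into groups of size
    # `group_size` and normalize each group onto a simplex.
    def __init__(self, group_size: int = 8) -> None:
        super().__init__()
        self.group_size = group_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.view(*shape[:-1], -1, self.group_size)
        x = F.softmax(x, dim=-1)
        return x.view(*shape)

# Example: a 768-dim embedding becomes 96 groups of 8, each summing to 1.
emb = SimNormSketch(group_size=8)(torch.randn(2, 768))
assert torch.allclose(emb.view(2, 96, 8).sum(dim=-1), torch.ones(2, 96))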
3 changes: 1 addition & 2 deletions lzero/model/sampled_efficientzero_model.py
@@ -224,8 +224,7 @@ def __init__(
             # (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is
             # (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus,
             # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304
-            self.projection_input_dim = num_channels * math.ceil(observation_shape[1] / 16
-                                                                 ) * math.ceil(observation_shape[2] / 16)
+            self.projection_input_dim = num_channels * latent_size
         else:
             self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2]

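The new one-liner above presumably relies on latent_size being precomputed earlier in __init__ (not visible in this hunk) as math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16), which keeps the result identical to the deleted expression. A quick check of the arithmetic from the comment in the hunk, under that assumption:

import math

observation_shape = (3, 96, 96)  # stacked pixel observation, per the comment
num_channels = 64

# Assumed definition of latent_size with downsampling by 16 in each spatial dim.
latent_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16)  # 6 * 6 = 36
projection_input_dim = num_channels * latent_size  # 64 * 36 = 2304, matching the old expression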