fix(pu): fix smz and sez config for pixel-based dmc #322

Merged 3 commits on Jan 27, 2025
13 changes: 6 additions & 7 deletions lzero/mcts/buffer/game_buffer.py
@@ -564,14 +564,13 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None:
         # print(f'valid_len is {valid_len}')

         if meta['priorities'] is None:
-            max_prio = self.game_pos_priorities.max() if self.game_segment_buffer else 1
+            if self.game_segment_buffer:
+                max_prio = self.game_pos_priorities.max() if len(self.game_pos_priorities) > 0 else 1
+            else:
+                max_prio = 1

             # if no 'priorities' provided, set the valid part of the new-added game history the max_prio
-            self.game_pos_priorities = np.concatenate(
-                (
-                    self.game_pos_priorities, [max_prio
-                                               for _ in range(valid_len)] + [0. for _ in range(valid_len, data_length)]
-                )
-            )
+            self.game_pos_priorities = np.concatenate((self.game_pos_priorities, [max_prio for _ in range(valid_len)] + [0. for _ in range(valid_len, data_length)]))
         else:
             assert data_length == len(meta['priorities']), " priorities should be of same length as the game steps"
             priorities = meta['priorities'].copy().reshape(-1)
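The new guarded branch exists because NumPy's `max()` has no identity for an empty array. A minimal, self-contained sketch of the pattern (the variable names mirror the buffer's, but the setup here is hypothetical, not LightZero's actual buffer state):

```python
import numpy as np

# Hypothetical stand-ins for the buffer's state when the first segment arrives.
game_segment_buffer = []            # no segments stored yet
game_pos_priorities = np.array([])  # empty priority array: .max() would raise ValueError

# Guarded initialization, as in the patched code.
if game_segment_buffer:
    max_prio = game_pos_priorities.max() if len(game_pos_priorities) > 0 else 1
else:
    max_prio = 1

# Assign max_prio to the valid steps and 0 to the padded tail (lengths are hypothetical).
valid_len, data_length = 3, 5
game_pos_priorities = np.concatenate(
    (game_pos_priorities, [max_prio] * valid_len + [0.] * (data_length - valid_len))
)
print(game_pos_priorities)  # [1. 1. 1. 0. 0.]
```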
34 changes: 16 additions & 18 deletions lzero/mcts/buffer/game_buffer_sampled_unizero.py
@@ -479,15 +479,14 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
             m_output = model.initial_inference(batch_obs, action_batch[:self.reanalyze_num]) # NOTE: :self.reanalyze_num
             # =======================================================================

-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)
@@ -638,15 +637,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             m_output = model.initial_inference(batch_obs, action_batch)
             # ======================================================================

-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )
             network_output.append(m_output)

             if self._cfg.use_root_value:
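Both hunks in this file make the same change: the `if not model.training:` guard is dropped, so the reanalyze paths always detach the model outputs to CPU numpy and map the categorical value head back to a scalar. A hedged sketch of that pattern, with simplified stand-ins (not LightZero's actual `to_detach_cpu_numpy` / `inverse_scalar_transform` implementations):

```python
import torch

def to_detach_cpu_numpy(tensors):
    # Detach each tensor from the graph and move it to CPU numpy.
    return [t.detach().cpu().numpy() for t in tensors]

def inverse_scalar_transform(value_logits, support_scale):
    # Expectation over a categorical support in [-support_scale, support_scale].
    support = torch.arange(-support_scale, support_scale + 1, dtype=torch.float32)
    return torch.softmax(value_logits, dim=-1) @ support

# Hypothetical shapes: batch of 4, support_scale = 300, 64-dim latent, 6 actions.
value_logits = torch.randn(4, 2 * 300 + 1)
latent_state = torch.randn(4, 64)
policy_logits = torch.randn(4, 6)

latent_state, value, policy_logits = to_detach_cpu_numpy(
    [latent_state, inverse_scalar_transform(value_logits, 300), policy_logits]
)
print(value.shape)  # (4,)
```

Detaching here just turns the targets into plain numpy data, so no gradient flows back through the reanalyzed values.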
3 changes: 1 addition & 2 deletions lzero/model/sampled_efficientzero_model.py
@@ -224,8 +224,7 @@ def __init__(
             # (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is
             # (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus,
             # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304
-            self.projection_input_dim = num_channels * math.ceil(observation_shape[1] / 16
-                                                                 ) * math.ceil(observation_shape[2] / 16)
+            self.projection_input_dim = num_channels * latent_size
         else:
             self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2]

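The removed formula hard-coded the /16 downsampling ratio, while the patched line relies on a `latent_size` computed elsewhere in the model. A small sketch of the arithmetic in the comment above, assuming for this example that `latent_size = ceil(H/16) * ceil(W/16)`:

```python
import math

# Hypothetical example matching the comment: a (3, 96, 96) observation,
# 64 latent channels, and a /16 spatial downsample.
observation_shape = (3, 96, 96)
num_channels = 64

# Assumed definition of latent_size for this example; the real model derives
# it from the encoder's configuration, which is what the fix relies on.
latent_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16)

projection_input_dim = num_channels * latent_size
print(latent_size, projection_input_dim)  # 36 2304
```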