Skip to content

Commit

Permalink
fix(pu): fix device bug in sampled efficientzero
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Dec 12, 2023
1 parent 784a1c2 commit 9b00a00
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lzero/policy/sampled_efficientzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,9 +650,9 @@ def _calculate_policy_loss_cont(
y = 1 - target_sampled_actions[:, k, :].pow(2)

# NOTE: for numerical stability.
target_sampled_actions_clamped = torch.clamp(
target_sampled_actions[:, k, :], torch.tensor(-1 + 1e-6), torch.tensor(1 - 1e-6)
)
min_val = torch.tensor(-1 + 1e-6).to(target_sampled_actions.device)
max_val = torch.tensor(1 - 1e-6).to(target_sampled_actions.device)
target_sampled_actions_clamped = torch.clamp(target_sampled_actions[:, k, :], min_val, max_val)
target_sampled_actions_before_tanh = torch.arctanh(target_sampled_actions_clamped)

# keep dimension for loss computation (usually for action space is 1 env. e.g. pendulum)
Expand Down

0 comments on commit 9b00a00

Please sign in to comment.