Skip to content

Commit

Permalink
[BugFix] Avoid reshape(-1) for inputs to DreamerActorLoss (#2496)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler authored Oct 18, 2024
1 parent 30df6d9 commit a27514c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
4 changes: 3 additions & 1 deletion sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def compile_rssms(module):
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16
) if use_autocast else contextlib.nullcontext():
- actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict)
+ actor_loss_td, sampled_tensordict = actor_loss(
+ sampled_tensordict.reshape(-1)
+ )

actor_opt.zero_grad()
if use_autocast:
Expand Down
2 changes: 1 addition & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -10332,7 +10332,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est)
return
if td_est is not None:
loss_module.make_value_estimator(td_est)
- loss_td, fake_data = loss_module(tensordict)
+ loss_td, fake_data = loss_module(tensordict.reshape(-1))
assert not fake_data.requires_grad
assert fake_data.shape == torch.Size([tensordict.numel(), imagination_horizon])
if discount_loss:
Expand Down
1 change: 0 additions & 1 deletion torchrl/objectives/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:

def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]:
tensordict = tensordict.select("state", self.tensor_keys.belief).detach()
- tensordict = tensordict.reshape(-1)

with timeit("actor_loss/time-rollout"), hold_out_net(
self.model_based_env
Expand Down

0 comments on commit a27514c

Please sign in to comment.