Skip to content

Commit

Permalink
[BugFix] Avoid reshape(-1) for inputs to DreamerActorLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler committed Oct 15, 2024
1 parent d894358 commit 6009810
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 3 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -10332,7 +10332,9 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est)
return
if td_est is not None:
loss_module.make_value_estimator(td_est)
loss_td, fake_data = loss_module(tensordict)
# NOTE: Input is reshaped because GRUCell (which is part of the
# RSSMPrior module in `mb_env`) expects input to be either 1D or 2D
loss_td, fake_data = loss_module(tensordict.reshape(-1))
assert not fake_data.requires_grad
assert fake_data.shape == torch.Size([tensordict.numel(), imagination_horizon])
if discount_loss:
Expand Down
1 change: 0 additions & 1 deletion torchrl/objectives/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:

def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]:
tensordict = tensordict.select("state", self.tensor_keys.belief).detach()
tensordict = tensordict.reshape(-1)

with timeit("actor_loss/time-rollout"), hold_out_net(
self.model_based_env
Expand Down

0 comments on commit 6009810

Please sign in to comment.