diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index f28fac8e675..992abea64e0 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -217,7 +217,9 @@ def compile_rssms(module): with torch.autocast( device_type=device.type, dtype=torch.bfloat16 ) if use_autocast else contextlib.nullcontext(): - actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) + actor_loss_td, sampled_tensordict = actor_loss( + sampled_tensordict.reshape(-1) + ) actor_opt.zero_grad() if use_autocast: diff --git a/test/test_cost.py b/test/test_cost.py index 3530fff825d..6d8de531d49 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10332,7 +10332,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) return if td_est is not None: loss_module.make_value_estimator(td_est) - loss_td, fake_data = loss_module(tensordict) + loss_td, fake_data = loss_module(tensordict.reshape(-1)) assert not fake_data.requires_grad assert fake_data.shape == torch.Size([tensordict.numel(), imagination_horizon]) if discount_loss: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 30f6dd10699..73df58b7e56 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -271,7 +271,6 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.select("state", self.tensor_keys.belief).detach() - tensordict = tensordict.reshape(-1) with timeit("actor_loss/time-rollout"), hold_out_net( self.model_based_env