diff --git a/d3rlpy/algos/transformer/torch/tacr_impl.py b/d3rlpy/algos/transformer/torch/tacr_impl.py index 5c43f7b7..fbb1504d 100644 --- a/d3rlpy/algos/transformer/torch/tacr_impl.py +++ b/d3rlpy/algos/transformer/torch/tacr_impl.py @@ -80,7 +80,7 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: action = self._modules.transformer( inpt.observations, inpt.actions, - inpt.returns_to_go, + inpt.rewards, inpt.timesteps, 1 - inpt.masks, ) @@ -128,7 +128,7 @@ def compute_actor_loss( action = self._modules.transformer( batch.observations, batch.actions, - batch.returns_to_go, + batch.rewards, batch.timesteps, 1 - batch.masks, ) @@ -168,7 +168,7 @@ def compute_target( action = self._modules.transformer( batch.observations, batch.actions, - batch.returns_to_go, + batch.rewards, batch.timesteps, 1 - batch.masks, )[:, :-1].reshape(-1, self._action_size)