[BugFix] Avoid reshape(-1) for inputs to objectives modules #2494

Merged: 3 commits, Oct 15, 2024
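For context: before this PR, the objectives' `forward` passes flattened the input tensordict with `reshape(-1)` and restored the original shape afterwards; after this PR they operate on the batched tensordict directly. Below is a minimal standalone sketch of the two patterns, not the torchrl code itself (it assumes the `tensordict` package is installed; the `toy_loss_*` names and keys are illustrative only):

```python
import torch
from tensordict import TensorDict


def toy_loss_flattening(td: TensorDict) -> TensorDict:
    # Old pattern: flatten all batch dims, compute, then restore the shape.
    shape = None
    if td.ndimension() > 1:
        shape = td.shape
        td_reshape = td.reshape(-1)
    else:
        td_reshape = td
    td_reshape.set("td_error", (td_reshape["pred"] - td_reshape["target"]).pow(2))
    if shape:
        td.update(td_reshape.view(shape))
    return td


def toy_loss_batched(td: TensorDict) -> TensorDict:
    # New pattern: work on the tensordict as-is; elementwise ops and dim=-1
    # reductions already broadcast over any leading batch dims.
    td.set("td_error", (td["pred"] - td["target"]).pow(2))
    return td


td = TensorDict(
    {"pred": torch.randn(4, 3), "target": torch.randn(4, 3)},
    batch_size=[4, 3],
)
assert toy_loss_batched(td.clone())["td_error"].shape == torch.Size([4, 3])
```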
56 changes: 22 additions & 34 deletions torchrl/objectives/cql.py
@@ -514,32 +514,21 @@ def out_keys(self, values):

     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        q_loss, metadata = self.q_loss(tensordict_reshape)
-        cql_loss, cql_metadata = self.cql_loss(tensordict_reshape)
+        q_loss, metadata = self.q_loss(tensordict)
+        cql_loss, cql_metadata = self.cql_loss(tensordict)
         if self.with_lagrange:
-            alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(
-                tensordict_reshape
-            )
+            alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(tensordict)
             metadata.update(alpha_prime_metadata)
-        loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict_reshape)
-        loss_actor, actor_metadata = self.actor_loss(tensordict_reshape)
+        loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict)
+        loss_actor, actor_metadata = self.actor_loss(tensordict)
         loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata)
         metadata.update(bc_metadata)
         metadata.update(cql_metadata)
         metadata.update(actor_metadata)
         metadata.update(alpha_metadata)
-        tensordict_reshape.set(
+        tensordict.set(
             self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
         )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         out = {
             "loss_actor": loss_actor,
             "loss_actor_bc": loss_actor_bc,
@@ -682,7 +671,9 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
         )
         # take max over actions
         state_action_value = state_action_value.reshape(
-            self.num_qvalue_nets, tensordict.shape[0], self.num_random, -1
+            torch.Size(
+                [self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
+            )
         ).max(-2)[0]
         # take min over qvalue nets
         next_state_value = state_action_value.min(0)[0]
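Aside: unpacking `*tensordict.shape` instead of hard-coding `tensordict.shape[0]` lets this reshape accept any number of leading batch dimensions. A small plain-torch check with made-up sizes (not torchrl code):

```python
import torch

num_qvalue_nets, num_random, feat = 2, 10, 1
batch_shape = (4, 3)  # e.g. [batch, time]; a flat (12,) batch works the same way
x = torch.randn(num_qvalue_nets, 4 * 3 * num_random, feat)
y = x.reshape(torch.Size([num_qvalue_nets, *batch_shape, num_random, -1]))
print(y.shape)  # torch.Size([2, 4, 3, 10, 1])
```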
@@ -739,14 +730,13 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
                 "This could be caused by calling cql_loss method before q_loss method."
             )

-        random_actions_tensor = (
-            torch.FloatTensor(
-                tensordict.shape[0] * self.num_random,
+        random_actions_tensor = pred_q1.new_empty(
+            (
+                *tensordict.shape[:-1],
+                tensordict.shape[-1] * self.num_random,
                 tensordict[self.tensor_keys.action].shape[-1],
             )
-            .uniform_(-1, 1)
-            .to(tensordict.device)
-        )
+        ).uniform_(-1, 1)
         curr_actions_td, curr_log_pis = self._get_policy_actions(
             tensordict.copy(),
             self.actor_network_params,
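Aside: `pred_q1.new_empty(...)` already inherits `pred_q1`'s dtype and device, so the explicit `.to(tensordict.device)` of the old `torch.FloatTensor` path is no longer needed. A tiny standalone check (illustrative values, not torchrl code):

```python
import torch

pred_q1 = torch.zeros(4, 3, dtype=torch.float32)  # stand-in for the Q prediction
rand_actions = pred_q1.new_empty((4, 3 * 10, 6)).uniform_(-1, 1)
assert rand_actions.dtype == pred_q1.dtype and rand_actions.device == pred_q1.device
```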
@@ -833,31 +823,31 @@ def filter_and_repeat(name, x):
                 q_new[0] - new_log_pis.detach().unsqueeze(-1),
                 q_curr[0] - curr_log_pis.detach().unsqueeze(-1),
             ],
-            1,
+            -1,
         )
         cat_q2 = torch.cat(
             [
                 q_random[1] - random_density,
                 q_new[1] - new_log_pis.detach().unsqueeze(-1),
                 q_curr[1] - curr_log_pis.detach().unsqueeze(-1),
             ],
-            1,
+            -1,
         )

         min_qf1_loss = (
-            torch.logsumexp(cat_q1 / self.temperature, dim=1)
+            torch.logsumexp(cat_q1 / self.temperature, dim=-1)
             * self.min_q_weight
             * self.temperature
         )
         min_qf2_loss = (
-            torch.logsumexp(cat_q2 / self.temperature, dim=1)
+            torch.logsumexp(cat_q2 / self.temperature, dim=-1)
             * self.min_q_weight
             * self.temperature
         )

         # Subtract the log likelihood of data
-        cql_q1_loss = min_qf1_loss - pred_q1 * self.min_q_weight
-        cql_q2_loss = min_qf2_loss - pred_q2 * self.min_q_weight
+        cql_q1_loss = min_qf1_loss.flatten() - pred_q1 * self.min_q_weight
+        cql_q2_loss = min_qf2_loss.flatten() - pred_q2 * self.min_q_weight

         # write cql losses in tensordict for alpha prime loss
         tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss)
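Aside: switching the concatenation and reductions from `dim=1` to `dim=-1` is what keeps this equivalent to the old flattened computation once extra leading batch dims are allowed. A quick plain-torch check with illustrative shapes:

```python
import torch

cat_q = torch.randn(4, 3, 30)  # [batch, time, 3 * num_random] instead of [batch * time, 30]
assert torch.logsumexp(cat_q, dim=-1).shape == torch.Size([4, 3])
flat = cat_q.reshape(-1, 30)
# Reducing over the last dim commutes with flattening the leading batch dims.
assert torch.allclose(
    torch.logsumexp(flat, dim=1), torch.logsumexp(cat_q, dim=-1).reshape(-1)
)
```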
@@ -1080,9 +1070,9 @@ def __init__(
         self.loss_function = loss_function
         if action_space is None:
             # infer from value net
-            try:
+            if hasattr(value_network, "action_space"):
                 action_space = value_network.spec
-            except AttributeError:
+            else:
                 # let's try with action_space then
                 try:
                     action_space = value_network.action_space
@@ -1205,8 +1195,6 @@ def value_loss(
         with torch.no_grad():
             td_error = (pred_val_index - target_value).pow(2)
             td_error = td_error.unsqueeze(-1)
-        if tensordict.device is not None:
-            td_error = td_error.to(tensordict.device)

         tensordict.set(
             self.tensor_keys.priority,
15 changes: 3 additions & 12 deletions torchrl/objectives/crossq.py
@@ -495,23 +495,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         To see what keys are expected in the input tensordict and what keys are expected as output, check the
         class's `"in_keys"` and `"out_keys"` attributes.
         """
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape)
-        loss_actor, metadata_actor = self.actor_loss(tensordict_reshape)
+        loss_qvalue, value_metadata = self.qvalue_loss(tensordict)
+        loss_actor, metadata_actor = self.actor_loss(tensordict)
         loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"])
-        tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
+        tensordict.set(self.tensor_keys.priority, value_metadata["td_error"])
         if loss_actor.shape != loss_qvalue.shape:
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}"
             )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         entropy = -metadata_actor["log_prob"]
         out = {
             "loss_actor": loss_actor,
20 changes: 5 additions & 15 deletions torchrl/objectives/iql.py
@@ -373,16 +373,9 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:

     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        loss_actor, metadata = self.actor_loss(tensordict_reshape)
-        loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict_reshape)
-        loss_value, metadata_value = self.value_loss(tensordict_reshape)
+        loss_actor, metadata = self.actor_loss(tensordict)
+        loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict)
+        loss_value, metadata_value = self.value_loss(tensordict)
         metadata.update(metadata_qvalue)
         metadata.update(metadata_value)

@@ -392,13 +385,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
             )
-        tensordict_reshape.set(
+        tensordict.set(
             self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
         )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
-
-        entropy = -tensordict_reshape.get(self.tensor_keys.log_prob).detach()
+        entropy = -tensordict.get(self.tensor_keys.log_prob).detach()
         out = {
             "loss_actor": loss_actor,
             "loss_qvalue": loss_qvalue,
34 changes: 8 additions & 26 deletions torchrl/objectives/sac.py
@@ -577,30 +577,21 @@ def out_keys(self, values):

     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
         if self._version == 1:
-            loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict_reshape)
-            loss_value, _ = self._value_loss(tensordict_reshape)
+            loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict)
+            loss_value, _ = self._value_loss(tensordict)
         else:
-            loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
+            loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict)
             loss_value = None
-        loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
+        loss_actor, metadata_actor = self._actor_loss(tensordict)
         loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"])
-        tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
+        tensordict.set(self.tensor_keys.priority, value_metadata["td_error"])
         if (loss_actor.shape != loss_qvalue.shape) or (
             loss_value is not None and loss_actor.shape != loss_value.shape
         ):
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
             )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         entropy = -metadata_actor["log_prob"]
         out = {
             "loss_actor": loss_actor,
@@ -1158,26 +1149,17 @@ def in_keys(self, values):

     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        loss_value, metadata_value = self._value_loss(tensordict_reshape)
-        loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
+        loss_value, metadata_value = self._value_loss(tensordict)
+        loss_actor, metadata_actor = self._actor_loss(tensordict)
         loss_alpha = self._alpha_loss(
             log_prob=metadata_actor["log_prob"],
         )

-        tensordict_reshape.set(self.tensor_keys.priority, metadata_value["td_error"])
+        tensordict.set(self.tensor_keys.priority, metadata_value["td_error"])
         if loss_actor.shape != loss_value.shape:
             raise RuntimeError(
                 f"Losses shape mismatch: {loss_actor.shape}, and {loss_value.shape}"
             )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         entropy = -metadata_actor["log_prob"]
         out = {
             "loss_actor": loss_actor,