[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward to further simplify the user experience). #47889

Open · wants to merge 6 commits into base: master

Changes from 2 commits
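For context on the PR title: with a generic `_forward`, a user-defined RLModule only needs to override a single method for the default (inference and exploration) pass. Below is a minimal, self-contained sketch of that override surface; the class and layer names are hypothetical illustrations, not RLlib code.

```python
import torch
from torch import nn
from typing import Any, Dict


class TinyModule(nn.Module):
    """Hypothetical module showing the single-override pattern (not RLlib code)."""

    def __init__(self, obs_dim: int = 4, num_actions: int = 2):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, 16)
        self.pi = nn.Linear(16, num_actions)

    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # One generic pass serving both inference and exploration.
        features = torch.relu(self.encoder(batch["obs"]))
        return {"action_dist_inputs": self.pi(features)}

    # Before this PR, users had to override _forward_inference AND
    # _forward_exploration separately; now both default to _forward.
```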
16 changes: 8 additions & 8 deletions rllib/algorithms/ppo/torch/ppo_torch_learner.py
@@ -40,6 +40,8 @@ def compute_loss_for_module(
batch: Dict[str, Any],
fwd_out: Dict[str, TensorType],
) -> TensorType:
module = self.module[module_id].unwrapped()

# Possibly apply masking to some sub loss terms and to the total loss term
# at the end. Masking could be used for RNN-based model (zero padded `batch`)
# and for PPO's batched value function (and bootstrap value) computations,
@@ -55,12 +57,8 @@ def possibly_masked_mean(data_):
else:
possibly_masked_mean = torch.mean

action_dist_class_train = (
self.module[module_id].unwrapped().get_train_action_dist_cls()
)
action_dist_class_exploration = (
self.module[module_id].unwrapped().get_exploration_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
Collaborator:

Don't we need to use `unwrapped` in case DDP is used?

Contributor Author:

Great question! DDP already wraps this method to use the unwrapped underlying RLModule, so this is ok here.
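For readers unfamiliar with the wrapping discussed here: under distributed training, the learner's module may sit inside a DDP-style container, and `unwrapped()` returns the underlying RLModule. A rough sketch of that pattern (illustrative only, not RLlib's actual wrapper implementation):

```python
from torch import nn


class DDPStyleWrapper(nn.Module):
    """Illustrative only: forwards training calls, exposes the raw module."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.wrapped = module  # in real DDP: DistributedDataParallel(module)

    def unwrapped(self) -> nn.Module:
        # Methods like get_train_action_dist_cls() live on the raw module,
        # not on the wrapper, hence the unwrapped() calls in the loss code.
        return self.wrapped
```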

action_dist_class_exploration = module.get_exploration_action_dist_cls()

curr_action_dist = action_dist_class_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
@@ -91,12 +89,14 @@ def possibly_masked_mean(data_):

# Compute a value function loss.
if config.use_critic:
value_fn_out = fwd_out[Columns.VF_PREDS]
value_fn_out = module.compute_values(
Collaborator:

I wonder if this again causes problems in the DDP case. I remember similar problems with CQL and SAC when not running everything in `forward_train`, but I guess the problem there was that `forward_train` was run multiple times. So, my guess: it works here.

Contributor Author:

Yeah, good point, I think you are right. Let's see what the tests say ...

batch, features=fwd_out.get(Columns.FEATURES)
)
vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param)
mean_vf_loss = possibly_masked_mean(vf_loss_clipped)
mean_vf_unclipped_loss = possibly_masked_mean(vf_loss)
# Ignore the value function.
# Ignore the value function -> Set all to 0.0.
else:
z = torch.tensor(0.0, device=surrogate_loss.device)
value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z
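The comment at the top of this hunk mentions masking for RNN-based models with zero-padded batches. A minimal sketch of what a masked-mean helper computes, assuming a boolean `mask` marking valid (non-padded) timesteps; this is an illustration, not the exact RLlib helper:

```python
import torch


def masked_mean(data: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average only over valid timesteps so zero-padding does not dilute
    # the loss; equivalent to torch.mean when every entry is valid.
    return torch.sum(data[mask]) / torch.clamp(mask.sum(), min=1)


data = torch.tensor([1.0, 2.0, 0.0, 0.0])  # last two entries are padding
mask = torch.tensor([True, True, False, False])
assert masked_mean(data, mask) == 1.5
```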
69 changes: 28 additions & 41 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.core.columns import Columns
@@ -17,63 +17,50 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule):
framework: str = "torch"

@override(RLModule)
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Default forward pass (used for inference and exploration)."""
output = {}

# Encoder forward pass.
encoder_outs = self.encoder(batch)
# Stateful encoder?
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Pi head.
output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR])

return output

@override(RLModule)
def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return self._forward_inference(batch)

@override(RLModule)
def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
if self.config.inference_only:
raise RuntimeError(
"Trying to train a module that is not a learner module. Set the "
"flag `inference_only=False` when building the module."
)
def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Train forward pass (keep features for possible shared value func. call)."""
output = {}

# Shared encoder.
encoder_outs = self.encoder(batch)
output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC]
Collaborator:

IMO, `features` is a misleading term here, as features are usually the inputs to a neural network (or a model in general); `embeddings` might fit better.

if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Value head.
vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC])
# Squeeze out last dim (value function node).
output[Columns.VF_PREDS] = vf_out.squeeze(-1)

# Policy head.
action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
output[Columns.ACTION_DIST_INPUTS] = action_logits

output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
return output

@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
batch_ = batch
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch_ = batch.copy()
batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
encoder_outs = self.encoder.critic_encoder(batch_)[ENCODER_OUT]
# Shared encoder.
else:
encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC]
def compute_values(
self,
batch: Dict[str, Any],
features: Optional[Any] = None,
) -> TensorType:
if features is None:
Collaborator:

Why not use `features` inside `batch` instead of passing it in as an extra argument?

# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
batch_ = batch
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch_ = batch.copy()
batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
features = self.encoder.critic_encoder(batch_)[ENCODER_OUT]
# Shared encoder.
else:
features = self.encoder(batch)[ENCODER_OUT][CRITIC]

# Value head.
vf_out = self.vf(encoder_outs)
vf_out = self.vf(features)
# Squeeze out last dimension (single node value head).
return vf_out.squeeze(-1)
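Taken together with the learner change above, the new `features` argument lets the loss reuse the critic embedding from the train forward pass instead of re-encoding the batch. A hedged usage sketch, assuming a built PPOTorchRLModule instance `module` and a train `batch`:

```python
from ray.rllib.core.columns import Columns

# Illustrative training-step fragment: `fwd_out` is the output of the
# module's train forward pass.
fwd_out = module.forward_train(batch)
# Reuse the encoder output stored under Columns.FEATURES; compute_values
# only runs its own encoder pass when `features` is None.
values = module.compute_values(batch, features=fwd_out.get(Columns.FEATURES))
```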
1 change: 1 addition & 0 deletions rllib/core/columns.py
@@ -44,6 +44,7 @@ class Columns:
# Common extra RLModule output keys.
STATE_IN = "state_in"
STATE_OUT = "state_out"
FEATURES = "features"
ACTION_DIST_INPUTS = "action_dist_inputs"
ACTION_PROB = "action_prob"
ACTION_LOGP = "action_logp"
4 changes: 2 additions & 2 deletions rllib/core/learner/learner.py
@@ -691,7 +691,7 @@ def filter_param_dict_for_optimizer(
def get_param_ref(self, param: Param) -> Hashable:
"""Returns a hashable reference to a trainable parameter.

This should be overriden in framework specific specialization. For example in
This should be overridden in framework specific specialization. For example in
torch it will return the parameter itself, while in tf it returns the .ref() of
the variable. The purpose is to retrieve a unique reference to the parameters.

@@ -706,7 +706,7 @@ def get_param_ref(self, param: Param) -> Hashable:
def get_parameters(self, module: RLModule) -> Sequence[Param]:
"""Returns the list of parameters of a module.

This should be overriden in framework specific learner. For example in torch it
This should be overridden in framework specific learner. For example in torch it
will return .parameters(), while in tf it returns .trainable_variables.

Args:
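For concreteness, the torch behavior these docstrings describe looks roughly like this; a sketch of the specialization under those stated semantics, not a verbatim copy of TorchLearner:

```python
from typing import Hashable, Sequence

from torch import nn


def get_param_ref(param: nn.Parameter) -> Hashable:
    # Torch parameters are hashable, so the parameter is its own unique ref.
    return param


def get_parameters(module: nn.Module) -> Sequence[nn.Parameter]:
    # Torch exposes trainable parameters via .parameters().
    return list(module.parameters())
```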
16 changes: 13 additions & 3 deletions rllib/core/rl_module/apis/value_function_api.py
@@ -1,20 +1,30 @@
import abc
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.utils.typing import TensorType


class ValueFunctionAPI(abc.ABC):
"""An API to be implemented by RLModules for handling value function-based learning.

RLModules implementing this API must override the `compute_values` method."""
RLModules implementing this API must override the `compute_values` method.
"""

@abc.abstractmethod
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
def compute_values(
self,
batch: Dict[str, Any],
features: Optional[Any] = None,
) -> TensorType:
"""Computes the value estimates given `batch`.

Args:
batch: The batch to compute value function estimates for.
features: Optional features already computed from the `batch` (by a
previous forward pass through the model's encoder or other
feature-computing subcomponent). For example, the caller of this
method should provide `features` - if available - to avoid duplicate
passes through a shared encoder.

Returns:
A tensor of shape (B,) or (B, T) (in case the input `batch` has a
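A minimal sketch of a module implementing this extended signature (hypothetical class; assumes a single shared encoder and a one-node value head):

```python
import torch
from torch import nn
from typing import Any, Dict, Optional


class TinyValueModule(nn.Module):
    """Hypothetical implementer of the extended compute_values signature."""

    def __init__(self, obs_dim: int = 4):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, 16)
        self.vf = nn.Linear(16, 1)

    def compute_values(
        self,
        batch: Dict[str, Any],
        features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Only encode if the caller did not already provide features.
        if features is None:
            features = torch.relu(self.encoder(batch["obs"]))
        # Squeeze out the last dim (single value-function node).
        return self.vf(features).squeeze(-1)
```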