-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward
to further simplify the user experience).
#47889
base: master
Are you sure you want to change the base?
Changes from 2 commits
103eb20
2d99364
d85d95e
9196ca2
5d32009
26e14e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,8 @@ def compute_loss_for_module( | |
batch: Dict[str, Any], | ||
fwd_out: Dict[str, TensorType], | ||
) -> TensorType: | ||
module = self.module[module_id].unwrapped() | ||
|
||
# Possibly apply masking to some sub loss terms and to the total loss term | ||
# at the end. Masking could be used for RNN-based model (zero padded `batch`) | ||
# and for PPO's batched value function (and bootstrap value) computations, | ||
|
@@ -55,12 +57,8 @@ def possibly_masked_mean(data_): | |
else: | ||
possibly_masked_mean = torch.mean | ||
|
||
action_dist_class_train = ( | ||
self.module[module_id].unwrapped().get_train_action_dist_cls() | ||
) | ||
action_dist_class_exploration = ( | ||
self.module[module_id].unwrapped().get_exploration_action_dist_cls() | ||
) | ||
action_dist_class_train = module.get_train_action_dist_cls() | ||
action_dist_class_exploration = module.get_exploration_action_dist_cls() | ||
|
||
curr_action_dist = action_dist_class_train.from_logits( | ||
fwd_out[Columns.ACTION_DIST_INPUTS] | ||
|
@@ -91,12 +89,14 @@ def possibly_masked_mean(data_): | |
|
||
# Compute a value function loss. | ||
if config.use_critic: | ||
value_fn_out = fwd_out[Columns.VF_PREDS] | ||
value_fn_out = module.compute_values( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if this gives again problems in the DDP case. I remember similar problems with CQL and SAC when not running everything in | There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, good point, I think you are right. Let's see what the tests say ... |
||
batch, features=fwd_out.get(Columns.FEATURES) | ||
) | ||
vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0) | ||
vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param) | ||
mean_vf_loss = possibly_masked_mean(vf_loss_clipped) | ||
mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) | ||
# Ignore the value function. | ||
# Ignore the value function -> Set all to 0.0. | ||
else: | ||
z = torch.tensor(0.0, device=surrogate_loss.device) | ||
value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import Any, Dict | ||
from typing import Any, Dict, Optional | ||
|
||
from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule | ||
from ray.rllib.core.columns import Columns | ||
|
@@ -17,63 +17,50 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule): | |
framework: str = "torch" | ||
|
||
@override(RLModule) | ||
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]: | ||
def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: | ||
"""Default forward pass (used for inference and exploration).""" | ||
output = {} | ||
|
||
# Encoder forward pass. | ||
encoder_outs = self.encoder(batch) | ||
# Stateful encoder? | ||
if Columns.STATE_OUT in encoder_outs: | ||
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] | ||
|
||
# Pi head. | ||
output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) | ||
|
||
return output | ||
|
||
@override(RLModule) | ||
def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: | ||
return self._forward_inference(batch) | ||
|
||
@override(RLModule) | ||
def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: | ||
if self.config.inference_only: | ||
raise RuntimeError( | ||
"Trying to train a module that is not a learner module. Set the " | ||
"flag `inference_only=False` when building the module." | ||
) | ||
def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: | ||
"""Train forward pass (keep features for possible shared value func. call).""" | ||
output = {} | ||
|
||
# Shared encoder. | ||
encoder_outs = self.encoder(batch) | ||
output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Imo |
||
if Columns.STATE_OUT in encoder_outs: | ||
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] | ||
|
||
# Value head. | ||
vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC]) | ||
# Squeeze out last dim (value function node). | ||
output[Columns.VF_PREDS] = vf_out.squeeze(-1) | ||
|
||
# Policy head. | ||
action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) | ||
output[Columns.ACTION_DIST_INPUTS] = action_logits | ||
|
||
output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) | ||
return output | ||
|
||
@override(ValueFunctionAPI) | ||
def compute_values(self, batch: Dict[str, Any]) -> TensorType: | ||
# Separate vf-encoder. | ||
if hasattr(self.encoder, "critic_encoder"): | ||
batch_ = batch | ||
if self.is_stateful(): | ||
# The recurrent encoders expect a `(state_in, h)` key in the | ||
# input dict while the key returned is `(state_in, critic, h)`. | ||
batch_ = batch.copy() | ||
batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] | ||
encoder_outs = self.encoder.critic_encoder(batch_)[ENCODER_OUT] | ||
# Shared encoder. | ||
else: | ||
encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC] | ||
def compute_values( | ||
self, | ||
batch: Dict[str, Any], | ||
features: Optional[Any] = None, | ||
) -> TensorType: | ||
if features is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not using |
||
# Separate vf-encoder. | ||
if hasattr(self.encoder, "critic_encoder"): | ||
batch_ = batch | ||
if self.is_stateful(): | ||
# The recurrent encoders expect a `(state_in, h)` key in the | ||
# input dict while the key returned is `(state_in, critic, h)`. | ||
batch_ = batch.copy() | ||
batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] | ||
features = self.encoder.critic_encoder(batch_)[ENCODER_OUT] | ||
# Shared encoder. | ||
else: | ||
features = self.encoder(batch)[ENCODER_OUT][CRITIC] | ||
|
||
# Value head. | ||
vf_out = self.vf(encoder_outs) | ||
vf_out = self.vf(features) | ||
# Squeeze out last dimension (single node value head). | ||
return vf_out.squeeze(-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Don't we need to use `unwrapped` in case DDP is used?

There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Great question! DDP already wraps this method to use the `unwrapped` underlying RLModule, so this is ok here.