[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward to further simplify the user experience) #47889
base: master
Conversation
LGTM. Some nits in the docstrings.
action_dist_class_exploration = (
    self.module[module_id].unwrapped().get_exploration_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
Don't we need to use unwrapped() in case DDP is used?
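To illustrate the reviewer's point: a DDP-style wrapper hides subclass-specific methods, so callers need an unwrapping accessor to reach them. A minimal stand-in sketch (plain Python, not actual torch DDP; the class names here are hypothetical):

```python
class CustomRLModule:
    """Stand-in for an RLModule subclass with an algorithm-specific method."""

    def get_exploration_action_dist_cls(self):
        return "CategoricalDist"


class DDPStyleWrapper:
    """Minimal stand-in for a DDP-style wrapper: it does NOT forward
    arbitrary attribute access, so subclass-specific methods are hidden."""

    def __init__(self, module):
        self._module = module

    def unwrapped(self):
        # Return the underlying module so callers can reach its full API.
        return self._module


wrapped = DDPStyleWrapper(CustomRLModule())
# Calling the custom method on the wrapper directly would raise
# AttributeError; unwrapping first reaches the real module.
dist_cls = wrapped.unwrapped().get_exploration_action_dist_cls()
```

This is why code paths that may receive a wrapped module call `unwrapped()` before touching algorithm-specific methods.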
@@ -91,12 +89,14 @@ def possibly_masked_mean(data_):

# Compute a value function loss.
if config.use_critic:
    value_fn_out = fwd_out[Columns.VF_PREDS]
    value_fn_out = module.compute_values(
I wonder if this again causes problems in the DDP case. I remember similar problems with CQL and SAC when not running everything in forward_train, but I guess the problem there was that forward_train was run multiple times. So, my guess: it works here.
encoder_outs = self.encoder(batch)
output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC]
Imo, "features" is a misleading term here, as features are usually the inputs to a neural network (or a model in general). "embeddings" might fit better.
batch: Dict[str, Any],
features: Optional[Any] = None,
) -> TensorType:
    if features is None:
Why not include features in the batch instead of passing it in as an extra argument?
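For context on the trade-off being discussed: an optional, precomputed `features` argument lets the training path reuse its encoder output while standalone callers pass only the batch. A hedged sketch of that calling pattern (the encoder and value head are stubbed out; names are illustrative, not RLlib's actual implementation):

```python
from typing import Any, Dict, Optional


def compute_values(batch: Dict[str, Any], features: Optional[Any] = None) -> float:
    """Sketch: compute state values, optionally reusing precomputed features.

    If `features` is None, run the (stubbed) encoder on the batch first;
    otherwise skip the encoder and reuse what the training pass produced.
    """
    if features is None:
        # No precomputed features: run the stand-in encoder on the batch.
        features = sum(batch["obs"])  # stand-in for self.encoder(batch)
    # Stand-in for the value head applied to the features.
    return features * 0.5


# Standalone call: encoder runs internally.
v1 = compute_values({"obs": [1, 2, 3]})
# Training call: reuse features already computed by the forward pass,
# avoiding a second encoder pass.
v2 = compute_values({"obs": [1, 2, 3]}, features=6)
```

Putting the features into `batch` instead (the reviewer's suggestion) would keep the signature smaller; the extra argument makes the reuse explicit in the call site.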
batch: The batch of multi-agent data (i.e. mapping from module ids to
    individual modules' batches).
def items(self) -> ItemsView[ModuleID, RLModule]:
    """Returns a keys view over the module IDs in this MultiRLModule."""
"keys" -> "items"
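The docstring fix the reviewer asks for mirrors Python's dict view semantics: `keys()` yields module IDs, `items()` yields (module ID, module) pairs, and `values()` yields the modules themselves. A minimal dict-backed sketch of such a container (illustrative names, not RLlib's actual class):

```python
class MultiModuleContainer:
    """Minimal sketch of a dict-like module container exposing
    keys/items/values views, matching the corrected docstrings."""

    def __init__(self, modules):
        self._modules = dict(modules)

    def keys(self):
        """Returns a keys view over the module IDs."""
        return self._modules.keys()

    def items(self):
        """Returns an items view over (module ID, module) pairs."""
        return self._modules.items()

    def values(self):
        """Returns a values view over the modules."""
        return self._modules.values()


container = MultiModuleContainer({"policy_0": "moduleA", "policy_1": "moduleB"})
```

Copy-pasted docstrings like "Returns a keys view" on `items()`/`values()` are exactly the kind of drift the review flags.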
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
    """Runs the forward_exploration pass.
def values(self) -> ValuesView[ModuleID]:
    """Returns a keys view over the module IDs in this MultiRLModule."""
"keys" -> "values"
By default, RLlib assumes that the module is non-recurrent if the initial
state is an empty dict and recurrent otherwise.
This behavior can be overridden by implementing this method.
Note that RLlib's distribution classes all implement the `Distribution`
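The recurrence convention described above (non-recurrent iff the initial state is an empty dict) can be sketched in a few lines. This is an illustrative reduction, not RLlib's actual class hierarchy; the state layout in `LSTMModule` is hypothetical:

```python
from typing import Any, Dict


class ModuleSketch:
    """Sketch of the default recurrence check: a module counts as
    recurrent iff its initial state dict is non-empty."""

    def get_initial_state(self) -> Dict[str, Any]:
        # Empty dict -> non-recurrent by default.
        return {}

    def is_stateful(self) -> bool:
        # Overridable; the default derives recurrence from the initial state.
        return bool(self.get_initial_state())


class LSTMModuleSketch(ModuleSketch):
    def get_initial_state(self) -> Dict[str, Any]:
        # Hypothetical hidden/cell state layout for illustration.
        return {"h": [0.0] * 16, "c": [0.0] * 16}
```

A module with unusual needs (e.g. stateful but with a lazily built initial state) would override `is_stateful` directly, which is the escape hatch the docstring mentions.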
Very nice! This makes it clear why!
values = self._values(features).squeeze(-1)
# Same logic as _forward, but also return features to be used by value function
# branch during training.
features, state_outs = self._compute_features_and_state_outs(batch)
As before, in my opinion "features" is a misleading name, as it usually refers to the inputs of a neural network.
New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward to further simplify the user experience).

- Introduces a generic _forward method to be used by RLModule subclasses (by default, all _forward_[inference|exploration|train] call this).
- Subclasses can override _forward_[inference|exploration|train] to individualize behavior for the different algo phases.

Why are these changes needed?
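The dispatch pattern the PR describes can be sketched as follows. This is a hedged reduction of the idea, not RLlib's actual RLModule code; the toy model pass is invented for illustration:

```python
from typing import Any, Dict


class RLModuleSketch:
    """Sketch of the generic-_forward pattern: `_forward` holds the shared
    logic, and the three phase-specific methods default to it, so simple
    modules only need to override `_forward`."""

    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        raise NotImplementedError

    def _forward_inference(self, batch, **kwargs):
        return self._forward(batch, **kwargs)

    def _forward_exploration(self, batch, **kwargs):
        return self._forward(batch, **kwargs)

    def _forward_train(self, batch, **kwargs):
        return self._forward(batch, **kwargs)


class MyModule(RLModuleSketch):
    def _forward(self, batch, **kwargs):
        # Toy computation standing in for the actual model pass.
        return {"action_logits": [o * 2 for o in batch["obs"]]}


m = MyModule()
out = m._forward_inference({"obs": [1, 2]})
```

A module that needs, say, deterministic inference but stochastic exploration would additionally override just the phase-specific method it cares about.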
Related issue number

Checks

- I've signed every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.