
[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward to further simplify the user experience). #47889

Open · wants to merge 6 commits into base: master
Conversation

@sven1977 (Contributor) commented Oct 3, 2024


  • Adds a generic _forward method to be used by RLModule subclasses (by default, all of _forward_[inference|exploration|train] call this).
  • Users can still override _forward_[inference|exploration|train] individually to customize behavior for the different algo phases (see the sketch below).
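
For illustration, a minimal sketch of what a custom module could look like under this design (the class name, network, and layer sizes are hypothetical; only the _forward override pattern is taken from this PR):

```python
import torch
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class TinyModule(TorchRLModule):
    def setup(self):
        # Hypothetical policy net; layer sizes are illustrative only.
        self._net = torch.nn.Sequential(
            torch.nn.Linear(4, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2),
        )

    # One generic forward pass: _forward_inference, _forward_exploration,
    # and _forward_train all route here by default.
    def _forward(self, batch, **kwargs):
        return {Columns.ACTION_DIST_INPUTS: self._net(batch[Columns.OBS])}

    # Individual phases can still be overridden, e.g. to make sure no
    # gradients are tracked during pure inference.
    def _forward_inference(self, batch, **kwargs):
        with torch.no_grad():
            return self._forward(batch, **kwargs)
```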

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <svenmika1977@gmail.com>
@simonsays1980 (Collaborator) left a comment:

LGTM. Some nits in the docstrings.

action_dist_class_exploration = (
    self.module[module_id].unwrapped().get_exploration_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
Collaborator:

Don't we need to use unwrapped in case DDP is used?
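
For reference, a short sketch of the pattern in question (assuming the new-API-stack learner setup, where a module may be wrapped, e.g., in TorchDDPRLModule, when DDP is on):

```python
# Inside a loss function: self.module[module_id] may be a DDP wrapper that
# only forwards the standard forward calls, so custom accessors are reached
# through the underlying module via unwrapped().
module = self.module[module_id].unwrapped()
action_dist_class_train = module.get_train_action_dist_cls()
```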

@@ -91,12 +89,14 @@ def possibly_masked_mean(data_):

         # Compute a value function loss.
         if config.use_critic:
-            value_fn_out = fwd_out[Columns.VF_PREDS]
+            value_fn_out = module.compute_values(
Collaborator:

I wonder if this again causes problems in the DDP case. I remember similar problems with CQL and SAC when not running everything in forward_train, but I think the issue there was that forward_train was run multiple times. So my guess: it works here.

encoder_outs = self.encoder(batch)
output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC]
Collaborator:

IMO, "features" is a misleading term here, as features are usually the inputs to a neural network (or to a model in general). "embeddings" might fit better.

    batch: Dict[str, Any],
    features: Optional[Any] = None,
) -> TensorType:
    if features is None:
Collaborator:

Why not use features inside batch instead of passing it in as an extra argument?
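
For context, a rough sketch of the trade-off under discussion (encoder and value_head are hypothetical submodules; only the optional-features signature comes from the diff). Passing the cached features lets loss code that already ran the train forward pass skip a second encoder pass, while standalone callers just pass the batch:

```python
from typing import Any, Dict, Optional


class ValueBranchSketch:
    def compute_values(self, batch: Dict[str, Any], features: Optional[Any] = None):
        if features is None:
            # Standalone call (no prior forward pass): compute features here.
            features = self.encoder(batch)
        # Callers that already ran _forward_train() can hand in their cached
        # features and avoid recomputing the encoder outputs.
        return self.value_head(features).squeeze(-1)
```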

batch: The batch of multi-agent data (i.e. mapping from module ids to
    individual modules' batches).

def items(self) -> ItemsView[ModuleID, RLModule]:
    """Returns a keys view over the module IDs in this MultiRLModule."""
Collaborator:

"keys" -> "items"

) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
    """Runs the forward_exploration pass.

def values(self) -> ValuesView[ModuleID]:
    """Returns a keys view over the module IDs in this MultiRLModule."""
Collaborator:

"keys" -> "values"

By default, RLlib assumes that the module is non-recurrent if the initial
state is an empty dict and recurrent otherwise.
This behavior can be overridden by implementing this method.
Note that RLlib's distribution classes all implement the `Distribution`
Collaborator:

Very nice! This makes it clear why!
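
A small sketch of what the docstring describes (the state shape is arbitrary): returning a non-empty initial-state dict is what flags a module as recurrent, so usually only get_initial_state needs overriding:

```python
import torch
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class RecurrentSketch(TorchRLModule):
    # A non-empty dict here makes RLlib treat this module as recurrent;
    # the default implementation returns {} -> non-recurrent.
    def get_initial_state(self):
        return {"h": torch.zeros(64)}
```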

values = self._values(features).squeeze(-1)

# Same logic as _forward, but also return features to be used by value function
# branch during training.
features, state_outs = self._compute_features_and_state_outs(batch)
Collaborator:

As before, in my opinion "features" is a misleading name, as it is usually used for the inputs of a neural network.
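
For context, a sketch of the pattern this snippet comes from (_compute_features_and_state_outs is the helper named in the diff; the policy head self.pi is hypothetical): the train pass returns the encoder outputs so compute_values can reuse them:

```python
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class TrainPassSketch(TorchRLModule):
    def _forward_train(self, batch, **kwargs):
        # Same computation as _forward, but additionally return the encoder
        # outputs so the value-function branch can reuse them in the loss.
        features, state_outs = self._compute_features_and_state_outs(batch)
        return {
            Columns.ACTION_DIST_INPUTS: self.pi(features),
            Columns.STATE_OUT: state_outs,
            Columns.FEATURES: features,
        }
```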
