From 103eb203179311a342422fef496e02ba73f0aadf Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 3 Oct 2024 17:39:49 +0200 Subject: [PATCH 1/7] wip Signed-off-by: sven1977 --- .../algorithms/ppo/torch/ppo_torch_learner.py | 2 +- .../ppo/torch/ppo_torch_rl_module.py | 85 +++--- rllib/core/columns.py | 1 + rllib/core/learner/learner.py | 4 +- .../core/rl_module/apis/value_function_api.py | 16 +- rllib/core/rl_module/multi_rl_module.py | 268 +++++++++--------- rllib/core/rl_module/rl_module.py | 242 +++++++++------- rllib/core/rl_module/torch/torch_rl_module.py | 49 +++- .../classes/intrinsic_curiosity_model_rlm.py | 8 +- .../rl_modules/classes/lstm_containing_rlm.py | 68 ++--- .../rl_modules/classes/tiny_atari_cnn_rlm.py | 35 +-- rllib/utils/annotations.py | 6 +- 12 files changed, 426 insertions(+), 358 deletions(-) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index f866165e22435..bcfa64813739f 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -96,7 +96,7 @@ def possibly_masked_mean(data_): vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param) mean_vf_loss = possibly_masked_mean(vf_loss_clipped) mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) - # Ignore the value function. + # Ignore the value function -> Set all to 0.0. else: z = torch.tensor(0.0, device=surrogate_loss.device) value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 2f3283cf1ca0c..ead91ca819662 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule from ray.rllib.core.columns import Columns @@ -17,63 +17,64 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule): framework: str = "torch" @override(RLModule) - def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]: - output = {} + def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Default forward pass (used for inference and exploration).""" # Encoder forward pass. encoder_outs = self.encoder(batch) + # Stateful encoder? + state_out = None if Columns.STATE_OUT in encoder_outs: - output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + state_out = encoder_outs[Columns.STATE_OUT] # Pi head. - output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) + logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - return output + return { + Columns.ACTION_DIST_INPUTS: logits + } | {Columns.STATE_OUT: state_out} if state_out else {} @override(RLModule) - def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - return self._forward_inference(batch) + def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Train forward pass (keep features for possible shared value function call).""" - @override(RLModule) - def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: - if self.config.inference_only: - raise RuntimeError( - "Trying to train a module that is not a learner module. Set the " - "flag `inference_only=False` when building the module." - ) - output = {} - - # Shared encoder. + # Encoder forward pass. encoder_outs = self.encoder(batch) + features = encoder_outs[ENCODER_OUT][CRITIC] + # Stateful encoder? + state_out = None if Columns.STATE_OUT in encoder_outs: - output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + state_out = encoder_outs[Columns.STATE_OUT] - # Value head. - vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC]) - # Squeeze out last dim (value function node). - output[Columns.VF_PREDS] = vf_out.squeeze(-1) - - # Policy head. - action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - output[Columns.ACTION_DIST_INPUTS] = action_logits + # Pi head. + logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - return output + return { + Columns.ACTION_DIST_INPUTS: logits, + Columns.FEATURES: features, + } | {Columns.STATE_OUT: state_out} if state_out else {} @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: - # Separate vf-encoder. - if hasattr(self.encoder, "critic_encoder"): - batch_ = batch - if self.is_stateful(): - # The recurrent encoders expect a `(state_in, h)` key in the - # input dict while the key returned is `(state_in, critic, h)`. - batch_ = batch.copy() - batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] - encoder_outs = self.encoder.critic_encoder(batch_)[ENCODER_OUT] - # Shared encoder. - else: - encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC] + def compute_values( + self, + batch: Dict[str, Any], + features: Optional[Any] = None, + ) -> TensorType: + if features is None: + # Separate vf-encoder. + if hasattr(self.encoder, "critic_encoder"): + batch_ = batch + if self.is_stateful(): + # The recurrent encoders expect a `(state_in, h)` key in the + # input dict while the key returned is `(state_in, critic, h)`. + batch_ = batch.copy() + batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] + features = self.encoder.critic_encoder(batch_)[ENCODER_OUT] + # Shared encoder. + else: + features = self.encoder(batch)[ENCODER_OUT][CRITIC] + # Value head. - vf_out = self.vf(encoder_outs) + vf_out = self.vf(features) # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) diff --git a/rllib/core/columns.py b/rllib/core/columns.py index 0944d521e2c17..073f37a73d840 100644 --- a/rllib/core/columns.py +++ b/rllib/core/columns.py @@ -44,6 +44,7 @@ class Columns: # Common extra RLModule output keys. STATE_IN = "state_in" STATE_OUT = "state_out" + FEATURES = "features" ACTION_DIST_INPUTS = "action_dist_inputs" ACTION_PROB = "action_prob" ACTION_LOGP = "action_logp" diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index dcb088ac74507..f0db18a79295a 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -691,7 +691,7 @@ def filter_param_dict_for_optimizer( def get_param_ref(self, param: Param) -> Hashable: """Returns a hashable reference to a trainable parameter. - This should be overriden in framework specific specialization. For example in + This should be overridden in framework specific specialization. For example in torch it will return the parameter itself, while in tf it returns the .ref() of the variable. The purpose is to retrieve a unique reference to the parameters. @@ -706,7 +706,7 @@ def get_param_ref(self, param: Param) -> Hashable: def get_parameters(self, module: RLModule) -> Sequence[Param]: """Returns the list of parameters of a module. - This should be overriden in framework specific learner. For example in torch it + This should be overridden in framework specific learner. For example in torch it will return .parameters(), while in tf it returns .trainable_variables. Args: diff --git a/rllib/core/rl_module/apis/value_function_api.py b/rllib/core/rl_module/apis/value_function_api.py index 06f0afccc19e8..595969b646c5a 100644 --- a/rllib/core/rl_module/apis/value_function_api.py +++ b/rllib/core/rl_module/apis/value_function_api.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Dict +from typing import Any, Dict, Optional from ray.rllib.utils.typing import TensorType @@ -7,14 +7,24 @@ class ValueFunctionAPI(abc.ABC): """An API to be implemented by RLModules for handling value function-based learning. - RLModules implementing this API must override the `compute_values` method.""" + RLModules implementing this API must override the `compute_values` method. + """ @abc.abstractmethod - def compute_values(self, batch: Dict[str, Any]) -> TensorType: + def compute_values( + self, + batch: Dict[str, Any], + features: Optional[Any] = None, + ) -> TensorType: """Computes the value estimates given `batch`. Args: batch: The batch to compute value function estimates for. + features: Optional features already computed from the `batch` (by another + forward pass through the model's encoder (or other feature computing + subcomponent). For example, the caller of thie method should provide + `fetuares` - if available - to avoid duplicate passes through a shared + encoder. Returns: A tensor of shape (B,) or (B, T) (in case the input `batch` has a diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index a4b0deedce1e8..fe5b7cb594901 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -20,7 +20,6 @@ from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC from ray.rllib.core.models.specs.typing import SpecType from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec -from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils import force_list from ray.rllib.utils.annotations import ( ExperimentalAPI, @@ -74,7 +73,7 @@ def __init__(self, config: Optional["MultiRLModuleConfig"] = None) -> None: def setup(self): """Sets up the underlying RLModules.""" self._rl_modules = {} - self.__check_module_configs(self.config.modules) + self._check_module_configs(self.config.modules) # Make sure all individual RLModules have the same framework OR framework=None. framework = None for module_id, module_spec in self.config.modules.items(): @@ -85,6 +84,87 @@ def setup(self): assert self._rl_modules[module_id].framework in [None, framework] self.framework = framework + @override(RLModule) + def _forward( + self, + batch: Dict[ModuleID, Any], + **kwargs, + ) -> Dict[ModuleID, Dict[str, Any]]: + """Generic forward pass method, used in all phases of training and evaluation. + + If you need a more nuanced distinction between forward passes in the different + phases of training and evaluation, override the following methods insted: + For distinct action computation logic w/o exploration, override the + `self._forward_inference()` method. + For distinct action computation logic with exploration, override the + `self._forward_exploration()` method. + For distinct forward pass logic before loss computation, override the + `self._forward_train()` method. + + Args: + batch: The input batch, a dict mapping from ModuleID to individual modules' + batches. + **kwargs: Additional keyword arguments. + + Returns: + The output of the forward pass. + """ + return { + mid: self._rl_modules[mid]._forward(batch[mid], **kwargs) + for mid in batch.keys() if mid in self + } + + @override(RLModule) + def _forward_inference( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used for action computation without exploration behavior. + + Override this method only, if you need specific behavior for non-exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return { + mid: self._rl_modules[mid]._forward_inference(batch[mid], **kwargs) + for mid in batch.keys() if mid in self + } + + @override(RLModule) + def _forward_exploration( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used for action computation with exploration behavior. + + Override this method only, if you need specific behavior for exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return { + mid: self._rl_modules[mid]._forward_exploration(batch[mid], **kwargs) + for mid in batch.keys() if mid in self + } + + @override(RLModule) + def _forward_train( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used before the loss computation (training). + + Override this method only, if you need specific behavior and outputs for your + loss computations. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return { + mid: self._rl_modules[mid]._forward_train(batch[mid], **kwargs) + for mid in batch.keys() if mid in self + } + @OverrideToImplementCustomLogic @override(RLModule) def get_initial_state(self) -> Any: @@ -105,50 +185,6 @@ def is_stateful(self) -> bool: ) return bool(any(sa_init_state for sa_init_state in initial_state.values())) - @classmethod - def __check_module_configs(cls, module_configs: Dict[ModuleID, Any]): - """Checks the module configs for validity. - - The module_configs be a mapping from module_ids to RLModuleSpec - objects. - - Args: - module_configs: The module configs to check. - - Raises: - ValueError: If the module configs are invalid. - """ - for module_id, module_spec in module_configs.items(): - if not isinstance(module_spec, RLModuleSpec): - raise ValueError(f"Module {module_id} is not a RLModuleSpec object.") - - def items(self) -> ItemsView[ModuleID, RLModule]: - """Returns a keys view over the module IDs in this MultiRLModule.""" - return self._rl_modules.items() - - def keys(self) -> KeysView[ModuleID]: - """Returns a keys view over the module IDs in this MultiRLModule.""" - return self._rl_modules.keys() - - def values(self) -> ValuesView[ModuleID]: - """Returns a keys view over the module IDs in this MultiRLModule.""" - return self._rl_modules.values() - - def __len__(self) -> int: - """Returns the number of RLModules within this MultiRLModule.""" - return len(self._rl_modules) - - @override(RLModule) - def as_multi_rl_module(self) -> "MultiRLModule": - """Returns self in order to match `RLModule.as_multi_rl_module()` behavior. - - This method is overridden to avoid double wrapping. - - Returns: - The instance itself. - """ - return self - def add_module( self, module_id: ModuleID, @@ -271,70 +307,24 @@ def get( return default return self._rl_modules[module_id] - @override(RLModule) - def output_specs_train(self) -> SpecType: - return [] - - @override(RLModule) - def output_specs_inference(self) -> SpecType: - return [] - - @override(RLModule) - def output_specs_exploration(self) -> SpecType: - return [] - - @override(RLModule) - def _default_input_specs(self) -> SpecType: - """MultiRLModule should not check the input specs. - - The underlying single-agent RLModules will check the input specs. - """ - return [] - - @override(RLModule) - def _forward_train( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_train pass. - - Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - individual modules' batches). - - Returns: - The output of the forward_train pass the specified modules. - """ - return self._run_forward_pass("forward_train", batch, **kwargs) - - @override(RLModule) - def _forward_inference( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_inference pass. - - Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - individual modules' batches). + def items(self) -> ItemsView[ModuleID, RLModule]: + """Returns a keys view over the module IDs in this MultiRLModule.""" + return self._rl_modules.items() - Returns: - The output of the forward_inference pass the specified modules. - """ - return self._run_forward_pass("forward_inference", batch, **kwargs) + def keys(self) -> KeysView[ModuleID]: + """Returns a keys view over the module IDs in this MultiRLModule.""" + return self._rl_modules.keys() - @override(RLModule) - def _forward_exploration( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_exploration pass. + def values(self) -> ValuesView[ModuleID]: + """Returns a keys view over the module IDs in this MultiRLModule.""" + return self._rl_modules.values() - Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - individual modules' batches). + def __len__(self) -> int: + """Returns the number of RLModules within this MultiRLModule.""" + return len(self._rl_modules) - Returns: - The output of the forward_exploration pass the specified modules. - """ - return self._run_forward_pass("forward_exploration", batch, **kwargs) + def __repr__(self) -> str: + return f"MARL({pprint.pformat(self._rl_modules)})" @override(RLModule) def get_state( @@ -409,39 +399,53 @@ def set_state(self, state: StateDict) -> None: def get_checkpointable_components(self) -> List[Tuple[str, Checkpointable]]: return list(self._rl_modules.items()) - def __repr__(self) -> str: - return f"MARL({pprint.pformat(self._rl_modules)})" + @override(RLModule) + def output_specs_train(self) -> SpecType: + return [] - def _run_forward_pass( - self, - forward_fn_name: str, - batch: Dict[ModuleID, Any], - **kwargs, - ) -> Dict[ModuleID, Dict[ModuleID, Any]]: - """This is a helper method that runs the forward pass for the given module. + @override(RLModule) + def output_specs_inference(self) -> SpecType: + return [] - It uses forward_fn_name to get the forward pass method from the RLModule - (e.g. forward_train vs. forward_exploration) and runs it on the given batch. + @override(RLModule) + def output_specs_exploration(self) -> SpecType: + return [] - Args: - forward_fn_name: The name of the forward pass method to run. - batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). - **kwargs: Additional keyword arguments to pass to the forward function. + @override(RLModule) + def _default_input_specs(self) -> SpecType: + """MultiRLModule should not check the input specs. + + The underlying single-agent RLModules will check the input specs. + """ + return [] + + @override(RLModule) + def as_multi_rl_module(self) -> "MultiRLModule": + """Returns self in order to match `RLModule.as_multi_rl_module()` behavior. + + This method is overridden to avoid double wrapping. Returns: - The output of the forward pass the specified modules. The output is a - mapping from module ID to the output of the forward pass. + The instance itself. """ + return self - outputs = {} - for module_id in batch.keys(): - self._check_module_exists(module_id) - rl_module = self._rl_modules[module_id] - forward_fn = getattr(rl_module, forward_fn_name) - outputs[module_id] = forward_fn(batch[module_id], **kwargs) + @classmethod + def _check_module_configs(cls, module_configs: Dict[ModuleID, Any]): + """Checks the module configs for validity. + + The module_configs be a mapping from module_ids to RLModuleSpec + objects. + + Args: + module_configs: The module configs to check. - return outputs + Raises: + ValueError: If the module configs are invalid. + """ + for module_id, module_spec in module_configs.items(): + if not isinstance(module_spec, RLModuleSpec): + raise ValueError(f"Module {module_id} is not a RLModuleSpec object.") def _check_module_exists(self, module_id: ModuleID) -> None: if module_id not in self._rl_modules: @@ -457,7 +461,7 @@ class MultiRLModuleSpec: """A utility spec class to make it constructing MultiRLModules easier. Users can extend this class to modify the behavior of base class. For example to - share neural networks across the modules, the build method can be overriden to + share neural networks across the modules, the build method can be overridden to create the shared module first and then pass it to custom module classes that would then use it as a shared module. diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 88debf2204c71..91d246dbb9250 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -4,13 +4,6 @@ import gymnasium as gym -if TYPE_CHECKING: - from ray.rllib.core.rl_module.multi_rl_module import ( - MultiRLModule, - MultiRLModuleSpec, - ) - from ray.rllib.core.models.catalog import Catalog - from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns from ray.rllib.core.models.specs.typing import SpecType @@ -34,9 +27,16 @@ serialize_type, deserialize_type, ) -from ray.rllib.utils.typing import SampleBatchType, StateDict +from ray.rllib.utils.typing import StateDict from ray.util.annotations import PublicAPI +if TYPE_CHECKING: + from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, + ) + from ray.rllib.core.models.catalog import Catalog + @PublicAPI(stability="alpha") @dataclass @@ -476,21 +476,6 @@ def setup(self): """ return None - @OverrideToImplementCustomLogic - def get_train_action_dist_cls(self) -> Type[Distribution]: - """Returns the action distribution class for this RLModule used for training. - - This class is used to get the correct action distribution class to be used by - the training components. In case that no action distribution class is needed, - this method can return None. - - Note that RLlib's distribution classes all implement the `Distribution` - interface. This requires two special methods: `Distribution.from_logits()` and - `Distribution.to_deterministic()`. See the documentation of the - :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. - """ - raise NotImplementedError - @OverrideToImplementCustomLogic def get_exploration_action_dist_cls(self) -> Type[Distribution]: """Returns the action distribution class for this RLModule used for exploration. @@ -522,77 +507,49 @@ def get_inference_action_dist_cls(self) -> Type[Distribution]: raise NotImplementedError @OverrideToImplementCustomLogic - def get_initial_state(self) -> Any: - """Returns the initial state of the RLModule. - - This can be used for recurrent models. - """ - return {} + def get_train_action_dist_cls(self) -> Type[Distribution]: + """Returns the action distribution class for this RLModule used for training. - @OverrideToImplementCustomLogic - def is_stateful(self) -> bool: - """Returns False if the initial state is an empty dict (or None). + This class is used to get the correct action distribution class to be used by + the training components. In case that no action distribution class is needed, + this method can return None. - By default, RLlib assumes that the module is non-recurrent if the initial - state is an empty dict and recurrent otherwise. - This behavior can be overridden by implementing this method. + Note that RLlib's distribution classes all implement the `Distribution` + interface. This requires two special methods: `Distribution.from_logits()` and + `Distribution.to_deterministic()`. See the documentation of the + :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. """ - initial_state = self.get_initial_state() - assert isinstance(initial_state, dict), ( - "The initial state of an RLModule must be a dict, but is " - f"{type(initial_state)} instead." - ) - return bool(initial_state) - - @OverrideToImplementCustomLogic_CallToSuperRecommended - def output_specs_inference(self) -> SpecType: - """Returns the output specs of the `forward_inference()` method. + raise NotImplementedError - Override this method to customize the output specs of the inference call. - The default implementation requires the `forward_inference()` method to return - a dict that has `action_dist` key and its value is an instance of - `Distribution`. - """ - return [Columns.ACTION_DIST_INPUTS] + @OverrideToImplementCustomLogic + @abc.abstractmethod + def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Generic forward pass method, used in all phases of training and evaluation. + + If you need a more nuanced distinction between forward passes in the different + phases of training and evaluation, override the following methods insted: + For distinct action computation logic w/o exploration, override the + `self._forward_inference()` method. + For distinct action computation logic with exploration, override the + `self._forward_exploration()` method. + For distinct forward pass logic before loss computation, override the + `self._forward_train()` method. - @OverrideToImplementCustomLogic_CallToSuperRecommended - def output_specs_exploration(self) -> SpecType: - """Returns the output specs of the `forward_exploration()` method. + Args: + batch: The input batch. + **kwargs: Additional keyword arguments. - Override this method to customize the output specs of the exploration call. - The default implementation requires the `forward_exploration()` method to return - a dict that has `action_dist` key and its value is an instance of - `Distribution`. + Returns: + The output of the forward pass. """ - return [Columns.ACTION_DIST_INPUTS] - - def output_specs_train(self) -> SpecType: - """Returns the output specs of the forward_train method.""" - return {} - - def input_specs_inference(self) -> SpecType: - """Returns the input specs of the forward_inference method.""" - return self._default_input_specs() - - def input_specs_exploration(self) -> SpecType: - """Returns the input specs of the forward_exploration method.""" - return self._default_input_specs() - - def input_specs_train(self) -> SpecType: - """Returns the input specs of the forward_train method.""" - return self._default_input_specs() - - def _default_input_specs(self) -> SpecType: - """Returns the default input specs.""" - return [Columns.OBS] @check_input_specs("_input_specs_inference") @check_output_specs("_output_specs_inference") - def forward_inference(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: - """Forward-pass during evaluation, called from the sampler. + def forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler. - This method should not be overriden to implement a custom forward inference - method. Instead, override the _forward_inference method. + This method should not be overridden. Override the `self._forward_inference()` + method instead. Args: batch: The input batch. This input batch should comply with @@ -605,17 +562,25 @@ def forward_inference(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: """ return self._forward_inference(batch, **kwargs) - @abc.abstractmethod + @OverrideToImplementCustomLogic def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Forward-pass during evaluation. See forward_inference for details.""" + """Forward-pass used for action computation without exploration behavior. + + Override this method only, if you need specific behavior for non-exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return self._forward(batch, **kwargs) @check_input_specs("_input_specs_exploration") @check_output_specs("_output_specs_exploration") - def forward_exploration(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: - """Forward-pass during exploration, called from the sampler. + def forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler. - This method should not be overriden to implement a custom forward exploration - method. Instead, override the _forward_exploration method. + This method should not be overridden. Override the `self._forward_exploration()` + method instead. Args: batch: The input batch. This input batch should comply with @@ -628,15 +593,25 @@ def forward_exploration(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any """ return self._forward_exploration(batch, **kwargs) - @abc.abstractmethod + @OverrideToImplementCustomLogic def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Forward-pass during exploration. See forward_exploration for details.""" + """Forward-pass used for action computation with exploration behavior. + + Override this method only, if you need specific behavior for exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return self._forward(batch, **kwargs) @check_input_specs("_input_specs_train") @check_output_specs("_output_specs_train") - def forward_train(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: - """Forward-pass during training called from the learner. This method should - not be overriden. Instead, override the _forward_train method. + def forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during training called from the learner. + + This method should not be overridden. Override the `self._forward_train()` + method instead. Args: batch: The input batch. This input batch should comply with @@ -655,9 +630,42 @@ def forward_train(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: ) return self._forward_train(batch, **kwargs) - @abc.abstractmethod + @OverrideToImplementCustomLogic def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Forward-pass during training. See forward_train for details.""" + """Forward-pass used before the loss computation (training). + + Override this method only, if you need specific behavior and outputs for your + loss computations. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return self._forward(batch, **kwargs) + + @OverrideToImplementCustomLogic + def get_initial_state(self) -> Any: + """Returns the initial state of the RLModule, in case this is a stateful module. + + Returns: + A tensor or any nested struct of tensors, representing an initial state for + this (stateful) RLModule. + """ + return {} + + @OverrideToImplementCustomLogic + def is_stateful(self) -> bool: + """By default, returns False if the initial state is an empty dict (or None). + + By default, RLlib assumes that the module is non-recurrent, if the initial + state is an empty dict and recurrent otherwise. + This behavior can be customized by overriding this method. + """ + initial_state = self.get_initial_state() + assert isinstance(initial_state, dict), ( + "The initial state of an RLModule must be a dict, but is " + f"{type(initial_state)} instead." + ) + return bool(initial_state) @OverrideToImplementCustomLogic @override(Checkpointable) @@ -701,6 +709,48 @@ def get_ctor_args_and_kwargs(self): {}, # **kwargs ) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def output_specs_inference(self) -> SpecType: + """Returns the output specs of the `forward_inference()` method. + + Override this method to customize the output specs of the inference call. + The default implementation requires the `forward_inference()` method to return + a dict that has `action_dist` key and its value is an instance of + `Distribution`. + """ + return [Columns.ACTION_DIST_INPUTS] + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def output_specs_exploration(self) -> SpecType: + """Returns the output specs of the `forward_exploration()` method. + + Override this method to customize the output specs of the exploration call. + The default implementation requires the `forward_exploration()` method to return + a dict that has `action_dist` key and its value is an instance of + `Distribution`. + """ + return [Columns.ACTION_DIST_INPUTS] + + def output_specs_train(self) -> SpecType: + """Returns the output specs of the forward_train method.""" + return {} + + def input_specs_inference(self) -> SpecType: + """Returns the input specs of the forward_inference method.""" + return self._default_input_specs() + + def input_specs_exploration(self) -> SpecType: + """Returns the input specs of the forward_exploration method.""" + return self._default_input_specs() + + def input_specs_train(self) -> SpecType: + """Returns the input specs of the forward_train method.""" + return self._default_input_specs() + + def _default_input_specs(self) -> SpecType: + """Returns the default input specs.""" + return [Columns.OBS] + def as_multi_rl_module(self) -> "MultiRLModule": """Returns a multi-agent wrapper around this module.""" from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index db33cf9f9e0e0..8cc315baad627 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -68,15 +68,6 @@ def __init__(self, *args, **kwargs) -> None: if target is not None: del target - @override(nn.Module) - def forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """forward pass of the module. - - This is aliased to forward_train because Torch DDP requires a forward method to - be implemented for backpropagation to work. - """ - return self.forward_train(batch, **kwargs) - def compile(self, compile_config: TorchCompileConfig): """Compile the forward methods of this module. @@ -88,6 +79,20 @@ def compile(self, compile_config: TorchCompileConfig): """ return compile_wrapper(self, compile_config) + @OverrideToImplementCustomLogic + def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + # By default, calls the generic `_forward()` method, but with a no-grad context + # for performance reasons. + with torch.no_grad(): + return self._forward(batch, **kwargs) + + @OverrideToImplementCustomLogic + def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + # By default, calls the generic `_forward()` method, but with a no-grad context + # for performance reasons. + with torch.no_grad(): + return self._forward(batch, **kwargs) + @OverrideToImplementCustomLogic @override(RLModule) def get_state( @@ -156,6 +161,24 @@ def get_exploration_action_dist_cls(self) -> Type[TorchDistribution]: def get_train_action_dist_cls(self) -> Type[TorchDistribution]: return self.get_inference_action_dist_cls() + @override(nn.Module) + def forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! + + This is aliased to `self.forward_train` because Torch DDP requires a forward + method to be implemented for backpropagation to work. + + Instead, override: + `_forward()` to define a generic forward pass for all phases (exploration, + inference, training) + `_forward_inference()` to define the forward pass for action inference in + deployment/production (no exploration). + `_forward_exploration()` to define the forward pass for action inference during + training sample collection (w/ exploration behavior). + `_forward_train()` to define the forward pass prior to loss computation. + """ + return self.forward_train(batch, **kwargs) + class TorchDDPRLModule(RLModule, nn.parallel.DistributedDataParallel): def __init__(self, *args, **kwargs) -> None: @@ -187,8 +210,8 @@ def is_stateful(self) -> bool: return self.unwrapped().is_stateful() @override(RLModule) - def _forward_train(self, *args, **kwargs): - return self(*args, **kwargs) + def _forward(self, *args, **kwargs): + return self.unwrapped()._forward(*args, **kwargs) @override(RLModule) def _forward_inference(self, *args, **kwargs) -> Dict[str, Any]: @@ -198,6 +221,10 @@ def _forward_inference(self, *args, **kwargs) -> Dict[str, Any]: def _forward_exploration(self, *args, **kwargs) -> Dict[str, Any]: return self.unwrapped()._forward_exploration(*args, **kwargs) + @override(RLModule) + def _forward_train(self, *args, **kwargs): + return self(*args, **kwargs) + @override(RLModule) def get_state(self, *args, **kwargs): return self.unwrapped().get_state(*args, **kwargs) diff --git a/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py b/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py index c03f61d820281..ed1efbc1fc177 100644 --- a/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py +++ b/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py @@ -238,12 +238,8 @@ def compute_self_supervised_loss( # Inference and exploration not supported (this is a world-model that should only # be used for training). @override(TorchRLModule) - def _forward_inference(self, batch, **kwargs): + def _forward(self, batch, **kwargs): raise NotImplementedError( "`IntrinsicCuriosityModel` should only be used for training! " - "Use `forward_train()` instead." + "Only calls to `forward_train()` supported." ) - - @override(TorchRLModule) - def _forward_exploration(self, batch, **kwargs): - return self._forward_inference(batch) diff --git a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py index 87363c267a7a8..d856d607c6a48 100644 --- a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py +++ b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional import numpy as np @@ -77,7 +77,7 @@ def setup(self): # Get the LSTM cell size from our RLModuleConfig's (self.config) # `model_config_dict` property: self._lstm_cell_size = self.config.model_config_dict.get("lstm_cell_size", 256) - self._lstm = nn.LSTM(in_size, self._lstm_cell_size, batch_first=False) + self._lstm = nn.LSTM(in_size, self._lstm_cell_size, batch_first=True) in_size = self._lstm_cell_size # Build a sequential stack. @@ -94,7 +94,7 @@ def setup(self): self._fc_net = nn.Sequential(*layers) # Logits layer (no bias, no activation). - self._logits = nn.Linear(in_size, self.config.action_space.n) + self._pi_head = nn.Linear(in_size, self.config.action_space.n) # Single-node value layer. self._values = nn.Linear(in_size, 1) @@ -106,70 +106,48 @@ def get_initial_state(self) -> Any: } @override(TorchRLModule) - def _forward_inference(self, batch, **kwargs): + def _forward(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - _, state_out, logits = self._compute_features_state_out_and_logits(batch) + features, state_outs = self._compute_features_and_state_outs(batch) + logits = self._pi_head(features) # Return logits as ACTION_DIST_INPUTS (categorical distribution). # Note that the default `GetActions` connector piece (in the EnvRunner) will # take care of argmax-"sampling" from the logits to yield the inference (greedy) # action. return { - Columns.STATE_OUT: state_out, Columns.ACTION_DIST_INPUTS: logits, + Columns.STATE_OUT: state_outs, } - @override(TorchRLModule) - def _forward_exploration(self, batch, **kwargs): - # Exact same as `_forward_inference`. - # Note that the default `GetActions` connector piece (in the EnvRunner) will - # take care of stochastic sampling from the Categorical defined by the logits - # to yield the exploration action. - return self._forward_inference(batch, **kwargs) - @override(TorchRLModule) def _forward_train(self, batch, **kwargs): - # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - features, state_out, logits = self._compute_features_state_out_and_logits(batch) - # Besides the action logits, we also have to return value predictions here - # (to be used inside the loss function). - values = self._values(features).squeeze(-1) + # Same logic as _forward, but also return features to be used by value function + # branch during training. + features, state_outs = self._compute_features_and_state_outs(batch) + logits = self._pi_head(features) return { - Columns.STATE_OUT: state_out, Columns.ACTION_DIST_INPUTS: logits, - Columns.VF_PREDS: values, + Columns.STATE_OUT: state_outs, + Columns.FEATURES: features, } # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used # by value-based methods like PPO or IMPALA. @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: - obs = batch[Columns.OBS] - state_in = batch[Columns.STATE_IN] - h, c = state_in["h"], state_in["c"] - # Unsqueeze the layer dim (we only have 1 LSTM layer. - features, _ = self._lstm( - obs.permute(1, 0, 2), # we have to permute, b/c our LSTM is time-major - (h.unsqueeze(0), c.unsqueeze(0)), - ) - # Make batch-major again. - features = features.permute(1, 0, 2) - # Push through our FC net. - features = self._fc_net(features) - return self._values(features).squeeze(-1) + def compute_values(self, batch: Dict[str, Any], features: Optional[Any] = None) -> TensorType: + if features is None: + features, _ = self._compute_features_and_state_outs(batch) + values = self._values(features).squeeze(-1) + return values - def _compute_features_state_out_and_logits(self, batch): + def _compute_features_and_state_outs(self, batch): obs = batch[Columns.OBS] state_in = batch[Columns.STATE_IN] h, c = state_in["h"], state_in["c"] - # Unsqueeze the layer dim (we only have 1 LSTM layer. - features, (h, c) = self._lstm( - obs.permute(1, 0, 2), # we have to permute, b/c our LSTM is time-major - (h.unsqueeze(0), c.unsqueeze(0)), - ) - # Make batch-major again. - features = features.permute(1, 0, 2) + # Unsqueeze the layer dim (we only have 1 LSTM layer). + features, (h, c) = self._lstm(obs, (h.unsqueeze(0), c.unsqueeze(0))) # Push through our FC net. features = self._fc_net(features) - logits = self._logits(features) - return features, {"h": h.squeeze(0), "c": c.squeeze(0)}, logits + # Squeeze the layer dim (we only have 1 LSTM layer). + return features, {"h": h.squeeze(0), "c": c.squeeze(0)} diff --git a/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py b/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py index 22acf3939e8f8..fd9f1cefa8caf 100644 --- a/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py +++ b/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI @@ -49,7 +49,6 @@ class TinyAtariCNN(TorchRLModule, ValueFunctionAPI): num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters()) print(f"num params = {num_all_params}") - """ @override(TorchRLModule) @@ -122,35 +121,37 @@ def setup(self): normc_initializer(0.01)(self._values.weight) @override(TorchRLModule) - def _forward_inference(self, batch, **kwargs): + def _forward(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). _, logits = self._compute_features_and_logits(batch) - # Return logits as ACTION_DIST_INPUTS (categorical distribution). - return {Columns.ACTION_DIST_INPUTS: logits} - - @override(TorchRLModule) - def _forward_exploration(self, batch, **kwargs): - return self._forward_inference(batch, **kwargs) + # Return features and logits as ACTION_DIST_INPUTS (categorical distribution). + return { + Columns.ACTION_DIST_INPUTS: logits, + } @override(TorchRLModule) def _forward_train(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). features, logits = self._compute_features_and_logits(batch) - # Besides the action logits, we also have to return value predictions here - # (to be used inside the loss function). - values = self._values(features).squeeze(-1) + # Return features and logits as ACTION_DIST_INPUTS (categorical distribution). return { Columns.ACTION_DIST_INPUTS: logits, - Columns.VF_PREDS: values, + Columns.FEATURES: features, } # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used # by value-based methods like PPO or IMPALA. @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: - obs = batch[Columns.OBS] - features = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) - features = torch.squeeze(features, dim=[-1, -2]) + def compute_values( + self, + batch: Dict[str, Any], + features: Optional[Any] = None, + ) -> TensorType: + # Features not provided -> We need to compute them first. + if features is None: + obs = batch[Columns.OBS] + features = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) + features = torch.squeeze(features, dim=[-1, -2]) return self._values(features).squeeze(-1) def _compute_features_and_logits(self, batch): diff --git a/rllib/utils/annotations.py b/rllib/utils/annotations.py index d06a45dcb49bd..6824412b354f1 100644 --- a/rllib/utils/annotations.py +++ b/rllib/utils/annotations.py @@ -171,7 +171,7 @@ def loss(self, ...): ... """ - obj.__is_overriden__ = False + obj.__is_overridden__ = False return obj @@ -196,7 +196,7 @@ def setup(self, config): super().setup(config) # ... or here (after having called super()'s setup method. """ - obj.__is_overriden__ = False + obj.__is_overridden__ = False return obj @@ -206,7 +206,7 @@ def is_overridden(obj): Note, this only works for API calls decorated with OverrideToImplementCustomLogic or OverrideToImplementCustomLogic_CallToSuperRecommended. """ - return getattr(obj, "__is_overriden__", True) + return getattr(obj, "__is_overridden__", True) # Backward compatibility. From 2d9936495013ce3ccdc5dc4a1410039a1691d653 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 3 Oct 2024 19:06:14 +0200 Subject: [PATCH 2/7] wip Signed-off-by: sven1977 --- .../algorithms/ppo/torch/ppo_torch_learner.py | 14 ++++---- .../ppo/torch/ppo_torch_rl_module.py | 34 ++++++------------- rllib/core/rl_module/multi_rl_module.py | 12 ++++--- rllib/core/rl_module/torch/torch_rl_module.py | 8 ++--- .../rl_modules/classes/lstm_containing_rlm.py | 4 ++- 5 files changed, 32 insertions(+), 40 deletions(-) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index bcfa64813739f..66e398bb4cef5 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -40,6 +40,8 @@ def compute_loss_for_module( batch: Dict[str, Any], fwd_out: Dict[str, TensorType], ) -> TensorType: + module = self.module[module_id].unwrapped() + # Possibly apply masking to some sub loss terms and to the total loss term # at the end. Masking could be used for RNN-based model (zero padded `batch`) # and for PPO's batched value function (and bootstrap value) computations, @@ -55,12 +57,8 @@ def possibly_masked_mean(data_): else: possibly_masked_mean = torch.mean - action_dist_class_train = ( - self.module[module_id].unwrapped().get_train_action_dist_cls() - ) - action_dist_class_exploration = ( - self.module[module_id].unwrapped().get_exploration_action_dist_cls() - ) + action_dist_class_train = module.get_train_action_dist_cls() + action_dist_class_exploration = module.get_exploration_action_dist_cls() curr_action_dist = action_dist_class_train.from_logits( fwd_out[Columns.ACTION_DIST_INPUTS] @@ -91,7 +89,9 @@ def possibly_masked_mean(data_): # Compute a value function loss. if config.use_critic: - value_fn_out = fwd_out[Columns.VF_PREDS] + value_fn_out = module.compute_values( + batch, features=fwd_out.get(Columns.FEATURES) + ) vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0) vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param) mean_vf_loss = possibly_masked_mean(vf_loss_clipped) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index ead91ca819662..1cc99ccc2a17f 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -19,40 +19,26 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule): @override(RLModule) def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Default forward pass (used for inference and exploration).""" - + output = {} # Encoder forward pass. encoder_outs = self.encoder(batch) # Stateful encoder? - state_out = None if Columns.STATE_OUT in encoder_outs: - state_out = encoder_outs[Columns.STATE_OUT] - + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] # Pi head. - logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - - return { - Columns.ACTION_DIST_INPUTS: logits - } | {Columns.STATE_OUT: state_out} if state_out else {} + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) + return output @override(RLModule) def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Train forward pass (keep features for possible shared value function call).""" - - # Encoder forward pass. + """Train forward pass (keep features for possible shared value func. call).""" + output = {} encoder_outs = self.encoder(batch) - features = encoder_outs[ENCODER_OUT][CRITIC] - # Stateful encoder? - state_out = None + output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC] if Columns.STATE_OUT in encoder_outs: - state_out = encoder_outs[Columns.STATE_OUT] - - # Pi head. - logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - - return { - Columns.ACTION_DIST_INPUTS: logits, - Columns.FEATURES: features, - } | {Columns.STATE_OUT: state_out} if state_out else {} + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) + return output @override(ValueFunctionAPI) def compute_values( diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index fe5b7cb594901..5587e0772cb26 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -111,7 +111,8 @@ def _forward( """ return { mid: self._rl_modules[mid]._forward(batch[mid], **kwargs) - for mid in batch.keys() if mid in self + for mid in batch.keys() + if mid in self } @override(RLModule) @@ -128,7 +129,8 @@ def _forward_inference( """ return { mid: self._rl_modules[mid]._forward_inference(batch[mid], **kwargs) - for mid in batch.keys() if mid in self + for mid in batch.keys() + if mid in self } @override(RLModule) @@ -145,7 +147,8 @@ def _forward_exploration( """ return { mid: self._rl_modules[mid]._forward_exploration(batch[mid], **kwargs) - for mid in batch.keys() if mid in self + for mid in batch.keys() + if mid in self } @override(RLModule) @@ -162,7 +165,8 @@ def _forward_train( """ return { mid: self._rl_modules[mid]._forward_train(batch[mid], **kwargs) - for mid in batch.keys() if mid in self + for mid in batch.keys() + if mid in self } @OverrideToImplementCustomLogic diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index 8cc315baad627..536631db96a84 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -188,8 +188,8 @@ def __init__(self, *args, **kwargs) -> None: self.config = self.unwrapped().config @override(RLModule) - def get_train_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: - return self.unwrapped().get_train_action_dist_cls(*args, **kwargs) + def get_inference_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: + return self.unwrapped().get_inference_action_dist_cls(*args, **kwargs) @override(RLModule) def get_exploration_action_dist_cls( @@ -198,8 +198,8 @@ def get_exploration_action_dist_cls( return self.unwrapped().get_exploration_action_dist_cls(*args, **kwargs) @override(RLModule) - def get_inference_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: - return self.unwrapped().get_inference_action_dist_cls(*args, **kwargs) + def get_train_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: + return self.unwrapped().get_train_action_dist_cls(*args, **kwargs) @override(RLModule) def get_initial_state(self) -> Any: diff --git a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py index d856d607c6a48..0116145f9c6cf 100644 --- a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py +++ b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py @@ -135,7 +135,9 @@ def _forward_train(self, batch, **kwargs): # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used # by value-based methods like PPO or IMPALA. @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any], features: Optional[Any] = None) -> TensorType: + def compute_values( + self, batch: Dict[str, Any], features: Optional[Any] = None + ) -> TensorType: if features is None: features, _ = self._compute_features_and_state_outs(batch) values = self._values(features).squeeze(-1) From d85d95e5f9481e920785b39d6b87262b85cb804a Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 3 Oct 2024 22:24:34 +0200 Subject: [PATCH 3/7] wip Signed-off-by: sven1977 --- rllib/algorithms/ppo/torch/ppo_torch_learner.py | 2 +- rllib/core/rl_module/rl_module.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index 66e398bb4cef5..ee4dbc952d25a 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -45,7 +45,7 @@ def compute_loss_for_module( # Possibly apply masking to some sub loss terms and to the total loss term # at the end. Masking could be used for RNN-based model (zero padded `batch`) # and for PPO's batched value function (and bootstrap value) computations, - # for which we add an additional (artificial) timestep to each episode to + # for which we add an (artificial) timestep to each episode to # simplify the actual computation. if Columns.LOSS_MASK in batch: mask = batch[Columns.LOSS_MASK] diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 91d246dbb9250..32c09f977a887 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -522,7 +522,6 @@ def get_train_action_dist_cls(self) -> Type[Distribution]: raise NotImplementedError @OverrideToImplementCustomLogic - @abc.abstractmethod def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Generic forward pass method, used in all phases of training and evaluation. @@ -542,6 +541,7 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: Returns: The output of the forward pass. """ + return {} @check_input_specs("_input_specs_inference") @check_output_specs("_output_specs_inference") From 9196ca296318f4c9e96f35b8327ce27f96388447 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 3 Oct 2024 22:26:49 +0200 Subject: [PATCH 4/7] wip Signed-off-by: sven1977 --- rllib/core/rl_module/multi_rl_module.py | 2 +- rllib/core/rl_module/rl_module.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index 5587e0772cb26..38109867f4ef8 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -93,7 +93,7 @@ def _forward( """Generic forward pass method, used in all phases of training and evaluation. If you need a more nuanced distinction between forward passes in the different - phases of training and evaluation, override the following methods insted: + phases of training and evaluation, override the following methods instead: For distinct action computation logic w/o exploration, override the `self._forward_inference()` method. For distinct action computation logic with exploration, override the diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 32c09f977a887..1ddd33471a298 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -526,7 +526,7 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Generic forward pass method, used in all phases of training and evaluation. If you need a more nuanced distinction between forward passes in the different - phases of training and evaluation, override the following methods insted: + phases of training and evaluation, override the following methods instead: For distinct action computation logic w/o exploration, override the `self._forward_inference()` method. For distinct action computation logic with exploration, override the From 1a081194b5fe6fa16a04a93f1a8658c49351f759 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sat, 5 Oct 2024 14:15:03 +0200 Subject: [PATCH 5/7] wip Signed-off-by: sven1977 --- rllib/algorithms/appo/torch/appo_torch_learner.py | 11 ++++++----- .../algorithms/impala/torch/impala_torch_learner.py | 12 ++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/rllib/algorithms/appo/torch/appo_torch_learner.py b/rllib/algorithms/appo/torch/appo_torch_learner.py index bce58cd55c3ed..8e6c85c020bcc 100644 --- a/rllib/algorithms/appo/torch/appo_torch_learner.py +++ b/rllib/algorithms/appo/torch/appo_torch_learner.py @@ -14,7 +14,7 @@ ) from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY -from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI +from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import convert_to_numpy @@ -35,6 +35,10 @@ def compute_loss_for_module( batch: Dict, fwd_out: Dict[str, TensorType], ) -> TensorType: + module = self.module[module_id].unwrapped() + assert isinstance(module, TargetNetworkAPI) + assert isinstance(module, ValueFunctionAPI) + # TODO (sven): Now that we do the +1ts trick to be less vulnerable about # bootstrap values at the end of rollouts in the new stack, we might make # this a more flexible, configurable parameter for users, e.g. @@ -51,10 +55,7 @@ def compute_loss_for_module( ) size_loss_mask = torch.sum(loss_mask) - module = self.module[module_id].unwrapped() - assert isinstance(module, TargetNetworkAPI) - - values = fwd_out[Columns.VF_PREDS] + values = module.compute_values(batch, features=fwd_out.get(Columns.FEATURES)) action_dist_cls_train = module.get_train_action_dist_cls() target_policy_dist = action_dist_cls_train.from_logits( diff --git a/rllib/algorithms/impala/torch/impala_torch_learner.py b/rllib/algorithms/impala/torch/impala_torch_learner.py index c98c0947be059..792911a9ae359 100644 --- a/rllib/algorithms/impala/torch/impala_torch_learner.py +++ b/rllib/algorithms/impala/torch/impala_torch_learner.py @@ -28,6 +28,8 @@ def compute_loss_for_module( batch: Dict, fwd_out: Dict[str, TensorType], ) -> TensorType: + module = self.module[module_id].unwrapped() + # TODO (sven): Now that we do the +1ts trick to be less vulnerable about # bootstrap values at the end of rollouts in the new stack, we might make # this a more flexible, configurable parameter for users, e.g. @@ -46,17 +48,15 @@ def compute_loss_for_module( # Behavior actions logp and target actions logp. behaviour_actions_logp = batch[Columns.ACTION_LOGP] - target_policy_dist = ( - self.module[module_id] - .unwrapped() - .get_train_action_dist_cls() - .from_logits(fwd_out[Columns.ACTION_DIST_INPUTS]) + target_policy_dist = module.get_train_action_dist_cls().from_logits( + fwd_out[Columns.ACTION_DIST_INPUTS] ) target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS]) # Values and bootstrap values. + values = module.compute_values(batch, features=fwd_out.get(Columns.FEATURES)) values_time_major = make_time_major( - fwd_out[Columns.VF_PREDS], + values, trajectory_len=rollout_frag_or_episode_len, recurrent_seq_len=recurrent_seq_len, ) From 717f9e85dc9e0b67a380eec72c2365e5c52028f1 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sat, 5 Oct 2024 14:29:10 +0200 Subject: [PATCH 6/7] wip Signed-off-by: sven1977 --- .../appo/torch/appo_torch_learner.py | 4 ++- .../impala/torch/impala_torch_learner.py | 4 ++- .../marwil/torch/marwil_torch_rl_module.py | 31 +++++++++++-------- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 2 +- .../algorithms/ppo/torch/ppo_torch_learner.py | 2 +- .../ppo/torch/ppo_torch_rl_module.py | 12 +++---- rllib/core/columns.py | 2 +- .../core/rl_module/apis/value_function_api.py | 12 +++---- .../offline_rl/train_w_bc_finetune_w_ppo.py | 9 +++--- .../rl_modules/classes/action_masking_rlm.py | 4 +-- .../classes/autoregressive_actions_rlm.py | 19 +++--------- .../rl_modules/classes/lstm_containing_rlm.py | 26 ++++++++-------- .../rl_modules/classes/modelv2_to_rlm.py | 4 +-- .../rl_modules/classes/tiny_atari_cnn_rlm.py | 24 +++++++------- 14 files changed, 78 insertions(+), 77 deletions(-) diff --git a/rllib/algorithms/appo/torch/appo_torch_learner.py b/rllib/algorithms/appo/torch/appo_torch_learner.py index 8e6c85c020bcc..d53815989e09e 100644 --- a/rllib/algorithms/appo/torch/appo_torch_learner.py +++ b/rllib/algorithms/appo/torch/appo_torch_learner.py @@ -55,7 +55,9 @@ def compute_loss_for_module( ) size_loss_mask = torch.sum(loss_mask) - values = module.compute_values(batch, features=fwd_out.get(Columns.FEATURES)) + values = module.compute_values( + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) + ) action_dist_cls_train = module.get_train_action_dist_cls() target_policy_dist = action_dist_cls_train.from_logits( diff --git a/rllib/algorithms/impala/torch/impala_torch_learner.py b/rllib/algorithms/impala/torch/impala_torch_learner.py index 792911a9ae359..256e3b48fb79f 100644 --- a/rllib/algorithms/impala/torch/impala_torch_learner.py +++ b/rllib/algorithms/impala/torch/impala_torch_learner.py @@ -54,7 +54,9 @@ def compute_loss_for_module( target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS]) # Values and bootstrap values. - values = module.compute_values(batch, features=fwd_out.get(Columns.FEATURES)) + values = module.compute_values( + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) + ) values_time_major = make_time_major( values, trajectory_len=rollout_frag_or_episode_len, diff --git a/rllib/algorithms/marwil/torch/marwil_torch_rl_module.py b/rllib/algorithms/marwil/torch/marwil_torch_rl_module.py index fe774e9041f4f..56097d0bdcb7b 100644 --- a/rllib/algorithms/marwil/torch/marwil_torch_rl_module.py +++ b/rllib/algorithms/marwil/torch/marwil_torch_rl_module.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from ray.rllib.algorithms.marwil.marwil_rl_module import MARWILRLModule from ray.rllib.core.columns import Columns @@ -63,18 +63,23 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: # (similar to IMPALA's v-trace architecture). This would also get rid of the # second Connector pass currently necessary. @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: - # Separate vf-encoder. - if hasattr(self.encoder, "critic_encoder"): - if self.is_stateful(): - # The recurrent encoders expect a `(state_in, h)` key in the - # input dict while the key returned is `(state_in, critic, h)`. - batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] - encoder_outs = self.encoder.critic_encoder(batch)[ENCODER_OUT] - # Shared encoder. - else: - encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC] + def compute_values( + self, + batch: Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: + if embeddings is None: + # Separate vf-encoder. + if hasattr(self.encoder, "critic_encoder"): + if self.is_stateful(): + # The recurrent encoders expect a `(state_in, h)` key in the + # input dict while the key returned is `(state_in, critic, h)`. + batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] + embeddings = self.encoder.critic_encoder(batch)[ENCODER_OUT] + # Shared encoder. + else: + embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC] # Value head. - vf_out = self.vf(encoder_outs) + vf_out = self.vf(embeddings) # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index e242e2f892298..021b68e505dec 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -58,7 +58,7 @@ def _forward_train(self, batch: Dict): return output @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: + def compute_values(self, batch: Dict[str, Any], embeddings=None) -> TensorType: infos = batch.pop(Columns.INFOS, None) batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch) if infos is not None: diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index ee4dbc952d25a..8cff87e4fc2a7 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -90,7 +90,7 @@ def possibly_masked_mean(data_): # Compute a value function loss. if config.use_critic: value_fn_out = module.compute_values( - batch, features=fwd_out.get(Columns.FEATURES) + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) ) vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0) vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 1cc99ccc2a17f..86df016326a10 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -34,7 +34,7 @@ def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Train forward pass (keep features for possible shared value func. call).""" output = {} encoder_outs = self.encoder(batch) - output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC] + output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC] if Columns.STATE_OUT in encoder_outs: output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) @@ -44,9 +44,9 @@ def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: def compute_values( self, batch: Dict[str, Any], - features: Optional[Any] = None, + embeddings: Optional[Any] = None, ) -> TensorType: - if features is None: + if embeddings is None: # Separate vf-encoder. if hasattr(self.encoder, "critic_encoder"): batch_ = batch @@ -55,12 +55,12 @@ def compute_values( # input dict while the key returned is `(state_in, critic, h)`. batch_ = batch.copy() batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] - features = self.encoder.critic_encoder(batch_)[ENCODER_OUT] + embeddings = self.encoder.critic_encoder(batch_)[ENCODER_OUT] # Shared encoder. else: - features = self.encoder(batch)[ENCODER_OUT][CRITIC] + embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC] # Value head. - vf_out = self.vf(features) + vf_out = self.vf(embeddings) # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) diff --git a/rllib/core/columns.py b/rllib/core/columns.py index 073f37a73d840..2fc722f5c7240 100644 --- a/rllib/core/columns.py +++ b/rllib/core/columns.py @@ -44,7 +44,7 @@ class Columns: # Common extra RLModule output keys. STATE_IN = "state_in" STATE_OUT = "state_out" - FEATURES = "features" + EMBEDDINGS = "embeddings" ACTION_DIST_INPUTS = "action_dist_inputs" ACTION_PROB = "action_prob" ACTION_LOGP = "action_logp" diff --git a/rllib/core/rl_module/apis/value_function_api.py b/rllib/core/rl_module/apis/value_function_api.py index 595969b646c5a..43280228badb0 100644 --- a/rllib/core/rl_module/apis/value_function_api.py +++ b/rllib/core/rl_module/apis/value_function_api.py @@ -14,17 +14,17 @@ class ValueFunctionAPI(abc.ABC): def compute_values( self, batch: Dict[str, Any], - features: Optional[Any] = None, + embeddings: Optional[Any] = None, ) -> TensorType: """Computes the value estimates given `batch`. Args: batch: The batch to compute value function estimates for. - features: Optional features already computed from the `batch` (by another - forward pass through the model's encoder (or other feature computing - subcomponent). For example, the caller of thie method should provide - `fetuares` - if available - to avoid duplicate passes through a shared - encoder. + embeddings: Optional embeddings already computed from the `batch` (by + another forward pass through the model's encoder (or other subcomponent + that computes an embedding). For example, the caller of thie method + should provide `embeddings` - if available - to avoid duplicate passes + through a shared encoder. Returns: A tensor of shape (B,) or (B, T) (in case the input `batch` has a diff --git a/rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py b/rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py index e18512dbf47e3..348dfb2af1427 100644 --- a/rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py +++ b/rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py @@ -176,11 +176,12 @@ def _forward_train(self, batch, **kwargs): } @override(ValueFunctionAPI) - def compute_values(self, batch): - # Compute features ... - features = self._encoder(batch)[ENCODER_OUT] + def compute_values(self, batch, embeddings=None): + # Compute embeddings ... + if embeddings is None: + embeddings = self._encoder(batch)[ENCODER_OUT] # then values using our value head. - return self._vf(features).squeeze(-1) + return self._vf(embeddings).squeeze(-1) if __name__ == "__main__": diff --git a/rllib/examples/rl_modules/classes/action_masking_rlm.py b/rllib/examples/rl_modules/classes/action_masking_rlm.py index e948b8c1a1efa..2a71b66fa1094 100644 --- a/rllib/examples/rl_modules/classes/action_masking_rlm.py +++ b/rllib/examples/rl_modules/classes/action_masking_rlm.py @@ -99,14 +99,14 @@ def _forward_train( return self._mask_action_logits(outs, batch["action_mask"]) @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, TensorType]): + def compute_values(self, batch: Dict[str, TensorType], embeddings=None): # Preprocess the batch to extract the `observations` to `Columns.OBS`. action_mask, batch = self._preprocess_batch(batch) # NOTE: Because we manipulate the batch we need to add the `action_mask` # to the batch to access them in `_forward_train`. batch["action_mask"] = action_mask # Call the super's method to compute values for GAE. - return super().compute_values(batch) + return super().compute_values(batch, embeddings) def _preprocess_batch( self, batch: Dict[str, TensorType], **kwargs diff --git a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py index d0ff7650a166f..bbfcb69821514 100644 --- a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py +++ b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py @@ -212,10 +212,8 @@ def pi( @override(TorchRLModule) def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: - # Encoder forward pass. encoder_out = self.encoder(batch) - # Policy head forward pass. return self.pi(encoder_out[ENCODER_OUT], inference=True) @@ -225,21 +223,16 @@ def _forward_exploration( ) -> Dict[str, TensorType]: # Encoder forward pass. encoder_out = self.encoder(batch) - # Policy head forward pass. return self.pi(encoder_out[ENCODER_OUT], inference=False) @override(TorchRLModule) def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: - outs = {} - # Encoder forward pass. encoder_out = self.encoder(batch) - # Policy head forward pass. outs.update(self.pi(encoder_out[ENCODER_OUT])) - # Value function head forward pass. vf_out = self.vf(encoder_out[ENCODER_OUT]) outs[Columns.VF_PREDS] = vf_out.squeeze(-1) @@ -247,13 +240,11 @@ def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: return outs @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, TensorType]): - - # Encoder forward pass. - encoder_outs = self.encoder(batch)[ENCODER_OUT] - + def compute_values(self, batch: Dict[str, TensorType], embeddings=None): + # Encoder forward pass to get `embeddings`, if necessary. + if embeddings is None: + embeddings = self.encoder(batch)[ENCODER_OUT] # Value head forward pass. - vf_out = self.vf(encoder_outs) - + vf_out = self.vf(embeddings) # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) diff --git a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py index 0116145f9c6cf..a28a675b8aabf 100644 --- a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py +++ b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py @@ -108,8 +108,8 @@ def get_initial_state(self) -> Any: @override(TorchRLModule) def _forward(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - features, state_outs = self._compute_features_and_state_outs(batch) - logits = self._pi_head(features) + embeddings, state_outs = self._compute_embeddings_and_state_outs(batch) + logits = self._pi_head(embeddings) # Return logits as ACTION_DIST_INPUTS (categorical distribution). # Note that the default `GetActions` connector piece (in the EnvRunner) will @@ -124,32 +124,32 @@ def _forward(self, batch, **kwargs): def _forward_train(self, batch, **kwargs): # Same logic as _forward, but also return features to be used by value function # branch during training. - features, state_outs = self._compute_features_and_state_outs(batch) - logits = self._pi_head(features) + embeddings, state_outs = self._compute_features_and_state_outs(batch) + logits = self._pi_head(embeddings) return { Columns.ACTION_DIST_INPUTS: logits, Columns.STATE_OUT: state_outs, - Columns.FEATURES: features, + Columns.EMBEDDINGS: embeddings, } # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used # by value-based methods like PPO or IMPALA. @override(ValueFunctionAPI) def compute_values( - self, batch: Dict[str, Any], features: Optional[Any] = None + self, batch: Dict[str, Any], embeddings: Optional[Any] = None ) -> TensorType: - if features is None: - features, _ = self._compute_features_and_state_outs(batch) - values = self._values(features).squeeze(-1) + if embeddings is None: + embeddings, _ = self._compute_embeddings_and_state_outs(batch) + values = self._values(embeddings).squeeze(-1) return values - def _compute_features_and_state_outs(self, batch): + def _compute_embeddings_and_state_outs(self, batch): obs = batch[Columns.OBS] state_in = batch[Columns.STATE_IN] h, c = state_in["h"], state_in["c"] # Unsqueeze the layer dim (we only have 1 LSTM layer). - features, (h, c) = self._lstm(obs, (h.unsqueeze(0), c.unsqueeze(0))) + embeddings, (h, c) = self._lstm(obs, (h.unsqueeze(0), c.unsqueeze(0))) # Push through our FC net. - features = self._fc_net(features) + embeddings = self._fc_net(embeddings) # Squeeze the layer dim (we only have 1 LSTM layer). - return features, {"h": h.squeeze(0), "c": c.squeeze(0)} + return embeddings, {"h": h.squeeze(0), "c": c.squeeze(0)} diff --git a/rllib/examples/rl_modules/classes/modelv2_to_rlm.py b/rllib/examples/rl_modules/classes/modelv2_to_rlm.py index bf8e4731ceef5..0fa166a610a77 100644 --- a/rllib/examples/rl_modules/classes/modelv2_to_rlm.py +++ b/rllib/examples/rl_modules/classes/modelv2_to_rlm.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Dict +from typing import Any, Dict, Optional import tree from ray.rllib.core import Columns, DEFAULT_POLICY_ID @@ -181,7 +181,7 @@ def _forward_pass(self, batch, inference=True): return output @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]): + def compute_values(self, batch: Dict[str, Any], embeddings: Optional[Any] = None): self._forward_pass(batch, inference=False) v_preds = self._model_v2.value_function() if Columns.STATE_IN in batch and Columns.SEQ_LENS in batch: diff --git a/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py b/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py index fd9f1cefa8caf..317b3e3c8c091 100644 --- a/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py +++ b/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py @@ -123,7 +123,7 @@ def setup(self): @override(TorchRLModule) def _forward(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - _, logits = self._compute_features_and_logits(batch) + _, logits = self._compute_embeddings_and_logits(batch) # Return features and logits as ACTION_DIST_INPUTS (categorical distribution). return { Columns.ACTION_DIST_INPUTS: logits, @@ -132,11 +132,11 @@ def _forward(self, batch, **kwargs): @override(TorchRLModule) def _forward_train(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - features, logits = self._compute_features_and_logits(batch) + embeddings, logits = self._compute_embeddings_and_logits(batch) # Return features and logits as ACTION_DIST_INPUTS (categorical distribution). return { Columns.ACTION_DIST_INPUTS: logits, - Columns.FEATURES: features, + Columns.EMBEDDINGS: embeddings, } # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used @@ -145,20 +145,20 @@ def _forward_train(self, batch, **kwargs): def compute_values( self, batch: Dict[str, Any], - features: Optional[Any] = None, + embeddings: Optional[Any] = None, ) -> TensorType: # Features not provided -> We need to compute them first. - if features is None: + if embeddings is None: obs = batch[Columns.OBS] - features = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) - features = torch.squeeze(features, dim=[-1, -2]) - return self._values(features).squeeze(-1) + embeddings = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) + embeddings = torch.squeeze(embeddings, dim=[-1, -2]) + return self._values(embeddings).squeeze(-1) - def _compute_features_and_logits(self, batch): + def _compute_embeddings_and_logits(self, batch): obs = batch[Columns.OBS].permute(0, 3, 1, 2) - features = self._base_cnn_stack(obs) - logits = self._logits(features) + embeddings = self._base_cnn_stack(obs) + logits = self._logits(embeddings) return ( - torch.squeeze(features, dim=[-1, -2]), + torch.squeeze(embeddings, dim=[-1, -2]), torch.squeeze(logits, dim=[-1, -2]), ) From 8a009a9032b2a985b177c156953c10afe5352e9e Mon Sep 17 00:00:00 2001 From: sven1977 Date: Sat, 5 Oct 2024 14:34:36 +0200 Subject: [PATCH 7/7] wip Signed-off-by: sven1977 --- rllib/core/rl_module/multi_rl_module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index 38109867f4ef8..a9f8a7a606f84 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -312,15 +312,15 @@ def get( return self._rl_modules[module_id] def items(self) -> ItemsView[ModuleID, RLModule]: - """Returns a keys view over the module IDs in this MultiRLModule.""" + """Returns an ItemsView over the module IDs in this MultiRLModule.""" return self._rl_modules.items() def keys(self) -> KeysView[ModuleID]: - """Returns a keys view over the module IDs in this MultiRLModule.""" + """Returns a KeysView over the module IDs in this MultiRLModule.""" return self._rl_modules.keys() def values(self) -> ValuesView[ModuleID]: - """Returns a keys view over the module IDs in this MultiRLModule.""" + """Returns a ValuesView over the module IDs in this MultiRLModule.""" return self._rl_modules.values() def __len__(self) -> int: