-
Notifications
You must be signed in to change notification settings - Fork 245
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
465 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
from .base import * | ||
from .decision_transformer import * | ||
from .inputs import * | ||
from .tacr import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
import dataclasses | ||
|
||
from ...base import DeviceArg, register_learnable | ||
from ...constants import ActionSpace, PositionEncodingType | ||
from ...models import EncoderFactory, make_encoder_field | ||
from ...models.builders import ( | ||
create_continuous_decision_transformer, | ||
create_continuous_q_function, | ||
) | ||
from ...models.q_functions import QFunctionFactory, make_q_func_field | ||
from ...optimizers import OptimizerFactory, make_optimizer_field | ||
from ...types import Shape | ||
from .base import TransformerAlgoBase, TransformerConfig | ||
from .torch.tacr_impl import ( | ||
TACRImpl, | ||
TACRModules, | ||
) | ||
|
||
__all__ = [ | ||
"TACRConfig", | ||
"TACR", | ||
] | ||
|
||
|
||
@dataclasses.dataclass() | ||
class TACRConfig(TransformerConfig): | ||
"""Config of Decision Transformer. | ||
Decision Transformer solves decision-making problems as a sequence modeling | ||
problem. | ||
References: | ||
* `Chen at el., Decision Transformer: Reinforcement Learning via | ||
Sequence Modeling. <https://arxiv.org/abs/2106.01345>`_ | ||
Args: | ||
observation_scaler (d3rlpy.preprocessing.ObservationScaler): | ||
Observation preprocessor. | ||
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor. | ||
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor. | ||
context_size (int): Prior sequence length. | ||
max_timestep (int): Maximum environmental timestep. | ||
batch_size (int): Mini-batch size. | ||
learning_rate (float): Learning rate. | ||
encoder_factory (d3rlpy.models.encoders.EncoderFactory): | ||
Encoder factory. | ||
optim_factory (d3rlpy.optimizers.OptimizerFactory): | ||
Optimizer factory. | ||
num_heads (int): Number of attention heads. | ||
num_layers (int): Number of attention blocks. | ||
attn_dropout (float): Dropout probability for attentions. | ||
resid_dropout (float): Dropout probability for residual connection. | ||
embed_dropout (float): Dropout probability for embeddings. | ||
activation_type (str): Type of activation function. | ||
position_encoding_type (d3rlpy.PositionEncodingType): | ||
Type of positional encoding (``SIMPLE`` or ``GLOBAL``). | ||
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. | ||
""" | ||
|
||
batch_size: int = 64 | ||
learning_rate: float = 1e-4 | ||
critic_learning_rate: float = 1e-4 | ||
encoder_factory: EncoderFactory = make_encoder_field() | ||
critic_encoder_factory: EncoderFactory = make_encoder_field() | ||
optim_factory: OptimizerFactory = make_optimizer_field() | ||
critic_optim_factory: OptimizerFactory = make_optimizer_field() | ||
q_func_factory: QFunctionFactory = make_q_func_field() | ||
num_heads: int = 1 | ||
num_layers: int = 3 | ||
attn_dropout: float = 0.1 | ||
resid_dropout: float = 0.1 | ||
embed_dropout: float = 0.1 | ||
activation_type: str = "relu" | ||
position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE | ||
n_critics: int = 2 | ||
alpha: float = 2.5 | ||
tau: float = 0.005 | ||
target_smoothing_sigma: float = 0.2 | ||
target_smoothing_clip: float = 0.5 | ||
compile_graph: bool = False | ||
|
||
def create( | ||
self, device: DeviceArg = False, enable_ddp: bool = False | ||
) -> "TACR": | ||
return TACR(self, device, enable_ddp) | ||
|
||
@staticmethod | ||
def get_type() -> str: | ||
return "tacr" | ||
|
||
|
||
class TACR(TransformerAlgoBase[TACRImpl, TACRConfig]): | ||
def inner_create_impl( | ||
self, observation_shape: Shape, action_size: int | ||
) -> None: | ||
transformer = create_continuous_decision_transformer( | ||
observation_shape=observation_shape, | ||
action_size=action_size, | ||
encoder_factory=self._config.encoder_factory, | ||
num_heads=self._config.num_heads, | ||
max_timestep=self._config.max_timestep, | ||
num_layers=self._config.num_layers, | ||
context_size=self._config.context_size, | ||
attn_dropout=self._config.attn_dropout, | ||
resid_dropout=self._config.resid_dropout, | ||
embed_dropout=self._config.embed_dropout, | ||
activation_type=self._config.activation_type, | ||
position_encoding_type=self._config.position_encoding_type, | ||
device=self._device, | ||
enable_ddp=self._enable_ddp, | ||
) | ||
optim = self._config.optim_factory.create( | ||
transformer.named_modules(), | ||
lr=self._config.learning_rate, | ||
compiled=self.compiled, | ||
) | ||
|
||
q_funcs, q_func_forwarder = create_continuous_q_function( | ||
observation_shape, | ||
action_size, | ||
self._config.critic_encoder_factory, | ||
self._config.q_func_factory, | ||
n_ensembles=self._config.n_critics, | ||
device=self._device, | ||
enable_ddp=self._enable_ddp, | ||
) | ||
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( | ||
observation_shape, | ||
action_size, | ||
self._config.critic_encoder_factory, | ||
self._config.q_func_factory, | ||
n_ensembles=self._config.n_critics, | ||
device=self._device, | ||
enable_ddp=self._enable_ddp, | ||
) | ||
critic_optim = self._config.optim_factory.create( | ||
q_funcs.named_modules(), | ||
lr=self._config.critic_learning_rate, | ||
compiled=self.compiled, | ||
) | ||
|
||
modules = TACRModules( | ||
transformer=transformer, | ||
optim=optim, | ||
q_funcs=q_funcs, | ||
targ_q_funcs=targ_q_funcs, | ||
critic_optim=critic_optim, | ||
) | ||
|
||
self._impl = TACRImpl( | ||
observation_shape=observation_shape, | ||
action_size=action_size, | ||
modules=modules, | ||
q_func_forwarder=q_func_forwarder, | ||
targ_q_func_forwarder=targ_q_func_forwarder, | ||
alpha=self._config.alpha, | ||
gamma=self._config.gamma, | ||
tau=self._config.tau, | ||
target_smoothing_sigma=self._config.target_smoothing_sigma, | ||
target_smoothing_clip=self._config.target_smoothing_clip, | ||
device=self._device, | ||
compiled=self.compiled, | ||
) | ||
|
||
def get_action_type(self) -> ActionSpace: | ||
return ActionSpace.CONTINUOUS | ||
|
||
|
||
register_learnable(TACRConfig) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .decision_transformer_impl import * | ||
from .tacr_impl import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
import dataclasses | ||
from typing import Callable | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from ....models.torch import ( | ||
ContinuousDecisionTransformer, | ||
ContinuousEnsembleQFunctionForwarder, | ||
) | ||
from ....optimizers import OptimizerWrapper | ||
from ....torch_utility import ( | ||
CudaGraphWrapper, | ||
Modules, | ||
TorchMiniBatch, | ||
TorchTrajectoryMiniBatch, | ||
flatten_left_recursively, | ||
soft_sync, | ||
) | ||
from ....types import Shape | ||
from ..base import TransformerAlgoImplBase | ||
from ..inputs import TorchTransformerInput | ||
|
||
__all__ = [ | ||
"TACRImpl", | ||
"TACRModules", | ||
] | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class TACRModules(Modules): | ||
transformer: ContinuousDecisionTransformer | ||
optim: OptimizerWrapper | ||
q_funcs: nn.ModuleList | ||
targ_q_funcs: nn.ModuleList | ||
critic_optim: OptimizerWrapper | ||
|
||
|
||
class TACRImpl(TransformerAlgoImplBase): | ||
_modules: TACRModules | ||
_compute_actor_grad: Callable[[TorchTrajectoryMiniBatch], torch.Tensor] | ||
_compute_critic_grad: Callable[[TorchTrajectoryMiniBatch], torch.Tensor] | ||
|
||
def __init__( | ||
self, | ||
observation_shape: Shape, | ||
action_size: int, | ||
modules: Modules, | ||
q_func_forwarder: ContinuousEnsembleQFunctionForwarder, | ||
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, | ||
alpha: float, | ||
gamma: float, | ||
tau: float, | ||
target_smoothing_sigma: float, | ||
target_smoothing_clip: float, | ||
compiled: bool, | ||
device: str, | ||
): | ||
super().__init__(observation_shape, action_size, modules, device) | ||
self._q_func_forwarder = q_func_forwarder | ||
self._targ_q_func_forwarder = targ_q_func_forwarder | ||
self._alpha = alpha | ||
self._gamma = gamma | ||
self._tau = tau | ||
self._target_smoothing_sigma = target_smoothing_sigma | ||
self._target_smoothing_clip = target_smoothing_clip | ||
self._compute_actor_grad = ( | ||
CudaGraphWrapper(self.compute_actor_grad) | ||
if compiled | ||
else self.compute_actor_grad | ||
) | ||
self._compute_critic_grad = ( | ||
CudaGraphWrapper(self.compute_critic_grad) | ||
if compiled | ||
else self.compute_critic_grad | ||
) | ||
|
||
def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: | ||
# (1, T, A) | ||
action = self._modules.transformer( | ||
inpt.observations, | ||
inpt.actions, | ||
inpt.returns_to_go, | ||
inpt.timesteps, | ||
1 - inpt.masks, | ||
) | ||
# (1, T, A) -> (A,) | ||
return action[0][-1] | ||
|
||
def compute_actor_grad( | ||
self, batch: TorchTrajectoryMiniBatch | ||
) -> torch.Tensor: | ||
self._modules.optim.zero_grad() | ||
loss = self.compute_actor_loss(batch) | ||
loss.backward() | ||
return loss | ||
|
||
def compute_critic_grad( | ||
self, batch: TorchTrajectoryMiniBatch | ||
) -> torch.Tensor: | ||
self._modules.critic_optim.zero_grad() | ||
transition_batch, masks = batch.to_transition_batch() | ||
q_tpn = self.compute_target(batch, transition_batch) | ||
loss = self.compute_critic_loss(transition_batch, q_tpn, masks) | ||
loss.backward() | ||
return loss | ||
|
||
def inner_update( | ||
self, batch: TorchTrajectoryMiniBatch, grad_step: int | ||
) -> dict[str, float]: | ||
metrics = {} | ||
|
||
actor_loss = self._compute_actor_grad(batch) | ||
self._modules.optim.step() | ||
metrics.update({"actor_loss": float(actor_loss.cpu().detach())}) | ||
|
||
critic_loss = self._compute_critic_grad(batch) | ||
self._modules.critic_optim.step() | ||
soft_sync(self._modules.targ_q_funcs, self._modules.q_funcs, self._tau) | ||
metrics.update({"critic_loss": float(critic_loss.cpu().detach())}) | ||
|
||
return metrics | ||
|
||
def compute_actor_loss( | ||
self, batch: TorchTrajectoryMiniBatch | ||
) -> torch.Tensor: | ||
# (B, T, A) | ||
action = self._modules.transformer( | ||
batch.observations, | ||
batch.actions, | ||
batch.returns_to_go, | ||
batch.timesteps, | ||
1 - batch.masks, | ||
) | ||
# (B * T , 1) | ||
q_values = self._q_func_forwarder.compute_expected_q( | ||
x=flatten_left_recursively(batch.observations, dim=1), | ||
action=action.view(-1, self._action_size), | ||
reduction="min", | ||
) | ||
lam = self._alpha / (q_values.abs().mean()).detach() | ||
q_loss = lam * -q_values | ||
# (B, T, A) -> (B, T) | ||
bc_loss = ((action - batch.actions) ** 2).sum(dim=-1) | ||
return ( | ||
batch.masks.view(-1) * (q_loss.view(-1) + bc_loss.view(-1)) | ||
).mean() | ||
|
||
def compute_critic_loss( | ||
self, batch: TorchMiniBatch, q_tpn: torch.Tensor, masks: torch.Tensor | ||
) -> torch.Tensor: | ||
loss = self._q_func_forwarder.compute_error( | ||
observations=batch.observations, | ||
actions=batch.actions, | ||
rewards=batch.rewards, | ||
target=q_tpn, | ||
terminals=batch.terminals, | ||
gamma=self._gamma, | ||
masks=masks, | ||
) | ||
return loss | ||
|
||
def compute_target( | ||
self, batch: TorchTrajectoryMiniBatch, transition_batch: TorchMiniBatch | ||
) -> torch.Tensor: | ||
with torch.no_grad(): | ||
# (B, T, A) -> (B * (T - 1), A) | ||
action = self._modules.transformer( | ||
batch.observations, | ||
batch.actions, | ||
batch.returns_to_go, | ||
batch.timesteps, | ||
1 - batch.masks, | ||
)[:, :-1].reshape(-1, self._action_size) | ||
# smoothing target | ||
noise = torch.randn(action.shape, device=batch.device) | ||
scaled_noise = self._target_smoothing_sigma * noise | ||
clipped_noise = scaled_noise.clamp( | ||
-self._target_smoothing_clip, self._target_smoothing_clip | ||
) | ||
smoothed_action = action + clipped_noise | ||
clipped_action = smoothed_action.clamp(-1.0, 1.0) | ||
return self._targ_q_func_forwarder.compute_target( | ||
transition_batch.next_observations, | ||
clipped_action, | ||
reduction="min", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.