
Commit ffed37c: Implement TACR

takuseno committed Feb 15, 2025
1 parent 9251cde
Showing 7 changed files with 465 additions and 0 deletions.
1 change: 1 addition & 0 deletions d3rlpy/algos/transformer/__init__.py
@@ -2,3 +2,4 @@
from .base import *
from .decision_transformer import *
from .inputs import *
from .tacr import *
169 changes: 169 additions & 0 deletions d3rlpy/algos/transformer/tacr.py
@@ -0,0 +1,169 @@
import dataclasses

from ...base import DeviceArg, register_learnable
from ...constants import ActionSpace, PositionEncodingType
from ...models import EncoderFactory, make_encoder_field
from ...models.builders import (
create_continuous_decision_transformer,
create_continuous_q_function,
)
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import TransformerAlgoBase, TransformerConfig
from .torch.tacr_impl import (
TACRImpl,
TACRModules,
)

__all__ = [
"TACRConfig",
"TACR",
]


@dataclasses.dataclass()
class TACRConfig(TransformerConfig):
"""Config of Decision Transformer.
Decision Transformer solves decision-making problems as a sequence modeling
problem.
References:
* `Chen at el., Decision Transformer: Reinforcement Learning via
Sequence Modeling. <https://arxiv.org/abs/2106.01345>`_
Args:
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
context_size (int): Prior sequence length.
max_timestep (int): Maximum environmental timestep.
batch_size (int): Mini-batch size.
learning_rate (float): Learning rate.
encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory.
optim_factory (d3rlpy.optimizers.OptimizerFactory):
Optimizer factory.
num_heads (int): Number of attention heads.
num_layers (int): Number of attention blocks.
attn_dropout (float): Dropout probability for attentions.
resid_dropout (float): Dropout probability for residual connection.
embed_dropout (float): Dropout probability for embeddings.
activation_type (str): Type of activation function.
position_encoding_type (d3rlpy.PositionEncodingType):
Type of positional encoding (``SIMPLE`` or ``GLOBAL``).
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 64
learning_rate: float = 1e-4
critic_learning_rate: float = 1e-4
encoder_factory: EncoderFactory = make_encoder_field()
critic_encoder_factory: EncoderFactory = make_encoder_field()
optim_factory: OptimizerFactory = make_optimizer_field()
critic_optim_factory: OptimizerFactory = make_optimizer_field()
q_func_factory: QFunctionFactory = make_q_func_field()
num_heads: int = 1
num_layers: int = 3
attn_dropout: float = 0.1
resid_dropout: float = 0.1
embed_dropout: float = 0.1
activation_type: str = "relu"
position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE
n_critics: int = 2
alpha: float = 2.5
tau: float = 0.005
target_smoothing_sigma: float = 0.2
target_smoothing_clip: float = 0.5
compile_graph: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
) -> "TACR":
return TACR(self, device, enable_ddp)

@staticmethod
def get_type() -> str:
return "tacr"


class TACR(TransformerAlgoBase[TACRImpl, TACRConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
transformer = create_continuous_decision_transformer(
observation_shape=observation_shape,
action_size=action_size,
encoder_factory=self._config.encoder_factory,
num_heads=self._config.num_heads,
max_timestep=self._config.max_timestep,
num_layers=self._config.num_layers,
context_size=self._config.context_size,
attn_dropout=self._config.attn_dropout,
resid_dropout=self._config.resid_dropout,
embed_dropout=self._config.embed_dropout,
activation_type=self._config.activation_type,
position_encoding_type=self._config.position_encoding_type,
device=self._device,
enable_ddp=self._enable_ddp,
)
optim = self._config.optim_factory.create(
transformer.named_modules(),
lr=self._config.learning_rate,
compiled=self.compiled,
)

q_funcs, q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
enable_ddp=self._enable_ddp,
)
targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
observation_shape,
action_size,
self._config.critic_encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
enable_ddp=self._enable_ddp,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
)

modules = TACRModules(
transformer=transformer,
optim=optim,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
critic_optim=critic_optim,
)

self._impl = TACRImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
alpha=self._config.alpha,
gamma=self._config.gamma,
tau=self._config.tau,
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
device=self._device,
compiled=self.compiled,
)

def get_action_type(self) -> ActionSpace:
return ActionSpace.CONTINUOUS


register_learnable(TACRConfig)
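
For reference, a minimal training sketch in the style of the existing Decision Transformer usage in d3rlpy. The dataset helper, device string, target return, and the `as_stateful_wrapper` call are assumptions drawn from the current d3rlpy API rather than from this commit:

```python
import d3rlpy

# Any continuous-action dataset works; get_pendulum() is just a convenient example.
dataset, env = d3rlpy.datasets.get_pendulum()

tacr = d3rlpy.algos.TACRConfig(
    context_size=20,
    batch_size=64,
    learning_rate=1e-4,
    critic_learning_rate=1e-4,
).create(device="cuda:0")

# Offline training on trajectory mini-batches.
tacr.fit(
    dataset,
    n_steps=100000,
    n_steps_per_epoch=1000,
)

# Transformer policies are deployed through a stateful wrapper that keeps
# the running context (observations, actions, rewards) between steps.
actor = tacr.as_stateful_wrapper(target_return=0)
```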
1 change: 1 addition & 0 deletions d3rlpy/algos/transformer/torch/__init__.py
@@ -1 +1,2 @@
from .decision_transformer_impl import *
from .tacr_impl import *
187 changes: 187 additions & 0 deletions d3rlpy/algos/transformer/torch/tacr_impl.py
@@ -0,0 +1,187 @@
import dataclasses
from typing import Callable

import torch
from torch import nn

from ....models.torch import (
ContinuousDecisionTransformer,
ContinuousEnsembleQFunctionForwarder,
)
from ....optimizers import OptimizerWrapper
from ....torch_utility import (
CudaGraphWrapper,
Modules,
TorchMiniBatch,
TorchTrajectoryMiniBatch,
flatten_left_recursively,
soft_sync,
)
from ....types import Shape
from ..base import TransformerAlgoImplBase
from ..inputs import TorchTransformerInput

__all__ = [
"TACRImpl",
"TACRModules",
]


@dataclasses.dataclass(frozen=True)
class TACRModules(Modules):
transformer: ContinuousDecisionTransformer
optim: OptimizerWrapper
q_funcs: nn.ModuleList
targ_q_funcs: nn.ModuleList
critic_optim: OptimizerWrapper


class TACRImpl(TransformerAlgoImplBase):
_modules: TACRModules
_compute_actor_grad: Callable[[TorchTrajectoryMiniBatch], torch.Tensor]
_compute_critic_grad: Callable[[TorchTrajectoryMiniBatch], torch.Tensor]

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: Modules,
q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
alpha: float,
gamma: float,
tau: float,
target_smoothing_sigma: float,
target_smoothing_clip: float,
compiled: bool,
device: str,
):
super().__init__(observation_shape, action_size, modules, device)
self._q_func_forwarder = q_func_forwarder
self._targ_q_func_forwarder = targ_q_func_forwarder
self._alpha = alpha
self._gamma = gamma
self._tau = tau
self._target_smoothing_sigma = target_smoothing_sigma
self._target_smoothing_clip = target_smoothing_clip
self._compute_actor_grad = (
CudaGraphWrapper(self.compute_actor_grad)
if compiled
else self.compute_actor_grad
)
self._compute_critic_grad = (
CudaGraphWrapper(self.compute_critic_grad)
if compiled
else self.compute_critic_grad
)

def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
# (1, T, A)
action = self._modules.transformer(
inpt.observations,
inpt.actions,
inpt.returns_to_go,
inpt.timesteps,
1 - inpt.masks,
)
# (1, T, A) -> (A,)
return action[0][-1]

def compute_actor_grad(
self, batch: TorchTrajectoryMiniBatch
) -> torch.Tensor:
self._modules.optim.zero_grad()
loss = self.compute_actor_loss(batch)
loss.backward()
return loss

def compute_critic_grad(
self, batch: TorchTrajectoryMiniBatch
) -> torch.Tensor:
self._modules.critic_optim.zero_grad()
transition_batch, masks = batch.to_transition_batch()
q_tpn = self.compute_target(batch, transition_batch)
loss = self.compute_critic_loss(transition_batch, q_tpn, masks)
loss.backward()
return loss

def inner_update(
self, batch: TorchTrajectoryMiniBatch, grad_step: int
) -> dict[str, float]:
metrics = {}

actor_loss = self._compute_actor_grad(batch)
self._modules.optim.step()
metrics.update({"actor_loss": float(actor_loss.cpu().detach())})

critic_loss = self._compute_critic_grad(batch)
self._modules.critic_optim.step()
soft_sync(self._modules.targ_q_funcs, self._modules.q_funcs, self._tau)
metrics.update({"critic_loss": float(critic_loss.cpu().detach())})

return metrics

def compute_actor_loss(
self, batch: TorchTrajectoryMiniBatch
) -> torch.Tensor:
# (B, T, A)
action = self._modules.transformer(
batch.observations,
batch.actions,
batch.returns_to_go,
batch.timesteps,
1 - batch.masks,
)
# (B * T, 1)
q_values = self._q_func_forwarder.compute_expected_q(
x=flatten_left_recursively(batch.observations, dim=1),
action=action.view(-1, self._action_size),
reduction="min",
)
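# TD3+BC-style weighting: normalize the Q-value term by its average
# magnitude so that alpha controls its balance against the BC loss below.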
lam = self._alpha / (q_values.abs().mean()).detach()
q_loss = lam * -q_values
# (B, T, A) -> (B, T)
bc_loss = ((action - batch.actions) ** 2).sum(dim=-1)
return (
batch.masks.view(-1) * (q_loss.view(-1) + bc_loss.view(-1))
).mean()

def compute_critic_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor, masks: torch.Tensor
) -> torch.Tensor:
loss = self._q_func_forwarder.compute_error(
observations=batch.observations,
actions=batch.actions,
rewards=batch.rewards,
target=q_tpn,
terminals=batch.terminals,
gamma=self._gamma,
masks=masks,
)
return loss

def compute_target(
self, batch: TorchTrajectoryMiniBatch, transition_batch: TorchMiniBatch
) -> torch.Tensor:
with torch.no_grad():
# (B, T, A) -> (B * (T - 1), A)
action = self._modules.transformer(
batch.observations,
batch.actions,
batch.returns_to_go,
batch.timesteps,
1 - batch.masks,
)[:, :-1].reshape(-1, self._action_size)
# target policy smoothing: perturb the predicted action with clipped Gaussian noise
noise = torch.randn(action.shape, device=batch.device)
scaled_noise = self._target_smoothing_sigma * noise
clipped_noise = scaled_noise.clamp(
-self._target_smoothing_clip, self._target_smoothing_clip
)
smoothed_action = action + clipped_noise
clipped_action = smoothed_action.clamp(-1.0, 1.0)
return self._targ_q_func_forwarder.compute_target(
transition_batch.next_observations,
clipped_action,
reduction="min",
)
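
Taken together, `compute_actor_loss` above is the TD3+BC objective applied per timestep of the trajectory. A self-contained sketch with illustrative tensor shapes (standalone tensors, not d3rlpy internals):

```python
import torch

def tacr_actor_loss(
    predicted_actions: torch.Tensor,  # (B, T, A) transformer output
    dataset_actions: torch.Tensor,    # (B, T, A) actions from the dataset
    q_values: torch.Tensor,           # (B * T, 1) min over the critic ensemble
    masks: torch.Tensor,              # (B, T) 1 for valid steps, 0 for padding
    alpha: float = 2.5,
) -> torch.Tensor:
    # Scale the Q-value term by alpha / mean(|Q|) so it stays comparable
    # to the behavior-cloning term regardless of the reward scale.
    lam = alpha / q_values.abs().mean().detach()
    q_loss = lam * -q_values.view(-1)
    bc_loss = ((predicted_actions - dataset_actions) ** 2).sum(dim=-1).view(-1)
    # Padded timesteps are masked out before averaging.
    return (masks.view(-1) * (q_loss + bc_loss)).mean()

# Example with random tensors:
B, T, A = 4, 20, 6
loss = tacr_actor_loss(
    torch.randn(B, T, A),
    torch.randn(B, T, A),
    torch.randn(B * T, 1),
    torch.ones(B, T),
)
```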
5 changes: 5 additions & 0 deletions d3rlpy/models/torch/q_functions/ensemble_q_function.py
@@ -85,6 +85,7 @@ def compute_ensemble_q_function_error(
target: torch.Tensor,
terminals: torch.Tensor,
gamma: Union[float, torch.Tensor] = 0.99,
masks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert target.ndim == 2
td_sum = torch.tensor(
@@ -102,6 +103,8 @@ def compute_ensemble_q_function_error(
gamma=gamma,
reduction="none",
)
if masks is not None:
loss = loss * masks
td_sum += loss.mean()
return td_sum

@@ -252,6 +255,7 @@ def compute_error(
target: torch.Tensor,
terminals: torch.Tensor,
gamma: Union[float, torch.Tensor] = 0.99,
masks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return compute_ensemble_q_function_error(
forwarders=self._forwarders,
@@ -261,6 +265,7 @@
target=target,
terminals=terminals,
gamma=gamma,
masks=masks,
)

def compute_target(
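
The new `masks` argument lets trajectory-based callers such as TACR zero out TD errors from padded transitions before they are averaged into the ensemble loss. A minimal illustration with standalone tensors (not the d3rlpy forwarders):

```python
import torch

# Per-transition TD errors for one Q-function with reduction="none".
td_error = torch.tensor([0.5, 1.2, 0.3, 0.9, 0.7, 0.4])
# 1 marks real transitions, 0 marks padding added by the trajectory slicer.
masks = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0, 0.0])

unmasked_loss = td_error.mean()           # padded steps leak into the loss
masked_loss = (td_error * masks).mean()   # padded steps contribute nothing
```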