Merge branch 'master' into functional_components

takuseno committed Feb 16, 2025
2 parents 3f9a2aa + aa4c402 commit 7086972
Showing 15 changed files with 681 additions and 29 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -107,7 +107,8 @@ $ docker run -it --gpus all --name d3rlpy takuseno/d3rlpy:latest bash
 | [Calibrated Q-Learning (Cal-QL)](https://arxiv.org/abs/2303.05479) | :no_entry: | :white_check_mark: |
 | [ReBRAC](https://arxiv.org/abs/2305.09836) | :no_entry: | :white_check_mark: |
 | [Decision Transformer](https://arxiv.org/abs/2106.01345) | :white_check_mark: | :white_check_mark: |
-| [Q-learning Decision Transformer (QDT)](https://arxiv.org/abs/2209.03993) | :construction: | :white_check_mark: |
+| [Q-learning Decision Transformer (QDT)](https://arxiv.org/abs/2209.03993) | :no_entry: | :white_check_mark: |
+| [Transformer Actor-Critic with Regularization (TACR)](https://www.ifaamas.org/Proceedings/aamas2023/pdfs/p2815.pdf) | :no_entry: | :white_check_mark: |
 | [Gato](https://arxiv.org/abs/2205.06175) | :construction: | :construction: |

## Supported Q functions
1 change: 1 addition & 0 deletions d3rlpy/algos/transformer/__init__.py
@@ -2,3 +2,4 @@
 from .base import *
 from .decision_transformer import *
 from .inputs import *
+from .tacr import *
181 changes: 181 additions & 0 deletions d3rlpy/algos/transformer/tacr.py
@@ -0,0 +1,181 @@
import dataclasses

from ...base import DeviceArg, register_learnable
from ...constants import ActionSpace, PositionEncodingType
from ...models import EncoderFactory, make_encoder_field
from ...models.builders import (
    create_continuous_decision_transformer,
    create_continuous_q_function,
)
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...optimizers import OptimizerFactory, make_optimizer_field
from ...types import Shape
from .base import TransformerAlgoBase, TransformerConfig
from .torch.tacr_impl import (
    TACRImpl,
    TACRModules,
)

__all__ = [
    "TACRConfig",
    "TACR",
]


@dataclasses.dataclass()
class TACRConfig(TransformerConfig):
    """Config of Transformer Actor-Critic with Regularization.

    Decision Transformer-based actor-critic algorithm. The actor is modeled as
    a Decision Transformer and additionally trained with a critic model. The
    extended actor-critic part is implemented as TD3+BC.

    References:
        * `Lee et al., Transformer Actor-Critic with Regularization: Automated
          Stock Trading using Reinforcement Learning.
          <https://www.ifaamas.org/Proceedings/aamas2023/pdfs/p2815.pdf>`_

    Args:
        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
            Observation preprocessor.
        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
        context_size (int): Prior sequence length.
        max_timestep (int): Maximum environmental timestep.
        batch_size (int): Mini-batch size.
        actor_learning_rate (float): Learning rate for the actor.
        critic_learning_rate (float): Learning rate for the critic.
        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the actor.
        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the critic.
        actor_optim_factory (d3rlpy.optimizers.OptimizerFactory):
            Optimizer factory for the actor.
        critic_optim_factory (d3rlpy.optimizers.OptimizerFactory):
            Optimizer factory for the critic.
        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
            Q function factory for the critic.
        num_heads (int): Number of attention heads.
        num_layers (int): Number of attention blocks.
        attn_dropout (float): Dropout probability for attention.
        resid_dropout (float): Dropout probability for residual connections.
        embed_dropout (float): Dropout probability for embeddings.
        activation_type (str): Type of activation function.
        position_encoding_type (d3rlpy.PositionEncodingType):
            Type of positional encoding (``SIMPLE`` or ``GLOBAL``).
        n_critics (int): Number of critics.
        alpha (float): Weight of the Q-value term in the actor loss.
        tau (float): Target network synchronization coefficient.
        target_smoothing_sigma (float): Standard deviation for target noise.
        target_smoothing_clip (float): Clipping range for target noise.
        compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
    """

    batch_size: int = 64
    actor_learning_rate: float = 1e-4
    critic_learning_rate: float = 1e-4
    actor_encoder_factory: EncoderFactory = make_encoder_field()
    critic_encoder_factory: EncoderFactory = make_encoder_field()
    actor_optim_factory: OptimizerFactory = make_optimizer_field()
    critic_optim_factory: OptimizerFactory = make_optimizer_field()
    q_func_factory: QFunctionFactory = make_q_func_field()
    num_heads: int = 1
    num_layers: int = 3
    attn_dropout: float = 0.1
    resid_dropout: float = 0.1
    embed_dropout: float = 0.1
    activation_type: str = "relu"
    position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE
    n_critics: int = 2
    alpha: float = 2.5
    tau: float = 0.005
    target_smoothing_sigma: float = 0.2
    target_smoothing_clip: float = 0.5
    compile_graph: bool = False

    def create(
        self, device: DeviceArg = False, enable_ddp: bool = False
    ) -> "TACR":
        return TACR(self, device, enable_ddp)

    @staticmethod
    def get_type() -> str:
        return "tacr"


class TACR(TransformerAlgoBase[TACRImpl, TACRConfig]):
    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        transformer = create_continuous_decision_transformer(
            observation_shape=observation_shape,
            action_size=action_size,
            encoder_factory=self._config.actor_encoder_factory,
            num_heads=self._config.num_heads,
            max_timestep=self._config.max_timestep,
            num_layers=self._config.num_layers,
            context_size=self._config.context_size,
            attn_dropout=self._config.attn_dropout,
            resid_dropout=self._config.resid_dropout,
            embed_dropout=self._config.embed_dropout,
            activation_type=self._config.activation_type,
            position_encoding_type=self._config.position_encoding_type,
            device=self._device,
            enable_ddp=self._enable_ddp,
        )
        optim = self._config.actor_optim_factory.create(
            transformer.named_modules(),
            lr=self._config.actor_learning_rate,
            compiled=self.compiled,
        )

        q_funcs, q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
            enable_ddp=self._enable_ddp,
        )
        targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
            enable_ddp=self._enable_ddp,
        )
        critic_optim = self._config.critic_optim_factory.create(
            q_funcs.named_modules(),
            lr=self._config.critic_learning_rate,
            compiled=self.compiled,
        )

        modules = TACRModules(
            transformer=transformer,
            actor_optim=optim,
            q_funcs=q_funcs,
            targ_q_funcs=targ_q_funcs,
            critic_optim=critic_optim,
        )

        self._impl = TACRImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            alpha=self._config.alpha,
            gamma=self._config.gamma,
            tau=self._config.tau,
            target_smoothing_sigma=self._config.target_smoothing_sigma,
            target_smoothing_clip=self._config.target_smoothing_clip,
            device=self._device,
            compiled=self.compiled,
        )

    def get_action_type(self) -> ActionSpace:
        return ActionSpace.CONTINUOUS


register_learnable(TACRConfig)
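
As a quick orientation for the new algorithm, a minimal usage sketch follows. It assumes the standard d3rlpy workflow (a dataset from d3rlpy.datasets and the fit() loop shared with the other transformer algorithms such as Decision Transformer); none of these calls are part of this commit, so treat the dataset choice and hyperparameters as illustrative.

# Illustrative only: trains TACR on a small continuous-control dataset.
import d3rlpy

# Any continuous-action dataset works; the bundled pendulum dataset is used here.
dataset, env = d3rlpy.datasets.get_pendulum()

tacr = d3rlpy.algos.TACRConfig(
    context_size=20,
    batch_size=64,
    alpha=2.5,  # weight of the Q-value term in the actor loss (see docstring above)
).create(device="cuda:0")

# Transformer algorithms train on trajectory mini-batches, like Decision Transformer.
tacr.fit(dataset, n_steps=100000, n_steps_per_epoch=1000)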
1 change: 1 addition & 0 deletions d3rlpy/algos/transformer/torch/__init__.py
@@ -1 +1,2 @@
 from .decision_transformer_impl import *
+from .tacr_impl import *
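
The torch-side TACRModules and TACRImpl imported here live in d3rlpy/algos/transformer/torch/tacr_impl.py, which this view does not expand. Since the config docstring states that the actor-critic part follows TD3+BC, a hedged sketch of a TD3+BC-style actor objective may help; the function name, shapes, and epsilon term below are illustrative and not taken from the actual implementation.

# Sketch of a TD3+BC-style actor loss (illustrative; not the actual TACRImpl code).
import torch
import torch.nn.functional as F

def td3_plus_bc_actor_loss(
    q_value: torch.Tensor,           # Q(s, pi(s)), shape (B,)
    predicted_action: torch.Tensor,  # pi(s), shape (B, A)
    dataset_action: torch.Tensor,    # behavior action from the dataset, shape (B, A)
    alpha: float = 2.5,              # matches TACRConfig.alpha
) -> torch.Tensor:
    # TD3+BC scales the Q term by alpha / mean|Q| so it stays comparable to the BC term.
    lam = alpha / (q_value.abs().mean().detach() + 1e-8)
    bc_loss = F.mse_loss(predicted_action, dataset_action)
    return -lam * q_value.mean() + bc_loss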
16 changes: 13 additions & 3 deletions d3rlpy/algos/transformer/torch/decision_transformer_impl.py
@@ -55,7 +55,11 @@ def __init__(
     def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
         # (1, T, A)
         action = self._modules.transformer(
-            inpt.observations, inpt.actions, inpt.returns_to_go, inpt.timesteps
+            inpt.observations,
+            inpt.actions,
+            inpt.returns_to_go,
+            inpt.timesteps,
+            1 - inpt.masks,
         )
         # (1, T, A) -> (A,)
         return action[0][-1]
@@ -81,10 +85,11 @@ def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor:
             batch.actions,
             batch.returns_to_go,
             batch.timesteps,
+            1 - batch.masks,
         )
         # (B, T, A) -> (B, T)
         loss = ((action - batch.actions) ** 2).sum(dim=-1)
-        return loss.mean()
+        return (loss.reshape(-1) * batch.masks.reshape(-1)).mean()


@dataclasses.dataclass(frozen=True)
@@ -132,7 +137,11 @@ def __init__(
     def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
         # (1, T, A)
         _, logits = self._modules.transformer(
-            inpt.observations, inpt.actions, inpt.returns_to_go, inpt.timesteps
+            inpt.observations,
+            inpt.actions,
+            inpt.returns_to_go,
+            inpt.timesteps,
+            1 - inpt.masks,
         )
         # (1, T, A) -> (A,)
         return logits[0][-1]
@@ -177,6 +186,7 @@ def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor:
             batch.actions,
             batch.returns_to_go,
             batch.timesteps,
+            1 - batch.masks,
         )
         loss = F.cross_entropy(
             logits.view(-1, self._action_size),
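
The changes above thread the trajectory padding masks through the model (passing 1 - masks as the attention padding mask) and zero out padded timesteps in the loss. A standalone toy computation of the masked loss, assuming masks are 1 for valid timesteps and 0 for padding:

# Toy illustration of the masked loss used in compute_loss above.
import torch

B, T, A = 2, 4, 3
action = torch.randn(B, T, A)   # predicted actions
target = torch.randn(B, T, A)   # dataset actions
masks = torch.tensor([[0.0, 1.0, 1.0, 1.0],   # first trajectory padded at t=0
                      [1.0, 1.0, 1.0, 1.0]])

# (B, T, A) -> (B, T): per-timestep squared error summed over action dims
loss = ((action - target) ** 2).sum(dim=-1)

# Padded timesteps contribute zero before averaging, mirroring the new return statement.
masked_loss = (loss.reshape(-1) * masks.reshape(-1)).mean()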