diff --git a/d3rlpy/algos/transformer/tacr.py b/d3rlpy/algos/transformer/tacr.py index 1d604165..fef176f8 100644 --- a/d3rlpy/algos/transformer/tacr.py +++ b/d3rlpy/algos/transformer/tacr.py @@ -24,14 +24,16 @@ @dataclasses.dataclass() class TACRConfig(TransformerConfig): - """Config of Decision Transformer. + """Config of Transformer Actor-Critic with Regularization. - Decision Transformer solves decision-making problems as a sequence modeling - problem. + Decision Transformer-based actor-critic algorithm. The actor is modeled as + Decision Transformer and additionally trained with a critic model. The + extended actor-critic part is implemented as TD3+BC. References: - * `Chen at el., Decision Transformer: Reinforcement Learning via - Sequence Modeling. `_ + * `Lee et al., Transformer Actor-Critic with Regularization: Automated + Stock Trading using Reinforcement Learning. + `_ Args: observation_scaler (d3rlpy.preprocessing.ObservationScaler): @@ -41,11 +43,16 @@ class TACRConfig(TransformerConfig): context_size (int): Prior sequence length. max_timestep (int): Maximum environmental timestep. batch_size (int): Mini-batch size. - learning_rate (float): Learning rate. - encoder_factory (d3rlpy.models.encoders.EncoderFactory): - Encoder factory. - optim_factory (d3rlpy.optimizers.OptimizerFactory): - Optimizer factory. + actor_learning_rate (float): Learning rate for actor. + critic_learning_rate (float): Learning rate for critic. + actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory): + Encoder factory for actor. + critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory): + Encoder factory for critic. + actor_optim_factory (d3rlpy.optimizers.OptimizerFactory): + Optimizer factory for actor. + critic_optim_factory (d3rlpy.optimizers.OptimizerFactory): + Optimizer factory for critic. num_heads (int): Number of attention heads. num_layers (int): Number of attention blocks. attn_dropout (float): Dropout probability for attentions. 
@@ -54,6 +61,11 @@ class TACRConfig(TransformerConfig): activation_type (str): Type of activation function. position_encoding_type (d3rlpy.PositionEncodingType): Type of positional encoding (``SIMPLE`` or ``GLOBAL``). + n_critics (int): Number of critics. + alpha (float): Weight of Q-value actor loss. + tau (float): Target network synchronization coefficient. + target_smoothing_sigma (float): Standard deviation for target noise. + target_smoothing_clip (float): Clipping range for target noise. compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """