Update TACR comment
takuseno committed Feb 16, 2025
1 parent 916e030 commit c417972
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions d3rlpy/algos/transformer/tacr.py
@@ -24,14 +24,16 @@
 
 @dataclasses.dataclass()
 class TACRConfig(TransformerConfig):
-    """Config of Decision Transformer.
+    """Config of Transformer Actor-Critic with Regularization.
 
-    Decision Transformer solves decision-making problems as a sequence modeling
-    problem.
+    Decision Transformer-based actor-critic algorithm. The actor is modeled as
+    Decision Transformer and additionally trained with a critic model. The
+    extended actor-critic part is implemented as TD3+BC.
 
     References:
-        * `Chen et al., Decision Transformer: Reinforcement Learning via
-          Sequence Modeling. <https://arxiv.org/abs/2106.01345>`_
+        * `Lee et al., Transformer Actor-Critic with Regularization: Automated
+          Stock Trading using Reinforcement Learning.
+          <https://www.ifaamas.org/Proceedings/aamas2023/pdfs/p2815.pdf>`_
 
     Args:
         observation_scaler (d3rlpy.preprocessing.ObservationScaler):
@@ -41,11 +43,16 @@ class TACRConfig(TransformerConfig):
         context_size (int): Prior sequence length.
         max_timestep (int): Maximum environmental timestep.
         batch_size (int): Mini-batch size.
-        learning_rate (float): Learning rate.
-        encoder_factory (d3rlpy.models.encoders.EncoderFactory):
-            Encoder factory.
-        optim_factory (d3rlpy.optimizers.OptimizerFactory):
-            Optimizer factory.
+        actor_learning_rate (float): Learning rate for actor.
+        critic_learning_rate (float): Learning rate for critic.
+        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
+            Encoder factory for actor.
+        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
+            Encoder factory for critic.
+        actor_optim_factory (d3rlpy.optimizers.OptimizerFactory):
+            Optimizer factory for actor.
+        critic_optim_factory (d3rlpy.optimizers.OptimizerFactory):
+            Optimizer factory for critic.
         num_heads (int): Number of attention heads.
         num_layers (int): Number of attention blocks.
         attn_dropout (float): Dropout probability for attentions.
@@ -54,6 +61,11 @@ class TACRConfig(TransformerConfig):
         activation_type (str): Type of activation function.
         position_encoding_type (d3rlpy.PositionEncodingType):
             Type of positional encoding (``SIMPLE`` or ``GLOBAL``).
+        n_critics (int): Number of critics.
+        alpha (float): Weight of Q-value actor loss.
+        tau (float): Target network synchronization coefficient.
+        target_smoothing_sigma (float): Standard deviation for target noise.
+        target_smoothing_clip (float): Clipping range for target noise.
         compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
     """

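The newly documented parameters (`alpha`, `tau`, `target_smoothing_sigma`, `target_smoothing_clip`) come from the TD3+BC part of the algorithm. The following is a minimal pure-Python sketch of the roles these parameters conventionally play. It is an illustration under assumptions, not d3rlpy's implementation: the function names are hypothetical, the action bounds of [-1, 1] are assumed, and the actor loss follows the formulation in the TD3+BC paper rather than anything confirmed by this commit.

```python
import random


def soft_update(target, online, tau):
    """Polyak averaging: target <- tau * online + (1 - tau) * target.

    `tau` plays the role of the target synchronization coefficient.
    Hypothetical helper operating on plain lists of floats.
    """
    return [tau * w + (1.0 - tau) * t for t, w in zip(target, online)]


def smoothed_target_action(action, sigma, clip, rng, low=-1.0, high=1.0):
    """TD3-style target smoothing for a scalar action.

    Gaussian noise with standard deviation `sigma` is clipped to
    [-clip, clip], added to the action, and the result is clipped to the
    assumed action bounds [low, high].
    """
    noise = max(-clip, min(clip, rng.gauss(0.0, sigma)))
    return max(low, min(high, action + noise))


def td3_plus_bc_actor_loss(q_values, policy_actions, dataset_actions, alpha):
    """Actor loss as formulated in the TD3+BC paper (assumed here).

    loss = -lambda * mean(Q) + mean((pi(s) - a)^2),
    with lambda = alpha / mean(|Q|), so `alpha` weights the Q-value term
    against the behavior-cloning term.
    """
    n = len(q_values)
    mean_abs_q = sum(abs(q) for q in q_values) / n
    lam = alpha / mean_abs_q
    q_term = -lam * sum(q_values) / n
    bc_term = sum(
        (p - a) ** 2 for p, a in zip(policy_actions, dataset_actions)
    ) / n
    return q_term + bc_term
```

A larger `alpha` shifts the actor toward maximizing the critic's Q-values, while a smaller one keeps it closer to the dataset actions. A small `tau` makes the target network track the online network slowly.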