
Commit

Rename fields
takuseno committed Feb 16, 2025
1 parent f115e78 commit 916e030
Showing 4 changed files with 14 additions and 14 deletions.
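In short, the actor-side fields of TACR gain an actor_ prefix so they mirror the existing critic_* fields: learning_rate becomes actor_learning_rate, encoder_factory becomes actor_encoder_factory, optim_factory becomes actor_optim_factory, and the optim slot in TACRModules becomes actor_optim. A minimal before/after sketch of user-facing code, assembled from the diffs below (only the renamed arguments are shown; all other arguments keep their defaults and are omitted here):

import d3rlpy

# Before this commit (old field names), shown for comparison only:
# config = d3rlpy.algos.TACRConfig(
#     learning_rate=1e-4,
#     encoder_factory=d3rlpy.models.VectorEncoderFactory([128]),
#     optim_factory=d3rlpy.optimizers.AdamWFactory(weight_decay=1e-4),
# )

# After this commit, the actor-side names line up with the critic-side ones:
config = d3rlpy.algos.TACRConfig(
    actor_learning_rate=1e-4,
    actor_encoder_factory=d3rlpy.models.VectorEncoderFactory([128]),
    actor_optim_factory=d3rlpy.optimizers.AdamWFactory(weight_decay=1e-4),
)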
14 changes: 7 additions & 7 deletions d3rlpy/algos/transformer/tacr.py
@@ -58,11 +58,11 @@ class TACRConfig(TransformerConfig):
     """
 
     batch_size: int = 64
-    learning_rate: float = 1e-4
+    actor_learning_rate: float = 1e-4
     critic_learning_rate: float = 1e-4
-    encoder_factory: EncoderFactory = make_encoder_field()
+    actor_encoder_factory: EncoderFactory = make_encoder_field()
     critic_encoder_factory: EncoderFactory = make_encoder_field()
-    optim_factory: OptimizerFactory = make_optimizer_field()
+    actor_optim_factory: OptimizerFactory = make_optimizer_field()
     critic_optim_factory: OptimizerFactory = make_optimizer_field()
     q_func_factory: QFunctionFactory = make_q_func_field()
     num_heads: int = 1
@@ -96,7 +96,7 @@ def inner_create_impl(
         transformer = create_continuous_decision_transformer(
             observation_shape=observation_shape,
             action_size=action_size,
-            encoder_factory=self._config.encoder_factory,
+            encoder_factory=self._config.actor_encoder_factory,
             num_heads=self._config.num_heads,
             max_timestep=self._config.max_timestep,
             num_layers=self._config.num_layers,
@@ -109,9 +109,9 @@ def inner_create_impl(
             device=self._device,
             enable_ddp=self._enable_ddp,
         )
-        optim = self._config.optim_factory.create(
+        optim = self._config.actor_optim_factory.create(
             transformer.named_modules(),
-            lr=self._config.learning_rate,
+            lr=self._config.actor_learning_rate,
             compiled=self.compiled,
         )
 
@@ -141,7 +141,7 @@ def inner_create_impl(
 
         modules = TACRModules(
             transformer=transformer,
-            optim=optim,
+            actor_optim=optim,
             q_funcs=q_funcs,
             targ_q_funcs=targ_q_funcs,
             critic_optim=critic_optim,
6 changes: 3 additions & 3 deletions d3rlpy/algos/transformer/torch/tacr_impl.py
@@ -30,7 +30,7 @@
 @dataclasses.dataclass(frozen=True)
 class TACRModules(Modules):
     transformer: ContinuousDecisionTransformer
-    optim: OptimizerWrapper
+    actor_optim: OptimizerWrapper
     q_funcs: nn.ModuleList
     targ_q_funcs: nn.ModuleList
     critic_optim: OptimizerWrapper
@@ -90,7 +90,7 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
     def compute_actor_grad(
         self, batch: TorchTrajectoryMiniBatch
     ) -> torch.Tensor:
-        self._modules.optim.zero_grad()
+        self._modules.actor_optim.zero_grad()
         loss = self.compute_actor_loss(batch)
         loss.backward()
         return loss
@@ -111,7 +111,7 @@ def inner_update(
         metrics = {}
 
         actor_loss = self._compute_actor_grad(batch)
-        self._modules.optim.step()
+        self._modules.actor_optim.step()
         metrics.update({"actor_loss": float(actor_loss.cpu().detach())})
 
         critic_loss = self._compute_critic_grad(batch)
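The implementation change above is mechanical: every self._modules.optim access becomes self._modules.actor_optim, so the actor update now reads symmetrically with the critic update. A condensed, hypothetical walk-through of one update under the renamed fields (only the actor-side calls are taken from the hunks above; compute_critic_loss and the surrounding control flow are stand-ins for illustration, not the library's actual methods):

def sketch_update(modules, impl, batch):
    # Actor side: mirrors compute_actor_grad / inner_update after the rename.
    modules.actor_optim.zero_grad()              # was modules.optim.zero_grad()
    actor_loss = impl.compute_actor_loss(batch)
    actor_loss.backward()
    modules.actor_optim.step()                   # was modules.optim.step()

    # Critic side: these fields were already named critic_*; compute_critic_loss
    # is a hypothetical stand-in, not the actual d3rlpy method name.
    modules.critic_optim.zero_grad()
    critic_loss = impl.compute_critic_loss(batch)
    critic_loss.backward()
    modules.critic_optim.step()

    return {
        "actor_loss": float(actor_loss.cpu().detach()),
        "critic_loss": float(critic_loss.cpu().detach()),
    }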
6 changes: 3 additions & 3 deletions reproductions/offline/tacr.py
@@ -28,15 +28,15 @@ def main() -> None:
 
     tacr = d3rlpy.algos.TACRConfig(
         batch_size=64,
-        learning_rate=1e-4,
-        optim_factory=d3rlpy.optimizers.AdamWFactory(
+        actor_learning_rate=1e-4,
+        actor_optim_factory=d3rlpy.optimizers.AdamWFactory(
             weight_decay=1e-4,
             clip_grad_norm=0.25,
             lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory(
                 warmup_steps=10000
             ),
         ),
-        encoder_factory=d3rlpy.models.VectorEncoderFactory(
+        actor_encoder_factory=d3rlpy.models.VectorEncoderFactory(
             [128],
             exclude_last_activation=True,
         ),
2 changes: 1 addition & 1 deletion tests/algos/transformer/test_tacr.py
@@ -19,7 +19,7 @@ def test_tacr(observation_shape: Shape, scalers: Optional[str]) -> None:
         scalers, observation_shape
     )
     config = TACRConfig(
-        encoder_factory=DummyEncoderFactory(),
+        actor_encoder_factory=DummyEncoderFactory(),
         critic_encoder_factory=DummyEncoderFactory(),
        observation_scaler=observation_scaler,
        action_scaler=action_scaler,
