From 3f9a2aa02fede0eb7096476eb0c34e5ed22793f4 Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 12 Feb 2025 10:00:46 +0900 Subject: [PATCH] Fix mypy check --- reproductions/offline/qdt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/reproductions/offline/qdt.py b/reproductions/offline/qdt.py index 544d5997..9c9d421d 100644 --- a/reproductions/offline/qdt.py +++ b/reproductions/offline/qdt.py @@ -8,6 +8,7 @@ import d3rlpy from d3rlpy.algos import CQL, IQL +from d3rlpy.algos.qlearning.torch.iql_impl import IQLModules from d3rlpy.dataset import InfiniteBuffer, ReplayBuffer from d3rlpy.types import NDArray @@ -232,9 +233,10 @@ def fit_iql( # workaround for learning scheduler iql.build_with_dataset(dataset) assert iql.impl + assert isinstance(iql.impl.modules, IQLModules) scheduler = CosineAnnealingLR( - iql.impl._modules.actor_optim, # pylint: disable=protected-access - 500000, + optimizer=iql.impl.modules.actor_optim.optim, + T_max=500000, ) def callback(algo: d3rlpy.algos.IQL, epoch: int, total_step: int) -> None: