Skip to content

Commit

Permalink
fixed type check issues
Browse files Browse the repository at this point in the history
  • Loading branch information
takuyamagata committed Feb 9, 2025
1 parent 9958206 commit 25d2255
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions reproductions/offline/qdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@ def main() -> None:

# first fit Q-learning algorithm to the dataset
if args.model_file is not None:
q_algo = d3rlpy.load_learnable(args.model_file)
# load model and assert type
q_algo_loaded = d3rlpy.load_learnable(args.model_file)
if not isinstance(q_algo_loaded, (CQL, IQL)):
raise ValueError(
"The loaded model is not an instance of CQL or IQL."
)
# cast to the expected type
q_algo = q_algo_loaded
else:
if args.q_learning_type == "cql":
q_algo = fit_cql(dataset, env, args.seed, args.gpu, timestamp)
Expand All @@ -40,6 +47,9 @@ def main() -> None:

# relabel dataset RTGs with the learned value functions
print("Relabeling dataset with RTGs...")
if not isinstance(dataset._buffer, InfiniteBuffer):
raise ValueError("Dataset must be an InfiniteBuffer.")

relabel_dataset_rtg(
dataset._buffer, q_algo, args.context_size, seed=args.seed
)
Expand All @@ -49,6 +59,8 @@ def main() -> None:
dataset, env, args.context_size, args.seed, args.gpu, False, timestamp
)

return


""" --------------------------------------------------------------------
Aargument dataset
Expand All @@ -57,11 +69,11 @@ def main() -> None:

def relabel_dataset_rtg(
buffer: InfiniteBuffer,
q_algo: Union["CQL", "IQL"],
q_algo: Union[CQL, IQL],
k: int,
num_action_samples: int = 10,
seed: int = 0,
):
) -> None:
"""
Relabel RTG (reward-to-go) to the given dataset using the given Q-function.
Expand Down Expand Up @@ -109,6 +121,8 @@ def relabel_dataset_rtg(

prev_idx = idx

return


""" --------------------------------------------------------------------
Fit offline RL algorithms to the given dataset.
Expand Down

0 comments on commit 25d2255

Please sign in to comment.