diff --git a/reproductions/offline/qdt.py b/reproductions/offline/qdt.py
index d2f79a17..94910e94 100644
--- a/reproductions/offline/qdt.py
+++ b/reproductions/offline/qdt.py
@@ -31,7 +31,14 @@ def main() -> None:
 
     # first fit Q-learning algorithm to the dataset
     if args.model_file is not None:
-        q_algo = d3rlpy.load_learnable(args.model_file)
+        # load model and assert type
+        q_algo_loaded = d3rlpy.load_learnable(args.model_file)
+        if not isinstance(q_algo_loaded, (CQL, IQL)):
+            raise ValueError(
+                "The loaded model is not an instance of CQL or IQL."
+            )
+        # cast to the expected type
+        q_algo = q_algo_loaded
     else:
         if args.q_learning_type == "cql":
             q_algo = fit_cql(dataset, env, args.seed, args.gpu, timestamp)
@@ -40,6 +47,9 @@ def main() -> None:
 
     # relabel dataset RTGs with the learned value functions
     print("Relabeling dataset with RTGs...")
+    if not isinstance(dataset._buffer, InfiniteBuffer):
+        raise ValueError("Dataset must be an InfiniteBuffer.")
+
     relabel_dataset_rtg(
         dataset._buffer, q_algo, args.context_size, seed=args.seed
     )
@@ -49,6 +59,8 @@ def main() -> None:
         dataset, env, args.context_size, args.seed, args.gpu, False, timestamp
     )
 
+    return
+
 """
 --------------------------------------------------------------------
 Aargument dataset
@@ -57,11 +69,11 @@ def main() -> None:
 
 def relabel_dataset_rtg(
     buffer: InfiniteBuffer,
-    q_algo: Union["CQL", "IQL"],
+    q_algo: Union[CQL, IQL],
     k: int,
     num_action_samples: int = 10,
     seed: int = 0,
-):
+) -> None:
     """
     Relabel RTG (reward-to-go) to the given dataset using the given
     Q-function.
@@ -109,6 +121,8 @@ def relabel_dataset_rtg(
 
         prev_idx = idx
 
+    return
+
 """
 --------------------------------------------------------------------
 Fit offline RL algorithms to the given dataset.
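
Note on the isinstance guard added in main(): d3rlpy.load_learnable is, as far as the
diff suggests, typed to return a generic learnable object rather than a specific
algorithm class, so the runtime check both rejects an incompatible checkpoint early
and lets static checkers such as mypy narrow the result to Union[CQL, IQL]. Below is
a minimal, self-contained sketch of that pattern; the bare CQL/IQL classes and the
load_model/get_q_algo helpers are hypothetical stand-ins for illustration, not the
d3rlpy API.

    from typing import Union


    class CQL:  # stand-in for d3rlpy's CQL algorithm class
        pass


    class IQL:  # stand-in for d3rlpy's IQL algorithm class
        pass


    def load_model(path: str) -> object:
        # stand-in loader; like d3rlpy.load_learnable, its static return
        # type is broader than the algorithms the caller can work with
        return CQL()


    def get_q_algo(path: str) -> Union[CQL, IQL]:
        loaded = load_model(path)
        if not isinstance(loaded, (CQL, IQL)):
            # fail fast on an unexpected checkpoint type
            raise ValueError(
                "The loaded model is not an instance of CQL or IQL."
            )
        # after the isinstance check, type checkers narrow `loaded`
        # to Union[CQL, IQL], so the return type-checks without a cast
        return loaded

Assigning the narrowed value to a fresh name (q_algo = q_algo_loaded in the patch)
keeps the rest of main() typed against the narrowed union without an explicit
typing.cast.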