diff --git a/rlhfblender/routes/data.py b/rlhfblender/routes/data.py index 355f87a..4e406ed 100644 --- a/rlhfblender/routes/data.py +++ b/rlhfblender/routes/data.py @@ -219,13 +219,15 @@ async def get_single_step_details(request: SingleStepDetailRequest): reward = episode_benchmark_data["rewards"][request.step] info = episode_benchmark_data["infos"][request.step] - return { - "action_distribution": action_distribution.tolist(), - "action": action.item() if np.isscalar(action) else action.tolist(), - "reward": reward.item(), - "info": convert_to_serializable(info), - "action_space": action_space, - } + return convert_to_serializable( + { + "action_distribution": action_distribution, + "action": action, + "reward": reward, + "info": info, + "action_space": action_space, + } + ) @router.post("/get_actions_for_episode", response_model=list[int | float | list[float]], tags=["DATA"])