-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from hubverse-org/elr/load_model_out_in_window
load model out in window
- Loading branch information
Showing
27 changed files
with
12,077 additions
and
438,803 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#' Load model output data from a hub, filtering to a specified target and
#' evaluation window.
#'
#' The target is looked up in the task groups of the first round id returned
#' for the hub's tasks config, and the hub connection is filtered using the
#' `target_keys` of the matching `target_metadata` entry. Further filtering
#' depends on which fields are present in `eval_window`:
#' - only `window_name`: all rows for the target are returned.
#' - `min_round_id`: only rows whose round id is at least this value are kept.
#' - `n_last_round_ids`: of the configured round ids that are no later than the
#'   latest round id present in the loaded data, only the last
#'   `n_last_round_ids` are kept.
#'
#' @param hub_path A path to the hub.
#' @param target_id The target_id to filter to.
#' @param eval_window A list specifying the evaluation window, derived from the
#'   eval_windows field of the predeval config.
#'
#' @return A data frame containing the model output data. An error is raised
#'   if `target_id` is not found in the hub's tasks config.
#' @noRd
load_model_out_in_window <- function(hub_path, target_id, eval_window) {
  conn <- hubData::connect_hub(hub_path)

  # locate the task group(s) that define the requested target_id
  hub_tasks_config <- hubUtils::read_config(hub_path, config = "tasks")
  round_ids <- hubUtils::get_round_ids(hub_tasks_config)
  # NOTE(review): only the first round's task groups are consulted; this
  # presumably assumes targets are defined consistently across rounds.
  task_groups <- hubUtils::get_round_model_tasks(hub_tasks_config, round_ids[1])
  target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)
  task_group_idxs_w_target <- get_task_group_idxs_w_target(
    target_id,
    target_ids_by_task_group
  )

  # fail with an informative message rather than a "subscript out of bounds"
  # error when the target is not defined in the tasks config
  if (length(task_group_idxs_w_target) == 0) {
    stop(
      "target_id \"", target_id, "\" was not found in the target_metadata ",
      "of any task group in the hub's tasks config.",
      call. = FALSE
    )
  }

  target_meta <- purrr::keep(
    task_groups[[task_group_idxs_w_target[[1]]]][["target_metadata"]],
    function(x) x[["target_id"]] == target_id
  )
  target_keys <- target_meta[[1]][["target_keys"]]

  # filter to the requested target: one equality filter per target key.
  # When target_keys is NULL there is no task id column identifying the
  # target, so the loop body never runs and no target filter is applied.
  for (key_name in names(target_keys)) {
    key_value <- target_keys[[key_name]]
    conn <- conn |>
      dplyr::filter(!!rlang::sym(key_name) == key_value)
  }

  # if eval_window doesn't specify any subsetting by rounds, return the full data
  no_limits <- identical(names(eval_window), "window_name")
  if (no_limits) {
    return(conn |> dplyr::collect())
  }

  # if eval_window specifies a minimum round id, filter to that
  # NOTE(review): the `round_id` field of the first round is used as a column
  # name here, which presumably requires `round_id_from_variable` to be true.
  round_id_var_name <- hub_tasks_config[["rounds"]][[1]][["round_id"]]
  if ("min_round_id" %in% names(eval_window)) {
    conn <- conn |>
      dplyr::filter(!!rlang::sym(round_id_var_name) >= eval_window$min_round_id)
  }

  # load the data
  model_out_tbl <- conn |> dplyr::collect()

  # with no rows left there is no "latest present round" to anchor on; skip
  # the step so that max() is never called on an empty column
  if ("n_last_round_ids" %in% names(eval_window) && nrow(model_out_tbl) > 0) {
    # keep only the last n configured rounds, counting back from the latest
    # round id that actually appears in the loaded data
    max_present_round_id <- max(model_out_tbl[[round_id_var_name]])
    eligible_round_ids <- round_ids[round_ids <= max_present_round_id]
    kept_round_ids <- utils::tail(
      eligible_round_ids,
      eval_window[["n_last_round_ids"]]
    )
    model_out_tbl <- model_out_tbl |>
      dplyr::filter(!!rlang::sym(round_id_var_name) %in% kept_round_ids)
  }

  return(model_out_tbl)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#' Collect the target_ids declared in each task group's target_metadata
#'
#' @param task_groups a list of task groups, each with a `target_metadata`
#'   element whose entries each have a `target_id` element
#'
#' @return a list with one entry for each task group,
#'   where each entry is a character vector of target_ids for that group
#'
#' @noRd
get_target_ids_by_task_group <- function(task_groups) {
  # pull the target_id out of every target_metadata entry of one task group
  target_ids_of_group <- function(task_group) {
    purrr::map_chr(
      task_group[["target_metadata"]],
      \(meta) meta[["target_id"]]
    )
  }

  purrr::map(task_groups, target_ids_of_group)
}
|
||
|
||
#' Find the positions of the entries of target_ids_by_task_group that
#' contain the target_id
#'
#' @param target_id the target_id to look for
#' @param target_ids_by_task_group a list with one vector of target_ids per
#'   task group
#'
#' @return an unnamed integer vector of matching positions, or `NULL` when no
#'   entry contains the target_id
#' @noRd
get_task_group_idxs_w_target <- function(target_id, target_ids_by_task_group) {
  # one TRUE/FALSE per task group: does it list the target_id?
  group_has_target <- purrr::map_lgl(
    target_ids_by_task_group,
    \(target_ids) target_id %in% target_ids
  )

  matching_idxs <- unname(which(group_has_target))
  if (length(matching_idxs) == 0) {
    return(NULL)
  }

  matching_idxs
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#' Expect two data frames to hold the same rows, ignoring row order
#'
#' Column names (including their order) must match exactly. Both data frames
#' are then sorted by all columns of `df_act` before being compared.
#'
#' @param df_act The actual data frame
#' @param df_exp The expected data frame
expect_df_equal_up_to_order <- function(df_act, df_exp) {
  sort_cols <- colnames(df_act)

  # put rows into a canonical order so that row order cannot cause a mismatch
  sort_rows <- function(df) {
    dplyr::arrange(df, dplyr::across(dplyr::all_of(sort_cols)))
  }

  testthat::expect_equal(sort_cols, colnames(df_exp))
  testthat::expect_equal(
    sort_rows(df_act),
    sort_rows(df_exp),
    ignore_attr = FALSE
  )
}
Oops, something went wrong.