Skip to content

Commit

Permalink
Merge pull request #5 from hubverse-org/elr/load_model_out_in_window
Browse files Browse the repository at this point in the history
load model out in window
  • Loading branch information
elray1 authored Dec 20, 2024
2 parents b2d8313 + 9d03da5 commit 70a3b12
Show file tree
Hide file tree
Showing 27 changed files with 12,077 additions and 438,803 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ License: MIT + file LICENSE
Imports:
cli,
dplyr (>= 1.1.0),
hubData,
hubUtils,
jsonlite,
jsonvalidate,
Expand All @@ -18,6 +19,7 @@ Imports:
scoringutils (>= 2.0.0.9000),
yaml
Remotes:
hubverse-org/hubData,
hubverse-org/hubUtils,
epiforecasts/scoringutils
Encoding: UTF-8
Expand Down
43 changes: 10 additions & 33 deletions R/config.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#' Load and validate a webevals config file
#'
#' @param hub_path A path to the hub.
Expand All @@ -9,7 +8,7 @@
read_config <- function(hub_path, config_path) {
tryCatch(
{
config <- yaml::read_yaml(config_path)
config <- yaml::read_yaml(config_path, eval.expr = FALSE)
},
error = function(e) {
# This handler is used when an unrecoverable error is thrown while
Expand Down Expand Up @@ -88,23 +87,9 @@ validate_config_vs_hub_tasks <- function(hub_path, webevals_config) {
}

task_groups <- hub_tasks_config[["rounds"]][[1]][["model_tasks"]]
target_ids_by_task_group <- purrr::map(
task_groups,
function(task_group) {
purrr::map_chr(
task_group[["target_metadata"]],
function(target) target[["target_id"]]
)
}
)

# checks for targets
validate_config_targets(
webevals_config,
task_groups,
target_ids_by_task_group,
task_id_names
)
validate_config_targets(webevals_config, task_groups, task_id_names)

# checks for eval_windows
validate_config_eval_windows(webevals_config, hub_tasks_config)
Expand All @@ -120,39 +105,30 @@ validate_config_vs_hub_tasks <- function(hub_path, webevals_config) {
#' - disaggregate_by entries are task id variable names
#'
#' @noRd
validate_config_targets <- function(webevals_config, task_groups,
target_ids_by_task_group, task_id_names) {
validate_config_targets <- function(webevals_config, task_groups, task_id_names) {
target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)

for (target in webevals_config$targets) {
target_id <- target$target_id

# get task groups for this target
task_group_idxs_with_target <- purrr::map2(
seq_along(target_ids_by_task_group), target_ids_by_task_group,
function(i, target_ids) {
if (target_id %in% target_ids) {
return(i)
}
return(NULL)
}
) |>
purrr::compact() |>
unlist()
task_group_idxs_w_target <- get_task_group_idxs_w_target(target_id, target_ids_by_task_group)

# check that target_id in the webevals config appears in the hub tasks
if (length(task_group_idxs_with_target) == 0) {
if (length(task_group_idxs_w_target) == 0) {
raise_config_error(
cli::format_inline("Target id {.val {target_id}} not found in any task group.")
)
}

# check that metrics are valid for the available output types
output_types_for_target <- purrr::map(
task_group_idxs_with_target,
task_group_idxs_w_target,
function(i) names(task_groups[[i]][["output_type"]])
) |>
unlist() |>
unique()
task_group_idx <- task_group_idxs_with_target[[1]]
task_group_idx <- task_group_idxs_w_target[[1]]
target_type <- task_groups[[task_group_idx]]$target_metadata[[
which(target_ids_by_task_group[[task_group_idx]] == target_id)
]]$target_type
Expand Down Expand Up @@ -334,6 +310,7 @@ get_standard_metrics <- function(output_type, is_ordinal) {
)
}


#' Raise an error related to the webevals config file
#' @noRd
raise_config_error <- function(msgs) {
Expand Down
57 changes: 57 additions & 0 deletions R/load_model_out_in_window.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#' Load model output data from a hub, filtering to a specified target and
#' evaluation window.
#'
#' The target's metadata is looked up in the model tasks of the first round
#' listed in the hub's tasks config. Model output is filtered to the task id
#' value(s) named in that target's `target_keys`, then optionally restricted
#' by the round-based limits in `eval_window`.
#'
#' @param hub_path A path to the hub.
#' @param target_id The target_id to filter to.
#' @param eval_window A list specifying the evaluation window, derived from the
#'   eval_windows field of the predeval config. If its only entry is
#'   `window_name`, no round-based filtering is applied. Optional entries
#'   `min_round_id` and `n_last_round_ids` restrict the rounds returned.
#'
#' @return A data frame containing the model output data.
#' @noRd
load_model_out_in_window <- function(hub_path, target_id, eval_window) {
  conn <- hubData::connect_hub(hub_path)

  # locate the task group(s) whose target_metadata declares the target_id
  hub_tasks_config <- hubUtils::read_config(hub_path, config = "tasks")
  round_ids <- hubUtils::get_round_ids(hub_tasks_config)
  task_groups <- hubUtils::get_round_model_tasks(hub_tasks_config, round_ids[1])
  target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)
  task_group_idxs_w_target <- get_task_group_idxs_w_target(
    target_id,
    target_ids_by_task_group
  )

  # fail with a clear message rather than "subscript out of bounds" below
  if (length(task_group_idxs_w_target) == 0) {
    stop(
      "Target id \"", target_id, "\" not found in any task group.",
      call. = FALSE
    )
  }

  target_meta <- purrr::keep(
    task_groups[[task_group_idxs_w_target[[1]]]][["target_metadata"]],
    function(x) x[["target_id"]] == target_id
  )

  # filter to the requested target_id: target_keys maps task id variable
  # name(s) to the value(s) identifying this target. It may be NULL, in which
  # case names() is NULL, the loop body never runs and no filter is applied.
  target_keys <- target_meta[[1]][["target_keys"]]
  for (task_id_var_name in names(target_keys)) {
    task_id_value <- target_keys[[task_id_var_name]]
    conn <- conn |>
      dplyr::filter(!!rlang::sym(task_id_var_name) == !!task_id_value)
  }

  # if eval_window doesn't specify any subsetting by rounds, return the full
  # data for the target
  no_limits <- identical(names(eval_window), "window_name")
  if (no_limits) {
    return(dplyr::collect(conn))
  }

  # NOTE(review): assumes the first round's `round_id` entry names the task id
  # variable (column) that holds round ids -- confirm for hubs that do not set
  # round ids from a variable.
  round_id_var_name <- hub_tasks_config[["rounds"]][[1]][["round_id"]]

  # if eval_window specifies a minimum round id, filter to that
  if ("min_round_id" %in% names(eval_window)) {
    conn <- conn |>
      dplyr::filter(
        !!rlang::sym(round_id_var_name) >= eval_window$min_round_id
      )
  }

  # load the data
  model_out_tbl <- dplyr::collect(conn)

  # Skip the "last n rounds" step when nothing was loaded: max() of a
  # zero-length column warns (numeric) or errors (character), and there is
  # nothing left to filter anyway.
  if ("n_last_round_ids" %in% names(eval_window) && nrow(model_out_tbl) > 0) {
    # keep the last n configured round ids that are not later than the latest
    # round id present in the data
    # (presumably round_ids are in chronological order; verify upstream)
    max_present_round_id <- max(model_out_tbl[[round_id_var_name]])
    round_ids <- round_ids[round_ids <= max_present_round_id]
    round_ids <- utils::tail(round_ids, eval_window$n_last_round_ids)
    model_out_tbl <- model_out_tbl |>
      dplyr::filter(!!rlang::sym(round_id_var_name) %in% round_ids)
  }

  model_out_tbl
}
39 changes: 39 additions & 0 deletions R/utils-hub_tasks_config.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#' Collect the target ids declared in each task group's target_metadata
#'
#' @param task_groups A list of task groups; each is expected to carry a
#'   `target_metadata` list whose entries have a `target_id` field.
#'
#' @return a list with one entry for each task group,
#'   where each entry is a character vector of target_ids for that group
#'
#' @noRd
get_target_ids_by_task_group <- function(task_groups) {
  # pull the target_id out of a single target_metadata entry
  extract_target_id <- function(target) target[["target_id"]]

  # character vector of all target ids declared by one task group
  target_ids_in_group <- function(task_group) {
    purrr::map_chr(task_group[["target_metadata"]], extract_target_id)
  }

  purrr::map(task_groups, target_ids_in_group)
}


#' Get an integer vector with the indices of elements of
#' target_ids_by_task_group that contain the target_id
#'
#' @param target_id A single target id to look for.
#' @param target_ids_by_task_group A list with one entry per task group, each
#'   entry being a vector of the target ids in that group.
#'
#' @return An unnamed integer vector of the positions in
#'   `target_ids_by_task_group` whose entry contains `target_id`, in increasing
#'   order. Returns `integer(0)` (never `NULL`) when no task group contains
#'   the target.
#'
#' @noRd
get_task_group_idxs_w_target <- function(target_id, target_ids_by_task_group) {
  # one TRUE/FALSE per task group; vapply() pins the result type so an empty
  # input yields logical(0) rather than an empty list
  has_target <- vapply(
    target_ids_by_task_group,
    function(target_ids) target_id %in% target_ids,
    logical(1)
  )

  # which() keeps names from a named input list; drop them so the result is a
  # plain vector of positions
  unname(which(has_target))
}
8 changes: 0 additions & 8 deletions tests/testthat/_snaps/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

Code
read_config(hub_path, test_path("testdata", "test_configs", "config_valid.yaml"))
Message
i Updating superseded URL `Infectious-Disease-Modeling-hubs` to `hubverse-org`
Output
$targets
$targets[[1]]
Expand Down Expand Up @@ -235,8 +233,6 @@
Code
read_config(hub_path, test_path("testdata", "test_configs",
"config_valid_no_min_round_id.yaml"))
Message
i Updating superseded URL `Infectious-Disease-Modeling-hubs` to `hubverse-org`
Output
$targets
$targets[[1]]
Expand Down Expand Up @@ -454,8 +450,6 @@
Code
read_config(hub_path, test_path("testdata", "test_configs",
"config_valid_no_disaggregate_by.yaml"))
Message
i Updating superseded URL `Infectious-Disease-Modeling-hubs` to `hubverse-org`
Output
$targets
$targets[[1]]
Expand Down Expand Up @@ -681,8 +675,6 @@
Code
read_config(hub_path, test_path("testdata", "test_configs",
"config_valid_no_task_id_text.yaml"))
Message
i Updating superseded URL `Infectious-Disease-Modeling-hubs` to `hubverse-org`
Output
$targets
$targets[[1]]
Expand Down
13 changes: 13 additions & 0 deletions tests/testthat/helper-expect_df_equal_up_to_order.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#' Expect that two data frames hold the same rows, ignoring row order
#'
#' First expects the column names of the two data frames to be equal
#' (including their order), then sorts both data frames by all columns of
#' `df_act` and expects the sorted results to be equal.
#'
#' @param df_act The actual data frame
#' @param df_exp The expected data frame
#'
#' @return The value returned by the final `testthat::expect_equal()` call.
expect_df_equal_up_to_order <- function(df_act, df_exp) {
  cols <- colnames(df_act)

  # column names must agree before the rows can be compared
  testthat::expect_equal(object = cols, expected = colnames(df_exp))

  # sort both data frames on every column so row order cannot cause a mismatch
  testthat::expect_equal(
    object = df_act |> dplyr::arrange(dplyr::across(dplyr::all_of(cols))),
    expected = df_exp |> dplyr::arrange(dplyr::across(dplyr::all_of(cols))),
    ignore_attr = FALSE
  )
}
Loading

0 comments on commit 70a3b12

Please sign in to comment.