Skip to content

Commit

Permalink
Merge pull request #7 from hubverse-org/elr/refactor_task_subsetting
Browse files Browse the repository at this point in the history
filter task groups on target rather than getting indexes with target
  • Loading branch information
elray1 authored Dec 23, 2024
2 parents 70a3b12 + f1c0609 commit 269a97e
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 86 deletions.
15 changes: 5 additions & 10 deletions R/config.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,32 +106,27 @@ validate_config_vs_hub_tasks <- function(hub_path, webevals_config) {
#'
#' @noRd
validate_config_targets <- function(webevals_config, task_groups, task_id_names) {
target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)

for (target in webevals_config$targets) {
target_id <- target$target_id

# get task groups for this target
task_group_idxs_w_target <- get_task_group_idxs_w_target(target_id, target_ids_by_task_group)
task_groups_w_target <- filter_task_groups_to_target(task_groups, target_id)

# check that target_id in the webevals config appears in the hub tasks
if (length(task_group_idxs_w_target) == 0) {
if (length(task_groups_w_target) == 0) {
raise_config_error(
cli::format_inline("Target id {.val {target_id}} not found in any task group.")
)
}

# check that metrics are valid for the available output types
output_types_for_target <- purrr::map(
task_group_idxs_w_target,
function(i) names(task_groups[[i]][["output_type"]])
task_groups_w_target,
function(task_group) names(task_group[["output_type"]])
) |>
unlist() |>
unique()
task_group_idx <- task_group_idxs_w_target[[1]]
target_type <- task_groups[[task_group_idx]]$target_metadata[[
which(target_ids_by_task_group[[task_group_idx]] == target_id)
]]$target_type
target_type <- task_groups_w_target[[1]]$target_metadata[[1]]$target_type
target_is_ordinal <- target_type == "ordinal"

metric_name_to_output_type <- get_metric_name_to_output_type(
Expand Down
14 changes: 5 additions & 9 deletions R/load_model_out_in_window.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@ load_model_out_in_window <- function(hub_path, target_id, eval_window) {
hub_tasks_config <- hubUtils::read_config(hub_path, config = "tasks")
round_ids <- hubUtils::get_round_ids(hub_tasks_config)
task_groups <- hubUtils::get_round_model_tasks(hub_tasks_config, round_ids[1])
target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)
task_group_idxs_w_target <- get_task_group_idxs_w_target(target_id, target_ids_by_task_group)

target_meta <- purrr::keep(
task_groups[[task_group_idxs_w_target[[1]]]]$target_metadata,
function(x) x$target_id == target_id
)
target_task_id_var_name <- names(target_meta[[1]]$target_keys)
target_task_id_value <- target_meta[[1]]$target_keys[[target_task_id_var_name]]
task_groups_w_target <- filter_task_groups_to_target(task_groups, target_id)

target_meta <- task_groups_w_target[[1]]$target_metadata[[1]]
target_task_id_var_name <- names(target_meta$target_keys)
target_task_id_value <- target_meta$target_keys[[target_task_id_var_name]]

conn <- conn |>
dplyr::filter(!!rlang::sym(target_task_id_var_name) == target_task_id_value)
Expand Down
42 changes: 12 additions & 30 deletions R/utils-hub_tasks_config.R
Original file line number Diff line number Diff line change
@@ -1,39 +1,21 @@
#' Filter task groups from a hub's tasks config to those that contain a target_id.
#' Additionally, subset the target_metadata to just the entry for the target_id.
#'
#' @param task_groups list of task groups, e.g. as returned by
#'   `hubUtils::get_round_model_tasks()`. Each task group is a list whose
#'   `target_metadata` element is a list of target entries.
#' @param target_id the target id to filter to. It is compared with `==` to
#'   the `target_id` element of each `target_metadata` entry.
#'
#' @return a list holding only the task groups that have at least one
#'   `target_metadata` entry matching `target_id`, in their original order.
#'   Within each returned task group, `target_metadata` holds only the
#'   matching entries and all other elements are unchanged. An empty list is
#'   returned when no task group matches.
#'
#' @noRd
filter_task_groups_to_target <- function(task_groups, target_id) {
  # isTRUE() turns the comparison into a strict TRUE/FALSE, so an entry whose
  # target_id is missing (logical(0)) or NA counts as a non-match instead of
  # breaking the filter. `[[` is used because `$` partially matches list names.
  is_requested_target <- function(target) {
    isTRUE(target[["target_id"]] == target_id)
  }

  # For each task group, subset the target_metadata to just the entry for the
  # target_id. If the target_id is not in the task group, the target_metadata
  # will be empty.
  task_groups <- lapply(
    task_groups,
    function(task_group) {
      task_group[["target_metadata"]] <- Filter(
        is_requested_target,
        task_group[["target_metadata"]]
      )
      task_group
    }
  )

  # Remove task groups that don't contain the target_id
  Filter(
    function(task_group) length(task_group[["target_metadata"]]) > 0,
    task_groups
  )
}
63 changes: 26 additions & 37 deletions tests/testthat/test-utils-hub_tasks_config.R
Original file line number Diff line number Diff line change
@@ -1,74 +1,63 @@
test_that("filter_task_groups_to_target works", {
  # Build one task group: a group number plus one target_metadata entry per
  # target id, each entry being list(target_id = <id>).
  make_task_group <- function(group_number, target_ids) {
    list(
      group_number = group_number,
      target_metadata = lapply(
        target_ids,
        function(target_id) list(target_id = target_id)
      )
    )
  }

  task_groups <- list(
    make_task_group(1, c("target_id_1", "target_id_2")),
    make_task_group(2, c("target_id_3", "target_id_4", "target_id_5")),
    make_task_group(3, "target_id_4")
  )

  # A target that appears in exactly one task group: only that group is kept,
  # and its target_metadata is narrowed to the single matching entry.
  only_in_group_1 <- filter_task_groups_to_target(task_groups, "target_id_1")
  expect_equal(
    only_in_group_1,
    list(make_task_group(1, "target_id_1"))
  )

  # A target shared by two task groups: both are kept, in their original order.
  in_groups_2_and_3 <- filter_task_groups_to_target(task_groups, "target_id_4")
  expect_equal(
    in_groups_2_and_3,
    list(
      make_task_group(2, "target_id_4"),
      make_task_group(3, "target_id_4")
    )
  )

  # A target that appears nowhere: the result is an empty list.
  not_found <- filter_task_groups_to_target(
    task_groups,
    "NOT A REAL TARGET ID"
  )
  expect_equal(not_found, list())
})

0 comments on commit 269a97e

Please sign in to comment.