Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

filter task groups on target rather than getting indexes with target #7

Merged
merged 2 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions R/config.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,32 +106,27 @@ validate_config_vs_hub_tasks <- function(hub_path, webevals_config) {
#'
#' @noRd
validate_config_targets <- function(webevals_config, task_groups, task_id_names) {
target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)

for (target in webevals_config$targets) {
target_id <- target$target_id

# get task groups for this target
task_group_idxs_w_target <- get_task_group_idxs_w_target(target_id, target_ids_by_task_group)
task_groups_w_target <- filter_task_groups_to_target(task_groups, target_id)

# check that target_id in the webevals config appears in the hub tasks
if (length(task_group_idxs_w_target) == 0) {
if (length(task_groups_w_target) == 0) {
raise_config_error(
cli::format_inline("Target id {.val {target_id}} not found in any task group.")
)
}

# check that metrics are valid for the available output types
output_types_for_target <- purrr::map(
task_group_idxs_w_target,
function(i) names(task_groups[[i]][["output_type"]])
task_groups_w_target,
function(task_group) names(task_group[["output_type"]])
) |>
unlist() |>
unique()
task_group_idx <- task_group_idxs_w_target[[1]]
target_type <- task_groups[[task_group_idx]]$target_metadata[[
which(target_ids_by_task_group[[task_group_idx]] == target_id)
]]$target_type
target_type <- task_groups_w_target[[1]]$target_metadata[[1]]$target_type
target_is_ordinal <- target_type == "ordinal"

metric_name_to_output_type <- get_metric_name_to_output_type(
Expand Down
14 changes: 5 additions & 9 deletions R/load_model_out_in_window.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@ load_model_out_in_window <- function(hub_path, target_id, eval_window) {
hub_tasks_config <- hubUtils::read_config(hub_path, config = "tasks")
round_ids <- hubUtils::get_round_ids(hub_tasks_config)
task_groups <- hubUtils::get_round_model_tasks(hub_tasks_config, round_ids[1])
target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)
task_group_idxs_w_target <- get_task_group_idxs_w_target(target_id, target_ids_by_task_group)

target_meta <- purrr::keep(
task_groups[[task_group_idxs_w_target[[1]]]]$target_metadata,
function(x) x$target_id == target_id
)
target_task_id_var_name <- names(target_meta[[1]]$target_keys)
target_task_id_value <- target_meta[[1]]$target_keys[[target_task_id_var_name]]
task_groups_w_target <- filter_task_groups_to_target(task_groups, target_id)

target_meta <- task_groups_w_target[[1]]$target_metadata[[1]]
target_task_id_var_name <- names(target_meta$target_keys)
target_task_id_value <- target_meta$target_keys[[target_task_id_var_name]]

conn <- conn |>
dplyr::filter(!!rlang::sym(target_task_id_var_name) == target_task_id_value)
Expand Down
42 changes: 12 additions & 30 deletions R/utils-hub_tasks_config.R
Original file line number Diff line number Diff line change
@@ -1,39 +1,21 @@
#' For each task group, get the target_id entries from its target_metadata
#'
#' @return a list with one entry for each task group,
#' where each entry is a character vector of target_ids for that group
#' Filter task groups from a hub's tasks config to those that contain a target_id.
#' Additionally, subset the target_metadata to just the entry for the target_id.
#'
#' @noRd
get_target_ids_by_task_group <- function(task_groups) {
result <- purrr::map(
filter_task_groups_to_target <- function(task_groups, target_id) {
# For each task group, subset the target_metadata to just the entry for the target_id
# If the target_id is not in the task group, the target_metadata will be empty
task_groups <- purrr::map(
task_groups,
function(task_group) {
purrr::map_chr(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a much better second iteration! I would even go so far as to say that you could move this anonymous function into a named function (maybe trim_to_target_id).

task_group[["target_metadata"]],
function(target) target[["target_id"]]
)
task_group$target_metadata <- purrr::keep(task_group$target_metadata,
~ .x$target_id == target_id)
return(task_group)
}
Comment on lines +11 to +13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of returning task groups with empty target metadata, it might be better to return an empty list() or NULL. This way, you can use task_groups[lengths(task_groups) > 0] to filter the output later instead of using another purrr command.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions! (I didn't know the lengths function before)

I'm going to merge this as is so that I can keep moving on other pieces

)

return(result)
}


#' Get an integer vector with the indices of elements of
#' target_ids_by_task_group that contain the target_id
#' @noRd
get_task_group_idxs_w_target <- function(target_id, target_ids_by_task_group) {
result <- purrr::map2(
seq_along(target_ids_by_task_group), target_ids_by_task_group,
function(i, target_ids) {
if (target_id %in% target_ids) {
return(i)
}
return(NULL)
}
) |>
purrr::compact() |>
unlist()
# Remove task groups that don't contain the target_id
task_groups <- purrr::keep(task_groups, ~ length(.x$target_metadata) > 0)

return(result)
return(task_groups)
}
63 changes: 26 additions & 37 deletions tests/testthat/test-utils-hub_tasks_config.R
Original file line number Diff line number Diff line change
@@ -1,74 +1,63 @@
test_that(
"get_target_ids_by_task_group works",
"filter_task_groups_to_target works",
{
task_groups <- list(
list(
group_number = 1,
target_metadata = list(
list(target_id = "target_id_1"),
list(target_id = "target_id_2")
)
),
list(
group_number = 2,
target_metadata = list(
list(target_id = "target_id_3"),
list(target_id = "target_id_4"),
list(target_id = "target_id_5")
)
),
list(
group_number = 3,
target_metadata = list(
list(target_id = "target_id_4")
)
)
)

target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)

expect_equal(
target_ids_by_task_group,
filter_task_groups_to_target(task_groups, "target_id_1"),
list(
c("target_id_1", "target_id_2"),
c("target_id_3", "target_id_4", "target_id_5"),
"target_id_4"
list(
group_number = 1,
target_metadata = list(
list(target_id = "target_id_1")
)
)
)
)
}
)

test_that(
"get_task_group_idxs_w_target works",
{
task_groups <- list(
list(
target_metadata = list(
list(target_id = "target_id_1"),
list(target_id = "target_id_2")
)
),
list(
target_metadata = list(
list(target_id = "target_id_3"),
list(target_id = "target_id_4"),
list(target_id = "target_id_5")
)
),
expect_equal(
filter_task_groups_to_target(task_groups, "target_id_4"),
list(
target_metadata = list(
list(target_id = "target_id_4")
list(
group_number = 2,
target_metadata = list(
list(target_id = "target_id_4")
)
),
list(
group_number = 3,
target_metadata = list(
list(target_id = "target_id_4")
)
)
)
)

target_ids_by_task_group <- get_target_ids_by_task_group(task_groups)

expect_equal(
get_task_group_idxs_w_target("target_id_1", target_ids_by_task_group),
1L
)

expect_equal(
get_task_group_idxs_w_target("target_id_4", target_ids_by_task_group),
c(2L, 3L)
filter_task_groups_to_target(task_groups, "NOT A REAL TARGET ID"),
list()
)
}
)
Loading