
Commit

Merge pull request #11 from hubverse-org/elr/rel_metrics
support length 1 arrays and relative metrics
elray1 authored Jan 9, 2025
2 parents 6540a38 + 8a3632a commit a7cff83
Showing 23 changed files with 1,034 additions and 85 deletions.
21 changes: 21 additions & 0 deletions R/config.R
@@ -149,6 +149,27 @@ validate_config_targets <- function(webevals_config, task_groups, task_id_names)
       )
     }
 
+    # check that relative_metrics is a subset of metrics
+    extra_relative_metrics <- setdiff(
+      target$relative_metrics,
+      target$metrics
+    )
+    if (length(extra_relative_metrics) > 0) {
+      raise_config_error(
+        c(
+          cli::format_inline(
+            "Requested relative metrics for metrics that were not requested ",
+            "for {.arg target_id} {.val {target_id}}."
+          ),
+          "i" = cli::format_inline("Requested metric{?s}: {.val {target$metrics}}."),
+          "x" = cli::format_inline(
+            "Relative metric{?s} not found in the requested metrics: ",
+            "{.val {extra_relative_metrics}}."
+          )
+        )
+      )
+    }
+
     # check that disaggregate_by are task id variable names
     extra_disaggregate_by <- setdiff(
       target$disaggregate_by,
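
A minimal sketch of the subset check added above, using hypothetical metric names: setdiff() returns the relative metrics that are not also listed in metrics, and any non-empty result triggers the configuration error.

target <- list(
  metrics = c("wis", "ae_median"),
  relative_metrics = c("wis", "energy_score")
)
extra_relative_metrics <- setdiff(target$relative_metrics, target$metrics)
extra_relative_metrics
#> [1] "energy_score"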
45 changes: 39 additions & 6 deletions R/generate_eval_data.R
@@ -34,6 +34,9 @@ generate_target_eval_data <- function(hub_path,
                                        target) {
   target_id <- target$target_id
   metrics <- target$metrics
+  # if relative_metrics and baseline are not provided, they are NULL
+  relative_metrics <- target$relative_metrics
+  baseline <- target$baseline
   # adding `NULL` at the beginning will calculate overall scores
   disaggregate_by <- c(list(NULL), as.list(target$disaggregate_by))
   eval_windows <- config$eval_windows
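
The comment above relies on standard R list indexing: $ returns NULL for a name that is not present, so a target entry that omits relative_metrics and baseline yields NULL for both (the target entry below is hypothetical).

target <- list(target_id = "wk inc flu hosp", metrics = c("wis", "ae_median"))
target$relative_metrics
#> NULL
target$baseline
#> NULL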
@@ -50,6 +53,8 @@
       model_out_tbl = model_out_tbl,
       oracle_output = oracle_output,
       metric_name_to_output_type = metric_name_to_output_type,
+      relative_metrics = relative_metrics,
+      baseline = baseline,
       target_id = target_id,
       window_name = eval_window$window_name,
       by = by,
@@ -70,18 +75,22 @@
 #' out_path/target_id/window_name/by/scores.csv
 #' @noRd
 get_and_save_scores <- function(model_out_tbl, oracle_output, metric_name_to_output_type,
+                                relative_metrics, baseline,
                                 target_id, window_name, by,
                                 out_path) {
   # Iterate over the output types and calculate scores for each
   scores <- purrr::map(
     unique(metric_name_to_output_type$output_type),
-    ~ hubEvals::score_model_out(
-      model_out_tbl = model_out_tbl |> dplyr::filter(output_type == !!.x),
+    ~ get_scores_for_output_type(
+      model_out_tbl = model_out_tbl,
       oracle_output = oracle_output,
-      metrics = metric_name_to_output_type$metric[
-        metric_name_to_output_type$output_type == .x
-      ],
-      by = c("model_id", by)
+      metric_name_to_output_type = metric_name_to_output_type,
+      relative_metrics = relative_metrics,
+      baseline = baseline,
+      target_id = target_id,
+      window_name = window_name,
+      by = by,
+      output_type = .x
     )
   ) |>
     purrr::reduce(dplyr::left_join, by = c("model_id", by))
@@ -97,3 +106,27 @@ get_and_save_scores <- function(model_out_tbl, oracle_output, metric_name_to_output_type,
             file = file.path(target_window_by_out_path, "scores.csv"),
             row.names = FALSE)
 }
+
+
+#' Get scores for a target in a given evaluation window for a specific output type.
+get_scores_for_output_type <- function(model_out_tbl, oracle_output, metric_name_to_output_type,
+                                        relative_metrics, baseline,
+                                        target_id, window_name, by,
+                                        output_type) {
+  metrics <- metric_name_to_output_type$metric[
+    metric_name_to_output_type$output_type == output_type
+  ]
+  if (!is.null(relative_metrics)) {
+    relative_metrics <- relative_metrics[relative_metrics %in% metrics]
+  }
+  scores <- hubEvals::score_model_out(
+    model_out_tbl = model_out_tbl |> dplyr::filter(.data[["output_type"]] == !!output_type),
+    oracle_output = oracle_output,
+    metrics = metrics,
+    relative_metrics = relative_metrics,
+    baseline = baseline,
+    by = c("model_id", by)
+  )
+
+  return(scores)
+}
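
A short sketch of how get_scores_for_output_type narrows relative_metrics to the metrics belonging to a single output type before scoring; the metric names and output types below are hypothetical.

metric_name_to_output_type <- data.frame(
  metric = c("wis", "interval_coverage_90", "ae_point"),
  output_type = c("quantile", "quantile", "point")
)
relative_metrics <- c("wis", "ae_point")
output_type <- "quantile"

metrics <- metric_name_to_output_type$metric[
  metric_name_to_output_type$output_type == output_type
]
relative_metrics[relative_metrics %in% metrics]
#> [1] "wis"

Only "wis" survives the filter here, so pairwise relative skill for the quantile output type would be computed for wis alone.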
25 changes: 19 additions & 6 deletions inst/schema/config_schema.json
@@ -22,24 +22,37 @@
           },
           "metrics": {
             "description": "Names of metrics to compute for this target. These should be names of metrics supported by hubEvals::score_model_out.",
-            "type": "array",
+            "type": ["string", "array"],
             "items": {
               "type": "string"
             },
             "minItems": 1
           },
+          "relative_metrics": {
+            "description": "Optional names of metrics for which to compute pairwise relative skill for this target. These should be a subset of the metrics for the target.",
+            "type": ["string", "array"],
+            "items": {
+              "type": "string"
+            },
+            "minItems": 0
+          },
+          "baseline": {
+            "description": "Name of the model to use as a baseline for relative skill metrics for this target. Required if relative_metrics is provided.",
+            "type": "string",
+            "minItems": 0
+          },
           "disaggregate_by": {
             "description": "Optional list of task id columns to disaggregate by. Aggregated scores for each model will always be computed.",
-            "type": "array",
+            "type": ["string", "array"],
            "items": {
              "type": "string"
            }
          }
        },
-        "required": [
-          "target_id",
-          "metrics"
-        ]
+        "required": ["target_id", "metrics"],
+        "dependentRequired": {
+          "relative_metrics": ["baseline"]
+        }
      }
    },
    "eval_windows": {

