diff --git a/.Rbuildignore b/.Rbuildignore
index e57306d..8ca8d33 100644
--- a/.Rbuildignore
+++ b/.Rbuildignore
@@ -15,3 +15,4 @@
 ^\.Rdata$
 ^\.httr-oauth$
 ^\.secrets$
+^.vscode
diff --git a/.gitignore b/.gitignore
index 440e2e6..27440c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,3 +56,5 @@ docs
 .Rdata
 .secrets
 .quarto
+
+.vscode/
diff --git a/DESCRIPTION b/DESCRIPTION
index c711112..a832dd0 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -51,7 +51,7 @@ Remotes:
     hubverse-org/hubExamples,
     hubverse-org/hubUtils
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.3.1
+RoxygenNote: 7.3.2
 URL: https://hubverse-org.github.io/hubEvals/
 Depends:
     R (>= 2.10)
diff --git a/R/score_model_out.R b/R/score_model_out.R
index a2c8525..0769ed0 100644
--- a/R/score_model_out.R
+++ b/R/score_model_out.R
@@ -1,9 +1,13 @@
-#' Score model output predictions with a single `output_type` against observed data
+#' Score model output predictions
+#'
+#' Scores model outputs with a single `output_type` against observed data.
 #'
 #' @param model_out_tbl Model output tibble with predictions
 #' @param target_observations Observed 'ground truth' data to be compared to
 #' predictions
-#' @param metrics Optional character vector of scoring metrics to compute. See details for more.
+#' @param metrics Character vector of scoring metrics to compute. If `NULL`
+#' (the default), appropriate metrics are chosen automatically. See details
+#' for more.
 #' @param summarize Boolean indicator of whether summaries of forecast scores
 #' should be computed. Defaults to `TRUE`.
 #' @param by Character vector naming columns to summarize by. For example,
@@ -13,38 +17,40 @@
 #' vector of levels for pmf forecasts, in increasing order of the levels. For
 #' all other output types, this is ignored.
 #'
-#' @details If `metrics` is `NULL` (the default), this function chooses
-#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`:
-#' \itemize{
-#' \item For `output_type == "quantile"`, we use the default metrics provided by
-#' `scoringutils::metrics_quantile()`: `r names(scoringutils::metrics_quantile())`
-#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
-#' that the predicted variable is a nominal variable), we use the default metric
-#' provided by `scoringutils::metrics_nominal()`,
-#' `r names(scoringutils::metrics_nominal())`
-#' \item For `output_type == "median"`, we use "ae_point"
-#' \item For `output_type == "mean"`, we use "se_point"
-#' }
-#'
-#' Alternatively, a character vector of scoring metrics can be provided. In this
-#' case, the following options are supported:
-#' - `output_type == "median"` and `output_type == "mean"`:
-#'   - "ae_point": absolute error of a point prediction (generally recommended for the median)
-#'   - "se_point": squared error of a point prediction (generally recommended for the mean)
-#' - `output_type == "quantile"`:
-#'   - "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
-#'   - "wis": weighted interval score (WIS) of a collection of quantile predictions
-#'   - "overprediction": The component of WIS measuring the extent to which
-#' predictions fell above the observation.
-#'   - "underprediction": The component of WIS measuring the extent to which
-#' predictions fell below the observation.
-#'   - "dispersion": The component of WIS measuring the dispersion of forecast
-#' distributions.
-#'   - "interval_coverage_XX": interval coverage at the "XX" level. For example,
+#' @details
+#' Default metrics are provided by the `scoringutils` package. You can select
+#' metrics by passing in a character vector of metric names to the `metrics`
+#' argument.
+#'
+#' The following metrics can be selected (all are used by default) for the
+#' different `output_type`s:
+#'
+#' **Quantile forecasts:** (`output_type == "quantile"`)
+#' `r exclude <- c("interval_coverage_50", "interval_coverage_90")`
+#' `r metrics <- scoringutils::get_metrics(scoringutils::example_quantile, exclude = exclude)`
+#' `r paste("- ", names(metrics), collapse = "\n")`
+#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
 #' "interval_coverage_95" is the 95% interval coverage rate, which would be calculated
 #' based on quantiles at the probability levels 0.025 and 0.975.
-#' - `output_type == "pmf"`:
-#'   - "log_score": log score
+#'
+#' See [scoringutils::get_metrics.forecast_quantile] for details.
+#'
+#' **Nominal forecasts:** (`output_type == "pmf"` and `output_type_id_order` is `NULL`)
+#'
+#' `r paste("- ", names(scoringutils::get_metrics(scoringutils::example_nominal)), collapse = "\n")`
+#'
+#' (Scoring for ordinal forecasts will be added in the future.)
+#'
+#' See [scoringutils::get_metrics.forecast_nominal] for details.
+#'
+#' **Median forecasts:** (`output_type == "median"`)
+#'
+#' - `ae_point`: absolute error of the point forecast (recommended for the median, see Gneiting (2011))
+#'
+#' See [scoringutils::get_metrics.forecast_point] for details.
+#'
+#' **Mean forecasts:** (`output_type == "mean"`)
+#' - `se_point`: squared error of the point forecast (recommended for the mean, see Gneiting (2011))
 #'
 #' @examplesIf requireNamespace("hubExamples", quietly = TRUE)
 #' # compute WIS and interval coverage rates at 80% and 90% levels based on
@@ -72,7 +78,11 @@
 #' )
 #' head(pmf_scores)
 #'
-#' @return forecast_quantile
+#' @return A data.table with scores
+#'
+#' @references
+#' Gneiting, Tilmann. 2011. "Making and Evaluating Point Forecasts." Journal of the
+#' American Statistical Association 106 (494): 746–62.
 #'
 #' @export
 score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
@@ -82,9 +92,6 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
   # also, retrieve that output_type
   output_type <- validate_output_type(model_out_tbl)
 
-  # get/validate the scoring metrics
-  metrics <- get_metrics(metrics, output_type, output_type_id_order)
-
   # assemble data for scoringutils
   su_data <- switch(output_type,
     quantile = transform_quantile_model_out(model_out_tbl, target_observations),
@@ -94,16 +101,15 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
     NULL # default, should not happen because of the validation above
   )
 
+  # get/validate the scoring metrics
+  metrics <- get_metrics(forecast = su_data, output_type = output_type, select = metrics)
+
   # compute scores
   scores <- scoringutils::score(su_data, metrics)
 
   # switch back to hubverse naming conventions for model name
   scores <- dplyr::rename(scores, model_id = "model")
 
-  # if present, drop predicted and observed columns
-  drop_cols <- c("predicted", "observed")
-  scores <- scores[!colnames(scores) %in% drop_cols]
-
   # if requested, summarize scores
   if (summarize) {
     scores <- scoringutils::summarize_scores(scores = scores, by = by)
@@ -113,129 +119,56 @@
 }
-#' Get metrics if user didn't specify anything; otherwise, process
-#' and validate user inputs
-#'
-#' @inheritParams score_model_out
-#'
-#' @return a list of metric functions as required by scoringutils::score()
-#'
-#' @noRd
-get_metrics <- function(metrics, output_type, output_type_id_order) {
-  if (is.null(metrics)) {
-    return(get_metrics_default(output_type, output_type_id_order))
-  } else if (is.character(metrics)) {
-    return(get_metrics_character(metrics, output_type))
-  } else {
-    cli::cli_abort(
-      "{.arg metrics} must be either `NULL` or a character vector of supported metrics."
-    )
-  }
 }
-
-#' Default metrics if user didn't specify anything
+#' Get scoring metrics
 #'
+#' @param forecast A scoringutils `forecast` object (see
+#' [scoringutils::as_forecast()] for details).
 #' @inheritParams score_model_out
 #'
 #' @return a list of metric functions as required by scoringutils::score()
 #'
 #' @noRd
-get_metrics_default <- function(output_type, output_type_id_order) {
-  metrics <- switch(output_type,
-    quantile = scoringutils::metrics_quantile(),
-    pmf = scoringutils::metrics_nominal(),
-    mean = scoringutils::metrics_point(select = "se_point"),
-    median = scoringutils::metrics_point(select = "ae_point"),
-    NULL # default
-  )
-  if (is.null(metrics)) {
-    # we have already validated `output_type`, so this case should not be
-    # triggered; this case is just double checking in case we add something new
-    # later, to ensure we update this function.
-    supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter
-    cli::cli_abort(
-      "Provided `model_out_tbl` contains `output_type` {.val {output_type}};
-      hubEvals currently only supports the following types:
-      {.val {supported_types}}"
-    )
-  }
+get_metrics <- function(forecast, output_type, select = NULL) {
+  forecast_type <- class(forecast)[1]
-  return(metrics)
-}
-
-
-#' Convert character vector of metrics to list of functions
-#'
-#' @inheritParams score_model_out
-#'
-#' @return a list of metric functions as required by scoringutils::score()
-#'
-#' @noRd
-get_metrics_character <- function(metrics, output_type) {
-  if (output_type == "quantile") {
+  # process quantile metrics separately to allow better selection of interval
+  # coverage metrics
+  if (forecast_type == "forecast_quantile") {
     # split into metrics for interval coverage and others
-    interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics)
-    interval_metrics <- metrics[interval_metric_inds]
-    other_metrics <- metrics[!interval_metric_inds]
+    interval_metric_inds <- grepl(pattern = "^interval_coverage_", select)
+    interval_metrics <- select[interval_metric_inds]
+    other_metrics <- select[!interval_metric_inds]
-    # validate metrics
-    valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion")
-    invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics]
-    error_if_invalid_metrics(
-      valid_metrics = c(valid_metrics, "interval_coverage_XY"),
-      invalid_metrics = invalid_metrics,
-      output_type = output_type,
-      comment = c("i" = "NOTE: `XY` denotes the coverage level, e.g. {.val interval_coverage_95}.")
-    )
+    other_metric_fns <- scoringutils::get_metrics(forecast, select = other_metrics)
-    # assemble metric functions
+    # assemble interval coverage functions
     interval_metric_fns <- lapply(
       interval_metrics,
       function(metric) {
-        level <- as.integer(substr(metric, 19, 20))
+        level_str <- substr(metric, 19, nchar(metric))
+        level <- suppressWarnings(as.numeric(level_str))
+        if (is.na(level) || level <= 0 || level >= 100) {
+          cli::cli_abort(c(
+            "Invalid interval coverage level: {level_str}",
+            "i" = "must be a number between 0 and 100 (exclusive)"
+          ))
+        }
         return(purrr::partial(scoringutils::interval_coverage, interval_range = level))
       }
     )
     names(interval_metric_fns) <- interval_metrics
-    other_metric_fns <- scoringutils::metrics_quantile(select = other_metrics)
-
-    metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics]
-    metrics <- metric_fns
-  } else if (output_type == "pmf") {
-    valid_metrics <- c("log_score")
-    invalid_metrics <- metrics[!metrics %in% valid_metrics]
-    error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)
-
-    metrics <- scoringutils::metrics_nominal(select = metrics)
-  } else if (output_type %in% c("median", "mean")) {
-    valid_metrics <- c("ae_point", "se_point")
-    invalid_metrics <- metrics[!metrics %in% valid_metrics]
-    error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)
-
-    metrics <- scoringutils::metrics_point(select = metrics)
-  } else {
-    # we have already validated `output_type`, so this case should not be
-    # triggered; this case is just double checking in case we add something new
-    # later, to ensure we update this function.
-    error_if_invalid_output_type(output_type)
+    metric_fns <- c(other_metric_fns, interval_metric_fns)
+    return(metric_fns)
   }
-  return(metrics)
-}
-
-
-error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type, comment = NULL) {
-  n <- length(invalid_metrics)
-  if (n > 0) {
-    cli::cli_abort(
-      c(
-        "`metrics` had {n} unsupported metric{?s} for
-        {.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};
-        supported metrics include {.val {valid_metrics}}.",
-        comment
-      )
-    )
+  # leave validation of user selection to scoringutils
+  metric_fns <- scoringutils::get_metrics(forecast, select = select)
+  if (output_type == "mean") {
+    metric_fns <- scoringutils::select_metrics(metric_fns, "se_point")
+  } else if (output_type == "median") {
+    metric_fns <- scoringutils::select_metrics(metric_fns, "ae_point")
   }
+
+  return(metric_fns)
 }
diff --git a/man/score_model_out.Rd b/man/score_model_out.Rd
index 8c42a92..70b1f60 100644
--- a/man/score_model_out.Rd
+++ b/man/score_model_out.Rd
@@ -2,7 +2,7 @@
 % Please edit documentation in R/score_model_out.R
 \name{score_model_out}
 \alias{score_model_out}
-\title{Score model output predictions with a single \code{output_type} against observed data}
+\title{Score model output predictions}
 \usage{
 score_model_out(
   model_out_tbl,
@@ -19,7 +19,9 @@ score_model_out(
 \item{target_observations}{Observed 'ground truth' data to be compared to
 predictions}
 
-\item{metrics}{Optional character vector of scoring metrics to compute. See details for more.}
+\item{metrics}{Character vector of scoring metrics to compute. If \code{NULL}
+(the default), appropriate metrics are chosen automatically. See details
+for more.}
 
 \item{summarize}{Boolean indicator of whether summaries of forecast scores
 should be computed. Defaults to \code{TRUE}.}
@@ -33,51 +35,54 @@ vector of levels for pmf forecasts, in increasing order of the levels. For
 all other output types, this is ignored.}
 }
 \value{
-forecast_quantile
+A data.table with scores
 }
 \description{
-Score model output predictions with a single \code{output_type} against observed data
+Scores model outputs with a single \code{output_type} against observed data.
 }
 \details{
-If \code{metrics} is \code{NULL} (the default), this function chooses
-appropriate metrics based on the \code{output_type} contained in the \code{model_out_tbl}:
-\itemize{
-\item For \code{output_type == "quantile"}, we use the default metrics provided by
-\code{scoringutils::metrics_quantile()}: wis, overprediction, underprediction, dispersion, bias, interval_coverage_50, interval_coverage_90, interval_coverage_deviation, ae_median
-\item For \code{output_type == "pmf"} and \code{output_type_id_order} is \code{NULL} (indicating
-that the predicted variable is a nominal variable), we use the default metric
-provided by \code{scoringutils::metrics_nominal()},
-log_score
-\item For \code{output_type == "median"}, we use "ae_point"
-\item For \code{output_type == "mean"}, we use "se_point"
-}
+Default metrics are provided by the \code{scoringutils} package. You can select
+metrics by passing in a character vector of metric names to the \code{metrics}
+argument.
 
-Alternatively, a character vector of scoring metrics can be provided. In this
-case, the following options are supported:
-\itemize{
-\item \code{output_type == "median"} and \code{output_type == "mean"}:
-\itemize{
-\item "ae_point": absolute error of a point prediction (generally recommended for the median)
-\item "se_point": squared error of a point prediction (generally recommended for the mean)
-}
-\item \code{output_type == "quantile"}:
+The following metrics can be selected (all are used by default) for the
+different \code{output_type}s:
+
+\strong{Quantile forecasts:} (\code{output_type == "quantile"})
 \itemize{
-\item "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
-\item "wis": weighted interval score (WIS) of a collection of quantile predictions
-\item "overprediction": The component of WIS measuring the extent to which
-predictions fell above the observation.
-\item "underprediction": The component of WIS measuring the extent to which
-predictions fell below the observation.
-\item "dispersion": The component of WIS measuring the dispersion of forecast
-distributions.
+\item wis
+\item overprediction
+\item underprediction
+\item dispersion
+\item bias
+\item interval_coverage_deviation
+\item ae_median
 \item "interval_coverage_XX": interval coverage at the "XX" level. For example,
 "interval_coverage_95" is the 95\% interval coverage rate, which would be calculated
 based on quantiles at the probability levels 0.025 and 0.975.
 }
-\item \code{output_type == "pmf"}:
+
+See \link[scoringutils:get_metrics.forecast_quantile]{scoringutils::get_metrics.forecast_quantile} for details.
+
+\strong{Nominal forecasts:} (\code{output_type == "pmf"} and \code{output_type_id_order} is \code{NULL})
 \itemize{
-\item "log_score": log score
+\item log_score
 }
+
+(Scoring for ordinal forecasts will be added in the future.)
+
+See \link[scoringutils:get_metrics.forecast_nominal]{scoringutils::get_metrics.forecast_nominal} for details.
+
+\strong{Median forecasts:} (\code{output_type == "median"})
+\itemize{
+\item \code{ae_point}: absolute error of the point forecast (recommended for the median, see Gneiting (2011))
+}
+
+See \link[scoringutils:get_metrics.forecast_point]{scoringutils::get_metrics.forecast_point} for details.
+
+\strong{Mean forecasts:} (\code{output_type == "mean"})
+\itemize{
+\item \code{se_point}: squared error of the point forecast (recommended for the mean, see Gneiting (2011))
 }
 }
 \examples{
@@ -108,3 +113,7 @@ pmf_scores <- score_model_out(
 head(pmf_scores)
 \dontshow{\}) # examplesIf}
 }
+\references{
+Gneiting, Tilmann. 2011. "Making and Evaluating Point Forecasts." Journal of the
+American Statistical Association 106 (494): 746–62.
+}
diff --git a/tests/testthat/test-score_model_out.R b/tests/testthat/test-score_model_out.R
index 8987ee9..1c497f7 100644
--- a/tests/testthat/test-score_model_out.R
+++ b/tests/testthat/test-score_model_out.R
@@ -107,7 +107,6 @@ test_that("score_model_out succeeds with valid inputs: mean output_type, charact
       c("model_id", "location")
     ))) |>
     dplyr::summarize(
-      ae_point = mean(.data[["ae"]]),
       se_point = mean(.data[["se"]]),
       .groups = "drop"
     )
@@ -380,7 +379,7 @@ test_that("score_model_out errors when model_out_tbl has multiple output_types",
 })
 
 
-test_that("score_model_out errors when invalid interval levels are requested", {
+test_that("score_model_out works with all kinds of interval levels", {
   # Forecast data from HubExamples:
   load(test_path("testdata/forecast_outputs.rda")) # sets forecast_outputs
   load(test_path("testdata/forecast_target_observations.rda")) # sets forecast_target_observations
@@ -389,27 +388,36 @@ test_that("score_model_out errors when invalid interval levels are requested", {
     score_model_out(
       model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
      target_observations = forecast_target_observations,
-      metrics = "interval_level_5"
+      metrics = "interval_coverage_5d2a"
    ),
-    regexp = "unsupported metric"
+    regexp = "must be a number between 0 and 100"
  )

-  expect_error(
+  expect_warning(
    score_model_out(
      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
      target_observations = forecast_target_observations,
-      metrics = "interval_level_100"
+      metrics = "interval_coverage_55"
    ),
-    regexp = "unsupported metric"
+    "To compute the interval coverage for an interval range of" # scoringutils warning
  )

  expect_error(
    score_model_out(
      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
      target_observations = forecast_target_observations,
-      metrics = "interval_level_XY"
+      metrics = "interval_coverage_100"
+    ),
+    regexp = "must be a number between 0 and 100"
+  )
+
+  expect_warning(
+    score_model_out(
+      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
+      target_observations = forecast_target_observations,
+      metrics = "interval_coverage_5.3"
    ),
-    regexp = "unsupported metric"
+    "To compute the interval coverage for an interval range of" # scoringutils warning
  )
})
@@ -425,7 +433,7 @@ test_that("score_model_out errors when invalid metrics are requested", {
      target_observations = forecast_target_observations,
      metrics = "log_score"
    ),
-    regexp = "unsupported metric"
+    regexp = "has additional elements"
  )

  expect_error(
@@ -434,17 +442,29 @@
    score_model_out(
      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "mean"),
      target_observations = forecast_target_observations,
      metrics = list(5, 6, "asdf")
    ),
-    regexp = "`metrics` must be either `NULL` or a character vector of supported metrics."
+    regexp =
+      "^Assertion on 'c\\(select, exclude\\)' failed: Must be of type 'character' \\(or 'NULL'\\), not 'list'\\.$"
+  )
+
+  expect_error(
+    score_model_out(
+      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
+      target_observations = forecast_target_observations,
+      metrics = c("asdfinterval_coverage_90")
+    ),
+    regexp =
+      "has additional elements"
  )

  expect_error(
    score_model_out(
      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "mean"),
      target_observations = forecast_target_observations,
-      metrics = scoringutils::metrics_point(),
+      metrics = scoringutils::get_metrics(scoringutils::example_point),
      by = c("model_id", "location")
    ),
-    regexp = "`metrics` must be either `NULL` or a character vector of supported metrics."
+    regexp =
+      "^Assertion on 'c\\(select, exclude\\)' failed: Must be of type 'character' \\(or 'NULL'\\), not 'list'\\.$"
  )
})
diff --git a/tests/testthat/test-transform_point_model_out.R b/tests/testthat/test-transform_point_model_out.R
index 503c448..1bc6aa2 100644
--- a/tests/testthat/test-transform_point_model_out.R
+++ b/tests/testthat/test-transform_point_model_out.R
@@ -140,6 +140,9 @@ test_that("hubExamples data set is transformed correctly", {
     reference_date = as.Date(reference_date, "%Y-%m-%d"),
     target_end_date = as.Date(target_end_date, "%Y-%m-%d")
   )
-  class(exp_forecast) <- c("forecast_point", "forecast", "data.table", "data.frame")
-  expect_equal(act_forecast, exp_forecast)
+  expect_s3_class(
+    act_forecast,
+    c("forecast_point", "forecast", "data.table", "data.frame")
+  )
+  expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast))
 })
diff --git a/tests/testthat/test-transform_quantile_model_out.R b/tests/testthat/test-transform_quantile_model_out.R
index bdc01a0..0e9f14b 100644
--- a/tests/testthat/test-transform_quantile_model_out.R
+++ b/tests/testthat/test-transform_quantile_model_out.R
@@ -89,6 +89,9 @@ test_that("hubExamples data set is transformed correctly", {
     reference_date = as.Date(reference_date, "%Y-%m-%d"),
     target_end_date = as.Date(target_end_date, "%Y-%m-%d")
   )
-  class(exp_forecast) <- c("forecast", "forecast_quantile", "data.table", "data.frame")
-  expect_equal(act_forecast, exp_forecast, ignore_attr = "class")
+  expect_s3_class(
+    act_forecast,
+    c("forecast_quantile", "forecast", "data.table", "data.frame")
+  )
+  expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast))
 })
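
Usage sketch (not part of the patch): the snippet below illustrates, in a hedged way, how the reworked metric selection is used from the caller's side. It assumes the `forecast_outputs` and `forecast_target_observations` objects loaded from the package's hubExamples-based test fixtures (testdata/forecast_outputs.rda and testdata/forecast_target_observations.rda) are available in the session; the specific metric names follow the "interval_coverage_XX" convention documented above.

# Not part of the patch: illustrative use of score_model_out() with the new
# scoringutils-based metric selection. Object names come from the test fixtures.
library(hubEvals)

# Quantile forecasts: WIS plus interval coverage at the 80% and 90% levels,
# summarized by model and location.
quantile_scores <- score_model_out(
  model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
  target_observations = forecast_target_observations,
  metrics = c("wis", "interval_coverage_80", "interval_coverage_90"),
  by = c("model_id", "location")
)
head(quantile_scores)

# Mean forecasts: with metrics = NULL (the default), "se_point" is selected
# automatically, as described in the documentation above.
mean_scores <- score_model_out(
  model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "mean"),
  target_observations = forecast_target_observations,
  by = c("model_id", "location")
)
head(mean_scores)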