From 7ad94a300dca44b1a4c0c4eec190ccc7d4f511d5 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 14 Feb 2025 16:06:40 +0100 Subject: [PATCH] styling --- R/descriptive.R | 215 ++++++++++++++++++++++++++++----------------- R/get-metadata.R | 46 ++++++---- R/import-data.R | 34 ++++--- R/metadata-utils.R | 131 ++++++++++++++++----------- R/model-wis.R | 18 ++-- R/prep-data.R | 53 +++++++---- R/score.R | 1 - 7 files changed, 307 insertions(+), 191 deletions(-) diff --git a/R/descriptive.R b/R/descriptive.R index 93374ab..03fcf8a 100644 --- a/R/descriptive.R +++ b/R/descriptive.R @@ -20,34 +20,41 @@ table_confint <- function(scores, group_var = NULL) { total_forecasts <- nrow(scores) total_models <- n_distinct(scores$Model) if (!is.null(group_var)) { - scores <- scores |> - group_by(.data[[group_var]]) + scores <- scores |> + group_by(.data[[group_var]]) } table <- scores |> - summarise(n_forecasts = n(), - p_forecasts = round(n() / total_forecasts * 100, 1), - n_models = n_distinct(Model), - p_models = round(n_models / total_models * 100, 1), - mean = mean(wis, na.rm = TRUE), - lower = t.test(wis)$conf.int[1], - upper = t.test(wis)$conf.int[2], - median = median(wis, na.rm = TRUE), - lq = quantile(wis, 0.25, na.rm = TRUE), - uq = quantile(wis, 0.75, na.rm = TRUE), - se = sd(wis, na.rm = TRUE) / sqrt(sum(!is.na(wis))) - ) |> - mutate(across(c(mean, lower, upper, - median, lq, uq), ~ round(., 2)), - Models = paste0(n_models, " (", p_models, "%)"), - Forecasts = paste0(n_forecasts, " (", p_forecasts, "%)"), - "Mean WIS (95% CI)" = paste0(mean, " (", - lower, "-", upper, ")"), - "Median WIS (IQR)" = paste0(median, " (", lq, "-", uq, ")")) + summarise( + n_forecasts = n(), + p_forecasts = round(n() / total_forecasts * 100, 1), + n_models = n_distinct(Model), + p_models = round(n_models / total_models * 100, 1), + mean = mean(wis, na.rm = TRUE), + lower = t.test(wis)$conf.int[1], + upper = t.test(wis)$conf.int[2], + median = median(wis, na.rm = TRUE), + lq = quantile(wis, 0.25, na.rm = TRUE), + uq = quantile(wis, 0.75, na.rm = TRUE), + se = sd(wis, na.rm = TRUE) / sqrt(sum(!is.na(wis))) + ) |> + mutate( + across(c( + mean, lower, upper, + median, lq, uq + ), ~ round(., 2)), + Models = paste0(n_models, " (", p_models, "%)"), + Forecasts = paste0(n_forecasts, " (", p_forecasts, "%)"), + "Mean WIS (95% CI)" = paste0( + mean, " (", + lower, "-", upper, ")" + ), + "Median WIS (IQR)" = paste0(median, " (", lq, "-", uq, ")") + ) if (!is.null(group_var)) { - table <- table |> - rename("Variable" = all_of(group_var)) |> - mutate(group = group_var) + table <- table |> + rename("Variable" = all_of(group_var)) |> + mutate(group = group_var) } return(table) } @@ -60,8 +67,10 @@ create_raw_table1 <- function(scores, targets) { horizon <- table_confint(scores, "Horizon") |> filter(!is.na(Variable)) trend <- table_confint(scores, "Trend") - bind_rows(overall, method, targets, - horizon, trend) + bind_rows( + overall, method, targets, + horizon, trend + ) } print_table1 <- function(scores) { @@ -75,7 +84,8 @@ print_table1 <- function(scores) { colnames(table)[!(colnames(table) %in% c("Variable", "group"))] <- paste( colnames(table)[!(colnames(table) %in% c("Variable", "group"))], - outcome, sep = "_" + outcome, + sep = "_" ) return(table) }) @@ -118,16 +128,18 @@ print_table1 <- function(scores) { col.names = str_remove(colnames(table1), "_.*$"), align = c("l", rep("r", ncol(table1) - 1)) ) |> - pack_rows(index = c(" " = 1, - "Method" = 5, - "Number of country targets" = 2, - "Week ahead horizon" = 4, - "3-week trend in incidence" = 3)) |> + pack_rows(index = c( + " " = 1, + "Method" = 5, + "Number of country targets" = 2, + "Week ahead horizon" = 4, + "3-week trend in incidence" = 3 + )) |> add_header_above(headers_to_add) } # Plot over time by explanatory variable ---------------------------------- -plot_over_time <- function(scores, ensemble, add_plot, show_uncertainty = TRUE){ +plot_over_time <- function(scores, ensemble, add_plot, show_uncertainty = TRUE) { quantiles <- c(0.25, 0.5, 0.75) plot_over_time_target <- scores |> @@ -136,26 +148,36 @@ plot_over_time <- function(scores, ensemble, add_plot, show_uncertainty = TRUE){ reframe( n = n(), value = quantile(wis, quantiles, na.rm = TRUE), - quantile = paste0("q", quantiles)) |> + quantile = paste0("q", quantiles) + ) |> pivot_wider(names_from = quantile) |> # Plot - ggplot(aes(x = target_end_date, - col = CountryTargets, - fill = CountryTargets)) + + ggplot(aes( + x = target_end_date, + col = CountryTargets, + fill = CountryTargets + )) + geom_line(aes(y = q0.5), alpha = 0.5) if (show_uncertainty) { plot_over_time_target <- plot_over_time_target + geom_ribbon(aes(ymin = q0.25, ymax = q0.75), - alpha = 0.1, col = NA) + alpha = 0.1, col = NA + ) } plot_over_time_target <- plot_over_time_target + facet_wrap(~outcome_target, scales = "free_y") + scale_x_date(date_labels = "%b %Y") + - scale_fill_manual(values = c("Single-country" = "#e7298a", - "Multi-country" = "#e6ab02"), - aesthetics = c("col", "fill")) + - labs(x = NULL, y = "median WIS (log scale)", - fill = NULL, col = NULL) + + scale_fill_manual( + values = c( + "Single-country" = "#e7298a", + "Multi-country" = "#e6ab02" + ), + aesthetics = c("col", "fill") + ) + + labs( + x = NULL, y = "median WIS (log scale)", + fill = NULL, col = NULL + ) + theme( legend.position = "bottom", strip.background = element_blank() @@ -167,7 +189,8 @@ plot_over_time <- function(scores, ensemble, add_plot, show_uncertainty = TRUE){ reframe( n = n(), value = quantile(wis, quantiles, na.rm = TRUE), - quantile = paste0("q", quantiles)) |> + quantile = paste0("q", quantiles) + ) |> pivot_wider(names_from = quantile) |> # Plot ggplot(aes(x = target_end_date, col = Method, fill = Method)) + @@ -175,15 +198,20 @@ plot_over_time <- function(scores, ensemble, add_plot, show_uncertainty = TRUE){ if (show_uncertainty) { plot_over_time_method <- plot_over_time_method + geom_ribbon(aes(ymin = q0.25, ymax = q0.75), - alpha = 0.1, col = NA) + alpha = 0.1, col = NA + ) } plot_over_time_method <- plot_over_time_method + facet_wrap(~outcome_target, scales = "free_y") + scale_x_date(date_labels = "%b %Y") + - scale_fill_brewer(aesthetics = c("col", "fill"), - type = "qual", palette = 2) + - labs(x = NULL, y = "median WIS (log scale)", - fill = NULL, col = NULL) + + scale_fill_brewer( + aesthetics = c("col", "fill"), + type = "qual", palette = 2 + ) + + labs( + x = NULL, y = "median WIS (log scale)", + fill = NULL, col = NULL + ) + theme( legend.position = "bottom", strip.background = element_blank() @@ -208,8 +236,10 @@ plot_over_time <- function(scores, ensemble, add_plot, show_uncertainty = TRUE){ plot_density <- function(scores) { plot_conditional_density <- function(scores, group_var) { scores |> - ggplot(aes(x = log_wis, - col = .data[[group_var]])) + + ggplot(aes( + x = log_wis, + col = .data[[group_var]] + )) + geom_density() + facet_wrap(~outcome_target, scales = "free") + labs(x = "Log of the weighted interval score") + @@ -222,33 +252,39 @@ plot_density <- function(scores) { method <- plot_conditional_density(scores, "Method") targets <- plot_conditional_density(scores, "CountryTargets") - #affiliated <- plot_conditional_density(scores, "CountryTargetAffiliated") + # affiliated <- plot_conditional_density(scores, "CountryTargetAffiliated") plot_density_patchwork <- method + targets + - #affiliated + + # affiliated + plot_layout(ncol = 1) + plot_annotation(tag_levels = "A") return(plot_density_patchwork) } # Ridge plot by model -------------------- -plot_ridges <- function(scores){ +plot_ridges <- function(scores) { scores |> group_by(Model) |> - mutate(median_score = median(wis, na.rm = TRUE), + mutate( + median_score = median(wis, na.rm = TRUE), lq = quantile(wis, 0.25, na.rm = TRUE), - uq = quantile(wis, 0.75, na.rm = TRUE)) |> + uq = quantile(wis, 0.75, na.rm = TRUE) + ) |> ungroup() |> mutate(Model = fct_reorder(Model, median_score)) |> filter(wis >= lq & wis <= uq) |> # Plot ggplot(aes(x = wis, y = Model, fill = stat(x))) + - geom_density_ridges_gradient(scale = 1.5, - rel_min_height = 0.01, - quantile_lines = TRUE, quantiles = 2) + - scale_fill_viridis_c(name = "Interval score", - option = "C", direction = -1) + + geom_density_ridges_gradient( + scale = 1.5, + rel_min_height = 0.01, + quantile_lines = TRUE, quantiles = 2 + ) + + scale_fill_viridis_c( + name = "Interval score", + option = "C", direction = -1 + ) + theme_ridges() + labs(x = "Interval score IQR", y = "Model") + theme(legend.position = "none") @@ -263,16 +299,21 @@ table_targets <- function(scores) { summarise(target_count = n(), .groups = "drop") |> ungroup() |> group_by(Model, outcome_target) |> - summarise(CountryTargets = all(target_count <= 2), - min_targets = min(target_count), - max_targets = max(target_count), - mean = mean(target_count), - median = median(target_count), - consistent = min_targets==max_targets) |> + summarise( + CountryTargets = all(target_count <= 2), + min_targets = min(target_count), + max_targets = max(target_count), + mean = mean(target_count), + median = median(target_count), + consistent = min_targets == max_targets + ) |> mutate(CountryTargets = factor(CountryTargets, - levels = c(TRUE, FALSE), - labels = c("Single-country", - "Multi-country"))) + levels = c(TRUE, FALSE), + labels = c( + "Single-country", + "Multi-country" + ) + )) return(table_targets) } @@ -297,7 +338,8 @@ table_metadata <- function(scores) { pivot_wider( names_from = "outcome_target", values_from = "Forecasts", - values_fill = "") |> + values_fill = "" + ) |> rename("Country Targets" = CountryTargets) |> arrange(Model) return(metadata_table) @@ -311,19 +353,28 @@ plot_linerange <- function(group_var) { reframe( n = n(), value = quantile(wis, quantiles), - quantile = paste0("q", quantiles)) |> + quantile = paste0("q", quantiles) + ) |> pivot_wider(names_from = quantile) |> # Plot ggplot(aes(y = .data[[group_var]], col = Horizon, fill = Horizon)) + - geom_point(aes(x = q0.5), alpha = 0.8, - position = position_dodge(width = 1)) + - geom_linerange(aes(xmin = q0.25, xmax = q0.75), linewidth = 4, - alpha = 0.5, position = position_dodge(width = 1)) + - geom_linerange(aes(xmin = q0.01, xmax = q0.99), linewidth = 4, - alpha = 0.2, position = position_dodge(width = 1)) + - labs(y = NULL, x = NULL, - col = "Week ahead forecast Horizon", - fill = "Week ahead forecast Horizon") + + geom_point(aes(x = q0.5), + alpha = 0.8, + position = position_dodge(width = 1) + ) + + geom_linerange(aes(xmin = q0.25, xmax = q0.75), + linewidth = 4, + alpha = 0.5, position = position_dodge(width = 1) + ) + + geom_linerange(aes(xmin = q0.01, xmax = q0.99), + linewidth = 4, + alpha = 0.2, position = position_dodge(width = 1) + ) + + labs( + y = NULL, x = NULL, + col = "Week ahead forecast Horizon", + fill = "Week ahead forecast Horizon" + ) + scale_color_viridis_d(direction = 1) + theme(legend.position = "bottom") return(plot) @@ -386,7 +437,7 @@ trends_plot <- function(scores) { geom_line() + scale_colour_brewer(palette = "Set2", na.value = "grey") + theme(legend.position = "bottom") + - facet_wrap(~ location, scales = "free_y") + + facet_wrap(~location, scales = "free_y") + theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) + xlab("") return(p) diff --git a/R/get-metadata.R b/R/get-metadata.R index 5eac3c8..8d1dcdb 100644 --- a/R/get-metadata.R +++ b/R/get-metadata.R @@ -4,31 +4,45 @@ library(lubridate) source(here("R", "metadata-utils.R")) # set repo paths -repos <- c("euro" = "covid19-forecast-hub-europe/covid19-forecast-hub-europe", - "us" = "reichlab/covid19-forecast-hub") +repos <- c( + "euro" = "covid19-forecast-hub-europe/covid19-forecast-hub-europe", + "us" = "reichlab/covid19-forecast-hub" +) # date range for submissions min_date <- "2021-03-07" max_date <- as.Date("2023-03-10") - weeks(4) # get euro metadata -eu_submissions <- get_model_submissions(repo = repos[["euro"]], - min_date = min_date, - max_date = max_date) -eu_metadata_raw <- get_metadata_raw(repo = repos[["euro"]], - model_abbr = unique(eu_submissions$model)) -write_metadata(metadata_raw = eu_metadata_raw, - sheet_name = "euro-metadata-raw", write_local = FALSE) +eu_submissions <- get_model_submissions( + repo = repos[["euro"]], + min_date = min_date, + max_date = max_date +) +eu_metadata_raw <- get_metadata_raw( + repo = repos[["euro"]], + model_abbr = unique(eu_submissions$model) +) +write_metadata( + metadata_raw = eu_metadata_raw, + sheet_name = "euro-metadata-raw", write_local = FALSE +) # get US metadata -us_submissions <- get_model_submissions(repo = repos[["us"]], - min_date = min_date, - max_date = max_date) -us_metadata_raw <- get_metadata_raw(repo = repos[["us"]], - model_abbr = unique(us_submissions$model)) +us_submissions <- get_model_submissions( + repo = repos[["us"]], + min_date = min_date, + max_date = max_date +) +us_metadata_raw <- get_metadata_raw( + repo = repos[["us"]], + model_abbr = unique(us_submissions$model) +) us_metadata_raw <- us_metadata_raw |> mutate(euro_overlap = model_abbr %in% unique(eu_submissions$model)) -write_metadata(metadata_raw = us_metadata_raw, - sheet_name = "us-metadata-raw", write_local = FALSE) +write_metadata( + metadata_raw = us_metadata_raw, + sheet_name = "us-metadata-raw", write_local = FALSE +) diff --git a/R/import-data.R b/R/import-data.R index b474bb7..6154e3a 100644 --- a/R/import-data.R +++ b/R/import-data.R @@ -16,11 +16,15 @@ walk(c("case", "death"), \(data_type) { "covid19-forecast-hub-europe/main/data-truth/JHU/", file_name )) |> # aggregate to weekly incidence - mutate(year = epiyear(date), - week = epiweek(date)) |> + mutate( + year = epiyear(date), + week = epiweek(date) + ) |> group_by(location, location_name, year, week) |> - summarise(target_end_date = max(date), - observed = sum(value, na.rm = TRUE)) |> + summarise( + target_end_date = max(date), + observed = sum(value, na.rm = TRUE) + ) |> ungroup() |> select(-year, -week) @@ -32,17 +36,23 @@ walk(c("case", "death"), \(data_type) { # Add "trend" as change in 3-week moving average obs <- obs |> - mutate(ma = zoo::rollmean(observed, - align = "right", k = 3, fill = NA), - trend = ma / lag(ma, n=1), - trend = as.factor(ifelse(is.nan(trend), "Stable", - ifelse(trend >= 1.05, "Increasing", - ifelse(trend <= 0.95, "Decreasing", - "Stable"))))) + mutate( + ma = zoo::rollmean(observed, + align = "right", k = 3, fill = NA + ), + trend = ma / lag(ma, n = 1), + trend = as.factor(ifelse(is.nan(trend), "Stable", + ifelse(trend >= 1.05, "Increasing", + ifelse(trend <= 0.95, "Decreasing", + "Stable" + ) + ) + )) + ) obs <- obs |> select(location, target_end_date, observed, trend) write_csv(obs, here("data", paste0("observed-", data_type, ".csv"))) -} +}) # Population data --------------------------------------------------------- pop <- read_csv(paste0( diff --git a/R/metadata-utils.R b/R/metadata-utils.R index 84d9f3c..094ccc8 100644 --- a/R/metadata-utils.R +++ b/R/metadata-utils.R @@ -15,8 +15,10 @@ library(googlesheets4) # repo: name of github repo (format: "org/repo") get_model_names <- function(repo) { # Get all filepaths in top level data-processed directory - data_processed <- gh(paste0("/repos/", repo, - "/contents/data-processed/?recursive=1")) + data_processed <- gh(paste0( + "/repos/", repo, + "/contents/data-processed/?recursive=1" + )) # Get names of all models model_abbr <- transpose(data_processed) model_abbr <- unlist(model_abbr[["name"]]) @@ -30,32 +32,45 @@ get_model_names <- function(repo) { # model_abbr: optional vector of model names to get content for # min_date, max_date: set to NULL to include all dates get_model_submissions <- function(repo = "covid19-forecast-hub-europe/covid19-forecast-hub-europe", - model_abbr = NULL, - min_date = NULL, - max_date = NULL) { + model_abbr = NULL, + min_date = NULL, + max_date = NULL) { if (is.null(model_abbr)) { model_abbr <- get_model_names(repo) } # Get contents of individual models' directories - model_content <- map(model_abbr, - ~ gh(paste0("/repos/", repo, - "/contents/data-processed/", - .x, - "?recursive=1"))) + model_content <- map( + model_abbr, + ~ gh(paste0( + "/repos/", repo, + "/contents/data-processed/", + .x, + "?recursive=1" + )) + ) # Get forecasts within each model's submissions folder model_content <- unlist(model_content) model_dates <- model_content[names(model_content) == "name"] model_submissions <- tibble("date" = model_dates) |> - mutate(model = substr(date, 12, nchar(date)-4), - date = floor_date(as.Date(substr(date, 1, 10)), - unit="week")) + mutate( + model = substr(date, 12, nchar(date) - 4), + date = floor_date(as.Date(substr(date, 1, 10)), + unit = "week" + ) + ) # Inclusion by date - if (is.null(min_date)) {min_date <- min(model_submissions$date)} - if (is.null(max_date)) {min_date <- max(model_submissions$date)} + if (is.null(min_date)) { + min_date <- min(model_submissions$date) + } + if (is.null(max_date)) { + min_date <- max(model_submissions$date) + } model_submissions <- model_submissions |> - filter(between(date, - as.Date(min_date), - as.Date(max_date))) + filter(between( + date, + as.Date(min_date), + as.Date(max_date) + )) return(model_submissions) } @@ -68,31 +83,41 @@ get_model_submissions <- function(repo = "covid19-forecast-hub-europe/covid19-fo get_metadata_raw <- function(model_abbr, repo) { # - read metadata files from github # -- try a single filepath first to check which formatting style in use - try_read <- try(read_yaml(paste0("https://raw.githubusercontent.com/", - repo, "/main/model-metadata/", - model_abbr[1], - ".yml"))) + try_read <- try(read_yaml(paste0( + "https://raw.githubusercontent.com/", + repo, "/main/model-metadata/", + model_abbr[1], + ".yml" + ))) # (US hub uses old formatting style: txt on master) if ("try-error" %in% class(try_read)) { - metadata_raw <- map_dfr(model_abbr, - ~ read_yaml( - paste0("https://raw.githubusercontent.com/", - repo, "/master", - "/data-processed/", .x, - "/metadata-", .x, ".txt"))) - } - else # (euro hub uses yml on main/model-metadata) + metadata_raw <- map_dfr( + model_abbr, + ~ read_yaml( + paste0( + "https://raw.githubusercontent.com/", + repo, "/master", + "/data-processed/", .x, + "/metadata-", .x, ".txt" + ) + ) + ) + } else # (euro hub uses yml on main/model-metadata) { - metadata_raw <- tryCatch(map_dfr(model_abbr, - ~ read_yaml(paste0( - "https://raw.githubusercontent.com/", - repo, "/main/model-metadata/", - .x, ".yml")))) + metadata_raw <- tryCatch(map_dfr( + model_abbr, + ~ read_yaml(paste0( + "https://raw.githubusercontent.com/", + repo, "/main/model-metadata/", + .x, ".yml" + )) + )) } # add name of hub metadata_raw <- mutate(metadata_raw, - hub = repo) + hub = repo + ) return(metadata_raw) } @@ -100,7 +125,7 @@ get_metadata_raw <- function(model_abbr, repo) { # get_contributors() # args # metadata_raw: raw metadata from get_metadata_raw() -get_contributors <- function(metadata_raw){ +get_contributors <- function(metadata_raw) { ctb <- metadata_raw |> select(model_abbr, model_contributors, team_funding) |> unnest_wider(col = c("model_contributors")) @@ -123,16 +148,16 @@ write_metadata <- function(metadata_raw, distinct() # Write to google sheet - gs4_auth() - cat("Writing to google sheet") - sheet_url <- "https://docs.google.com/spreadsheets/d/1XgXLYBCpdtjztJFhWDJz6G7A_Uw92dnnr-WFqGAFGn4/edit#gid=0" - write_sheet(metadata, ss = sheet_url, sheet = sheet_name) + gs4_auth() + cat("Writing to google sheet") + sheet_url <- "https://docs.google.com/spreadsheets/d/1XgXLYBCpdtjztJFhWDJz6G7A_Uw92dnnr-WFqGAFGn4/edit#gid=0" + write_sheet(metadata, ss = sheet_url, sheet = sheet_name) - # optionally save local copy - if(write_local) { - cat("Writing csv to local /data") - write_csv(metadata, here("data", paste0("metadata-raw", ".csv"))) - } + # optionally save local copy + if (write_local) { + cat("Writing csv to local /data") + write_csv(metadata, here("data", paste0("metadata-raw", ".csv"))) + } } # Fetch metadata from google sheet --------------------------------------- @@ -142,13 +167,13 @@ write_metadata <- function(metadata_raw, # write_local: optionally write to local csv file get_metadata_processed <- function(sheet_name = "model-classification", write_local = TRUE) { - sheet_url <- "https://docs.google.com/spreadsheets/d/1XgXLYBCpdtjztJFhWDJz6G7A_Uw92dnnr-WFqGAFGn4/edit#gid=0" - metadata <- read_sheet(sheet_url, sheet = sheet_name) |> - select(model, KS, RB, SF, JM) + sheet_url <- "https://docs.google.com/spreadsheets/d/1XgXLYBCpdtjztJFhWDJz6G7A_Uw92dnnr-WFqGAFGn4/edit#gid=0" + metadata <- read_sheet(sheet_url, sheet = sheet_name) |> + select(model, KS, RB, SF, JM) - if (write_local) { - cat("Writing csv to local /data") - write_csv(metadata, here("data", paste0(sheet_name, ".csv"))) - } + if (write_local) { + cat("Writing csv to local /data") + write_csv(metadata, here("data", paste0(sheet_name, ".csv"))) + } return(metadata) } diff --git a/R/model-wis.R b/R/model-wis.R index e442b7f..f0c1993 100644 --- a/R/model-wis.R +++ b/R/model-wis.R @@ -18,8 +18,10 @@ m.data <- data |> filter(!grepl("EuroCOVIDhub-", Model)) |> mutate(location = factor(location)) |> group_by(location) |> - mutate(time = as.numeric(forecast_date - min(forecast_date)) / 7, - Horizon = as.numeric(Horizon)) |> + mutate( + time = as.numeric(forecast_date - min(forecast_date)) / 7, + Horizon = as.numeric(Horizon) + ) |> ungroup() # --- Model --- @@ -48,11 +50,11 @@ m.formula <- log_wis ~ m.fits <- outcomes |> set_names() |> map(\(outcome) { - bam( - formula = m.formula, - data = m.data |> filter(outcome_target == outcome), - family = gaussian() - ) -}) + bam( + formula = m.formula, + data = m.data |> filter(outcome_target == outcome), + family = gaussian() + ) + }) saveRDS(m.fits, here("output", "fits.rds")) diff --git a/R/prep-data.R b/R/prep-data.R index 9cc46b9..b06129e 100644 --- a/R/prep-data.R +++ b/R/prep-data.R @@ -7,7 +7,8 @@ library("readr") classify_models <- function(file = here("data", "model-classification.csv")) { methods <- read_csv(file) |> pivot_longer( - -model, names_to = "classifier", values_to = "classification" + -model, + names_to = "classifier", values_to = "classification" ) |> filter(!(is.na(classification) | classification == "#N/A")) |> group_by(model) |> @@ -51,9 +52,12 @@ prep_data <- function(scoring_scale = "log") { group_by(model) |> summarise(CountryTargets = all(target_count == 1), .groups = "drop") |> mutate(CountryTargets = factor(CountryTargets, - levels = c(TRUE, FALSE), - labels = c("Single-country", - "Multi-country"))) + levels = c(TRUE, FALSE), + labels = c( + "Single-country", + "Multi-country" + ) + )) # Method type methods <- classify_models() |> @@ -73,38 +77,48 @@ prep_data <- function(scoring_scale = "log") { left_join(country_targets, by = "model") |> left_join(methods, by = "model") |> rename(Model = model, Horizon = horizon) |> - mutate(Model = as.factor(Model), - outcome_target = paste0(str_to_title(outcome_target), "s"), - Horizon = ordered(Horizon, - levels = 1:4, labels = 1:4), - log_wis = log(wis + 0.01)) |> + mutate( + Model = as.factor(Model), + outcome_target = paste0(str_to_title(outcome_target), "s"), + Horizon = ordered(Horizon, + levels = 1:4, labels = 1:4 + ), + log_wis = log(wis + 0.01) + ) |> filter(!is.na(Horizon)) ## horizon not in 1:4 return(data) } # Prediction data ------------------------------------------------------ get_forecasts <- function(data_type = "death") { - forecasts <- arrow::read_parquet(here("data", - "covid19-forecast-hub-europe.parquet")) |> + forecasts <- arrow::read_parquet(here( + "data", + "covid19-forecast-hub-europe.parquet" + )) |> filter(grepl(data_type, target)) forecasts <- forecasts |> - separate(target, into = c("horizon", "target_variable"), - sep = " wk ahead ") |> + separate(target, + into = c("horizon", "target_variable"), + sep = " wk ahead " + ) |> # set forecast date to corresponding submission date mutate( horizon = as.numeric(horizon), - forecast_date = target_end_date - weeks(horizon) + days(1)) |> + forecast_date = target_end_date - weeks(horizon) + days(1) + ) |> rename(prediction = value) |> - select(location, forecast_date, - horizon, target_end_date, - model, quantile, prediction) + select( + location, forecast_date, + horizon, target_end_date, + model, quantile, prediction + ) # Exclusions # dates should be between start of hub and until end of JHU data forecasts <- forecasts |> filter(forecast_date >= as.Date("2021-03-07") & - target_end_date <= as.Date("2023-03-10")) + target_end_date <= as.Date("2023-03-10")) # only keep forecasts up to 4 weeks ahead forecasts <- filter(forecasts, horizon <= 4) @@ -114,7 +128,8 @@ get_forecasts <- function(data_type = "death") { summarise(q = length(unique(quantile))) |> filter(q < 23) forecasts <- anti_join(forecasts, rm_quantiles, - by = c("model", "forecast_date", "location")) + by = c("model", "forecast_date", "location") + ) forecasts <- filter(forecasts, !is.na(quantile)) # remove "median" # remove duplicates diff --git a/R/score.R b/R/score.R index b69b93a..74996bc 100644 --- a/R/score.R +++ b/R/score.R @@ -5,7 +5,6 @@ library(scoringutils) source(here("R", "import-data.R")) walk(c("case", "death"), \(target) { - # Get forecasts & observations ----- # Get forecasts (note this is slow) forecasts_raw <- get_forecasts(data_type = target)