Skip to content

Commit

Permalink
reorganisation/tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Feb 14, 2025
1 parent 3ce35cc commit bae2c62
Show file tree
Hide file tree
Showing 60 changed files with 301 additions and 42,367 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/render-report.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:

- name: Compile the report
run: |
rmarkdown::render("output/results.Rmd")
rmarkdown::render("report/results.Rmd")
shell: Rscript {0}

- name: Create Pull Request
Expand All @@ -46,5 +46,5 @@ jobs:
branch: "render-report-${{ github.run_number }}"
labels: "documentation"
add-paths: |
output/results.pdf
report/results.pdf
token: ${{ secrets.GITHUB_TOKEN }}
4 changes: 2 additions & 2 deletions .github/workflows/test-report.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ jobs:

- name: Compile the report
run: |
rmarkdown::render("output/results.Rmd")
rmarkdown::render("report/results.Rmd")
shell: Rscript {0}

- name: Upload pdf as an artifact
uses: actions/upload-artifact@v4
with:
name: results
path: output/results.pdf
path: report/results.pdf
131 changes: 10 additions & 121 deletions R/import-data.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
# Functions to import and save data for predictions and observations
# Examples:
# forecasts <- get_forecasts()
# obs <- get_observed()
# forecasts <- left_join(forecasts, obs,
# by = c("location", "target_end_date"))
# anomalies <- get_anomalies()
# forecasts <- anti_join(forecasts, anomalies,
# by = c("target_end_date", "location"))

library(here)
library(dplyr)
library(readr)
Expand All @@ -16,57 +6,10 @@ library(arrow)
library(tidyr)
library(ggplot2)
library(stringr)
theme_set(theme_minimal())

# Prediction data ------------------------------------------------------
get_forecasts <- function(data_type = "death") {
forecasts <- arrow::read_parquet(here("data",
"covid19-forecast-hub-europe.parquet")) |>
filter(grepl(data_type, target))

forecasts <- forecasts |>
separate(target, into = c("horizon", "target_variable"),
sep = " wk ahead ") |>
# set forecast date to corresponding submission date
mutate(
horizon = as.numeric(horizon),
forecast_date = target_end_date - weeks(horizon) + days(1)) |>
rename(prediction = value) |>
select(location, forecast_date,
horizon, target_end_date,
model, quantile, prediction)

# Exclusions
# dates should be between start of hub and until end of JHU data
forecasts <- forecasts |>
filter(forecast_date >= as.Date("2021-03-07") &
target_end_date <= as.Date("2023-03-10"))
# only keep forecasts up to 4 weeks ahead
forecasts <- filter(forecasts, horizon <= 4)

# only include predictions from models with all quantiles
rm_quantiles <- forecasts |>
group_by(model, forecast_date, location) |>
summarise(q = length(unique(quantile))) |>
filter(q < 23)
forecasts <- anti_join(forecasts, rm_quantiles,
by = c("model", "forecast_date", "location"))
forecasts <- filter(forecasts, !is.na(quantile)) # remove "median"

# remove duplicates
forecasts <- forecasts |>
group_by_all() |>
mutate(duplicate = row_number()) |>
ungroup() |>
filter(duplicate == 1) |>
select(-duplicate)

return(forecasts)
}
library(purrr)

# Observed data ---------------------------------------------------------
# Get raw values
get_observed <- function(data_type = "death") {
walk(c("case", "death"), \(data_type) {
file_name <- paste0("truth_JHU-Incident%20", str_to_title(data_type), "s.csv")
obs <- read_csv(paste0(
"https://raw.githubusercontent.com/covid19-forecast-hub-europe/",
Expand Down Expand Up @@ -98,67 +41,13 @@ get_observed <- function(data_type = "death") {
"Stable")))))
obs <- obs |>
select(location, target_end_date, observed, trend)
return(obs)
}

# Observed data ---------------------------------------------------------
# Get raw values
get_pop <- function() {
pop <- read_csv(paste0(
"https://raw.githubusercontent.com/european-modelling-hubs/",
"covid19-forecast-hub-europe/main/data-locations/locations_eu.csv"
), show_col_types = FALSE) |>
select(location, population)
return(pop)
}

# Plot observed data and trend classification
plot_observed <- function() {
obs <- import_observed()
obs |>
ggplot(aes(x = target_end_date, y = log(observed))) +
geom_point(col = trend) +
geom_line(alpha = 0.3) +
scale_x_date() +
labs(x = NULL, y = "Log observed", col = "Trend",
caption = "Trend (coloured points) of weekly change in 3-week moving average") +
theme(legend.position = "bottom", ) +
facet_wrap(facets = "location", ncol = 1,
strip.position = "left")

ggsave(filename = here("output/fig-trends.pdf"),
height = 50, width = 15, limitsize = FALSE)
}

# Anomalies
get_anomalies <- function() {
read_csv("https://raw.githubusercontent.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/5a2a8d48e018888f652e981c95de0bf05a838135/data-truth/anomalies/anomalies.csv") |>
filter(target_variable == "inc death") |>
select(-target_variable) |>
mutate(anomaly = TRUE)
write_csv(obs, here("data", paste0("observed-", data_type, ".csv")))
}

# Plot anomalies
plot_anomalies <- function() {
obs <- get_observed()
anomalies <- get_anomalies()

obs <- left_join(obs, anomalies) |>
group_by(location) |>
mutate(anomaly = replace_na(anomaly, FALSE))

obs |>
ggplot(aes(x = target_end_date,
y = log(observed),
col = anomaly)) +
geom_line() +
geom_point(size = 0.3) +
scale_x_date() +
labs(x = NULL, y = "Log observed") +
theme(legend.position = "bottom", ) +
facet_wrap(facets = "location", ncol = 1,
strip.position = "left")

ggsave(filename = here("output/fig-anomalies.pdf"),
height = 50, width = 15, limitsize = FALSE)
}
# Population data ---------------------------------------------------------
pop <- read_csv(paste0(
"https://raw.githubusercontent.com/european-modelling-hubs/",
"covid19-forecast-hub-europe/main/data-locations/locations_eu.csv"
), show_col_types = FALSE) |>
select(location, population)
write_csv(pop, here("data", paste0("populations.csv")))
75 changes: 75 additions & 0 deletions R/model-plots.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
library("purrr")
library("dplyr")
library("ggplot2")
library("patchwork")
library("gammit")
source(here("R", "prep-data.R"))
source(here("R", "descriptive.R"))

plot_models <- function(fits, scores, x_labels = TRUE) {
outcomes <- unique(scores$outcome_target)
classification <- classify_models() |>
rename(group = model)
targets <- table_targets(scores) |>
select(group = Model, CountryTargets) |>
distinct()
plots <- map(fits, function(fit) {
plot <- extract_ranef(fit) |>
filter(group_var == "Model") |>
left_join(classification) |>
left_join(targets) |>
mutate(group = sub(".*-", "", group)) |> ## remove institution identifier
select(-group_var) |>
arrange(-value) |>
mutate(group = factor(group, levels = unique(as.character(group)))) |>
ggplot(aes(x = group, col = classification, shape = CountryTargets)) +
geom_point(aes(y = value)) +
geom_linerange(aes(ymin = lower_2.5, ymax = upper_97.5)) +
geom_hline(yintercept = 0, lty = 2) +
labs(y = "Partial effect", x = "Model", colour = NULL, shape = NULL) +
scale_colour_brewer(type = "qual", palette = 2) +
theme(
legend.position = "bottom",
axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
) +
coord_flip()
if (!x_labels) {
plot <- plot +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
)
}
return(plot)
})
## remove legends
if (length(plots) > 1) {
for (i in seq_len(length(plots) - 1)) {
plots[[i]] <- plots[[i]] + theme(legend.position = "none")
}
}
for (i in seq_along(plots)) {
plots[[i]] <- plots[[i]] + ggtitle(outcomes[i])
}
Reduce(`+`, plots) + plot_layout(ncol = 2)
}

plot_effects <- function(fits, scores) {
map(fits, extract_ranef) |>
bind_rows(.id = "outcome_target") |>
filter(!(group_var %in% c("Model", "location"))) |>
mutate(group = factor(group, levels = unique(as.character(rev(group))))) |>
ggplot(aes(x = group, col = group_var)) +
geom_point(aes(y = value)) +
geom_linerange(aes(ymin = lower_2.5, ymax = upper_97.5)) +
geom_hline(yintercept = 0, lty = 2, alpha = 0.25) +
facet_wrap(~outcome_target, scales = "free_y") +
labs(y = "Partial effect", x = NULL, colour = NULL, shape = NULL) +
scale_colour_brewer(type = "qual", palette = "Set1") +
theme(
legend.position = "bottom",
strip.background = element_blank(),
axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
) +
coord_flip()
}
75 changes: 2 additions & 73 deletions R/model-wis.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,8 @@ library(readr)
library(tidyr)
library(purrr)
library(mgcv)
library(gratia) # devtools::install_github('gavinsimpson/gratia')
library(broom)
library(ggplot2)
library(broom)
library(gammit)
theme_set(theme_classic())
source(here("R", "prep-data.R"))
source(here("R", "descriptive.R"))

# --- Get data ---
data <- prep_data(scoring_scale = "log")
Expand Down Expand Up @@ -60,70 +55,4 @@ m.fits <- outcomes |>
)
})

plot_models <- function(fits, scores, x_labels = TRUE) {
outcomes <- unique(scores$outcome_target)
classification <- classify_models() |>
rename(group = model)
targets <- table_targets(scores) |>
select(group = Model, CountryTargets) |>
distinct()
plots <- map(fits, function(fit) {
plot <- extract_ranef(fit) |>
filter(group_var == "Model") |>
left_join(classification) |>
left_join(targets) |>
mutate(group = sub(".*-", "", group)) |> ## remove institution identifier
select(-group_var) |>
arrange(-value) |>
mutate(group = factor(group, levels = unique(as.character(group)))) |>
ggplot(aes(x = group, col = classification, shape = CountryTargets)) +
geom_point(aes(y = value)) +
geom_linerange(aes(ymin = lower_2.5, ymax = upper_97.5)) +
geom_hline(yintercept = 0, lty = 2) +
labs(y = "Partial effect", x = "Model", colour = NULL, shape = NULL) +
scale_colour_brewer(type = "qual", palette = 2) +
theme(
legend.position = "bottom",
axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
) +
coord_flip()
if (!x_labels) {
plot <- plot +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
)
}
return(plot)
})
## remove legends
if (length(plots) > 1) {
for (i in seq_len(length(plots) - 1)) {
plots[[i]] <- plots[[i]] + theme(legend.position = "none")
}
}
for (i in seq_along(plots)) {
plots[[i]] <- plots[[i]] + ggtitle(outcomes[i])
}
Reduce(`+`, plots) + plot_layout(ncol = 2)
}

plot_effects <- function(fits, scores) {
map(fits, extract_ranef) |>
bind_rows(.id = "outcome_target") |>
filter(!(group_var %in% c("Model", "location"))) |>
mutate(group = factor(group, levels = unique(as.character(rev(group))))) |>
ggplot(aes(x = group, col = group_var)) +
geom_point(aes(y = value)) +
geom_linerange(aes(ymin = lower_2.5, ymax = upper_97.5)) +
geom_hline(yintercept = 0, lty = 2, alpha = 0.25) +
facet_wrap(~outcome_target, scales = "free_y") +
labs(y = "Partial effect", x = NULL, colour = NULL, shape = NULL) +
scale_colour_brewer(type = "qual", palette = "Set1") +
theme(
legend.position = "bottom",
strip.background = element_blank(),
axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
) +
coord_flip()
}
saveRDS(m.fits, here("output", "fits.rds"))
55 changes: 0 additions & 55 deletions R/natural-scale-scores.R

This file was deleted.

Loading

0 comments on commit bae2c62

Please sign in to comment.