Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions R/analysis-descriptive.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Aim: describe interval score in terms of model structure and country target type
# Load data:
# source(here("R", "process-data.R"))
# scores <- prep_data(scoring_scale = "log")
# scores <- process_data(scoring_scale = "log")
library(here)
library(dplyr)
library(purrr)
Expand Down Expand Up @@ -258,7 +258,7 @@ plot_ridges <- function(scores, target = "Deaths") {
# Table of targets by model -------------
table_targets <- function(scores) {
table_targets <- scores |>
select(Model, outcome_target, forecast_date, location) |>
select(Model, outcome_target, forecast_date, Location) |>
distinct() |>
group_by(Model, outcome_target, forecast_date) |>
summarise(target_count = n(), .groups = "drop") |>
Expand Down Expand Up @@ -312,11 +312,12 @@ table_metadata <- function(scores) {
# Data --------------------
data_plot <- function(scores, log = FALSE, all = FALSE) {
data <- scores |>
select(location, outcome_target, target_end_date, Incidence) |>
select(Location, outcome_target, target_end_date, Incidence) |>
distinct()
pop <- read_csv(here("data", "populations.csv"), show_col_types = FALSE)
pop <- read_csv(here("data", "populations.csv"), show_col_types = FALSE) |>
rename(Location = location)
data <- data |>
left_join(pop, by = join_by(location)) |>
left_join(pop, by = join_by(Location)) |>
mutate(
rel_inc = Incidence / population * 1e5,
log_inc = log(Incidence + 1)
Expand All @@ -331,11 +332,11 @@ data_plot <- function(scores, log = FALSE, all = FALSE) {
mutate(
rel_inc = Incidence / population * 1e5,
log_inc = log(Incidence + 1),
location = "Total"
Location = "Total"
)
var_name <- ifelse(log, "log_inc", "rel_inc")
plot <- ggplot(mapping = aes(
x = target_end_date, y = .data[[var_name]], group = location
x = target_end_date, y = .data[[var_name]], group = Location
))

if (all) {
Expand All @@ -360,14 +361,14 @@ data_plot <- function(scores, log = FALSE, all = FALSE) {

trends_plot <- function(scores) {
trends <- scores |>
select(location, target_end_date, Incidence, Trend) |>
select(Location, target_end_date, Incidence, Trend) |>
distinct()
p <- ggplot(trends, aes(x = target_end_date, y = Incidence)) +
geom_point(mapping = aes(colour = Trend), size = 1) +
geom_line() +
scale_colour_brewer(palette = "Set2", na.value = "grey") +
theme(legend.position = "bottom") +
facet_wrap(~location, scales = "free_y") +
facet_wrap(~Location, scales = "free_y") +
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) +
xlab("")
return(p)
Expand Down
82 changes: 40 additions & 42 deletions R/analysis-model.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
# Aim: use a GAMM to model the effects of model structure and country target type on WIS
# Model:
#
# Method: model method (mechanistic, statistical, etc.)
# CountryTargets: model predicts for single- vs multi-country
# Trend: epidemic trend (stable, increasing, decreasing)
# Location: location (random effect)
# VariantPhase: dominant variant phase (random effect)
# Horizon: forecast horizon (smooth, by model)
# Model: individual model (random effect)
#
# Response: WIS (log-transformed, Gaussian family with log link)

library(here)
library(dplyr)
library(readr)
Expand All @@ -11,44 +23,31 @@ source(here("R", "process-data.R"))
source(here("R", "analysis-descriptive.R"))

# --- Get data ---
data <- prep_data(scoring_scale = "log")
outcomes <- unique(data$outcome_target)
classification <- classify_models()
targets <- table_targets(data)

data <- process_data(scoring_scale = "log")
m.data <- data |>
filter(!grepl("EuroCOVIDhub-", Model)) |>
mutate(location = factor(location)) |>
group_by(location) |>
mutate(
time = as.numeric(forecast_date - min(forecast_date)) / 7,
Horizon = as.numeric(Horizon),
wis = wis + 1e-7
) |>
ungroup()
filter(!grepl("EuroCOVIDhub-", Model))
outcomes <- unique(data$outcome_target)

# --- Model formula ---
# Univariate for explanatory variables
m.formula_uni_type <- wis ~ s(Method, bs = "re")
m.formula_uni_tgt <- wis ~ s(CountryTargets, bs = "re")
m.formula_uni_model <- wis ~ s(Model, bs = "re")
# Univariate for each explanatory variable
m.formulas_uni <- list(
method = wis ~ s(Method, bs = "re"),
target = wis ~ s(CountryTargets, bs = "re"),
trend = wis ~ s(Trend, bs = "re"),
location = wis ~ s(Location, bs = "re"),
variant = wis ~ s(VariantPhase, bs = "re"),
horizon = wis ~ s(Horizon, by = Model, k = 3, bs = "sz"),
model = wis ~ s(Model, bs = "re")
)

# Full model
m.formula <- wis ~
# Method
# Full joint model
m.formula_joint <- wis ~
s(Method, bs = "re") +
# Number of target countries
s(CountryTargets, bs = "re") +
# -----------------------------
# Trend
s(Trend, bs = "re") +
# Location
s(location, bs = "re") +
# Week * location
s(time, by = location, k = 40) +
# Horizon
s(Horizon, k = 3, by = Model, bs = "sz") +
# Individual model
s(Location, bs = "re") +
s(VariantPhase, bs = "re") +
s(Horizon, by = Model, k = 3, bs = "sz") +
s(Model, bs = "re")

# --- Model fitting ---
Expand All @@ -69,23 +68,22 @@ m.fit <- function(outcomes, m.formula) {
}
# Fit
cat("--------fitting univariate models")
m.fits_uni_type <- m.fit(outcomes, m.formula_uni_type)
m.fits_uni_tgt <- m.fit(outcomes, m.formula_uni_tgt)
m.fits_uni_model <- m.fit(outcomes, m.formula_uni_model)
m.fits_uni <- map(m.formulas_uni, ~ m.fit(outcomes, .x))

cat("--------fitting joint model")
m.fits_joint <- m.fit(outcomes, m.formula)
cat("finished fitting")
m.fits_joint <- m.fit(outcomes, m.formula_joint)

# --- Output handling ---
# Extract estimates for random effects
random_effects_uni <- map_df(
c(m.fits_uni_type, m.fits_uni_tgt, m.fits_uni_model),
extract_ranef,
.id = "outcome_target") |>
random_effects_uni <- m.fits_uni[!grepl("horizon", names(m.fits_uni))] |>
map_depth(.depth = 2, ~ extract_ranef(.x)) |>
map(~ list_rbind(.x, names_to = "outcome_target")) |>
list_rbind() |>
mutate(model = "Unadjusted")

random_effects_joint <- map_df(m.fits_joint,
extract_ranef,
.id = "outcome_target") |>
extract_ranef,
.id = "outcome_target") |>
mutate(model = "Adjusted")

random_effects <- random_effects_joint |>
Expand Down
4 changes: 3 additions & 1 deletion R/plot-model-results.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ plot_models <- function(random_effects, scores, x_labels = TRUE,
}

plot_effects <- function(random_effects,
variables = c("Method", "CountryTargets")) {
variables = NULL) {
if(is.null(variables)){variables <- unique(random_effects$group_var)}

random_effects |>
filter(group_var %in% variables) |>
mutate(group = factor(group, levels = unique(as.character(rev(group)))),
Expand Down
41 changes: 24 additions & 17 deletions R/process-data.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
library("dplyr")
library("tidyr")
library("purrr")
library("readr")
library("lubridate")
library(here)
library(dplyr)
library(tidyr)
library(purrr)
library(readr)
library(lubridate)
source(here("R", "utils-variants.R"))

# Metadata ----------------------------------------------------------------
# Get classification of model types
Expand Down Expand Up @@ -32,20 +34,21 @@ classify_models <- function(file = here("data", "model-classification.csv")) {
return(methods)
}

# Scores data: add explanatory variables -----------------------------
# Get scores for all forecasts and add explanatory variables used:
# number of country targets, method classification, trend of observed incidence
prep_data <- function(scoring_scale = "log") {
# Prepare data for analysis -----------------------------
# Get scores for all forecasts; and add explanatory variables in a single dframe
process_data <- function(scoring_scale = "log") {
# Get raw interval scores ----------------------------------------
# scores data created in: R/process-score.r
scores_files <- list.files(here("data"), pattern = "scores-raw-.*\\.csv")
names(scores_files) <- sub("scores-raw-(.*)\\..*$", "\\1", scores_files)
# Get raw interval score
scores_raw <- scores_files |>
map(\(file) {
read_csv(here("data", file))
}) |>
bind_rows(.id = "outcome_target") |>
filter(scale == scoring_scale)

# Add variables of interest to scores dataframe ----------------------
# Target type
country_targets <- scores_raw |>
select(model, forecast_date, location) |>
Expand All @@ -66,7 +69,7 @@ prep_data <- function(scoring_scale = "log") {
methods <- classify_models() |>
select(model, Method = classification, agreement)

# Incidence level + trend (see: R/import-data.r)
# Incidence level + trend (observed data from: R/utils-data.r)
obs <- names(scores_files) |>
set_names() |>
map(~ read_csv(here("data", paste0("observed-", .x, ".csv")))) |>
Expand All @@ -76,19 +79,23 @@ prep_data <- function(scoring_scale = "log") {
rename(Incidence = observed) |>
select(target_end_date, location, outcome_target, Trend, Incidence)

# Variant phase
variant_phase <- classify_variant_phases()

# Combine all data -----------------------------------------------------
data <- scores_raw |>
left_join(obs, by = c("location", "target_end_date", "outcome_target")) |>
left_join(variant_phase, by = c("location", "target_end_date")) |>
left_join(country_targets, by = "model") |>
left_join(methods, by = "model") |>
rename(Model = model, Horizon = horizon) |>
# set to factors
rename(Model = model, Horizon = horizon, Location = location) |>
mutate(
Horizon = ifelse(!Horizon %in% 1:4, NA_integer_, Horizon),
Model = as.factor(Model),
Location = as.factor(Location),
outcome_target = paste0(str_to_title(outcome_target), "s"),
Horizon = ordered(Horizon,
levels = 1:4, labels = 1:4
),
log_wis = log(wis + 0.01)
) |>
wis = wis + 1e-7) |>
filter(!is.na(Horizon)) ## horizon not in 1:4
return(data)
}
Expand Down
Loading
Loading