Skip to content

Commit b43cd7d

Browse files
authored
Climatological backtest (#190)
1 parent cfee369 commit b43cd7d

9 files changed

+226
-50
lines changed

R/aux_data_utils.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ drop_non_seasons <- function(epi_data, min_window = 12) {
304304
(forecast_date - time_value < as.difftime(min_window, units = "weeks")),
305305
season != "2020/21",
306306
# season != "2021/22", # keeping this because whitening otherwise gets really bad with the single season of data
307-
(season != "2019/20") | (time_value < "2020-03-01"),
307+
(season != "2019/20"),
308308
season != "2008/09"
309309
)
310310
}

R/forecasters/climatological_model.R

+30-15
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
#' @param epi_data expected to have columns time_value, geo_value, season, value,
33
climatological_model <- function(epi_data, ahead, window_size = 3,
44
recent_window = 3, quantile_method = c("baseR", "epipredict"),
5-
quant_type = 8, geo_agg = FALSE) {
5+
quant_type = 8, geo_agg = FALSE,
6+
floor_value = 0, pop_scale = FALSE, include_forecast_date = TRUE) {
67
quantile_method <- arg_match(quantile_method)
78
forecast_date <- attributes(epi_data)$metadata$as_of
89
forecast_week <- epiweek(forecast_date)
@@ -20,17 +21,23 @@ climatological_model <- function(epi_data, ahead, window_size = 3,
2021
# drop weird years
2122
filtered %<>% filter((season != "2020/21") & (season != "2021/22"))
2223
# keep data within window_size weeks of the target epiweek; if include_forecast_date, also keep the most recent recent_window weeks of data
23-
filtered %<>% filter(
24-
(abs(forecast_week + ahead - epiweek) <= window_size) |
25-
(last_date_data - time_value <= recent_window * 7)
26-
)
27-
28-
if (geo_agg) {
24+
if (include_forecast_date) {
25+
filtered %<>% filter(
26+
(abs(forecast_week + ahead - epiweek) <= window_size) |
27+
(last_date_data - time_value <= recent_window * 7)
28+
)
29+
} else {
30+
filtered %<>% filter(
31+
(abs(forecast_week + ahead - epiweek) <= window_size)
32+
)
33+
}
34+
# filtered %>% ggplot(aes(x = epiweek, y = value, color = source)) + geom_point() + facet_wrap(~geo_value); epi_data %>% autoplot(value, .facet_by = "geo_value", color = "source")
35+
if (geo_agg && pop_scale) {
2936
filtered %<>%
3037
add_pop_and_density() %>%
3138
mutate(value = value / population * 1e5) %>%
32-
select(geo_value, epiweek, epiyear, season, season_week, value, population)
33-
} else {
39+
select(any_of(c("geo_value", "epiweek", "epiyear", "season", "season_week", "value", "population")))
40+
} else if (!geo_agg) {
3441
filtered %<>%
3542
group_by(geo_value)
3643
}
@@ -56,18 +63,26 @@ climatological_model <- function(epi_data, ahead, window_size = 3,
5663
summarize(.dist_quantile = dist_quantiles(value, quantile), .groups = "keep") %>%
5764
reframe(tibble(quantile = covidhub_probs(), value = quantile(.dist_quantile, p = covidhub_probs())[[1]]))
5865
}
59-
naive_preds %<>% mutate(value = pmax(0, value))
66+
naive_preds %<>% mutate(value = pmax(floor_value, value))
6067
if (geo_agg) {
6168
naive_preds %<>%
6269
expand_grid(
63-
filtered %>% distinct(geo_value, population)
64-
) %>%
65-
mutate(value = value * population / 1e5) %>%
66-
select(-population) %>%
70+
filtered %>% distinct(geo_value)
71+
)
72+
if (pop_scale) {
73+
naive_preds %<>%
74+
left_join(
75+
filtered %>%
76+
distinct(geo_value, population)
77+
) %>%
78+
mutate(value = value * population / 1e5)
79+
}
80+
naive_preds %<>%
81+
select(-any_of("population")) %>%
6782
select(geo_value, forecast_date, target_end_date, quantile, value) %>%
6883
arrange(geo_value, forecast_date, target_end_date)
6984
}
7085
naive_preds %>%
71-
mutate(value = pmax(0, value)) %>%
86+
mutate(value = pmax(floor_value, value)) %>%
7287
ungroup()
7388
}

R/forecasters/ensemble_linear_climate.R

-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
#' forecasts and averages them on a per-quantile basis. By default the average
55
#' used is the median, but it can accept any vectorized function.
66
#'
7-
#' @param epi_data The data for fitting. Currently unused, but matches interface
8-
#' of other forecasters.
97
#' @param forecasts A tibble of quantile forecasts to aggregate. They should
108
#' be tibbles with columns `(geo_value, forecast_date, target_end_date,
119
#' quantile, value)`, preferably in that order.

R/forecasters/forecaster_baseline_linear.R

+17-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
#' epi_data is expected to have: geo_value, time_value, and value columns.
2-
forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALSE, residual_tail = 0.85, residual_center = 0.085, no_intercept = FALSE) {
2+
forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALSE, residual_tail = 0.85, residual_center = 0.085, no_intercept = TRUE, floor_value = 0, population_scale = TRUE) {
33
epi_data <- validate_epi_data(epi_data)
44
forecast_date <- attributes(epi_data)$metadata$as_of
55
population_data <- get_population_data() %>%
66
rename(geo_value = state_id) %>%
77
distinct(geo_value, population)
8+
if (population_scale) {
89
df_processed <- epi_data %>%
910
left_join(population_data, by = "geo_value") %>%
1011
mutate(value = value / population * 10**5)
12+
} else {
13+
df_processed <- epi_data
14+
}
1115

1216
if (log) {
1317
df_processed <- df_processed %>% mutate(value = log(value))
@@ -112,21 +116,27 @@ forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALS
112116
pivot_quantiles_longer(dist) %>%
113117
rename(quantile_levels = dist_quantile_level, values = dist_value) %>%
114118
select(-value) %>%
115-
left_join(population_data, by = "geo_value") %>%
116119
rename(quantile = quantile_levels) %>%
117120
{
118121
if (log) {
119122
(.) %>% mutate(values = exp(values))
120123
} else {
121124
.
122125
}
123-
} %>%
124-
mutate(
125-
value = values * population / 10**5,
126+
}
127+
if (population_scale) {
128+
quantile_forecast %<>%
129+
left_join(population_data, by = "geo_value") %>%
130+
mutate(
131+
value = values * population / 10**5
132+
) %>%
133+
select(-population)
134+
}
135+
quantile_forecast %<>% mutate(
126136
target_end_date = reference_date + ahead * 7,
127137
forecast_date = forecast_date,
128138
) %>%
129-
select(-model, -values, -population, -season_week) %>%
130-
mutate(value = pmax(0, value))
139+
select(-model, -values, -season_week) %>%
140+
mutate(value = pmax(floor_value, value))
131141
quantile_forecast
132142
}

R/forecasters/forecaster_climatological.R

+89-15
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
#' @param model_used the model used. "climate" means just climatological_model, "linear" means just forecaster_baseline_linear, "climate_linear" means the weighted ensemble with a linear model, "climatological_forecaster" means using the model from epipredict
2+
#'
13
climate_linear_ensembled <- function(epi_data,
24
outcome,
35
extra_sources = "",
46
ahead = 7,
57
trainer = parsnip::linear_reg(),
68
quantile_levels = covidhub_probs(),
9+
model_used = "climate_linear",
710
filter_source = "",
811
filter_agg_level = "",
912
scale_method = c("quantile", "std", "none"),
1013
center_method = c("median", "mean", "none"),
1114
nonlin_method = c("quart_root", "none"),
12-
drop_non_seasons = FALSE,
15+
quantiles_by_geo = TRUE,
16+
drop_non_season = FALSE,
1317
residual_tail = 0.99,
1418
residual_center = 0.35,
1519
...) {
@@ -45,31 +49,101 @@ climate_linear_ensembled <- function(epi_data,
4549
by = c("epiweek", "epiyear")
4650
)
4751
}
48-
if (drop_non_seasons) {
49-
season_data <- epi_data %>% drop_non_seasons()
52+
if (drop_non_season) {
53+
season_data <- epi_data %>%
54+
drop_non_seasons() %>%
55+
filter(season != "2021/22")
5056
} else {
5157
season_data <- epi_data
5258
}
5359
learned_params <- calculate_whitening_params(season_data, outcome, scale_method, center_method, nonlin_method)
54-
epi_data %<>% data_whitening(outcome, learned_params, nonlin_method)
55-
epi_data <- epi_data %>%
60+
season_data %<>% data_whitening(outcome, learned_params, nonlin_method)
61+
# epi_data %>% drop_non_seasons() %>% ggplot(aes(x = time_value, y = hhs, color = source)) + geom_line() + facet_wrap(~geo_value)
62+
season_data <- season_data %>%
5663
select(geo_value, source, time_value, season, value = !!outcome) %>%
5764
mutate(epiweek = epiweek(time_value))
58-
pred_climate <- climatological_model(epi_data, ahead) %>% mutate(forecaster = "climate")
59-
pred_geo_climate <- climatological_model(epi_data, ahead, geo_agg = FALSE) %>% mutate(forecaster = "climate_geo")
60-
pred_linear <- forecaster_baseline_linear(epi_data, ahead, residual_tail = residual_tail, residual_center = residual_center) %>% mutate(forecaster = "linear")
61-
pred <- bind_rows(pred_climate, pred_linear, pred_geo_climate) %>%
62-
ensemble_climate_linear((args_list$aheads[[1]]) / 7) %>%
63-
ungroup()
65+
66+
# either climate or climate linear needs the climate prediction
67+
if (model_used == "climate" || model_used == "climate_linear") {
68+
pred_climate <- climatological_model(season_data, ahead, geo_agg = quantiles_by_geo, floor_value = min(season_data$value, na.rm = TRUE), pop_scale = FALSE) %>% mutate(forecaster = "climate")
69+
pred <- pred_climate %>% select(-forecaster)
70+
}
71+
72+
# either linear or climate linear needs the linear prediction
73+
if (model_used == "linear" || model_used == "climate_linear") {
74+
pred_linear <- forecaster_baseline_linear(
75+
season_data %>% filter(source %in% c("nhsn", "none")),
76+
ahead,
77+
residual_tail = residual_tail,
78+
residual_center = residual_center,
79+
no_intercept = TRUE,
80+
floor_value = min(season_data$value, na.rm = TRUE, population_scale = FALSE)
81+
) %>%
82+
mutate(forecaster = "linear")
83+
pred <- pred_linear %>% select(-forecaster)
84+
}
85+
86+
if (model_used == "climate_linear") {
87+
pred <- bind_rows(pred_climate, pred_linear) %>%
88+
ensemble_climate_linear((args_list$aheads[[1]]) / 7) %>%
89+
ungroup()
90+
} else if (model_used == "climatological_forecaster") {
91+
# forecast all aheads at the same time
92+
if (ahead == args_list$aheads[[1]][[1]] / 7) {
93+
if (quantiles_by_geo) {
94+
quantile_key <- "geo_value"
95+
} else {
96+
quantile_key <- character(0)
97+
}
98+
clim_res <- climatological_forecaster(
99+
season_data,
100+
"value",
101+
args_list = climate_args_list(
102+
nonneg = (scale_method == "none"),
103+
time_type = "epiweek",
104+
quantile_levels = quantile_levels,
105+
forecast_horizon = args_list$aheads[[1]] / 7,
106+
quantile_by_key = quantile_key
107+
)
108+
)
109+
## clim_res$predictions
110+
pred <- clim_res$predictions %>%
111+
filter(source %in% c("nhsn", "none")) %>%
112+
pivot_quantiles_longer(.pred_distn) %>%
113+
select(geo_value, forecast_date, target_end_date = target_date, value = .pred_distn_value, quantile = .pred_distn_quantile_level) %>%
114+
mutate(target_end_date = ceiling_date(target_end_date, unit = "weeks", week_start = 6))
115+
} else {
116+
# we're fitting everything all at once in the first ahead for the
117+
# climatological_forecaster, so just return a null result for the other
118+
# aheads
119+
null_result <- tibble(
120+
geo_value = character(),
121+
forecast_date = lubridate::Date(),
122+
target_end_date = lubridate::Date(),
123+
quantile = numeric(),
124+
value = numeric()
125+
)
126+
return(null_result)
127+
}
128+
}
64129
# undo whitening
130+
if (adding_source) {
131+
pred %<>%
132+
rename({{ outcome }} := value) %>%
133+
mutate(source = "none")
134+
} else {
135+
pred %<>%
136+
rename({{ outcome }} := value) %>%
137+
mutate(source = "nhsn")
138+
}
65139
pred_final <- pred %>%
66-
rename({{ outcome }} := value) %>%
67-
mutate(source = "nhsn") %>%
68-
data_coloring(outcome, learned_params, join_cols = key_colnames(epi_data, exclude = "time_value"), nonlin_method = nonlin_method) %>%
140+
data_coloring(outcome, learned_params, join_cols = key_colnames(season_data, exclude = "time_value"), nonlin_method = nonlin_method) %>%
69141
rename(value = {{ outcome }}) %>%
70142
mutate(value = pmax(0, value)) %>%
71143
select(-source)
72144
# move dates to appropriate markers
73-
pred_final <- pred_final %>% mutate(target_end_date = target_end_date - 3)
145+
pred_final <- pred_final %>%
146+
mutate(target_end_date = target_end_date - 3) %>%
147+
sort_by_quantile()
74148
return(pred_final)
75149
}

R/scoring.R

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
# Scoring and Evaluation Functions
22

33
evaluate_predictions <- function(forecasts, truth_data) {
4+
# make sure the quantiles are in ascending order
5+
forecasts <- forecasts %>%
6+
arrange(model, geo_value, target_end_date, forecast_date, quantile) %>%
7+
group_by(model, geo_value, target_end_date, forecast_date) %>%
8+
mutate(prediction = sort(prediction)) %>%
9+
ungroup()
10+
411
checkmate::assert_data_frame(forecasts)
512
checkmate::assert_data_frame(truth_data)
613
checkmate::assert_names(

R/targets/covid_forecaster_config.R

+39
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,45 @@ get_covid_forecaster_params <- function() {
113113
c("climatological"),
114114
c("climatological", "window")
115115
)
116+
),
117+
climate_linear = bind_rows(
118+
expand_grid(
119+
forecaster = "climate_linear_ensembled",
120+
scale_method = "quantile",
121+
center_method = "median",
122+
nonlin_method = c("quart_root", "none"),
123+
model_used = c("climate_linear", "climate", "climatological_forecaster"),
124+
filter_agg_level = "state",
125+
drop_non_seasons = c(TRUE, FALSE),
126+
quantiles_by_geo = c(TRUE, FALSE),
127+
aheads = list(g_aheads),
128+
residual_tail = 0.70,
129+
residual_center = 0.127
130+
),
131+
expand_grid(
132+
forecaster = "climate_linear_ensembled",
133+
scale_method = "none",
134+
center_method = "none",
135+
nonlin_method = c("quart_root", "none"),
136+
model_used = c("climate_linear", "climate", "climatological_forecaster"),
137+
filter_agg_level = "state",
138+
drop_non_seasons = c(TRUE, FALSE),
139+
quantiles_by_geo = c(TRUE, FALSE),
140+
aheads = list(g_aheads),
141+
residual_tail = 0.97,
142+
residual_center = 0.097
143+
),
144+
expand_grid(
145+
forecaster = "climate_linear_ensembled",
146+
scale_method = "none",
147+
center_method = "none",
148+
nonlin_method = "none",
149+
model_used = "linear",
150+
filter_agg_level = "state",
151+
aheads = list(g_aheads),
152+
residual_tail = 0.97,
153+
residual_center = 0.097
154+
),
116155
)
117156
) %>%
118157
map(function(x) {

0 commit comments

Comments
 (0)