Skip to content

Commit 37d78ff

Browse files
committed
wip: start test project
1 parent 5171d91 commit 37d78ff

File tree

3 files changed

+116
-278
lines changed

3 files changed

+116
-278
lines changed

tests/testthat/test-forecasters-data.R renamed to scripts/test_proj.R

+106-36
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,116 @@
1-
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
1+
# Test project to test our forecasters on synthetic data.
2+
suppressPackageStartupMessages(source("R/load_all.R"))
23

3-
testthat::skip("Optional, long-running tests skipped.")
4+
# ================================ GLOBALS =================================
5+
# Variables prefixed with 'g_' are globals needed by the targets pipeline (they
6+
# need to persist during the actual targets run, since the commands are frozen
7+
# as expressions).
8+
9+
# Setup targets config.
10+
set_targets_config()
11+
g_aheads <- -1:3
12+
g_submission_directory <- Sys.getenv("FLU_SUBMISSION_DIRECTORY", "cache")
13+
g_insufficient_data_geos <- c("as", "mp", "vi", "gu")
14+
g_excluded_geos <- c("as", "gu", "mh")
15+
g_time_value_adjust <- 3
16+
g_fetch_args <- epidatr::fetch_args_list(return_empty = FALSE, timeout_seconds = 400)
17+
g_disease <- "flu"
18+
g_external_object_name <- glue::glue("exploration/2024-2025_{g_disease}_hosp_forecasts.parquet")
19+
# needed for windowed_seasonal
20+
g_very_latent_locations <- list(list(
21+
c("source"),
22+
c("flusurv", "ILI+")
23+
))
24+
# Date to cut the truth data off at, so we don't have too much of the past for
25+
# plotting.
26+
g_truth_data_date <- "2023-09-01"
27+
# Whether we're running in backtest mode.
28+
# If TRUE, we don't run the report notebook, which is (a) slow and (b) should be
29+
# preserved as an ASOF snapshot of our production results for that week.
30+
# If TRUE, we run a scoring notebook, which scores the historical forecasts
31+
# against the truth data and compares them to the ensemble.
32+
# If FALSE, we run the weekly report notebook.
33+
g_backtest_mode <- as.logical(Sys.getenv("BACKTEST_MODE", FALSE))
34+
if (!g_backtest_mode) {
35+
# This is the as_of for the forecast. If run on our typical schedule, it's
36+
# today, which is a Wednesday. Sometimes, if we're doing a delayed forecast,
37+
# it's a Thursday. It's used for stamping the data and for determining the
38+
# appropriate as_of when creating the forecast.
39+
g_forecast_generation_dates <- Sys.Date()
40+
# Usually, the forecast_date is the same as the generation date, but you can
41+
# override this. It should be a Wednesday.
42+
g_forecast_dates <- round_date(g_forecast_generation_dates, "weeks", week_start = 3)
43+
} else {
44+
g_forecast_generation_dates <- c(as.Date(c("2024-11-22", "2024-11-27", "2024-12-04", "2024-12-11", "2024-12-18", "2024-12-26", "2025-01-02")), seq.Date(as.Date("2025-01-08"), Sys.Date(), by = 7L))
45+
g_forecast_dates <- seq.Date(as.Date("2024-11-20"), Sys.Date(), by = 7L)
46+
}
47+
48+
# TODO: Forecaster definitions. We should have a representative from each forecaster.
49+
g_linear <- function(epi_data, ahead, extra_data, ...) {
50+
epi_data %>%
51+
filter(source == "nhsn") %>%
52+
forecaster_baseline_linear(
53+
ahead, ...,
54+
residual_tail = 0.99,
55+
residual_center = 0.35,
56+
no_intercept = TRUE
57+
)
58+
}
59+
g_climate_base <- function(epi_data, ahead, extra_data, ...) {
60+
epi_data %>%
61+
filter(source == "nhsn") %>%
62+
climatological_model(ahead, ...)
63+
}
64+
g_climate_geo_agged <- function(epi_data, ahead, extra_data, ...) {
65+
epi_data %>%
66+
filter(source == "nhsn") %>%
67+
climatological_model(ahead, ..., geo_agg = TRUE)
68+
}
69+
g_windowed_seasonal <- function(epi_data, ahead, extra_data, ...) {
70+
scaled_pop_seasonal(
71+
epi_data,
72+
outcome = "value",
73+
ahead = ahead * 7,
74+
...,
75+
trainer = epipredict::quantile_reg(),
76+
seasonal_method = "window",
77+
pop_scaling = FALSE,
78+
lags = c(0, 7),
79+
keys_to_ignore = g_very_latent_locations
80+
) %>%
81+
mutate(target_end_date = target_end_date + 3)
82+
}
83+
g_windowed_seasonal_extra_sources <- function(epi_data, ahead, extra_data, ...) {
84+
fcst <-
85+
epi_data %>%
86+
left_join(extra_data, by = join_by(geo_value, time_value)) %>%
87+
scaled_pop_seasonal(
88+
outcome = "value",
89+
ahead = ahead * 7,
90+
extra_sources = "nssp",
91+
...,
92+
seasonal_method = "window",
93+
trainer = epipredict::quantile_reg(),
94+
drop_non_seasons = TRUE,
95+
pop_scaling = FALSE,
96+
lags = list(c(0, 7), c(0, 7)),
97+
keys_to_ignore = g_very_latent_locations
98+
) %>%
99+
select(-source) %>%
100+
mutate(target_end_date = target_end_date + 3) %>%
101+
fcst
102+
}
4103

5-
# A list of forecasters to be tested. Add here to test new forecasters.
6104
forecasters <- tibble::tribble(
7105
~forecaster, ~forecaster_args, ~forecaster_args_names, ~fc_name, ~outcome, ~extra_sources, ~ahead,
8106
scaled_pop, list(TRUE), list("pop_scaling"), "scaled_pop", "a", "", 1,
9107
scaled_pop, list(FALSE), list("pop_scaling"), "scaled_pop", "a", "", 1,
10108
flatline_fc, list(), list(), "flatline_fc", "a", "", 1,
11109
smoothed_scaled, list(list(c(0, 7, 14), c(0)), 14, 7), list("lags", "sd_width", "sd_mean_width"), "smoothed_scaled", "a", "", 1,
12110
)
13-
# Which forecasters expect the data to be non-identical?
14-
expects_nonequal <- c("scaled_pop", "smoothed_scaled")
15-
16-
#' A wrapper for a common call to slide a forecaster over a dataset.
17-
#'
18-
#' @param dataset The dataset to be used for the forecast.
19-
#' @param ii The row of the forecasters table to be used.
20-
#' @param outcome The name of the target column in the dataset.
21-
#' @param extra_sources Any extra columns used for prediction that aren't
22-
#' default.
23-
#' @param expect_linreg_warnings Whether to expect and then suppress warnings
24-
#' from linear_reg.
25-
#'
26-
#' Notes:
27-
#' - n_training_pad is set to avoid warnings from the trainer.
28-
#' - linear_reg doesn't like exactly equal data when training and throws a
29-
#' warning. wrapperfun is used to suppress that.
30-
default_slide_forecaster <- function(dataset, ii, expect_linreg_warnings = TRUE) {
31-
if (any(forecasters$fc_name[[ii]] %in% expects_nonequal) && expect_linreg_warnings) {
32-
wrapperfun <- function(x) {
33-
suppressWarnings(expect_warning(x, regexp = "prediction from rank-deficient fit"))
34-
}
35-
} else {
36-
wrapperfun <- identity
37-
}
38-
args <- forecasters %>%
39-
select(-fc_name) %>%
40-
slice(ii) %>%
41-
purrr::transpose() %>%
42-
pluck(1)
43-
wrapperfun(res <- inject(slide_forecaster(epi_archive = dataset, n_training_pad = 30, !!!args)))
44-
return(res)
45-
}
111+
112+
113+
### Datasets TODO: Convert to targets?
46114

47115
# Some arbitrary magic numbers used to generate data.
48116
synth_mean <- 25
@@ -76,10 +144,12 @@ different_constants <- rbind(
76144
) %>%
77145
arrange(version, time_value) %>%
78146
epiprocess::as_epi_archive()
147+
79148
different_constants_truth <- different_constants$DT %>%
80149
tibble() %>%
81150
rename("true_value" = "a", "target_end_date" = "time_value") %>%
82151
select(-version)
152+
83153
for (ii in seq_len(nrow(forecasters))) {
84154
test_that(paste(
85155
forecasters$fc_name[[ii]],

test_proj/.gitignore

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# CAUTION: do not edit this file by hand!
2+
# _targets/objects/ may have large data files,
3+
# and _targets/meta/process may have sensitive information.
4+
# It is good pratice to either commit nothing from _targets/,
5+
# or if your data is not too sensitive,
6+
# commit only _targets/meta/meta.
7+
*
8+
!.gitignore
9+
!meta
10+
meta/*

0 commit comments

Comments
 (0)