|
1 |
| -suppressPackageStartupMessages(source(here::here("R", "load_all.R"))) |
| 1 | +# Test project to test our forecasters on synthetic data. |
| 2 | +suppressPackageStartupMessages(source("R/load_all.R")) |
2 | 3 |
|
3 |
| -testthat::skip("Optional, long-running tests skipped.") |
# ================================ GLOBALS =================================
# Variables prefixed with 'g_' are globals needed by the targets pipeline (they
# need to persist during the actual targets run, since the commands are frozen
# as expressions).

# Setup targets config.
set_targets_config()
g_aheads <- -1:3
# Output directory; overridable through the FLU_SUBMISSION_DIRECTORY env var.
g_submission_directory <- Sys.getenv("FLU_SUBMISSION_DIRECTORY", "cache")
g_insufficient_data_geos <- c("as", "mp", "vi", "gu")
g_excluded_geos <- c("as", "gu", "mh")
g_time_value_adjust <- 3
g_fetch_args <- epidatr::fetch_args_list(
  return_empty = FALSE,
  timeout_seconds = 400
)
g_disease <- "flu"
g_external_object_name <- glue::glue(
  "exploration/2024-2025_{g_disease}_hosp_forecasts.parquet"
)
# needed for windowed_seasonal
g_very_latent_locations <- list(list(
  c("source"),
  c("flusurv", "ILI+")
))
# Date to cut the truth data off at, so we don't have too much of the past for
# plotting.
g_truth_data_date <- "2023-09-01"
# Whether we're running in backtest mode.
# If TRUE, we don't run the report notebook, which is (a) slow and (b) should be
# preserved as an ASOF snapshot of our production results for that week.
# If TRUE, we run a scoring notebook, which scores the historical forecasts
# against the truth data and compares them to the ensemble.
# If FALSE, we run the weekly report notebook.
g_backtest_mode <- as.logical(Sys.getenv("BACKTEST_MODE", "FALSE"))
# as.logical() returns NA for strings it cannot parse (e.g. "1", "yes", ""),
# which would otherwise surface below as an opaque "missing value where
# TRUE/FALSE needed" error. Fail here with an actionable message instead.
if (is.na(g_backtest_mode)) {
  stop(
    "BACKTEST_MODE must be parseable by as.logical() (e.g. TRUE or FALSE); got '",
    Sys.getenv("BACKTEST_MODE"),
    "'.",
    call. = FALSE
  )
}
if (!g_backtest_mode) {
  # This is the as_of for the forecast. If run on our typical schedule, it's
  # today, which is a Wednesday. Sometimes, if we're doing a delayed forecast,
  # it's a Thursday. It's used for stamping the data and for determining the
  # appropriate as_of when creating the forecast.
  g_forecast_generation_dates <- Sys.Date()
  # Usually, the forecast_date is the same as the generation date, but you can
  # override this. It should be a Wednesday.
  g_forecast_dates <- round_date(g_forecast_generation_dates, "weeks", week_start = 3)
} else {
  # Backtest: a hand-listed set of early generation dates followed by a weekly
  # sequence up to today, paired with weekly forecast dates from 2024-11-20.
  g_forecast_generation_dates <- c(
    as.Date(c(
      "2024-11-22", "2024-11-27", "2024-12-04", "2024-12-11",
      "2024-12-18", "2024-12-26", "2025-01-02"
    )),
    seq.Date(as.Date("2025-01-08"), Sys.Date(), by = 7L)
  )
  g_forecast_dates <- seq.Date(as.Date("2024-11-20"), Sys.Date(), by = 7L)
}
| 47 | + |
| 48 | +# TODO: Forecaster definitions. We should have a representative from each forecaster. |
# Linear baseline forecaster run on the rows of `epi_data` whose `source` is
# "nhsn". `extra_data` is not used in the body; `...` is forwarded to
# forecaster_baseline_linear() ahead of the fixed residual settings.
g_linear <- function(epi_data, ahead, extra_data, ...) {
  nhsn_only <- filter(epi_data, source == "nhsn")
  forecaster_baseline_linear(
    nhsn_only,
    ahead, ...,
    residual_tail = 0.99,
    residual_center = 0.35,
    no_intercept = TRUE
  )
}
# Climatological forecaster run on the rows of `epi_data` whose `source` is
# "nhsn". `extra_data` is not used in the body; `...` is forwarded to
# climatological_model().
g_climate_base <- function(epi_data, ahead, extra_data, ...) {
  nhsn_only <- filter(epi_data, source == "nhsn")
  climatological_model(nhsn_only, ahead, ...)
}
# Same as the base climatological forecaster (rows with `source == "nhsn"`
# only), but calls climatological_model() with `geo_agg = TRUE`. `extra_data`
# is not used in the body.
g_climate_geo_agged <- function(epi_data, ahead, extra_data, ...) {
  nhsn_only <- filter(epi_data, source == "nhsn")
  climatological_model(nhsn_only, ahead, ..., geo_agg = TRUE)
}
# Windowed-seasonal forecaster: calls scaled_pop_seasonal() on `epi_data` with
# `ahead` multiplied by 7, a quantile-regression trainer and no population
# scaling, then adds 3 to `target_end_date`. `extra_data` is not used in the
# body; `...` is forwarded to scaled_pop_seasonal().
g_windowed_seasonal <- function(epi_data, ahead, extra_data, ...) {
  seasonal_fcst <- scaled_pop_seasonal(
    epi_data,
    outcome = "value",
    ahead = ahead * 7,
    ...,
    trainer = epipredict::quantile_reg(),
    seasonal_method = "window",
    pop_scaling = FALSE,
    lags = c(0, 7),
    keys_to_ignore = g_very_latent_locations
  )
  mutate(seasonal_fcst, target_end_date = target_end_date + 3)
}
# Windowed-seasonal forecaster with an extra data source.
#
# Left-joins `extra_data` onto `epi_data` by `geo_value` and `time_value`, calls
# scaled_pop_seasonal() with `ahead` multiplied by 7 and "nssp" as the extra
# source, drops the `source` column, and adds 3 to `target_end_date`.
#
# @param epi_data Primary data; must have `geo_value` and `time_value` columns.
# @param ahead Forecast horizon; multiplied by 7 before being passed on.
# @param extra_data Data joined on `geo_value` and `time_value`.
# @param ... Forwarded to scaled_pop_seasonal().
# @return The forecast returned by scaled_pop_seasonal(), without `source` and
#   with `target_end_date` shifted by 3.
g_windowed_seasonal_extra_sources <- function(epi_data, ahead, extra_data, ...) {
  # The pipeline must end at mutate(): piping the result on into `fcst` would
  # make magrittr try to call `fcst` as a function, which does not exist.
  fcst <-
    epi_data %>%
    left_join(extra_data, by = join_by(geo_value, time_value)) %>%
    scaled_pop_seasonal(
      outcome = "value",
      ahead = ahead * 7,
      extra_sources = "nssp",
      ...,
      seasonal_method = "window",
      trainer = epipredict::quantile_reg(),
      drop_non_seasons = TRUE,
      pop_scaling = FALSE,
      lags = list(c(0, 7), c(0, 7)),
      keys_to_ignore = g_very_latent_locations
    ) %>%
    select(-source) %>%
    mutate(target_end_date = target_end_date + 3)
  fcst
}
4 | 103 |
|
5 |
| -# A list of forecasters to be tested. Add here to test new forecasters. |
# Table of forecaster configurations, one row per configuration. Each row gives
# the forecaster function, its extra argument values and the matching argument
# names, a display name, the outcome column, extra sources, and the ahead.
forecasters <- tibble::tribble(
  ~forecaster,     ~forecaster_args,                      ~forecaster_args_names,                    ~fc_name,          ~outcome, ~extra_sources, ~ahead,
  scaled_pop,      list(TRUE),                            list("pop_scaling"),                       "scaled_pop",      "a",      "",             1,
  scaled_pop,      list(FALSE),                           list("pop_scaling"),                       "scaled_pop",      "a",      "",             1,
  flatline_fc,     list(),                                list(),                                    "flatline_fc",     "a",      "",             1,
  smoothed_scaled, list(list(c(0, 7, 14), c(0)), 14, 7),  list("lags", "sd_width", "sd_mean_width"), "smoothed_scaled", "a",      "",             1
)
|
13 |
| -# Which forecasters expect the data to be non-identical? |
14 |
| -expects_nonequal <- c("scaled_pop", "smoothed_scaled") |
15 |
| - |
16 |
| -#' A wrapper for a common call to slide a forecaster over a dataset. |
17 |
| -#' |
18 |
| -#' @param dataset The dataset to be used for the forecast. |
19 |
| -#' @param ii The row of the forecasters table to be used. |
20 |
| -#' @param outcome The name of the target column in the dataset. |
21 |
| -#' @param extra_sources Any extra columns used for prediction that aren't |
22 |
| -#' default. |
23 |
| -#' @param expect_linreg_warnings Whether to expect and then suppress warnings |
24 |
| -#' from linear_reg. |
25 |
| -#' |
26 |
| -#' Notes: |
27 |
| -#' - n_training_pad is set to avoid warnings from the trainer. |
28 |
| -#' - linear_reg doesn't like exactly equal data when training and throws a |
29 |
| -#' warning. wrapperfun is used to suppress that. |
30 |
| -default_slide_forecaster <- function(dataset, ii, expect_linreg_warnings = TRUE) { |
31 |
| - if (any(forecasters$fc_name[[ii]] %in% expects_nonequal) && expect_linreg_warnings) { |
32 |
| - wrapperfun <- function(x) { |
33 |
| - suppressWarnings(expect_warning(x, regexp = "prediction from rank-deficient fit")) |
34 |
| - } |
35 |
| - } else { |
36 |
| - wrapperfun <- identity |
37 |
| - } |
38 |
| - args <- forecasters %>% |
39 |
| - select(-fc_name) %>% |
40 |
| - slice(ii) %>% |
41 |
| - purrr::transpose() %>% |
42 |
| - pluck(1) |
43 |
| - wrapperfun(res <- inject(slide_forecaster(epi_archive = dataset, n_training_pad = 30, !!!args))) |
44 |
| - return(res) |
45 |
| -} |
| 111 | + |
| 112 | + |
| 113 | +### Datasets TODO: Convert to targets? |
46 | 114 |
|
47 | 115 | # Some arbitrary magic numbers used to generate data.
|
48 | 116 | synth_mean <- 25
|
@@ -76,10 +144,12 @@ different_constants <- rbind(
|
76 | 144 | ) %>%
|
77 | 145 | arrange(version, time_value) %>%
|
78 | 146 | epiprocess::as_epi_archive()
|
| 147 | + |
# Truth table derived from the `different_constants` archive: its `DT` as a
# tibble, with `a` renamed to `true_value` and `time_value` renamed to
# `target_end_date`, and the `version` column dropped.
different_constants_truth <- select(
  rename(
    tibble(different_constants$DT),
    "true_value" = "a",
    "target_end_date" = "time_value"
  ),
  -version
)
|
| 152 | + |
83 | 153 | for (ii in seq_len(nrow(forecasters))) {
|
84 | 154 | test_that(paste(
|
85 | 155 | forecasters$fc_name[[ii]],
|
|
0 commit comments