|
| 1 | +#' @params model_used the model used. "climate" means just climatological_model, "climate_linear" means the weighted ensemble with a linear model, "climatological_forecaster" means using the model from epipredict |
| 2 | +#' |
1 | 3 | climate_linear_ensembled <- function(epi_data,
|
2 | 4 | outcome,
|
3 | 5 | extra_sources = "",
|
4 | 6 | ahead = 7,
|
5 | 7 | trainer = parsnip::linear_reg(),
|
6 | 8 | quantile_levels = covidhub_probs(),
|
| 9 | + model_used = "climate_linear", |
7 | 10 | filter_source = "",
|
8 | 11 | filter_agg_level = "",
|
9 | 12 | scale_method = c("quantile", "std", "none"),
|
10 | 13 | center_method = c("median", "mean", "none"),
|
11 | 14 | nonlin_method = c("quart_root", "none"),
|
12 |
| - drop_non_seasons = FALSE, |
| 15 | + quantiles_by_geo = TRUE, |
| 16 | + drop_non_season = FALSE, |
13 | 17 | residual_tail = 0.99,
|
14 | 18 | residual_center = 0.35,
|
15 | 19 | ...) {
|
@@ -45,31 +49,101 @@ climate_linear_ensembled <- function(epi_data,
|
45 | 49 | by = c("epiweek", "epiyear")
|
46 | 50 | )
|
47 | 51 | }
|
48 |
| - if (drop_non_seasons) { |
49 |
| - season_data <- epi_data %>% drop_non_seasons() |
| 52 | + if (drop_non_season) { |
| 53 | + season_data <- epi_data %>% |
| 54 | + drop_non_seasons() %>% |
| 55 | + filter(season != "2021/22") |
50 | 56 | } else {
|
51 | 57 | season_data <- epi_data
|
52 | 58 | }
|
53 | 59 | learned_params <- calculate_whitening_params(season_data, outcome, scale_method, center_method, nonlin_method)
|
54 |
| - epi_data %<>% data_whitening(outcome, learned_params, nonlin_method) |
55 |
| - epi_data <- epi_data %>% |
| 60 | + season_data %<>% data_whitening(outcome, learned_params, nonlin_method) |
| 61 | + # epi_data %>% drop_non_seasons() %>% ggplot(aes(x = time_value, y = hhs, color = source)) + geom_line() + facet_wrap(~geo_value) |
| 62 | + season_data <- season_data %>% |
56 | 63 | select(geo_value, source, time_value, season, value = !!outcome) %>%
|
57 | 64 | mutate(epiweek = epiweek(time_value))
|
58 |
| - pred_climate <- climatological_model(epi_data, ahead) %>% mutate(forecaster = "climate") |
59 |
| - pred_geo_climate <- climatological_model(epi_data, ahead, geo_agg = FALSE) %>% mutate(forecaster = "climate_geo") |
60 |
| - pred_linear <- forecaster_baseline_linear(epi_data, ahead, residual_tail = residual_tail, residual_center = residual_center) %>% mutate(forecaster = "linear") |
61 |
| - pred <- bind_rows(pred_climate, pred_linear, pred_geo_climate) %>% |
62 |
| - ensemble_climate_linear((args_list$aheads[[1]]) / 7) %>% |
63 |
| - ungroup() |
| 65 | + |
| 66 | + # either climate or climate linear needs the climate prediction |
| 67 | + if (model_used == "climate" || model_used == "climate_linear") { |
| 68 | + pred_climate <- climatological_model(season_data, ahead, geo_agg = quantiles_by_geo, floor_value = min(season_data$value, na.rm = TRUE), pop_scale = FALSE) %>% mutate(forecaster = "climate") |
| 69 | + pred <- pred_climate %>% select(-forecaster) |
| 70 | + } |
| 71 | + |
| 72 | + # either linear or climate linear needs the linear prediction |
| 73 | + if (model_used == "linear" || model_used == "climate_linear") { |
| 74 | + pred_linear <- forecaster_baseline_linear( |
| 75 | + season_data %>% filter(source %in% c("nhsn", "none")), |
| 76 | + ahead, |
| 77 | + residual_tail = residual_tail, |
| 78 | + residual_center = residual_center, |
| 79 | + no_intercept = TRUE, |
| 80 | + floor_value = min(season_data$value, na.rm = TRUE, population_scale = FALSE) |
| 81 | + ) %>% |
| 82 | + mutate(forecaster = "linear") |
| 83 | + pred <- pred_linear %>% select(-forecaster) |
| 84 | + } |
| 85 | + |
| 86 | + if (model_used == "climate_linear") { |
| 87 | + pred <- bind_rows(pred_climate, pred_linear) %>% |
| 88 | + ensemble_climate_linear((args_list$aheads[[1]]) / 7) %>% |
| 89 | + ungroup() |
| 90 | + } else if (model_used == "climatological_forecaster") { |
| 91 | + # forecast all aheads at the same time |
| 92 | + if (ahead == args_list$aheads[[1]][[1]] / 7) { |
| 93 | + if (quantiles_by_geo) { |
| 94 | + quantile_key <- "geo_value" |
| 95 | + } else { |
| 96 | + quantile_key <- character(0) |
| 97 | + } |
| 98 | + clim_res <- climatological_forecaster( |
| 99 | + season_data, |
| 100 | + "value", |
| 101 | + args_list = climate_args_list( |
| 102 | + nonneg = (scale_method == "none"), |
| 103 | + time_type = "epiweek", |
| 104 | + quantile_levels = quantile_levels, |
| 105 | + forecast_horizon = args_list$aheads[[1]] / 7, |
| 106 | + quantile_by_key = quantile_key |
| 107 | + ) |
| 108 | + ) |
| 109 | + ## clim_res$predictions |
| 110 | + pred <- clim_res$predictions %>% |
| 111 | + filter(source %in% c("nhsn", "none")) %>% |
| 112 | + pivot_quantiles_longer(.pred_distn) %>% |
| 113 | + select(geo_value, forecast_date, target_end_date = target_date, value = .pred_distn_value, quantile = .pred_distn_quantile_level) %>% |
| 114 | + mutate(target_end_date = ceiling_date(target_end_date, unit = "weeks", week_start = 6)) |
| 115 | + } else { |
| 116 | + # we're fitting everything all at once in the first ahead for the |
| 117 | + # climatological_forecaster, so just return a null result for the other |
| 118 | + # aheads |
| 119 | + null_result <- tibble( |
| 120 | + geo_value = character(), |
| 121 | + forecast_date = lubridate::Date(), |
| 122 | + target_end_date = lubridate::Date(), |
| 123 | + quantile = numeric(), |
| 124 | + value = numeric() |
| 125 | + ) |
| 126 | + return(null_result) |
| 127 | + } |
| 128 | + } |
64 | 129 | # undo whitening
|
| 130 | + if (adding_source) { |
| 131 | + pred %<>% |
| 132 | + rename({{ outcome }} := value) %>% |
| 133 | + mutate(source = "none") |
| 134 | + } else { |
| 135 | + pred %<>% |
| 136 | + rename({{ outcome }} := value) %>% |
| 137 | + mutate(source = "nhsn") |
| 138 | + } |
65 | 139 | pred_final <- pred %>%
|
66 |
| - rename({{ outcome }} := value) %>% |
67 |
| - mutate(source = "nhsn") %>% |
68 |
| - data_coloring(outcome, learned_params, join_cols = key_colnames(epi_data, exclude = "time_value"), nonlin_method = nonlin_method) %>% |
| 140 | + data_coloring(outcome, learned_params, join_cols = key_colnames(season_data, exclude = "time_value"), nonlin_method = nonlin_method) %>% |
69 | 141 | rename(value = {{ outcome }}) %>%
|
70 | 142 | mutate(value = pmax(0, value)) %>%
|
71 | 143 | select(-source)
|
72 | 144 | # move dates to appropriate markers
|
73 |
| - pred_final <- pred_final %>% mutate(target_end_date = target_end_date - 3) |
| 145 | + pred_final <- pred_final %>% |
| 146 | + mutate(target_end_date = target_end_date - 3) %>% |
| 147 | + sort_by_quantile() |
74 | 148 | return(pred_final)
|
75 | 149 | }
|
0 commit comments