Skip to content

Commit cefa40c

Browse files
authored
feat: use veterans admissions data as exogenous variable (#189)
* feat: add va data exogenous forecasters to flu and covid explore * repo: add explore Make targets
1 parent b43cd7d commit cefa40c

19 files changed

+504
-78
lines changed

Makefile

+11-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,17 @@ prod-flu:
2020
export TAR_RUN_PROJECT=flu_hosp_prod; \
2121
Rscript scripts/run.R
2222

23-
prod: prod-covid prod-flu update_site netlify
23+
prod: prod-covid prod-flu update-site netlify
24+
25+
explore-covid:
26+
export TAR_RUN_PROJECT=covid_hosp_explore; \
27+
Rscript scripts/run.R
28+
29+
explore-flu:
30+
export TAR_RUN_PROJECT=flu_hosp_explore; \
31+
Rscript scripts/run.R
32+
33+
explore: explore-covid explore-flu update-site netlify
2434

2535
submit-covid:
2636
cd ../covid19-forecast-hub; \

R/aux_data_utils.R

+41-12
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,7 @@ add_pop_and_density <-
9090
)
9191
}
9292
if (!("agg_level" %in% names(original_dataset))) {
93-
original_dataset %<>%
94-
mutate(agg_level = ifelse(grepl("[0-9]{2}", geo_value), "hhs_region", ifelse(("us" == geo_value) | ("usa" == geo_value), "nation", "state")))
93+
original_dataset %<>% add_agg_level()
9594
}
9695
original_dataset %>%
9796
mutate(year = year(time_value)) %>%
@@ -106,6 +105,15 @@ add_pop_and_density <-
106105
fill(population, density)
107106
}
108107

108+
#' Add an `agg_level` column classifying each `geo_value`
#'
#' Conditions are checked in order: a `geo_value` containing two consecutive
#' digits anywhere in the string is labelled "hhs_region"; the exact
#' (case-sensitive) values "us" and "usa" are labelled "nation"; everything
#' else falls through to "state". An existing `agg_level` column is
#' overwritten.
#'
#' @param data A data frame with a `geo_value` column.
#' @return `data` with the `agg_level` column set.
add_agg_level <- function(data) {
109+
data %>%
110+
mutate(agg_level = case_when(
111+
grepl("[0-9]{2}", geo_value) ~ "hhs_region",
112+
geo_value %in% c("us", "usa") ~ "nation",
113+
.default = "state"
114+
))
115+
}
116+
109117
gen_pop_and_density_data <-
110118
function(apportion_filename = here::here("aux_data", "flusion_data", "apportionment.csv"),
111119
state_code_filename = here::here("aux_data", "flusion_data", "state_codes_table.csv"),
@@ -188,13 +196,15 @@ gen_pop_and_density_data <-
188196
daily_to_weekly <- function(epi_df, agg_method = c("sum", "mean"), keys = "geo_value", values = c("value")) {
189197
agg_method <- arg_match(agg_method)
190198
epi_df %>%
199+
arrange(across(all_of(c(keys, "time_value")))) %>%
191200
mutate(epiweek = epiweek(time_value), year = epiyear(time_value)) %>%
192201
group_by(across(any_of(c(keys, "epiweek", "year")))) %>%
193202
summarize(
194203
across(all_of(values), ~ sum(.x, na.rm = TRUE)),
195204
time_value = floor_date(max(time_value), "weeks", week_start = 7) + 3,
196205
.groups = "drop"
197206
) %>%
207+
arrange(across(all_of(c(keys, "time_value")))) %>%
198208
select(-epiweek, -year)
199209
}
200210

@@ -336,9 +346,7 @@ add_hhs_region_sum <- function(archive_data_raw, hhs_region_table) {
336346
archive_data_raw %<>%
337347
filter(agg_level != "state") %>%
338348
mutate(hhs_region = hhs) %>%
339-
bind_rows(
340-
hhs_region_agg_state
341-
)
349+
bind_rows(hhs_region_agg_state)
342350
if (need_agg_level) {
343351
archive_data_raw %<>% select(-agg_level)
344352
}
@@ -401,11 +409,32 @@ get_health_data <- function(as_of, disease = c("covid", "flu")) {
401409
# Get something sort of compatible with that by summing to national with
402410
# na.rm = TRUE, as otherwise some NAs (probably from territories) get
403411
# propagated to US level.
404-
bind_rows(
405-
(.) %>%
406-
group_by(time_value) %>%
407-
summarize(geo_value = "us", hhs = sum(hhs, na.rm = TRUE))
408-
)
412+
append_us_aggregate("hhs")
413+
}
414+
415+
#' Append a national aggregate to a dataframe
416+
#'
417+
#' Computes national values by summing (with `na.rm = TRUE`) within each combination of `group_keys`.
418+
#' Rows whose `geo_value` is "us", "usa", "national", "nation", "US" or "USA" are removed before summing.
419+
#'
420+
#' @param df A dataframe with a `geo_value` column; anything that is not a data.frame raises an error.
421+
#' @param cols A character vector of column names to sum. If `NULL` (the default), all numeric columns are summed.
422+
#' @param group_keys A character vector of column names to group by.
423+
#' @return The non-national rows of `df`, followed by the aggregate rows, which have `geo_value = "us"`.
424+
append_us_aggregate <- function(df, cols = NULL, group_keys = c("time_value")) {
425+
if (!(is.data.frame(df))) {
426+
cli::cli_abort("df must be a data.frame", call = rlang::caller_env())
427+
}
428+
national_col_names <- c("us", "usa", "national", "nation", "US", "USA")
429+
df1 <- df %>% filter(geo_value %nin% national_col_names)
430+
if (is.null(cols)) {
431+
df2 <- df1 %>%
432+
summarize(geo_value = "us", across(where(is.numeric), ~ sum(.x, na.rm = TRUE)), .by = all_of(group_keys))
433+
} else {
434+
df2 <- df1 %>%
435+
summarize(geo_value = "us", across(all_of(cols), ~ sum(.x, na.rm = TRUE)), .by = all_of(group_keys))
436+
}
437+
bind_rows(df1, df2)
409438
}
410439

411440
calculate_burden_adjustment <- function(flusurv_latest) {
@@ -718,7 +747,7 @@ up_to_date_nssp_state_archive <- function(disease = c("covid", "influenza")) {
718747
wait_seconds = 1,
719748
fn = pub_covidcast,
720749
source = "nssp",
721-
signal = glue::glue("pct_ed_visits_{disease}"),
750+
signals = glue::glue("pct_ed_visits_{disease}"),
722751
time_type = "week",
723752
geo_type = "state",
724753
geo_values = "*",
@@ -727,7 +756,7 @@ up_to_date_nssp_state_archive <- function(disease = c("covid", "influenza")) {
727756
nssp_state %>%
728757
select(geo_value, time_value, issue, nssp = value) %>%
729758
as_epi_archive(compactify = TRUE) %>%
730-
`$`("DT") %>%
759+
extract2("DT") %>%
731760
# End of week to midweek correction.
732761
mutate(time_value = time_value + 3) %>%
733762
as_epi_archive(compactify = TRUE)

R/forecasters/data_transforms.R

+2-5
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,14 @@ rolling_sd <- function(epi_data, sd_width = 29L, mean_width = NULL, cols_to_sd =
104104
#' @importFrom tidyr drop_na
105105
#' @importFrom epiprocess as_epi_df
106106
#' @export
107-
clear_lastminute_nas <- function(epi_data, outcome, extra_sources) {
107+
clear_lastminute_nas <- function(epi_data, cols) {
108108
meta_data <- attr(epi_data, "metadata")
109-
if (extra_sources == c("")) {
110-
extra_sources <- character(0L)
111-
}
112109
as_of <- attributes(epi_data)$metadata$as_of
113110
other_keys <- attributes(epi_data)$metadata$other_keys %||% character()
114111
epi_data %>% na.omit()
115112
# make sure at least one column is not NA
116113
epi_data %<>%
117-
filter(if_any(c(!!outcome, !!extra_sources), ~ !is.na(.x))) %>%
114+
filter(if_any(c(!!cols), ~ !is.na(.x))) %>%
118115
as_epi_df(as_of = as_of, other_keys = other_keys)
119116
attr(epi_data, "metadata") <- meta_data
120117
return(epi_data)

R/forecasters/data_validation.R

+15-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ sanitize_args_predictors_trainer <- function(epi_data,
3030
if (!is.null(trainer) && !epipredict:::is_regression(trainer)) {
3131
cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.")
3232
} else if (inherits(trainer, "rand_forest") && trainer$engine == "grf_quantiles") {
33-
trainer %<>%
34-
set_engine("grf_quantiles", quantiles = args_list$quantile_levels)
33+
trainer %<>% set_engine("grf_quantiles", quantiles = args_list$quantile_levels)
3534
} else if (inherits(trainer, "quantile_reg")) {
3635
# add all quantile_levels to the trainer and update args list
3736
quantile_levels <- sort(epipredict:::compare_quantile_args(
@@ -97,3 +96,17 @@ filter_extraneous <- function(epi_data, filter_source, filter_agg_level) {
9796
}
9897
return(epi_data)
9998
}
99+
100+
#' Unwrap an argument if it's a list of length 1
101+
#'
102+
#' Many of our arguments to the forecasters come as lists not because we expect
103+
#' them that way, but as a byproduct of tibble and expand_grid.
#'
#' @param arg The argument to unwrap. A list of length 1 is replaced by its
#'   only element; anything else is left as it is.
#' @param default_trigger Value that, when `identical()` to the (possibly
#'   unwrapped) argument, causes `default` to be returned instead.
#' @param default Value returned when the argument matches `default_trigger`.
#' @return `default` if the (possibly unwrapped) argument is identical to
#'   `default_trigger`, otherwise the (possibly unwrapped) argument.
104+
unwrap_argument <- function(arg, default_trigger = "", default = character(0L)) {
105+
if (is.list(arg) && length(arg) == 1) {
106+
arg <- arg[[1]]
107+
}
108+
if (identical(arg, default_trigger)) {
109+
return(default)
110+
}
111+
return(arg)
112+
}

R/forecasters/forecaster_climatological.R

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ climate_linear_ensembled <- function(epi_data,
2222
nonlin_method <- arg_match(nonlin_method)
2323

2424
epi_data <- validate_epi_data(epi_data)
25+
extra_sources <- unwrap_argument(extra_sources)
26+
trainer <- unwrap_argument(trainer)
2527

2628
args_list <- list(...)
2729
ahead <- as.integer(ahead / 7)

R/forecasters/forecaster_flatline.R

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@ flatline_fc <- function(epi_data,
1818
filter_agg_level = "",
1919
...) {
2020
epi_data <- validate_epi_data(epi_data)
21+
extra_sources <- unwrap_argument(extra_sources)
22+
trainer <- unwrap_argument(trainer)
23+
2124
# perform any preprocessing not supported by epipredict
2225
epi_data %<>% filter_extraneous(filter_source, filter_agg_level)
2326
# this is a temp fix until a real fix gets put into epipredict
24-
epi_data <- clear_lastminute_nas(epi_data, outcome, extra_sources)
27+
epi_data <- clear_lastminute_nas(epi_data, cols = c(outcome, extra_sources))
2528
# one that every forecaster will need to handle: how to manage max(time_value)
2629
# that's older than the `as_of` date
2730
c(epi_data, ahead) %<-% extend_ahead(epi_data, ahead)

R/forecasters/forecaster_flusion.R

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ flusion <- function(epi_data,
2424
derivative_estimator <- arg_match(derivative_estimator)
2525

2626
epi_data <- validate_epi_data(epi_data)
27+
extra_sources <- unwrap_argument(extra_sources)
28+
trainer <- unwrap_argument(trainer)
2729

2830
# perform any preprocessing not supported by epipredict
2931
args_input <- list(...)

R/forecasters/forecaster_no_recent_outcome.R

+4-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ no_recent_outcome <- function(epi_data,
2424
week_method <- arg_match(week_method)
2525

2626
epi_data <- validate_epi_data(epi_data)
27+
extra_sources <- unwrap_argument(extra_sources)
28+
trainer <- unwrap_argument(trainer)
2729

2830
# this is for the case where there are multiple sources in the same column
2931
epi_data %<>% filter_extraneous(filter_source, filter_agg_level)
@@ -62,10 +64,10 @@ no_recent_outcome <- function(epi_data,
6264
args_input[["quantile_levels"]] <- quantile_levels
6365
args_list <- do.call(default_args_list, args_input)
6466
# if you want to hardcode particular predictors in a particular forecaster
65-
if (identical(extra_sources[[1]], "")) {
67+
if (identical(extra_sources, character(0L))) {
6668
predictors <- character()
6769
} else {
68-
predictors <- extra_sources[[1]]
70+
predictors <- extra_sources
6971
}
7072
c(args_list, tmp_pred, trainer) %<-% sanitize_args_predictors_trainer(epi_data, outcome, predictors, trainer, args_list)
7173

R/forecasters/forecaster_scaled_pop.R

+3-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ scaled_pop <- function(epi_data,
6464
nonlin_method <- arg_match(nonlin_method)
6565

6666
epi_data <- validate_epi_data(epi_data)
67+
extra_sources <- unwrap_argument(extra_sources)
68+
trainer <- unwrap_argument(trainer)
6769

6870
# perform any preprocessing not supported by epipredict
6971
#
@@ -94,7 +96,7 @@ scaled_pop <- function(epi_data,
9496
args_input[["nonneg"]] <- scale_method == "none"
9597
args_list <- inject(default_args_list(!!!args_input))
9698
# if you want to hardcode particular predictors in a particular forecaster
97-
predictors <- c(outcome, extra_sources[[1]])
99+
predictors <- c(outcome, extra_sources)
98100
c(args_list, predictors, trainer) %<-% sanitize_args_predictors_trainer(epi_data, outcome, predictors, trainer, args_list)
99101
# end of the copypasta
100102
# finally, any other pre-processing (e.g. smoothing) that isn't performed by

R/forecasters/forecaster_scaled_pop_seasonal.R

+4-6
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,9 @@ scaled_pop_seasonal <- function(epi_data,
5858
nonlin_method <- arg_match(nonlin_method)
5959

6060
epi_data <- validate_epi_data(epi_data)
61+
extra_sources <- unwrap_argument(extra_sources)
62+
trainer <- unwrap_argument(trainer)
6163

62-
# TODO: handle this when creating param grid?
63-
if (typeof(seasonal_method) == "list") {
64-
seasonal_method <- seasonal_method[[1]]
65-
}
6664
if (all(seasonal_method == c("none", "flu", "covid", "indicator", "window", "climatological"))) {
6765
seasonal_method <- "none"
6866
}
@@ -71,7 +69,7 @@ scaled_pop_seasonal <- function(epi_data,
7169
# this is for the case where there are multiple sources in the same column
7270
epi_data %<>% filter_extraneous(filter_source, filter_agg_level)
7371
# this is a temp fix until a real fix gets put into epipredict
74-
epi_data <- clear_lastminute_nas(epi_data, outcome, extra_sources)
72+
epi_data <- clear_lastminute_nas(epi_data, cols = c(outcome, extra_sources))
7573
# this next part is basically unavoidable boilerplate you'll want to copy
7674
args_input <- list(...)
7775
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
@@ -100,7 +98,7 @@ scaled_pop_seasonal <- function(epi_data,
10098
args_input[["seasonal_forward_window"]] <- seasonal_forward_window + ahead
10199
args_list <- inject(default_args_list(!!!args_input))
102100
# if you want to hardcode particular predictors in a particular forecaster
103-
predictors <- c(outcome, extra_sources[[1]])
101+
predictors <- c(outcome, extra_sources)
104102
c(args_list, predictors, trainer) %<-% sanitize_args_predictors_trainer(epi_data, outcome, predictors, trainer, args_list)
105103

106104
if ("season_week" %nin% names(epi_data)) {

R/forecasters/forecaster_smoothed_scaled.R

+4-4
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,15 @@ smoothed_scaled <- function(epi_data,
7373
nonlin_method <- arg_match(nonlin_method)
7474

7575
epi_data <- validate_epi_data(epi_data)
76+
extra_sources <- unwrap_argument(extra_sources)
77+
trainer <- unwrap_argument(trainer)
7678

7779
# perform any preprocessing not supported by epipredict
7880
#
7981
# this is for the case where there are multiple sources in the same column
8082
epi_data %<>% filter_extraneous(filter_source, filter_agg_level)
8183
# this is a temp fix until a real fix gets put into epipredict
82-
epi_data <- clear_lastminute_nas(epi_data, outcome, extra_sources)
84+
epi_data <- clear_lastminute_nas(epi_data, cols = c(outcome, extra_sources))
8385
# see latency_adjusting for other examples
8486
args_input <- list(...)
8587
# edge case where there is no data or less data than the lags; eventually epipredict will handle this
@@ -106,14 +108,12 @@ smoothed_scaled <- function(epi_data,
106108
args_list <- inject(default_args_list(!!!args_input))
107109
# `extra_sources` sets which variables beyond the outcome are lagged and used as predictors
108110
# any which are modified by `rolling_mean` or `rolling_sd` have their original values dropped later
109-
predictors <- c(outcome, extra_sources[[1]])
110-
predictors <- predictors[predictors != ""]
111+
predictors <- c(outcome, extra_sources)
111112
# end of the copypasta
112113
# finally, any other pre-processing (e.g. smoothing) that isn't performed by
113114
# epipredict
114115

115116

116-
117117
#######################
118118
# robust whitening
119119
#######################

0 commit comments

Comments
 (0)