Skip to content

Commit 3e57d42

Browse files
committed
Update more prob functions
1 parent a2a7276 commit 3e57d42

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

R/prob_thief_wrapper.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,15 @@ prob_thief_wrapper <-
6464
if (num_missing_obs >= 1) {
6565
truth_data <- truth_data |>
6666
dplyr::add_row(target_end_date = most_recent_date + 1:num_missing_obs) |>
67+
tidyr::complete(!!!rlang::syms(c("target_end_date", "location"))) |>
6768
tidyr::fill(!!!rlang::syms(names(truth_data)))
6869
warning(paste(num_missing_obs, "missing truth data observations will be imputed using last available value."))
6970
}
7071

72+
remainder_observation <- as.numeric(end_date - start_date + 1) %% frequency
73+
truth_dates_desc <- sort(unique(truth_data$target_end_date), decreasing = TRUE)
74+
temp_res <- as.integer(as.Date(truth_dates_desc[1]) - as.Date(truth_dates_desc[2]))
75+
start_date <- start_date + (remainder_observation) * temp_res
7176
thief_aggregation <- truth_data |>
7277
aggregate_thief_df(ts_col, start_date, end_date, fips_code,
7378
aggregate_levels, NULL, frequency, transform.4root) |>

R/transform_matrix_to_hub_df.R

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
transform_matrix_to_hub_df <-
3636
function(fc_matrix, forecast_date, fips_code, target_name, n_samples = NULL,
3737
quantile_levels = c(0.025, 0.25, 0.5, 0.75, 0.975), h_ahead = 56,
38-
keep_bottommost_only = TRUE) {
38+
temp_res = 1, keep_bottommost_only = TRUE) {
3939
if (!is.null(n_samples)) {
4040
if (n_samples > nrow(fc_matrix)) {
4141
stop("Requested number of samples cannot exceed the number that have been provided")
@@ -67,18 +67,27 @@ transform_matrix_to_hub_df <-
6767
) |>
6868
dplyr::mutate(
6969
forecast_date = forecast_date,
70+
level = .data[["k"]],
7071
location = fips_code,
7172
horizon = as.numeric(.data[["h"]]) * .data[["k"]],
72-
temporal_resolution = "daily",
73+
temporal_resolution = dplyr::case_when(
74+
temp_res == 1 ~ "daily",
75+
temp_res == 7 ~ "weekly",
76+
temp_res == 30 ~ "monthly",
77+
temp_res == 365 ~ "yearly",
78+
.default = as.character(temp_res)
79+
),
7380
target = target_name,
74-
target_end_date = forecast_date + .data[["horizon"]],
81+
target_end_date = forecast_date + .data[["horizon"]] * temp_res,
7582
type = ifelse(as.numeric(.data[["Var1"]] < 1), "quantile", "sample"),
7683
quantile = as.numeric(.data[["Var1"]]),
7784
value = ifelse(.data[["Freq"]] < 0, 0, .data[["Freq"]]),
7885
)
7986

8087
if (keep_bottommost_only) {
81-
hub_df <- dplyr::filter(hub_df, .data[["k"]] == min(hub_df$k))
88+
hub_df <- hub_df |>
89+
dplyr::filter(.data[["k"]] == min(hub_df$k)) |>
90+
dplyr::select(-"level")
8291
}
8392

8493
hub_df |>

0 commit comments

Comments
 (0)