@@ -233,64 +233,66 @@ generate_stacking_results <- function(object,
233233 grid = 6 ,
234234 control = control_grid()) {
235235
236- # 1. Fit Resamples ----
237- # - This is now performed separately with modeltime_fit_resamples()
236+ if (control $ verbose ) tictoc :: tic()
238237
239238 # 2. Wrangle Predictions ----
240239 predictions_tbl <- modeltime.resample :: unnest_modeltime_resamples(object )
241240
242- # Target Variable is the name in the data
241+ # Target variable name comes right after .model_desc (new tune) or .row (old tune)
243242 if (utils :: packageVersion(" tune" ) > = " 1.3.0.9006" ) {
244243 target_text <- predictions_tbl %> %
245244 modeltime.resample :: get_target_text_from_resamples(column_before_target = " .model_desc" )
246245 } else {
247246 target_text <- predictions_tbl %> %
248247 modeltime.resample :: get_target_text_from_resamples(column_before_target = " .row" )
249248 }
250- target_var <- rlang :: sym(target_text )
249+ target_var <- rlang :: sym(target_text )
251250
251+ # Keep resample id so keys are unique across slices
252252 predictions_tbl <- predictions_tbl %> %
253- dplyr :: select(.row_id , .model_id , .pred , !! target_var )
253+ dplyr :: select(.resample_id , . row_id , .model_id , .pred , !! target_var )
254254
255- # * Actuals By Row ID ----
255+ # Defuse any list-column predictions (can arise when duplicates exist pre-pivot)
256+ if (is.list(predictions_tbl $ .pred )) {
257+ predictions_tbl <- predictions_tbl %> %
258+ dplyr :: mutate(.pred = purrr :: map_dbl(.pred , ~ if (length(.x )) as.numeric(.x )[1 ] else NA_real_ ))
259+ }
260+
261+ # * Actuals: one row per resample + row id
256262 actuals_by_rowid_tbl <- predictions_tbl %> %
257- dplyr :: filter(.model_id %in% unique(.model_id )[1 ]) %> %
258- dplyr :: select(.row_id , !! target_var )
263+ dplyr :: distinct(.resample_id , .row_id , !! target_var )
259264
260- # * Get Predictions by Row ID ----
265+ # * Predictions wide: id by resample + row id; columns per model
261266 predictions_by_rowid_tbl <- predictions_tbl %> %
262- dplyr :: select(.row_id , .model_id , .pred ) %> %
267+ dplyr :: select(.resample_id , . row_id , .model_id , .pred ) %> %
263268 dplyr :: mutate(.model_id = stringr :: str_c(" .model_id_" , .model_id )) %> %
264269 tidyr :: pivot_wider(
265- names_from = .model_id ,
270+ id_cols = c(.resample_id , .row_id ),
271+ names_from = .model_id ,
266272 values_from = .pred
267273 )
268274
269- # * Join Actuals & Predictions ----
275+ # * Join Actuals & Predictions
270276 data_prepared_tbl <- actuals_by_rowid_tbl %> %
271- dplyr :: left_join(predictions_by_rowid_tbl , by = " . row_id" )
277+ dplyr :: left_join(predictions_by_rowid_tbl , by = c( " .resample_id " , " . row_id" ) )
272278
273279 # 3. Build Model ----
274-
275280 form <- stats :: formula(stringr :: str_glue(" {target_text} ~ ." ))
276281
277282 recipe_spec <- recipes :: recipe(
278283 formula = form ,
279- data = data_prepared_tbl %> % dplyr :: select(- .row_id )
284+ data = data_prepared_tbl %> % dplyr :: select(- .resample_id , - . row_id )
280285 )
281286
282287 wflw_spec <- workflows :: workflow() %> %
283288 workflows :: add_model(model_spec ) %> %
284289 workflows :: add_recipe(recipe_spec )
285290
286- # **** Split Paths (Tuned vs Non-Tuned) **** ----
291+ # Tuned vs non-tuned paths
292+ tune_args_tbl <- wflw_spec %> % tune :: tune_args()
293+ tuning_required <- nrow(tune_args_tbl ) > 0
287294
288- tune_args_tbl <- wflw_spec %> % tune :: tune_args()
289- tuning_required <- nrow(tune_args_tbl ) > 0
290-
291- # 4A. Tune Model ----
292295 if (tuning_required ) {
293-
294296 if (control $ verbose ) {
295297 print(cli :: rule(" Tuning Model Specification" , width = 65 ))
296298 cli :: cli_alert_info(stringr :: str_glue(" Performing {kfolds}-Fold Cross Validation." ))
@@ -326,37 +328,26 @@ generate_stacking_results <- function(object,
326328 }
327329
328330 final_model <- wflw_spec %> %
329- tune :: finalize_workflow(
330- best_params_tbl
331- ) %> %
331+ tune :: finalize_workflow(best_params_tbl ) %> %
332332 generics :: fit(data_prepared_tbl )
333333
334- }
335-
336- # 4B. No Tuning -----
337- if (! tuning_required ) {
334+ } else {
338335
339336 if (control $ verbose ) {
340337 print(cli :: rule(" Fitting Non-Tunable Model Specification" , width = 65 ))
341- cli :: cli_alert_info(stringr :: str_glue( " Fitting model spec to submodel cross-validation predictions." ) )
338+ cli :: cli_alert_info(" Fitting model spec to submodel cross-validation predictions." )
342339 cli :: cat_line()
343340 }
344341
345342 best_params_tbl <- NULL
346343
347344 final_model <- wflw_spec %> %
348345 generics :: fit(data_prepared_tbl )
349-
350346 }
351347
352-
353-
354348 # 5. Fit Best Model ----
355-
356349 pred_tbl <- data_prepared_tbl %> %
357- dplyr :: bind_cols(
358- stats :: predict(final_model , data_prepared_tbl )
359- )
350+ dplyr :: bind_cols(stats :: predict(final_model , data_prepared_tbl ))
360351
361352 cv_comparison_tbl <- pred_tbl %> %
362353 dplyr :: rename(.model_id_ensemble = .pred ) %> %
@@ -366,15 +357,15 @@ generate_stacking_results <- function(object,
366357 values_to = " .preds"
367358 ) %> %
368359 dplyr :: group_by(.model_id ) %> %
369- dplyr :: summarise(rmse = yardstick :: rmse_vec(!! target_var , .preds ), .groups = " drop" ) %> %
360+ dplyr :: summarise(rmse = yardstick :: rmse_vec(!! target_var , .preds ), .groups = " drop" ) %> %
370361 dplyr :: mutate(.model_id = stringr :: str_remove(.model_id , " .model_id_" )) %> %
371362 dplyr :: left_join(
372363 object %> %
373364 dplyr :: select(.model_id , .model_desc ) %> %
374365 dplyr :: mutate(.model_id = as.character(.model_id )),
375366 by = " .model_id"
376367 ) %> %
377- dplyr :: mutate(.model_desc = ifelse (is.na(.model_desc ), " ENSEMBLE (MODEL SPEC)" , .model_desc ))
368+ dplyr :: mutate(.model_desc = dplyr :: if_else (is.na(.model_desc ), " ENSEMBLE (MODEL SPEC)" , .model_desc ))
378369
379370 if (control $ verbose ) {
380371 cli :: cli_alert_info(" Prediction Error Comparison:" )
@@ -383,27 +374,22 @@ generate_stacking_results <- function(object,
383374 }
384375
385376 if (control $ verbose ) print(cli :: rule(" Final Model" , width = 65 ))
386-
387377 if (control $ verbose ) {
388-
389378 cli :: cat_line()
390379 cli :: cli_alert_info(" Model Workflow:" )
391380 print(final_model )
392381 cli :: cat_line()
393382 }
394383
395- # Return ----
396- ret <- list (
384+ list (
397385 fit = final_model ,
398386 fit_params = best_params_tbl ,
399387 prediction_tbl = pred_tbl ,
400388 prediction_error_tbl = cv_comparison_tbl
401389 )
402-
403- return (ret )
404-
405390}
406391
407392
408393
409394
395+
0 commit comments