Skip to content

Error with modeltime_refit() using resample spec #11

Open
@spsanderson

Description

data_tbl.xlsx

I am running a script where I have created a cross validation that is being passed to modeltime_refit I do believe this may be an underlying issue with tune but am posting here because I am using the modeltime_refit function. Data is attached.

> sessionInfo()
R version 4.0.3 (2020-10-10)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19041)

Matrix products: default

Random number generation:
 RNG:     L'Ecuyer-CMRG 
 Normal:  Inversion 
 Sample:  Rejection 
 
locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252   
[3] LC_MONETARY=English_United States.1252 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] forecast_8.14              xgboost_1.4.1.1            vctrs_0.3.8               
 [4] rlang_0.4.11               modeltime.h2o_0.1.1        h2o_3.32.1.3              
 [7] modeltime.ensemble_0.4.1   modeltime.resample_0.2.0   tidyquant_1.0.3           
[10] quantmod_0.4.18            TTR_0.24.2                 PerformanceAnalytics_2.0.4
[13] xts_0.12.1                 zoo_1.8-9                  janitor_2.1.0             
[16] DBI_1.1.1                  odbc_1.3.2                 timetk_2.6.1              
[19] lubridate_1.7.10           forcats_0.5.1              stringr_1.4.0             
[22] readr_1.4.0                tidyverse_1.3.1            modeltime_0.6.0           
[25] yardstick_0.0.8            workflowsets_0.0.2         workflows_0.2.2           
[28] tune_0.1.5                 tidyr_1.1.3                tibble_3.1.2              
[31] rsample_0.1.0              recipes_0.1.16             purrr_0.3.4               
[34] parsnip_0.1.6              modeldata_0.1.0            infer_0.5.4               
[37] ggplot2_3.3.3              dplyr_1.0.6                dials_0.0.9               
[40] scales_1.1.1               broom_0.7.6                tidymodels_0.1.3          
[43] pacman_0.5.1              

loaded via a namespace (and not attached):
  [1] readxl_1.3.1         backports_1.2.1      plyr_1.8.6           lazyeval_0.2.2      
  [5] splines_4.0.3        crosstalk_1.1.1      listenv_0.8.0        inline_0.3.19       
  [9] digest_0.6.27        foreach_1.5.1        htmltools_0.5.1.1    earth_5.3.0         
 [13] fansi_0.5.0          magrittr_2.0.1       xlsx_0.6.5           globals_0.14.0      
 [17] modelr_0.1.8         gower_0.2.2          matrixStats_0.59.0   RcppParallel_5.1.4  
 [21] hardhat_0.1.5        prettyunits_1.1.1    tseries_0.10-48      colorspace_2.0-1    
 [25] blob_1.2.1           rvest_1.0.0          haven_2.4.1          callr_3.7.0         
 [29] crayon_1.4.1         RCurl_1.98-1.3       jsonlite_1.7.2       progressr_0.7.0     
 [33] survival_3.2-11      iterators_1.0.13     glue_1.4.2           gtable_0.3.0        
 [37] ipred_0.9-11         V8_3.4.2             pkgbuild_1.2.0       rstan_2.21.2        
 [41] Quandl_2.10.0        Rcpp_1.0.6           plotrix_3.8-1        viridisLite_0.4.0   
 [45] GPfit_1.0-8          bit_4.0.4            Formula_1.2-4        stats4_4.0.3        
 [49] lava_1.6.9           StanHeaders_2.21.0-7 prodlim_2019.11.13   htmlwidgets_1.5.3   
 [53] httr_1.4.2           ellipsis_0.3.2       rJava_1.0-4          loo_2.4.1           
 [57] pkgconfig_2.0.3      farver_2.1.0         nnet_7.3-16          dbplyr_2.1.1        
 [61] utf8_1.2.1           tidyselect_1.1.1     labeling_0.4.2       DiceDesign_1.9      
 [65] reactR_0.4.4         TeachingDemos_2.12   munsell_0.5.0        cellranger_1.1.0    
 [69] tools_4.0.3          cli_2.5.0            generics_0.1.0       yaml_2.2.1          
 [73] processx_3.5.2       bit64_4.0.5          fs_1.5.0             nlme_3.1-152        
 [77] future_1.21.0        reactable_0.2.3      tictoc_1.0.1         xml2_1.3.2          
 [81] LICHospitalR_0.2.0   compiler_4.0.3       rstudioapi_0.13      plotly_4.9.3        
 [85] curl_4.3.1           reprex_2.0.0         lhs_1.1.1            stringi_1.6.2       
 [89] plotmo_3.6.0         ps_1.6.0             lattice_0.20-44      Matrix_1.3-4        
 [93] urca_1.3-0           pillar_1.6.1         lifecycle_1.0.0      furrr_0.2.2         
 [97] lmtest_0.9-38        data.table_1.14.0    bitops_1.0-7         R6_2.5.0            
[101] gridExtra_2.3        parallelly_1.25.0    codetools_0.2-18     MASS_7.3-54         
[105] assertthat_0.2.1     xlsxjars_0.6.1       withr_2.4.2          fracdiff_1.5-1      
[109] parallel_4.0.3       hms_1.1.0            quadprog_1.5-8       grid_4.0.3          
[113] rpart_4.1-15         timeDate_3043.102    class_7.3-19         snakecase_0.11.0    
[117] prophet_1.0          pROC_1.17.0.1    

Script Fails at this:

resample_tscv <- training(splits) %>%
  time_series_cv(
    date_var      = date_col
    , assess      = "12 months"
    , initial     = "24 months"
    , skip        = "3 months"
    , slice_limit = 1
  )

refit_tbl <- calibration_tbl %>%
  modeltime_refit(
    data        = data_tbl
    , resamples = resample_tscv
    , control   = control_resamples(verbose = TRUE)
  ) # Fail

The error message produced:

> refit_tbl <- calibration_tbl %>%
+   modeltime_refit(
+     data        = data_tbl
+     , resamples = resample_tscv
+     , control   = control_resamples(verbose = TRUE)
+   )
Error in if ((control$cores > 1) && control$allow_par) { : 
  missing value where TRUE/FALSE needed

Full script:

# Lib Load ----------------------------------------------------------------

if(!require(pacman)) install.packages("pacman")
pacman::p_load(
  "tidymodels",
  "modeltime",
  "tidyverse",
  "lubridate",
  "timetk",
  "odbc",
  "DBI",
  "janitor",
  "timetk",
  "tidyquant",
  "modeltime.ensemble",
  "modeltime.resample",
  "modeltime.h2o"
)

interactive <- TRUE

# Read Data ----
data_final_tbl <- # read in the excel file here

# Data Split --------------------------------------------------------------
data_final_tbl <- data_tbl %>%
  select(date_col, excess_days)

splits <- initial_time_split(
  data_final_tbl
  , prop = 0.8
  , cumulative = TRUE
)

# Features ----------------------------------------------------------------

recipe_base <- recipe(excess_days ~ ., data = training(splits)) %>%
  step_timeseries_signature(date_col)

recipe_final <- recipe_base %>%
  step_rm(matches("(iso$)|(xts$)|(hour)|(min)|(sec)|(am.pm)")) %>%
  step_normalize(contains("index.num"), date_col_year) %>%
  step_dummy(contains("lbl"), one_hot = TRUE) %>%
  step_fourier(date_col, period = 365/12, K = 2) %>%
  step_holiday_signature(date_col) %>%
  step_YeoJohnson(excess_days)

# Models ------------------------------------------------------------------

# Auto ARIMA --------------------------------------------------------------

model_spec_arima_no_boost <- arima_reg() %>%
  set_engine(engine = "auto_arima")

wflw_fit_arima_no_boost <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_arima_no_boost) %>%
  fit(training(splits))

# Boosted Auto ARIMA ------------------------------------------------------

model_spec_arima_boosted <- arima_boost(
    min_n = 2
    , learn_rate = 0.015
  ) %>%
  set_engine(engine = "auto_arima_xgboost")

wflw_fit_arima_boosted <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_arima_boosted) %>%
  fit(training(splits))


# ETS ---------------------------------------------------------------------

model_spec_ets <- exp_smoothing() %>%
  set_engine(engine = "ets") 

wflw_fit_ets <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_ets) %>%
  fit(training(splits))

model_spec_croston <- exp_smoothing() %>%
  set_engine(engine = "croston")

wflw_fit_croston <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_croston) %>%
  fit(training(splits))

model_spec_theta <- exp_smoothing() %>%
  set_engine(engine = "theta")

wflw_fit_theta <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_theta) %>%
  fit(training(splits))


# STLM ETS ----------------------------------------------------------------

model_spec_stlm_ets <- seasonal_reg() %>%
  set_engine("stlm_ets")

wflw_fit_stlm_ets <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_stlm_ets) %>%
  fit(training(splits))

model_spec_stlm_tbats <- seasonal_reg() %>%
  set_engine("tbats")

wflw_fit_stlm_tbats <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_stlm_tbats) %>%
  fit(training(splits))

model_spec_stlm_arima <- seasonal_reg() %>%
  set_engine("stlm_arima")

wflw_fit_stlm_arima <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_stlm_arima) %>%
  fit(training(splits))

# NNETAR ------------------------------------------------------------------

model_spec_nnetar <- nnetar_reg() %>%
  set_engine("nnetar")

wflw_fit_nnetar <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_nnetar) %>%
  fit(training(splits))

# Prophet -----------------------------------------------------------------

model_spec_prophet <- prophet_reg() %>%
  set_engine(engine = "prophet")

wflw_fit_prophet <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_prophet) %>%
  fit(training(splits))

model_spec_prophet_boost <- prophet_boost(learn_rate = 0.1) %>% 
  set_engine("prophet_xgboost") 

wflw_fit_prophet_boost <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_prophet_boost) %>%
  fit(training(splits))

# TSLM --------------------------------------------------------------------

model_spec_lm <- linear_reg() %>%
  set_engine("lm")

wflw_fit_lm <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_lm) %>%
  fit(training(splits))


# MARS --------------------------------------------------------------------

model_spec_mars <- mars(mode = "regression") %>%
  set_engine("earth")

wflw_fit_mars <- workflow() %>%
  add_recipe(recipe = recipe_final) %>%
  add_model(model_spec_mars) %>%
  fit(training(splits))

# Model Table -------------------------------------------------------------

models_tbl <- modeltime_table(
  #wflw_fit_arima_no_boost,
  wflw_fit_arima_boosted,
  wflw_fit_ets,
  wflw_fit_theta,
  wflw_fit_stlm_ets,
  wflw_fit_stlm_tbats,
  wflw_fit_nnetar,
  wflw_fit_prophet,
  wflw_fit_prophet_boost,
  wflw_fit_lm, 
  wflw_fit_mars
)

# Model Ensemble Table ----------------------------------------------------
resample_tscv <- training(splits) %>%
  time_series_cv(
    date_var      = date_col
    , assess      = "12 months"
    , initial     = "24 months"
    , skip        = "3 months"
    , slice_limit = 1
  )

# This fails from #
submodel_predictions <- models_tbl %>%
  modeltime_fit_resamples(
    resamples = resample_tscv
    , control = control_resamples(verbose = TRUE)
  )

ensemble_fit <- submodel_predictions %>%
  ensemble_model_spec(
    model_spec = linear_reg(
      penalty  = tune()
      , mixture = tune()
    ) %>%
      set_engine("glmnet")
    , kfold    = 5
    , grid     = 6
    , control  = control_grid(verbose = TRUE)
  )

fit_mean_ensemble <- models_tbl %>%
  ensemble_average(type = "mean")

fit_median_ensemble <- models_tbl %>%
  ensemble_average(type = "median")

# Model Table -------------------------------------------------------------

models_tbl <- modeltime_table(
  #wflw_fit_arima_no_boost,
  wflw_fit_arima_boosted,
  wflw_fit_ets,
  wflw_fit_theta,
  wflw_fit_stlm_ets,
  wflw_fit_stlm_tbats,
  wflw_fit_nnetar,
  wflw_fit_prophet,
  wflw_fit_prophet_boost,
  wflw_fit_lm, 
  wflw_fit_mars,
  fit_mean_ensemble,
  fit_median_ensemble
)

models_tbl

# Calibrate Model Testing -------------------------------------------------

calibration_tbl <- models_tbl %>%
  modeltime_calibrate(new_data = testing(splits))
calibration_tbl

# Testing Accuracy --------------------------------------------------------

calibration_tbl %>%
  modeltime_forecast(
    new_data    = testing(splits),
    actual_data = data_tbl
  ) %>%
  plot_modeltime_forecast(
    .legend_max_width   = 25,
    .interactive        = interactive,
    .conf_interval_show = FALSE
  )

calibration_tbl %>%
  modeltime_accuracy() %>%
  arrange(mae) %>%
  table_modeltime_accuracy(resizable = TRUE, bordered = TRUE)

# Refit to all Data -------------------------------------------------------
# **** Failure **** ----
refit_tbl <- calibration_tbl %>%
  modeltime_refit(
    data        = data_tbl
    , resamples = resample_tscv
    , control   = control_resamples(verbose = TRUE)
  )

top_two_models <- refit_tbl %>% 
  modeltime_accuracy() %>% 
  arrange(mae) %>% 
  head(2)

ensemble_models <- refit_tbl %>%
  filter(
    .model_desc %>% 
      str_to_lower() %>%
      str_detect("ensemble")
  ) %>%
  modeltime_accuracy()

model_choices <- rbind(top_two_models, ensemble_models)

refit_tbl %>%
  filter(.model_id %in% model_choices$.model_id) %>%
  modeltime_forecast(h = "1 year", actual_data = data_tbl) %>%
  plot_modeltime_forecast(
    .legend_max_width     = 25
    , .interactive        = FALSE
    , .conf_interval_show = FALSE
  )

# Misc --------------------------------------------------------------------
models_tbl %>%
  modeltime_calibrate(new_data = testing(splits)) %>%
  modeltime_residuals() %>%
  plot_modeltime_residuals()

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions