@@ -163,9 +163,10 @@ def ModelPrediction(df_train, forecast_length: int, transformation_dict: dict,
163163
164164 return df_forecast
165165
166- ModelNames = ['ZeroesNaive' , 'LastValueNaive' , 'MedValueNaive' ,
167- 'GLM' , 'ETS' , 'ARIMA' , 'FBProphet' , 'RandomForestRolling' ]
168-
166+ ModelNames = ['ZeroesNaive' , 'LastValueNaive' , 'MedValueNaive' , 'GLS' ,
167+ 'GLM' , 'ETS' , 'ARIMA' , 'FBProphet' , 'RollingRegression' ,
168+ 'UnobservedComponents' , 'VARMAX' , 'VECM' , 'DynamicFactor' ]
169+ # ModelNames = ['RollingRegression']
169170def ModelMonster (model : str , parameters : dict = {}, frequency : str = 'infer' ,
170171 prediction_interval : float = 0.9 , holiday_country : str = 'US' ,
171172 startTimeStamps = None ,
@@ -188,9 +189,17 @@ def ModelMonster(model: str, parameters: dict = {}, frequency: str = 'infer',
188189 from autots .models .basics import MedValueNaive
189190 return MedValueNaive (frequency = frequency , prediction_interval = prediction_interval )
190191
192+ if model == 'GLS' :
193+ from autots .models .statsmodels import GLS
194+ return GLS (frequency = frequency , prediction_interval = prediction_interval )
195+
191196 if model == 'GLM' :
192197 from autots .models .statsmodels import GLM
193- return GLM (frequency = frequency , prediction_interval = prediction_interval )
198+ if parameters == {}:
199+ model = GLM (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose )
200+ else :
201+ model = GLM (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose , family = parameters ['family' ])
202+ return model
194203
195204 if model == 'ETS' :
196205 from autots .models .statsmodels import ETS
@@ -216,15 +225,56 @@ def ModelMonster(model: str, parameters: dict = {}, frequency: str = 'infer',
216225 model = FBProphet (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , holiday = parameters ['holiday' ], regression_type = parameters ['regression_type' ], random_seed = random_seed , verbose = verbose )
217226 return model
218227
219- if model == 'RandomForestRolling' :
220- from autots .models .sklearn import RandomForestRolling
228+ if model == 'RollingRegression' :
229+ from autots .models .sklearn import RollingRegression
230+ if parameters == {}:
231+ model = RollingRegression (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose )
232+ else :
233+ model = RollingRegression (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , holiday = parameters ['holiday' ], regression_type = parameters ['regression_type' ], random_seed = random_seed , verbose = verbose ,
234+ regression_model = parameters ['regression_model' ], mean_rolling_periods = parameters ['mean_rolling_periods' ], std_rolling_periods = parameters ['std_rolling_periods' ])
235+ return model
236+
237+ if model == 'UnobservedComponents' :
238+ from autots .models .statsmodels import UnobservedComponents
239+ if parameters == {}:
240+ model = UnobservedComponents (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose )
241+ else :
242+ model = UnobservedComponents (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country ,
243+ regression_type = parameters ['regression_type' ], random_seed = random_seed , verbose = verbose ,
244+ level = parameters ['level' ], trend = parameters ['trend' ], cycle = parameters ['cycle' ],
245+ damped_cycle = parameters ['damped_cycle' ], irregular = parameters ['irregular' ],
246+ stochastic_trend = parameters ['stochastic_trend' ], stochastic_level = parameters ['stochastic_level' ],
247+ stochastic_cycle = parameters ['stochastic_cycle' ])
248+ return model
249+
250+ if model == 'DynamicFactor' :
251+ from autots .models .statsmodels import DynamicFactor
252+ if parameters == {}:
253+ model = DynamicFactor (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose )
254+ else :
255+ model = DynamicFactor (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country ,
256+ regression_type = parameters ['regression_type' ], random_seed = random_seed , verbose = verbose ,
257+ k_factors = parameters ['k_factors' ], factor_order = parameters ['factor_order' ])
258+ return model
259+
260+ if model == 'VECM' :
261+ from autots .models .statsmodels import VECM
221262 if parameters == {}:
222- model = RandomForestRolling (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose )
263+ model = VECM (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose )
223264 else :
224- model = RandomForestRolling (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , holiday = parameters ['holiday' ], regression_type = parameters ['regression_type' ], random_seed = random_seed , verbose = verbose ,
225- n_estimators = parameters ['n_estimators' ], min_samples_split = parameters ['min_samples_split' ], max_depth = parameters ['max_depth' ], mean_rolling_periods = parameters ['mean_rolling_periods' ], std_rolling_periods = parameters ['std_rolling_periods' ])
265+ model = VECM (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country ,
266+ regression_type = parameters ['regression_type' ], random_seed = random_seed , verbose = verbose ,
267+ deterministic = parameters ['deterministic' ], k_ar_diff = parameters ['k_ar_diff' ])
226268 return model
227269
270+ if model == 'VARMAX' :
271+ from autots .models .statsmodels import VARMAX
272+ if parameters == {}:
273+ model = VARMAX (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose )
274+ else :
275+ model = VARMAX (frequency = frequency , prediction_interval = prediction_interval , holiday_country = holiday_country , random_seed = random_seed , verbose = verbose ,
276+ order = parameters ['order' ], trend = parameters ['trend' ])
277+ return model
228278
229279 else :
230280 raise AttributeError ("Model String not found in ModelMonster" )
0 commit comments