1010import numpy as np
1111import optuna
1212import pandas as pd
13+ from pandas .tseries .frequencies import to_offset
1314from joblib import Parallel , delayed
1415from optuna .trial import TrialState
16+ from scipy .stats import linregress
1517from sktime .split import ExpandingWindowSplitter
1618from statsmodels .tsa .exponential_smoothing .ets import ETSModel
1719
@@ -63,6 +65,39 @@ def preprocess(self, data, series_id):
6365 )
6466 return df_encoded .set_index (self .spec .datetime_column .name )
6567
68+ def _auto_detect_ets_params (self , y : pd .Series , seasonal_period : int ) -> Dict [str , Any ]:
69+ """Detect trend, error type, and damping from data."""
70+ params = {}
71+
72+ # Detect trend via linear regression significance
73+ slope , _ , r_value , p_value , _ = linregress (range (len (y )), y .values )
74+ has_trend = p_value < 0.05
75+ params ["trend" ] = "add" if has_trend else None
76+ params ["damped_trend" ] = (has_trend and abs (r_value ) < 0.7 ) # weak trend → damp
77+
78+ # Detect additive vs multiplicative: does variance scale with level?
79+ if seasonal_period and len (y ) >= 2 * seasonal_period :
80+ segments = [y .iloc [i :i + seasonal_period ] for i in range (0 , len (y ) - seasonal_period , seasonal_period )]
81+ means = np .array ([s .mean () for s in segments ])
82+ stds = np .array ([s .std () for s in segments ])
83+ if means .std () > 0 :
84+ corr = np .corrcoef (means , stds )[0 , 1 ]
85+ params ["error" ] = "mul" if corr > 0.7 else "add"
86+ params ["seasonal" ] = params ["error" ] # match seasonal to error type
87+ else :
88+ params ["error" ] = "add"
89+ params ["seasonal" ] = "add"
90+ else :
91+ params ["error" ] = "add"
92+ params ["seasonal" ] = "add"
93+
94+ # Multiplicative requires strictly positive data
95+ if (y <= 0 ).any ():
96+ params ["error" ] = "add"
97+ params ["seasonal" ] = "add"
98+
99+ return params
100+
66101 def _train_model (self , i , series_id , df : pd .DataFrame , model_kwargs : Dict [str , Any ]):
67102 try :
68103 self .forecast_output .init_series_output (series_id = series_id , data_at_series = df )
@@ -73,12 +108,24 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
73108
74109 Y = data_i [self .spec .target_column ]
75110 dates = data_i .index .values
111+ is_daily = False
112+ if freq is not None :
113+ try :
114+ is_daily = to_offset (freq ).name in ["D" , "B" ]
115+ except ValueError :
116+ pass
76117
118+ ets_model_args = self .spec .model_kwargs .get ("restrict_daily_series_improvement_flow" , True )
119+ sp , probable_sps = find_seasonal_period_from_dataset (Y )
77120 if model_kwargs ["seasonal" ] is None :
78121 model_kwargs ["seasonal" ] = "add"
79- if model_kwargs [ "seasonal_periods" ] is None :
80- sp , probable_sps = find_seasonal_period_from_dataset ( Y )
122+
123+ if ets_model_args and is_daily :
81124 model_kwargs ["seasonal_periods" ] = sp if sp > 1 else None
125+ else :
126+ auto_params = self ._auto_detect_ets_params (Y , sp )
127+ for k , v in auto_params .items ():
128+ model_kwargs [k ] = v
82129
83130 if self .loaded_models is not None and series_id in self .loaded_models :
84131 previous_res = self .loaded_models [series_id ].get ("model" )
@@ -92,14 +139,26 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
92139 if self .perform_tuning :
93140 model_kwargs = self .run_tuning (Y , model_kwargs )
94141
95- use_seasonal = (model_kwargs ["seasonal" ] is not None and
96- model_kwargs ["seasonal_periods" ] is not None and
97- len (Y ) >= 2 * model_kwargs ["seasonal_periods" ]
98- )
142+ if (model_kwargs ["error" ] == "mul" or
143+ model_kwargs ["trend" ] == "mul" or
144+ model_kwargs ["seasonal" ] == "mul" ) and (Y <= 0 ).any ():
145+
146+ logger .info ("Multiplicative model incompatible with non-positive values. Switching to additive." )
147+ model_kwargs ["error" ] = "add"
148+ if model_kwargs ["trend" ] == "mul" :
149+ model_kwargs ["trend" ] = "add"
150+ if model_kwargs ["seasonal" ] == "mul" :
151+ model_kwargs ["seasonal" ] = "add"
152+
153+ use_seasonal = (
154+ model_kwargs ["seasonal" ] is not None and
155+ model_kwargs ["seasonal_periods" ] is not None and
156+ model_kwargs ["seasonal_periods" ] > 1 and
157+ len (Y ) >= max (2 * model_kwargs ["seasonal_periods" ], 10 )
158+ )
99159 if not use_seasonal :
100160 model_kwargs ["seasonal" ] = None
101161 model_kwargs ["seasonal_periods" ] = None
102-
103162 model = ETSModel (Y , error = model_kwargs ["error" ], trend = model_kwargs ["trend" ],
104163 damped_trend = model_kwargs ["damped_trend" ], seasonal = model_kwargs ["seasonal" ],
105164 seasonal_periods = model_kwargs ["seasonal_periods" ],
@@ -145,7 +204,7 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
145204 if param in params :
146205 params .pop (param )
147206 self .model_parameters [series_id ] = {
148- "framework" : SupportedModels .Arima ,
207+ "framework" : SupportedModels .ETSForecaster ,
149208 ** params ,
150209 }
151210
0 commit comments