Skip to content

Commit 0b8187b

Browse files
authored
Theta and ETS forecaster improvements (#1354)
1 parent 9f04e42 commit 0b8187b

File tree

4 files changed

+169
-47
lines changed

4 files changed

+169
-47
lines changed

ads/opctl/operator/lowcode/forecast/environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- oracle-ads>=2.9.0
99
- prophet
1010
- neuralprophet
11-
- mlforecast
11+
- mlforecast==1.0.2
1212
- pmdarima
1313
- statsmodels
1414
- report-creator

ads/opctl/operator/lowcode/forecast/model/ets.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import numpy as np
1111
import optuna
1212
import pandas as pd
13+
from pandas.tseries.frequencies import to_offset
1314
from joblib import Parallel, delayed
1415
from optuna.trial import TrialState
16+
from scipy.stats import linregress
1517
from sktime.split import ExpandingWindowSplitter
1618
from statsmodels.tsa.exponential_smoothing.ets import ETSModel
1719

@@ -63,6 +65,39 @@ def preprocess(self, data, series_id):
6365
)
6466
return df_encoded.set_index(self.spec.datetime_column.name)
6567

68+
def _auto_detect_ets_params(self, y: pd.Series, seasonal_period: int) -> Dict[str, Any]:
69+
"""Detect trend, error type, and damping from data."""
70+
params = {}
71+
72+
# Detect trend via linear regression significance
73+
slope, _, r_value, p_value, _ = linregress(range(len(y)), y.values)
74+
has_trend = p_value < 0.05
75+
params["trend"] = "add" if has_trend else None
76+
params["damped_trend"] = (has_trend and abs(r_value) < 0.7) # weak trend → damp
77+
78+
# Detect additive vs multiplicative: does variance scale with level?
79+
if seasonal_period and len(y) >= 2 * seasonal_period:
80+
segments = [y.iloc[i:i + seasonal_period] for i in range(0, len(y) - seasonal_period, seasonal_period)]
81+
means = np.array([s.mean() for s in segments])
82+
stds = np.array([s.std() for s in segments])
83+
if means.std() > 0:
84+
corr = np.corrcoef(means, stds)[0, 1]
85+
params["error"] = "mul" if corr > 0.7 else "add"
86+
params["seasonal"] = params["error"] # match seasonal to error type
87+
else:
88+
params["error"] = "add"
89+
params["seasonal"] = "add"
90+
else:
91+
params["error"] = "add"
92+
params["seasonal"] = "add"
93+
94+
# Multiplicative requires strictly positive data
95+
if (y <= 0).any():
96+
params["error"] = "add"
97+
params["seasonal"] = "add"
98+
99+
return params
100+
66101
def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, Any]):
67102
try:
68103
self.forecast_output.init_series_output(series_id=series_id, data_at_series=df)
@@ -73,12 +108,24 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
73108

74109
Y = data_i[self.spec.target_column]
75110
dates = data_i.index.values
111+
is_daily = False
112+
if freq is not None:
113+
try:
114+
is_daily = to_offset(freq).name in ["D", "B"]
115+
except ValueError:
116+
pass
76117

118+
ets_model_args = self.spec.model_kwargs.get("restrict_daily_series_improvement_flow", True)
119+
sp, probable_sps = find_seasonal_period_from_dataset(Y)
77120
if model_kwargs["seasonal"] is None:
78121
model_kwargs["seasonal"] = "add"
79-
if model_kwargs["seasonal_periods"] is None:
80-
sp, probable_sps = find_seasonal_period_from_dataset(Y)
122+
123+
if ets_model_args and is_daily:
81124
model_kwargs["seasonal_periods"] = sp if sp > 1 else None
125+
else:
126+
auto_params = self._auto_detect_ets_params(Y, sp)
127+
for k, v in auto_params.items():
128+
model_kwargs[k] = v
82129

83130
if self.loaded_models is not None and series_id in self.loaded_models:
84131
previous_res = self.loaded_models[series_id].get("model")
@@ -92,14 +139,26 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
92139
if self.perform_tuning:
93140
model_kwargs = self.run_tuning(Y, model_kwargs)
94141

95-
use_seasonal = (model_kwargs["seasonal"] is not None and
96-
model_kwargs["seasonal_periods"] is not None and
97-
len(Y) >= 2 * model_kwargs["seasonal_periods"]
98-
)
142+
if (model_kwargs["error"] == "mul" or
143+
model_kwargs["trend"] == "mul" or
144+
model_kwargs["seasonal"] == "mul") and (Y <= 0).any():
145+
146+
logger.info("Multiplicative model incompatible with non-positive values. Switching to additive.")
147+
model_kwargs["error"] = "add"
148+
if model_kwargs["trend"] == "mul":
149+
model_kwargs["trend"] = "add"
150+
if model_kwargs["seasonal"] == "mul":
151+
model_kwargs["seasonal"] = "add"
152+
153+
use_seasonal = (
154+
model_kwargs["seasonal"] is not None and
155+
model_kwargs["seasonal_periods"] is not None and
156+
model_kwargs["seasonal_periods"] > 1 and
157+
len(Y) >= max(2 * model_kwargs["seasonal_periods"], 10)
158+
)
99159
if not use_seasonal:
100160
model_kwargs["seasonal"] = None
101161
model_kwargs["seasonal_periods"] = None
102-
103162
model = ETSModel(Y, error=model_kwargs["error"], trend=model_kwargs["trend"],
104163
damped_trend=model_kwargs["damped_trend"], seasonal=model_kwargs["seasonal"],
105164
seasonal_periods=model_kwargs["seasonal_periods"],
@@ -145,7 +204,7 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
145204
if param in params:
146205
params.pop(param)
147206
self.model_parameters[series_id] = {
148-
"framework": SupportedModels.Arima,
207+
"framework": SupportedModels.ETSForecaster,
149208
**params,
150209
}
151210

ads/opctl/operator/lowcode/forecast/model/theta.py

Lines changed: 100 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
from joblib import Parallel, delayed
1414
from optuna.trial import TrialState
1515
from sktime.forecasting.base import ForecastingHorizon
16+
from sktime.forecasting.compose import ForecastingPipeline
17+
from sktime.forecasting.model_selection import ForecastingGridSearchCV
1618
from sktime.forecasting.theta import ThetaForecaster
19+
from sktime.param_est.seasonality import SeasonalityACF, SeasonalityPeriodogram
1720
from sktime.split import ExpandingWindowSplitter
1821
from sktime.transformations.series.detrend import Deseasonalizer
1922

2023
from ads.opctl import logger
21-
from ads.opctl.operator.lowcode.common.utils import find_seasonal_period_from_dataset, normalize_frequency
24+
from ads.opctl.operator.lowcode.common.utils import normalize_frequency
2225
from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorConfig
2326
from ads.opctl.operator.lowcode.forecast.utils import (_label_encode_dataframe, _build_metrics_df)
24-
from .base_model import ForecastOperatorBaseModel
2527
from .forecast_datasets import ForecastDatasets, ForecastOutput
2628
from .univariate_model import UnivariateForecasterOperatorModel
2729
from ..const import (
@@ -64,65 +66,126 @@ def preprocess(self, data, series_id):
6466
)
6567
return df_encoded.set_index(self.spec.datetime_column.name)
6668

69+
def _get_sp_candidates(self, y, freq):
70+
"""Finds SP candidates using 1. Freq, 2. ACF, and 3. Periodogram."""
71+
candidates = set()
72+
73+
# 1. Frequency mapping
74+
freq_map = {'H': 24, 'D': 7, 'W': 52, 'M': 12, 'Q': 4, 'Y': 1}
75+
if freq:
76+
base_freq = "".join(filter(str.isalpha, freq))
77+
if base_freq in freq_map:
78+
candidates.add(freq_map[base_freq])
79+
80+
# 2. SeasonalityACF
81+
try:
82+
acf_est = SeasonalityACF()
83+
acf_est.fit(y)
84+
candidates.add(acf_est.get_fitted_params()["sp"])
85+
except Exception as e:
86+
logger.debug(f"Unable to find seasonality using ACF: {e}")
87+
88+
# 3. SeasonalityPeriodogram
89+
try:
90+
period_est = SeasonalityPeriodogram()
91+
period_est.fit(y)
92+
candidates.add(period_est.get_fitted_params()["sp"])
93+
except Exception as e:
94+
logger.debug(f"Unable to find seasonality using SeasonalityPeriodogram: {e}")
95+
96+
valid_candidates = [int(sp) for sp in candidates if len(y) >= 2 * sp]
97+
if 1 not in valid_candidates:
98+
valid_candidates.append(1)
99+
logger.debug(f"Found {valid_candidates} seasonality candidates")
100+
return sorted(list(valid_candidates))
101+
67102
def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, Any]):
68103
try:
69104
self.forecast_output.init_series_output(series_id=series_id, data_at_series=df)
70105
data = self.preprocess(df, series_id)
71-
72106
data_i = self.drop_horizon(data)
73107
target = self.spec.target_column
74108

75-
freq = self.datasets.get_datetime_frequency() if self.datasets.get_datetime_frequency() is not None else pd.infer_freq(
76-
data_i.index)
109+
freq = self.datasets.get_datetime_frequency() or pd.infer_freq(data_i.index)
110+
normalized_freq = normalize_frequency(freq)
77111
if freq is not None:
78-
normalized_freq = normalize_frequency(freq)
79112
data_i.index = data_i.index.to_period(normalized_freq)
80113

81114
y = data_i[target]
82115
X_in = data_i.drop(target, axis=1)
83116

84-
if model_kwargs["deseasonalize"] and model_kwargs["sp"] is None:
85-
sp, probable_sps = find_seasonal_period_from_dataset(y)
117+
# --- 1. Determine Deseasonalization Strategy ---
118+
using_additive_deseasonalization = False
119+
additive_deseasonalizer = None
120+
121+
# If negative values exist, we must use manual additive deseasonalization
122+
if model_kwargs.get("deseasonalize", True) and (y <= 0).any():
123+
logger.info(f"Negative values detected in {series_id}. Using manual additive deseasonalization.")
124+
using_additive_deseasonalization = True
125+
model_kwargs["deseasonalize_model"] = "add"
86126
else:
87-
sp, probable_sps = 1, [1]
127+
model_kwargs["deseasonalize_model"] = "mul"
88128

89-
model_kwargs["sp"] = model_kwargs.get("sp") or sp
129+
sp_candidates = self._get_sp_candidates(y, normalized_freq)
90130

91-
if not sp or len(y) < 2 * model_kwargs["sp"]:
92-
model_kwargs["deseasonalize"] = False
93-
94-
# If model already loaded, extract parameters (best-effort)
95131
if self.loaded_models is not None and series_id in self.loaded_models:
96132
previous_res = self.loaded_models[series_id].get("model")
97133
fitted_params = previous_res.get_fitted_params()
98134
model_kwargs["initial_level"] = fitted_params.get("initial_level", None)
99135
elif self.perform_tuning:
100-
model_kwargs = self.run_tuning(y, X_in, model_kwargs, probable_sps)
136+
model_kwargs = self.run_tuning(y, X_in, model_kwargs, sp_candidates)
137+
elif model_kwargs.get("sp") is None:
138+
logger.debug(f"Found {sp_candidates} SP candidates")
139+
sp_candidates.append(1)
140+
if not sp_candidates:
141+
best_sp = 1
142+
elif len(sp_candidates) == 1:
143+
best_sp = sp_candidates[0]
144+
else:
145+
cv = ExpandingWindowSplitter(
146+
initial_window=min(int(len(y) * 0.7), len(y) - self.spec.horizon),
147+
step_length=max(1, int(self.spec.horizon / 2)),
148+
fh=range(1, self.spec.horizon + 1)
149+
)
101150

102-
# Fit ThetaModel using params
103-
using_additive_deseasonalization = False
104-
additive_deseasonalizer = None
105-
if model_kwargs["deseasonalize"]:
106-
if (y <= 0).any():
107-
logger.warning(
108-
"Processing data with additive deseasonalization model as data contains negative or zero values which can't be deseasonalized using multiplicative deseasonalization. And ThetaForecaster by default only supports multiplicative deseasonalization.")
109-
model_kwargs["deseasonalize_model"] = "add"
110-
using_additive_deseasonalization = True
111-
additive_deseasonalizer = Deseasonalizer(
112-
sp=model_kwargs["sp"],
113-
model="additive",
151+
if using_additive_deseasonalization:
152+
forecaster = ForecastingPipeline([
153+
("deseasonalize", Deseasonalizer(model="additive")),
154+
("theta", ThetaForecaster(deseasonalize=False))
155+
])
156+
param_grid = {"deseasonalize__sp": sp_candidates}
157+
else:
158+
forecaster = ThetaForecaster(deseasonalize=True)
159+
param_grid = {"sp": sp_candidates}
160+
161+
gscv = ForecastingGridSearchCV(
162+
forecaster=forecaster,
163+
cv=cv,
164+
param_grid=param_grid,
114165
)
115-
y_adj = additive_deseasonalizer.fit_transform(y)
116-
y = y_adj
117-
model_kwargs["deseasonalize"] = False
118-
else:
119-
model_kwargs["deseasonalize_model"] = ""
166+
gscv.fit(y, X=X_in)
167+
168+
# Extract the best sp based on which param name was used
169+
best_params = gscv.best_params_
170+
logger.info(f"Found {best_params} from seasonality candidates")
171+
best_sp = best_params.get("deseasonalize__sp") or best_params.get("sp")
120172

121-
model = ThetaForecaster(initial_level=model_kwargs["initial_level"],
122-
deseasonalize=model_kwargs["deseasonalize"],
123-
sp=1 if model_kwargs["deseasonalize_model"] == "add" else model_kwargs.get("sp",
124-
1), )
125-
model.fit(y, X=X_in)
173+
model_kwargs["sp"] = best_sp
174+
175+
y_to_fit = y.copy()
176+
if using_additive_deseasonalization and model_kwargs["sp"] > 1:
177+
additive_deseasonalizer = Deseasonalizer(sp=model_kwargs["sp"], model="additive")
178+
y_to_fit = additive_deseasonalizer.fit_transform(y)
179+
180+
if model_kwargs["sp"] == 1 or using_additive_deseasonalization:
181+
model_kwargs["deseasonalize"] = False
182+
183+
model = ThetaForecaster(
184+
initial_level=model_kwargs.get("initial_level"),
185+
deseasonalize=model_kwargs.get("deseasonalize", True),
186+
sp=1 if using_additive_deseasonalization else model_kwargs.get("sp", 1),
187+
)
188+
model.fit(y_to_fit, X=X_in)
126189

127190
fh = ForecastingHorizon(range(1, self.spec.horizon + 1), is_relative=True)
128191
fh_in_sample = ForecastingHorizon(range(-len(data_i) + 1, 1))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ forecast = [
176176
"py-cpuinfo",
177177
"rich",
178178
"autots",
179-
"mlforecast",
179+
"mlforecast==1.0.2",
180180
"neuralprophet>=0.7.0",
181181
"pytorch-lightning==2.5.5",
182182
"numpy<2.0.0",

0 commit comments

Comments
 (0)