Skip to content

Commit b3ef023

Browse files
committed
Merge branch 'main' into jupyter-slurm-integration
2 parents e5d62ae + 0b8187b commit b3ef023

File tree

13 files changed

+283
-94
lines changed

13 files changed

+283
-94
lines changed

ads/aqua/common/enums.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2026 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
from typing import Dict, List
@@ -25,6 +25,7 @@ class PredictEndpoints(ExtendedEnum):
2525
TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
2626
EMBEDDING_ENDPOINT = "/v1/embedding"
2727
RESPONSES = "/v1/responses"
28+
FORECAST = "v1/forecast"
2829

2930

3031
class Tags(ExtendedEnum):
@@ -47,6 +48,7 @@ class Tags(ExtendedEnum):
4748
MULTIMODEL_TYPE_TAG = "aqua_multimodel"
4849
STACKED_MODEL_TYPE_TAG = "aqua_stacked_model"
4950
AQUA_FINE_TUNE_MODEL_VERSION = "fine_tune_model_version"
51+
MODEL_DEPLOY_PREDICT_ENDPOINT = "model_deploy_predict_endpoint"
5052

5153

5254
class InferenceContainerType(ExtendedEnum):

ads/aqua/modeldeployment/deployment.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
ComputeShapeSummary,
2222
ContainerPath,
2323
)
24-
from ads.aqua.common.enums import InferenceContainerTypeFamily, ModelFormat, Tags
24+
from ads.aqua.common.enums import (
25+
InferenceContainerTypeFamily,
26+
ModelFormat,
27+
PredictEndpoints,
28+
Tags,
29+
)
2530
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
2631
from ads.aqua.common.utils import (
2732
DEFINED_METADATA_TO_FILE_MAP,
@@ -274,6 +279,11 @@ def create(
274279
create_deployment_details.env_var.update(
275280
{Tags.TASK.upper(): ModelTask.TIME_SERIES_FORECASTING}
276281
)
282+
create_deployment_details.env_var.update(
283+
{
284+
Tags.MODEL_DEPLOY_PREDICT_ENDPOINT.upper(): PredictEndpoints.FORECAST
285+
}
286+
)
277287
return self._create(
278288
aqua_model=aqua_model,
279289
create_deployment_details=create_deployment_details,

ads/dataset/dataset_browser.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,27 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77

8-
from __future__ import print_function, absolute_import
8+
from __future__ import absolute_import, print_function
99

10-
import re, pathlib, os
10+
import os
11+
import pathlib
12+
import re
1113
import urllib.parse
1214
from abc import ABC, abstractmethod
1315
from os import listdir
14-
from os.path import isfile, isdir, join, getsize
15-
from typing import List, Set, Tuple, Dict
16-
17-
import requests
16+
from os.path import getsize, isdir, isfile, join
17+
from typing import Dict, List, Set, Tuple
1818

1919
import pandas as pd
20+
import requests
2021
import sklearn.datasets as sk_datasets
2122

22-
from ads.dataset import helper
23-
from ads.common.utils import inject_and_copy_kwargs
2423
from ads.common.decorator.runtime_dependency import (
25-
runtime_dependency,
2624
OptionalDependency,
25+
runtime_dependency,
2726
)
27+
from ads.common.utils import inject_and_copy_kwargs
28+
from ads.dataset import helper
2829

2930

3031
class DatasetBrowser(ABC):
@@ -318,7 +319,7 @@ def open(self, name: str, **kwargs):
318319

319320
class SklearnDatasets(DatasetBrowser):
320321

321-
sklearn_datasets = ["breast_cancer", "diabetes", "iris", "wine", "digits"]
322+
sklearn_datasets = ["breast_cancer", "iris", "wine", "digits"]
322323

323324
def __init__(self):
324325
super(DatasetBrowser, self).__init__()

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,8 @@ def build_fforms_meta_features(self, data, target_col=None, group_cols=None):
329329
if target_col not in data.columns:
330330
raise ValueError(f"Target column '{target_col}' not found in DataFrame")
331331

332+
data[target_col] = data[target_col].fillna(0)
333+
332334
# Check if group_cols are provided and valid
333335
if group_cols is not None:
334336
if not isinstance(group_cols, list):

ads/opctl/operator/lowcode/forecast/__main__.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@
99
import sys
1010
from typing import Dict, List
1111

12-
import pandas as pd
1312
import yaml
1413

1514
from ads.opctl import logger
1615
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
1716
from ads.opctl.operator.common.utils import _parse_input_args
1817

19-
from .const import AUTO_SELECT_SERIES
18+
from .const import AUTO_SELECT, AUTO_SELECT_SERIES
2019
from .model.forecast_datasets import ForecastDatasets, ForecastResults
2120
from .operator_config import ForecastOperatorConfig
2221
from .whatifserve import ModelDeploymentManager
@@ -29,8 +28,10 @@ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
2928
datasets = ForecastDatasets(operator_config)
3029
model = ForecastOperatorModelFactory.get_model(operator_config, datasets)
3130

32-
if operator_config.spec.model == AUTO_SELECT_SERIES and hasattr(
33-
operator_config.spec, "meta_features"
31+
if (
32+
operator_config.spec.model == AUTO_SELECT_SERIES
33+
and hasattr(operator_config.spec, "meta_features")
34+
and operator_config.spec.target_category_columns
3435
):
3536
# For AUTO_SELECT_SERIES, handle each series with its specific model
3637
meta_features = operator_config.spec.meta_features
@@ -64,8 +65,6 @@ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
6465
)
6566
sub_results_list.append(sub_results)
6667

67-
# results_df = pd.concat([results_df, sub_result_df], ignore_index=True, axis=0)
68-
# elapsed_time += sub_elapsed_time
6968
# Merge all sub_results into a single ForecastResults object
7069
if sub_results_list:
7170
results = sub_results_list[0]
@@ -75,6 +74,20 @@ def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
7574
results = None
7675

7776
else:
77+
# When AUTO_SELECT_SERIES is specified but target_category_columns is not,
78+
# we fall back to AUTO_SELECT behavior.
79+
if (
80+
operator_config.spec.model == AUTO_SELECT_SERIES
81+
and not operator_config.spec.target_category_columns
82+
):
83+
84+
logger.warning(
85+
"AUTO_SELECT_SERIES cannot be run with a single-series dataset or when "
86+
"'target_category_columns' is not provided. Falling back to AUTO_SELECT."
87+
)
88+
89+
operator_config.spec.model = AUTO_SELECT
90+
model = ForecastOperatorModelFactory.get_model(operator_config, datasets)
7891
# For other cases, use the single selected model
7992
results = model.generate_report()
8093
# saving to model catalog

ads/opctl/operator/lowcode/forecast/environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- oracle-ads>=2.9.0
99
- prophet
1010
- neuralprophet
11-
- mlforecast
11+
- mlforecast==1.0.2
1212
- pmdarima
1313
- statsmodels
1414
- report-creator

ads/opctl/operator/lowcode/forecast/model/ets.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import numpy as np
1111
import optuna
1212
import pandas as pd
13+
from pandas.tseries.frequencies import to_offset
1314
from joblib import Parallel, delayed
1415
from optuna.trial import TrialState
16+
from scipy.stats import linregress
1517
from sktime.split import ExpandingWindowSplitter
1618
from statsmodels.tsa.exponential_smoothing.ets import ETSModel
1719

@@ -63,6 +65,39 @@ def preprocess(self, data, series_id):
6365
)
6466
return df_encoded.set_index(self.spec.datetime_column.name)
6567

68+
def _auto_detect_ets_params(self, y: pd.Series, seasonal_period: int) -> Dict[str, Any]:
69+
"""Detect trend, error type, and damping from data."""
70+
params = {}
71+
72+
# Detect trend via linear regression significance
73+
slope, _, r_value, p_value, _ = linregress(range(len(y)), y.values)
74+
has_trend = p_value < 0.05
75+
params["trend"] = "add" if has_trend else None
76+
params["damped_trend"] = (has_trend and abs(r_value) < 0.7) # weak trend → damp
77+
78+
# Detect additive vs multiplicative: does variance scale with level?
79+
if seasonal_period and len(y) >= 2 * seasonal_period:
80+
segments = [y.iloc[i:i + seasonal_period] for i in range(0, len(y) - seasonal_period, seasonal_period)]
81+
means = np.array([s.mean() for s in segments])
82+
stds = np.array([s.std() for s in segments])
83+
if means.std() > 0:
84+
corr = np.corrcoef(means, stds)[0, 1]
85+
params["error"] = "mul" if corr > 0.7 else "add"
86+
params["seasonal"] = params["error"] # match seasonal to error type
87+
else:
88+
params["error"] = "add"
89+
params["seasonal"] = "add"
90+
else:
91+
params["error"] = "add"
92+
params["seasonal"] = "add"
93+
94+
# Multiplicative requires strictly positive data
95+
if (y <= 0).any():
96+
params["error"] = "add"
97+
params["seasonal"] = "add"
98+
99+
return params
100+
66101
def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, Any]):
67102
try:
68103
self.forecast_output.init_series_output(series_id=series_id, data_at_series=df)
@@ -73,12 +108,24 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
73108

74109
Y = data_i[self.spec.target_column]
75110
dates = data_i.index.values
111+
is_daily = False
112+
if freq is not None:
113+
try:
114+
is_daily = to_offset(freq).name in ["D", "B"]
115+
except ValueError:
116+
pass
76117

118+
ets_model_args = self.spec.model_kwargs.get("restrict_daily_series_improvement_flow", True)
119+
sp, probable_sps = find_seasonal_period_from_dataset(Y)
77120
if model_kwargs["seasonal"] is None:
78121
model_kwargs["seasonal"] = "add"
79-
if model_kwargs["seasonal_periods"] is None:
80-
sp, probable_sps = find_seasonal_period_from_dataset(Y)
122+
123+
if ets_model_args and is_daily:
81124
model_kwargs["seasonal_periods"] = sp if sp > 1 else None
125+
else:
126+
auto_params = self._auto_detect_ets_params(Y, sp)
127+
for k, v in auto_params.items():
128+
model_kwargs[k] = v
82129

83130
if self.loaded_models is not None and series_id in self.loaded_models:
84131
previous_res = self.loaded_models[series_id].get("model")
@@ -92,14 +139,26 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
92139
if self.perform_tuning:
93140
model_kwargs = self.run_tuning(Y, model_kwargs)
94141

95-
use_seasonal = (model_kwargs["seasonal"] is not None and
96-
model_kwargs["seasonal_periods"] is not None and
97-
len(Y) >= 2 * model_kwargs["seasonal_periods"]
98-
)
142+
if (model_kwargs["error"] == "mul" or
143+
model_kwargs["trend"] == "mul" or
144+
model_kwargs["seasonal"] == "mul") and (Y <= 0).any():
145+
146+
logger.info("Multiplicative model incompatible with non-positive values. Switching to additive.")
147+
model_kwargs["error"] = "add"
148+
if model_kwargs["trend"] == "mul":
149+
model_kwargs["trend"] = "add"
150+
if model_kwargs["seasonal"] == "mul":
151+
model_kwargs["seasonal"] = "add"
152+
153+
use_seasonal = (
154+
model_kwargs["seasonal"] is not None and
155+
model_kwargs["seasonal_periods"] is not None and
156+
model_kwargs["seasonal_periods"] > 1 and
157+
len(Y) >= max(2 * model_kwargs["seasonal_periods"], 10)
158+
)
99159
if not use_seasonal:
100160
model_kwargs["seasonal"] = None
101161
model_kwargs["seasonal_periods"] = None
102-
103162
model = ETSModel(Y, error=model_kwargs["error"], trend=model_kwargs["trend"],
104163
damped_trend=model_kwargs["damped_trend"], seasonal=model_kwargs["seasonal"],
105164
seasonal_periods=model_kwargs["seasonal_periods"],
@@ -145,7 +204,7 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
145204
if param in params:
146205
params.pop(param)
147206
self.model_parameters[series_id] = {
148-
"framework": SupportedModels.Arima,
207+
"framework": SupportedModels.ETSForecaster,
149208
**params,
150209
}
151210

0 commit comments

Comments (0)