From 1ab8f5c740a0ec3b9e6356bdff42bab76e524639 Mon Sep 17 00:00:00 2001 From: YangCaiyin Date: Mon, 26 May 2025 20:55:15 +0800 Subject: [PATCH 1/9] support holtwinters and enhance codes --- .../apache/iotdb/ainode/it/AINodeBasicIT.java | 20 +++++++++++++++++++ iotdb-core/ainode/ainode/core/constant.py | 1 + .../core/model/built_in_model_factory.py | 4 ++-- .../confignode/persistence/ModelInfo.java | 3 ++- .../function/tvf/ForecastTableFunction.java | 2 +- 5 files changed, 26 insertions(+), 4 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java index 07d29c0d224e..40aad154f56b 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java @@ -176,6 +176,26 @@ public void ModelOperationTest() { } } + @Test + public void callInferenceTest2() { + String sql = + "CALL INFERENCE(_holtwinters, \"select s0 from root.AI.data\", predict_length=6, generateTime=true)"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + try (ResultSet resultSet = statement.executeQuery(sql)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "Time,output0"); + int count = 0; + while (resultSet.next()) { + count++; + } + assertEquals(6, count); + } + }catch (SQLException e) { + fail(e.getMessage()); + } + } + @Test public void callInferenceTest() { String sql = diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index a80ca680989c..eb7f7b717bdf 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -139,6 +139,7 @@ class ModelInputName(Enum): class BuiltInModelType(Enum): # forecast models ARIMA = "_arima" + HOLTWINTERS = "_hotwinters" EXPONENTIAL_SMOOTHING = "_exponentialsmoothing" NAIVE_FORECASTER = "_naiveforecaster" STL_FORECASTER = "_stlforecaster" diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py index 0d2991a7f9f4..662523f0ce9b 100644 --- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py @@ -85,7 +85,7 @@ def fetch_built_in_model(model_id, inference_attributes): # build the built-in model if model_id == BuiltInModelType.ARIMA.value: model = ArimaModel(attributes) - elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value: + elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value or model_id == BuiltInModelType.HOLTWINTERS.value: model = ExponentialSmoothingModel(attributes) elif model_id == BuiltInModelType.NAIVE_FORECASTER.value: model = NaiveForecasterModel(attributes) @@ -460,7 +460,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A ), AttributeName.ORDER.value: TupleAttribute( name=AttributeName.ORDER.value, - default_value=(1, 0, 0), + default_value=(96, 1, 96), value_type=int ), AttributeName.SEASONAL_ORDER.value: TupleAttribute( diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java index e2beede330ab..a72873e9aa7f 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java @@ -75,10 +75,11 @@ public class ModelInfo implements SnapshotProcessor { private static final Set builtInAnomalyDetectionModel = new HashSet<>(); static { - builtInForecastModel.add("_timerxl"); + builtInForecastModel.add("_TimerXL"); builtInForecastModel.add("_ARIMA"); builtInForecastModel.add("_NaiveForecaster"); builtInForecastModel.add("_STLForecaster"); + builtInForecastModel.add("_HoltWinters"); builtInForecastModel.add("_ExponentialSmoothing"); builtInAnomalyDetectionModel.add("_GaussianHMM"); builtInAnomalyDetectionModel.add("_GMMHMM"); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index dbe91526bd8a..69d637fae525 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -333,7 +333,7 @@ public TableFunctionAnalysis analyze(Map arguments) { } } } else { - String[] predictedColumnsArray = predicatedColumns.split(","); + String[] predictedColumnsArray = predicatedColumns.split(";"); Map inputColumnIndexMap = new HashMap<>(); for (int i = 0, size = allInputColumnsName.size(); i < size; i++) { Optional fieldName = allInputColumnsName.get(i); From 4d20e0664295dd58f06f4431386914d61eb19014 Mon Sep 17 00:00:00 2001 From: YangCaiyin Date: Mon, 26 May 2025 21:01:07 +0800 Subject: [PATCH 2/9] restore --- iotdb-core/ainode/ainode/core/model/built_in_model_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py index 662523f0ce9b..b15b8a8cbbbb 100644 --- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py @@ -460,7 +460,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A ), AttributeName.ORDER.value: TupleAttribute( name=AttributeName.ORDER.value, - default_value=(96, 1, 96), + default_value=(1, 0, 0), value_type=int ), AttributeName.SEASONAL_ORDER.value: TupleAttribute( From df052b59044bd5f7a365ed77b3eaf4e58940659a Mon Sep 17 00:00:00 2001 From: YangCaiyin Date: Mon, 26 May 2025 22:02:44 +0800 Subject: [PATCH 3/9] set the max line number of timer_XL to 2880 --- .../org/apache/iotdb/confignode/persistence/ModelInfo.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java index a72873e9aa7f..44093d100530 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java @@ -74,6 +74,8 @@ public class ModelInfo implements SnapshotProcessor { private static final Set builtInAnomalyDetectionModel = new HashSet<>(); + private static final int timerXLInputLength = 2880; + static { builtInForecastModel.add("_TimerXL"); builtInForecastModel.add("_ARIMA"); @@ -270,6 +272,9 @@ public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) { // check if it's a built-in model if ((modelType = checkModelType(modelName)) != ModelType.USER_DEFINED) { modelInformation = new ModelInformation(modelType, modelName); + if (modelName.equalsIgnoreCase("_timerxl")) { + modelInformation.setInputLength(timerXLInputLength); + } } else { modelInformation = modelTable.getModelInformationById(modelName); } From 9e830c19cc5db40e2fcf2943b283a30db7719e1a Mon Sep 17 00:00:00 2001 From: YangCaiyin Date: Mon, 26 May 2025 22:04:59 +0800 Subject: [PATCH 4/9] spotless --- .../test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java index 40aad154f56b..d6c37023b5de 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java @@ -179,9 +179,9 @@ public void ModelOperationTest() { @Test public void callInferenceTest2() { String sql = - "CALL INFERENCE(_holtwinters, \"select s0 from root.AI.data\", predict_length=6, generateTime=true)"; + "CALL INFERENCE(_holtwinters, \"select s0 from root.AI.data\", predict_length=6, generateTime=true)"; try (Connection connection = EnvFactory.getEnv().getConnection(); - Statement statement = connection.createStatement()) { + Statement statement = connection.createStatement()) { try (ResultSet resultSet = statement.executeQuery(sql)) { ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); checkHeader(resultSetMetaData, "Time,output0"); @@ -191,7 +191,7 @@ public void callInferenceTest2() { } assertEquals(6, count); } - }catch (SQLException e) { + } catch (SQLException e) { fail(e.getMessage()); } } From c6f425d4bf28352534816efd0d5d1f449634b58f Mon Sep 17 00:00:00 2001 From: YangCaiyin Date: Tue, 27 May 2025 14:28:05 +0800 Subject: [PATCH 5/9] fix typo --- iotdb-core/ainode/ainode/core/constant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index eb7f7b717bdf..467ffbfdec7b 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -139,7 +139,7 @@ class ModelInputName(Enum): class BuiltInModelType(Enum): # forecast models ARIMA = "_arima" - HOLTWINTERS = "_hotwinters" + HOLTWINTERS = "_holtwinters" EXPONENTIAL_SMOOTHING = "_exponentialsmoothing" NAIVE_FORECASTER = "_naiveforecaster" STL_FORECASTER = "_stlforecaster" From 74d8ec59dbed81d2ca5193db458b01df3ad2fc7c Mon Sep 17 00:00:00 2001 From: YangCaiyin Date: Tue, 27 May 2025 18:27:49 +0800 Subject: [PATCH 6/9] fix --- iotdb-core/ainode/ainode/core/model/built_in_model_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py index b15b8a8cbbbb..f395e0bae9ea 100644 --- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py @@ -45,7 +45,7 @@ def get_model_attributes(model_id: str): attribute_map = arima_attribute_map elif model_id == BuiltInModelType.NAIVE_FORECASTER.value: attribute_map = naive_forecaster_attribute_map - elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value: + elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value or model_id == BuiltInModelType.HOLTWINTERS.value: attribute_map = exponential_smoothing_attribute_map elif model_id == BuiltInModelType.STL_FORECASTER.value: attribute_map = stl_forecaster_attribute_map From 4d7ce9edb8635b7a3659cca4d1b884a3f5ba5339 Mon Sep 17 00:00:00 2001 From: ycycse Date: Wed, 28 May 2025 19:29:40 +0800 Subject: [PATCH 7/9] fix IT --- .../apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java index 78bf4a1c574b..6c82e326d313 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java @@ -537,11 +537,12 @@ public void testInformationSchema() throws SQLException { "model_id,", new HashSet<>( Arrays.asList( - "_timerxl,", "_STLForecaster,", "_NaiveForecaster,", - "_ARIMA,", - "_ExponentialSmoothing,"))); + "_HoltWinters,", + "_TimerXL,", + "_ExponentialSmoothing,", + "_ARIMA,"))); TestUtils.assertResultSetEqual( statement.executeQuery( From dcb7995618de7ba09d31abcf1033fa062904764e Mon Sep 17 00:00:00 2001 From: ycycse Date: Wed, 28 May 2025 20:48:41 +0800 Subject: [PATCH 8/9] fix IT --- .../org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java index 6c82e326d313..027bf3dfe96c 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java @@ -662,6 +662,7 @@ public void testInformationSchema() throws SQLException { "_timerxl,", "_STLForecaster,", "_NaiveForecaster,", + "_HoltWinters,", "_ARIMA,", "_ExponentialSmoothing,"))); From 266ce427b09be41b495706f182ceb55c6a99db72 Mon Sep 17 00:00:00 2001 From: ycycse Date: Wed, 28 May 2025 23:21:06 +0800 Subject: [PATCH 9/9] fix IT --- .../org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java index 027bf3dfe96c..7d5cda960fa7 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java @@ -659,7 +659,7 @@ public void testInformationSchema() throws SQLException { "model_id,", new HashSet<>( Arrays.asList( - "_timerxl,", + "_TimerXL,", "_STLForecaster,", "_NaiveForecaster,", "_HoltWinters,",