diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java index 07d29c0d224e..d6c37023b5de 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java @@ -176,6 +176,26 @@ public void ModelOperationTest() { } } + @Test + public void callInferenceTest2() { + String sql = + "CALL INFERENCE(_holtwinters, \"select s0 from root.AI.data\", predict_length=6, generateTime=true)"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + try (ResultSet resultSet = statement.executeQuery(sql)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "Time,output0"); + int count = 0; + while (resultSet.next()) { + count++; + } + assertEquals(6, count); + } + } catch (SQLException e) { + fail(e.getMessage()); + } + } + @Test public void callInferenceTest() { String sql = diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java index 78bf4a1c574b..7d5cda960fa7 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java @@ -537,11 +537,12 @@ public void testInformationSchema() throws SQLException { "model_id,", new HashSet<>( Arrays.asList( - "_timerxl,", "_STLForecaster,", "_NaiveForecaster,", - "_ARIMA,", - "_ExponentialSmoothing,"))); + "_HoltWinters,", + "_TimerXL,", + "_ExponentialSmoothing,", + "_ARIMA,"))); TestUtils.assertResultSetEqual( statement.executeQuery( @@ -658,9 +659,10 @@ public void testInformationSchema() throws SQLException { "model_id,", new HashSet<>( Arrays.asList( - "_timerxl,", + "_TimerXL,", "_STLForecaster,", "_NaiveForecaster,", + "_HoltWinters,", "_ARIMA,", "_ExponentialSmoothing,"))); diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index a80ca680989c..467ffbfdec7b 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -139,6 +139,7 @@ class ModelInputName(Enum): class BuiltInModelType(Enum): # forecast models ARIMA = "_arima" + HOLTWINTERS = "_holtwinters" EXPONENTIAL_SMOOTHING = "_exponentialsmoothing" NAIVE_FORECASTER = "_naiveforecaster" STL_FORECASTER = "_stlforecaster" diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py index 0d2991a7f9f4..f395e0bae9ea 100644 --- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py @@ -45,7 +45,7 @@ def get_model_attributes(model_id: str): attribute_map = arima_attribute_map elif model_id == BuiltInModelType.NAIVE_FORECASTER.value: attribute_map = naive_forecaster_attribute_map - elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value: + elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value or model_id == BuiltInModelType.HOLTWINTERS.value: attribute_map = exponential_smoothing_attribute_map elif model_id == BuiltInModelType.STL_FORECASTER.value: attribute_map = stl_forecaster_attribute_map @@ -85,7 +85,7 @@ def fetch_built_in_model(model_id, inference_attributes): # build the built-in model if model_id == BuiltInModelType.ARIMA.value: model = ArimaModel(attributes) - elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value: + elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value or model_id == BuiltInModelType.HOLTWINTERS.value: model = ExponentialSmoothingModel(attributes) elif model_id == BuiltInModelType.NAIVE_FORECASTER.value: model = NaiveForecasterModel(attributes) diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java index e2beede330ab..44093d100530 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java @@ -74,11 +74,14 @@ public class ModelInfo implements SnapshotProcessor { private static final Set builtInAnomalyDetectionModel = new HashSet<>(); + private static final int timerXLInputLength = 2880; + static { - builtInForecastModel.add("_timerxl"); + builtInForecastModel.add("_TimerXL"); builtInForecastModel.add("_ARIMA"); builtInForecastModel.add("_NaiveForecaster"); builtInForecastModel.add("_STLForecaster"); + builtInForecastModel.add("_HoltWinters"); builtInForecastModel.add("_ExponentialSmoothing"); builtInAnomalyDetectionModel.add("_GaussianHMM"); builtInAnomalyDetectionModel.add("_GMMHMM"); @@ -269,6 +272,9 @@ public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) { // check if it's a built-in model if ((modelType = checkModelType(modelName)) != ModelType.USER_DEFINED) { modelInformation = new ModelInformation(modelType, modelName); + if (modelName.equalsIgnoreCase("_timerxl")) { + modelInformation.setInputLength(timerXLInputLength); + } } else { modelInformation = modelTable.getModelInformationById(modelName); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index dbe91526bd8a..69d637fae525 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -333,7 +333,7 @@ public TableFunctionAnalysis analyze(Map arguments) { } } } else { - String[] predictedColumnsArray = predicatedColumns.split(","); + String[] predictedColumnsArray = predicatedColumns.split(";"); Map inputColumnIndexMap = new HashMap<>(); for (int i = 0, size = allInputColumnsName.size(); i < size; i++) { Optional fieldName = allInputColumnsName.get(i);