Skip to content

Commit ee343b2

Browse files
committed
refactor AIN CI tests
1 parent 687ae9f commit ee343b2

File tree

9 files changed

+369
-655
lines changed

9 files changed

+369
-655
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.ainode.it;
21+
22+
import org.apache.iotdb.ainode.utils.AINodeTestUtils;
23+
import org.apache.iotdb.it.env.EnvFactory;
24+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
25+
import org.apache.iotdb.itbase.category.AIClusterIT;
26+
import org.apache.iotdb.itbase.env.BaseEnv;
27+
28+
import org.junit.AfterClass;
29+
import org.junit.Assert;
30+
import org.junit.BeforeClass;
31+
import org.junit.Test;
32+
import org.junit.experimental.categories.Category;
33+
import org.junit.runner.RunWith;
34+
35+
import java.sql.Connection;
36+
import java.sql.ResultSet;
37+
import java.sql.SQLException;
38+
import java.sql.Statement;
39+
40+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
41+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
42+
import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
43+
44+
@RunWith(IoTDBTestRunner.class)
45+
@Category({AIClusterIT.class})
46+
public class AINodeCallInferenceIT {
47+
48+
static String[] WRITE_SQL_IN_TREE =
49+
new String[] {
50+
"CREATE DATABASE root.AI",
51+
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
52+
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
53+
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
54+
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
55+
};
56+
57+
@BeforeClass
58+
public static void setUp() throws Exception {
59+
// Init 1C1D1A cluster environment
60+
EnvFactory.getEnv().initClusterEnvironment(1, 1);
61+
prepareData(WRITE_SQL_IN_TREE);
62+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
63+
Statement statement = connection.createStatement()) {
64+
for (int i = 0; i < 2880; i++) {
65+
statement.execute(
66+
String.format(
67+
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
68+
i, (float) i, (double) i, i, i));
69+
}
70+
}
71+
}
72+
73+
@AfterClass
74+
public static void tearDown() throws Exception {
75+
EnvFactory.getEnv().cleanClusterEnvironment();
76+
}
77+
78+
@Test
79+
public void callInferenceTest() throws SQLException {
80+
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
81+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
82+
Statement statement = connection.createStatement()) {
83+
callInferenceTest(statement, modelInfo);
84+
}
85+
}
86+
}
87+
88+
public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
89+
throws SQLException {
90+
// Invoke call inference for specified models, there should exist result.
91+
String callInferenceSQL =
92+
String.format("CALL INFERENCE(%s, \"select s1 from root.AI\")", modelInfo.getModelId());
93+
try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
94+
int count = 0;
95+
while (resultSet.next()) {
96+
count++;
97+
}
98+
// Ensure the call inference return results
99+
Assert.assertTrue(count > 0);
100+
}
101+
}
102+
103+
@Test
104+
public void errorCallInferenceTestInTree() throws SQLException {
105+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
106+
Statement statement = connection.createStatement()) {
107+
String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\", window=head(5))";
108+
errorTest(statement, sql, "1505: model [notFound404] has not been created.");
109+
}
110+
}
111+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.ainode.it;
21+
22+
import org.apache.iotdb.ainode.utils.AINodeTestUtils;
23+
import org.apache.iotdb.it.env.EnvFactory;
24+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
25+
import org.apache.iotdb.itbase.category.AIClusterIT;
26+
import org.apache.iotdb.itbase.env.BaseEnv;
27+
28+
import org.junit.AfterClass;
29+
import org.junit.Assert;
30+
import org.junit.BeforeClass;
31+
import org.junit.Test;
32+
import org.junit.experimental.categories.Category;
33+
import org.junit.runner.RunWith;
34+
import org.slf4j.Logger;
35+
import org.slf4j.LoggerFactory;
36+
37+
import java.sql.Connection;
38+
import java.sql.ResultSet;
39+
import java.sql.SQLException;
40+
import java.sql.Statement;
41+
42+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
43+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
44+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
45+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
46+
47+
@RunWith(IoTDBTestRunner.class)
48+
@Category({AIClusterIT.class})
49+
public class AINodeConcurrentForecastIT {
50+
51+
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
52+
53+
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
54+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), forecast_length=>%d";
55+
56+
@BeforeClass
57+
public static void setUp() throws Exception {
58+
// Init 1C1D1A cluster environment
59+
EnvFactory.getEnv().initClusterEnvironment(1, 1);
60+
prepareDataForTableModel();
61+
}
62+
63+
@AfterClass
64+
public static void tearDown() throws Exception {
65+
EnvFactory.getEnv().cleanClusterEnvironment();
66+
}
67+
68+
private static void prepareDataForTableModel() throws SQLException {
69+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
70+
Statement statement = connection.createStatement()) {
71+
statement.execute("CREATE DATABASE root");
72+
statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)");
73+
for (int i = 0; i < 2880; i++) {
74+
statement.execute(
75+
String.format(
76+
"INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440)));
77+
}
78+
}
79+
}
80+
81+
@Test
82+
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
83+
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) {
84+
concurrentGPUForecastTest(modelInfo);
85+
}
86+
}
87+
88+
public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo)
89+
throws SQLException, InterruptedException {
90+
final int forecastLength = 512;
91+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
92+
Statement statement = connection.createStatement()) {
93+
// Single forecast request can be processed successfully
94+
final String forecastSQL =
95+
String.format(
96+
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength);
97+
try (ResultSet resultSet = statement.executeQuery(forecastSQL)) {
98+
int count = 0;
99+
while (resultSet.next()) {
100+
count++;
101+
}
102+
Assert.assertEquals(forecastLength, count);
103+
}
104+
105+
final int threadCnt = 10;
106+
final int loop = 100;
107+
final String devices = "0,1";
108+
statement.execute(
109+
String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices));
110+
checkModelOnSpecifiedDevice(
111+
statement, modelInfo.getModelId(), modelInfo.getModelType(), devices);
112+
long startTime = System.currentTimeMillis();
113+
concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength);
114+
long endTime = System.currentTimeMillis();
115+
LOGGER.info(
116+
String.format(
117+
"Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
118+
modelInfo.getModelId(), threadCnt * loop, threadCnt, loop, endTime - startTime));
119+
statement.execute(
120+
String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelInfo.getModelId(), devices));
121+
checkModelNotOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
122+
}
123+
}
124+
}

0 commit comments

Comments
 (0)