Skip to content

Commit 18cccc3

Browse files
committed
Reconstruct model management and model loading
1 parent de9fe48 commit 18cccc3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+4730
-598
lines changed

iotdb-core/ainode/iotdb/ainode/core/config.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,21 @@
3030
AINODE_CONF_FILE_NAME,
3131
AINODE_CONF_GIT_FILE_NAME,
3232
AINODE_CONF_POM_FILE_NAME,
33+
AINODE_FINETUNE_MODELS_DIR,
3334
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
3435
AINODE_INFERENCE_EXTRA_MEMORY_RATIO,
3536
AINODE_INFERENCE_MAX_PREDICT_LENGTH,
3637
AINODE_INFERENCE_MEMORY_USAGE_RATIO,
3738
AINODE_INFERENCE_MODEL_MEM_USAGE_MAP,
3839
AINODE_LOG_DIR,
3940
AINODE_MODELS_DIR,
40-
AINODE_ROOT_CONF_DIRECTORY_NAME,
41-
AINODE_ROOT_DIR,
4241
AINODE_RPC_ADDRESS,
4342
AINODE_RPC_PORT,
4443
AINODE_SYSTEM_DIR,
4544
AINODE_SYSTEM_FILE_NAME,
4645
AINODE_TARGET_CONFIG_NODE_LIST,
4746
AINODE_THRIFT_COMPRESSION_ENABLED,
47+
AINODE_USER_DEFINED_MODELS_DIR,
4848
AINODE_VERSION_INFO,
4949
)
5050
from iotdb.ainode.core.exception import BadNodeUrlError
@@ -97,6 +97,8 @@ def __init__(self):
9797
# Directory to save models
9898
self._ain_models_dir = AINODE_MODELS_DIR
9999
self._ain_builtin_models_dir = AINODE_BUILTIN_MODELS_DIR
100+
self._ain_finetune_models_dir = AINODE_FINETUNE_MODELS_DIR
101+
self._ain_user_defined_models_dir = AINODE_USER_DEFINED_MODELS_DIR
100102
self._ain_system_dir = AINODE_SYSTEM_DIR
101103

102104
# Whether to enable compression for thrift
@@ -211,6 +213,18 @@ def get_ain_builtin_models_dir(self) -> str:
211213
def set_ain_builtin_models_dir(self, ain_builtin_models_dir: str) -> None:
212214
self._ain_builtin_models_dir = ain_builtin_models_dir
213215

216+
def get_ain_finetune_models_dir(self) -> str:
217+
return self._ain_finetune_models_dir
218+
219+
def set_ain_finetune_models_dir(self, ain_finetune_models_dir: str) -> None:
220+
self._ain_finetune_models_dir = ain_finetune_models_dir
221+
222+
def get_ain_user_defined_models_dir(self) -> str:
223+
return self._ain_user_defined_models_dir
224+
225+
def set_ain_user_defined_models_dir(self, ain_user_defined_models_dir: str) -> None:
226+
self._ain_user_defined_models_dir = ain_user_defined_models_dir
227+
214228
def get_ain_system_dir(self) -> str:
215229
return self._ain_system_dir
216230

@@ -315,9 +329,7 @@ def _load_config_from_file(self) -> None:
315329
if "ainode_id" in system_configs:
316330
self._config.set_ainode_id(int(system_configs["ainode_id"]))
317331

318-
git_file = os.path.join(
319-
AINODE_ROOT_DIR, AINODE_ROOT_CONF_DIRECTORY_NAME, AINODE_CONF_GIT_FILE_NAME
320-
)
332+
git_file = os.path.join(AINODE_CONF_DIRECTORY_NAME, AINODE_CONF_GIT_FILE_NAME)
321333
if os.path.exists(git_file):
322334
git_configs = load_properties(git_file)
323335
if "git.commit.id.abbrev" in git_configs:
@@ -327,9 +339,7 @@ def _load_config_from_file(self) -> None:
327339
build_info += "-dev"
328340
self._config.set_build_info(build_info)
329341

330-
pom_file = os.path.join(
331-
AINODE_ROOT_DIR, AINODE_ROOT_CONF_DIRECTORY_NAME, AINODE_CONF_POM_FILE_NAME
332-
)
342+
pom_file = os.path.join(AINODE_CONF_DIRECTORY_NAME, AINODE_CONF_POM_FILE_NAME)
333343
if os.path.exists(pom_file):
334344
pom_configs = load_properties(pom_file)
335345
if "version" in pom_configs:

iotdb-core/ainode/iotdb/ainode/core/constant.py

Lines changed: 14 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@
2121
from enum import Enum
2222
from typing import List
2323

24-
from iotdb.ainode.core.model.model_enums import BuiltInModelType
2524
from iotdb.thrift.common.ttypes import TEndPoint
2625

26+
IOTDB_AINODE_HOME = os.getenv("IOTDB_AINODE_HOME", "")
2727
AINODE_VERSION_INFO = "UNKNOWN"
2828
AINODE_BUILD_INFO = "UNKNOWN"
29-
AINODE_CONF_DIRECTORY_NAME = "conf"
30-
AINODE_ROOT_CONF_DIRECTORY_NAME = "conf"
29+
AINODE_CONF_DIRECTORY_NAME = os.path.join(IOTDB_AINODE_HOME, "conf")
3130
AINODE_CONF_FILE_NAME = "iotdb-ainode.properties"
3231
AINODE_CONF_GIT_FILE_NAME = "git.properties"
3332
AINODE_CONF_POM_FILE_NAME = "pom.properties"
@@ -53,22 +52,27 @@
5352
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
5453
AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880
5554
AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
56-
BuiltInModelType.SUNDIAL.value: 1036 * 1024**2, # 1036 MiB
57-
BuiltInModelType.TIMER_XL.value: 856 * 1024**2, # 856 MiB
55+
"sundial": 1036 * 1024**2, # 1036 MiB
56+
"timerxl": 856 * 1024**2, # 856 MiB
5857
} # the memory usage of each model in bytes
5958
AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference
6059
AINODE_INFERENCE_EXTRA_MEMORY_RATIO = (
6160
1.2 # the overhead ratio for inference, used to estimate the pool size
6261
)
6362

64-
# AINode folder structure
65-
AINODE_ROOT_DIR = os.path.dirname(
66-
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
63+
AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models")
64+
AINODE_BUILTIN_MODELS_DIR = os.path.join(
65+
IOTDB_AINODE_HOME, "data/ainode/models/builtin"
66+
) # For built-in models, we only need to store their weights and config.
67+
AINODE_FINETUNE_MODELS_DIR = os.path.join(
68+
IOTDB_AINODE_HOME, "data/ainode/models/finetune"
69+
)
70+
AINODE_USER_DEFINED_MODELS_DIR = os.path.join(
71+
IOTDB_AINODE_HOME, "data/ainode/models/user_defined"
6772
)
68-
AINODE_MODELS_DIR = "data/ainode/models"
69-
AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights" # For built-in models, we only need to store their weights and config.
7073
AINODE_SYSTEM_DIR = "data/ainode/system"
7174
AINODE_LOG_DIR = "logs"
75+
AINODE_CACHE_DIR = os.path.expanduser("~/.cache/ainode")
7276

7377
# AINode log
7478
LOG_FILE_TYPE = ["all", "info", "warn", "error"]
@@ -143,132 +147,8 @@ def name(self):
143147
return self.value
144148

145149

146-
class ForecastModelType(Enum):
147-
DLINEAR = "dlinear"
148-
DLINEAR_INDIVIDUAL = "dlinear_individual"
149-
NBEATS = "nbeats"
150-
151-
@classmethod
152-
def values(cls) -> List[str]:
153-
values = []
154-
for item in list(cls):
155-
values.append(item.value)
156-
return values
157-
158-
159150
class ModelInputName(Enum):
160151
DATA_X = "data_x"
161152
TIME_STAMP_X = "time_stamp_x"
162153
TIME_STAMP_Y = "time_stamp_y"
163154
DEC_INP = "dec_inp"
164-
165-
166-
class AttributeName(Enum):
167-
# forecast Attribute
168-
PREDICT_LENGTH = "predict_length"
169-
170-
# NaiveForecaster
171-
STRATEGY = "strategy"
172-
SP = "sp"
173-
174-
# STLForecaster
175-
# SP = 'sp'
176-
SEASONAL = "seasonal"
177-
SEASONAL_DEG = "seasonal_deg"
178-
TREND_DEG = "trend_deg"
179-
LOW_PASS_DEG = "low_pass_deg"
180-
SEASONAL_JUMP = "seasonal_jump"
181-
TREND_JUMP = "trend_jump"
182-
LOSS_PASS_JUMP = "low_pass_jump"
183-
184-
# ExponentialSmoothing
185-
DAMPED_TREND = "damped_trend"
186-
INITIALIZATION_METHOD = "initialization_method"
187-
OPTIMIZED = "optimized"
188-
REMOVE_BIAS = "remove_bias"
189-
USE_BRUTE = "use_brute"
190-
191-
# Arima
192-
ORDER = "order"
193-
SEASONAL_ORDER = "seasonal_order"
194-
METHOD = "method"
195-
MAXITER = "maxiter"
196-
SUPPRESS_WARNINGS = "suppress_warnings"
197-
OUT_OF_SAMPLE_SIZE = "out_of_sample_size"
198-
SCORING = "scoring"
199-
WITH_INTERCEPT = "with_intercept"
200-
TIME_VARYING_REGRESSION = "time_varying_regression"
201-
ENFORCE_STATIONARITY = "enforce_stationarity"
202-
ENFORCE_INVERTIBILITY = "enforce_invertibility"
203-
SIMPLE_DIFFERENCING = "simple_differencing"
204-
MEASUREMENT_ERROR = "measurement_error"
205-
MLE_REGRESSION = "mle_regression"
206-
HAMILTON_REPRESENTATION = "hamilton_representation"
207-
CONCENTRATE_SCALE = "concentrate_scale"
208-
209-
# GAUSSIAN_HMM
210-
N_COMPONENTS = "n_components"
211-
COVARIANCE_TYPE = "covariance_type"
212-
MIN_COVAR = "min_covar"
213-
STARTPROB_PRIOR = "startprob_prior"
214-
TRANSMAT_PRIOR = "transmat_prior"
215-
MEANS_PRIOR = "means_prior"
216-
MEANS_WEIGHT = "means_weight"
217-
COVARS_PRIOR = "covars_prior"
218-
COVARS_WEIGHT = "covars_weight"
219-
ALGORITHM = "algorithm"
220-
N_ITER = "n_iter"
221-
TOL = "tol"
222-
PARAMS = "params"
223-
INIT_PARAMS = "init_params"
224-
IMPLEMENTATION = "implementation"
225-
226-
# GMMHMM
227-
# N_COMPONENTS = "n_components"
228-
N_MIX = "n_mix"
229-
# MIN_COVAR = "min_covar"
230-
# STARTPROB_PRIOR = "startprob_prior"
231-
# TRANSMAT_PRIOR = "transmat_prior"
232-
WEIGHTS_PRIOR = "weights_prior"
233-
234-
# MEANS_PRIOR = "means_prior"
235-
# MEANS_WEIGHT = "means_weight"
236-
# ALGORITHM = "algorithm"
237-
# COVARIANCE_TYPE = "covariance_type"
238-
# N_ITER = "n_iter"
239-
# TOL = "tol"
240-
# INIT_PARAMS = "init_params"
241-
# PARAMS = "params"
242-
# IMPLEMENTATION = "implementation"
243-
244-
# STRAY
245-
ALPHA = "alpha"
246-
K = "k"
247-
KNN_ALGORITHM = "knn_algorithm"
248-
P = "p"
249-
SIZE_THRESHOLD = "size_threshold"
250-
OUTLIER_TAIL = "outlier_tail"
251-
252-
# timerxl
253-
INPUT_TOKEN_LEN = "input_token_len"
254-
HIDDEN_SIZE = "hidden_size"
255-
INTERMEDIATE_SIZE = "intermediate_size"
256-
OUTPUT_TOKEN_LENS = "output_token_lens"
257-
NUM_HIDDEN_LAYERS = "num_hidden_layers"
258-
NUM_ATTENTION_HEADS = "num_attention_heads"
259-
HIDDEN_ACT = "hidden_act"
260-
USE_CACHE = "use_cache"
261-
ROPE_THETA = "rope_theta"
262-
ATTENTION_DROPOUT = "attention_dropout"
263-
INITIALIZER_RANGE = "initializer_range"
264-
MAX_POSITION_EMBEDDINGS = "max_position_embeddings"
265-
CKPT_PATH = "ckpt_path"
266-
267-
# sundial
268-
DROPOUT_RATE = "dropout_rate"
269-
FLOW_LOSS_DEPTH = "flow_loss_depth"
270-
NUM_SAMPLING_STEPS = "num_sampling_steps"
271-
DIFFUSION_BATCH_MUL = "diffusion_batch_mul"
272-
273-
def name(self) -> str:
274-
return self.value

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
#
18+
1819
import threading
1920
from typing import Any
2021

2122
import torch
2223

23-
from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import (
24-
AbstractInferencePipeline,
25-
)
2624
from iotdb.ainode.core.log import Logger
2725
from iotdb.ainode.core.util.atmoic_int import AtomicInt
2826

@@ -41,7 +39,6 @@ def __init__(
4139
req_id: str,
4240
model_id: str,
4341
inputs: torch.Tensor,
44-
inference_pipeline: AbstractInferencePipeline,
4542
max_new_tokens: int = 96,
4643
**infer_kwargs,
4744
):
@@ -52,7 +49,6 @@ def __init__(
5249
self.model_id = model_id
5350
self.inputs = inputs
5451
self.infer_kwargs = infer_kwargs
55-
self.inference_pipeline = inference_pipeline
5652
self.max_new_tokens = (
5753
max_new_tokens # Number of time series data points to generate
5854
)

0 commit comments

Comments
 (0)