|
21 | 21 | from enum import Enum |
22 | 22 | from typing import List |
23 | 23 |
|
24 | | -from iotdb.ainode.core.model.model_enums import BuiltInModelType |
25 | 24 | from iotdb.thrift.common.ttypes import TEndPoint |
26 | 25 |
|
| 26 | +IOTDB_AINODE_HOME = os.getenv("IOTDB_AINODE_HOME", "") |
27 | 27 | AINODE_VERSION_INFO = "UNKNOWN" |
28 | 28 | AINODE_BUILD_INFO = "UNKNOWN" |
29 | | -AINODE_CONF_DIRECTORY_NAME = "conf" |
30 | | -AINODE_ROOT_CONF_DIRECTORY_NAME = "conf" |
| 29 | +AINODE_CONF_DIRECTORY_NAME = os.path.join(IOTDB_AINODE_HOME, "conf") |
31 | 30 | AINODE_CONF_FILE_NAME = "iotdb-ainode.properties" |
32 | 31 | AINODE_CONF_GIT_FILE_NAME = "git.properties" |
33 | 32 | AINODE_CONF_POM_FILE_NAME = "pom.properties" |
|
53 | 52 | AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 |
54 | 53 | AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 |
55 | 54 | AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { |
56 | | - BuiltInModelType.SUNDIAL.value: 1036 * 1024**2, # 1036 MiB |
57 | | - BuiltInModelType.TIMER_XL.value: 856 * 1024**2, # 856 MiB |
| 55 | + "sundial": 1036 * 1024**2, # 1036 MiB |
| 56 | + "timerxl": 856 * 1024**2, # 856 MiB |
58 | 57 | } # the memory usage of each model in bytes |
59 | 58 | AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference |
60 | 59 | AINODE_INFERENCE_EXTRA_MEMORY_RATIO = ( |
61 | 60 | 1.2 # the overhead ratio for inference, used to estimate the pool size |
62 | 61 | ) |
63 | 62 |
|
64 | | -# AINode folder structure |
65 | | -AINODE_ROOT_DIR = os.path.dirname( |
66 | | - os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) |
| 63 | +AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") |
| 64 | +AINODE_BUILTIN_MODELS_DIR = os.path.join( |
| 65 | + IOTDB_AINODE_HOME, "data/ainode/models/builtin" |
| 66 | +) # For built-in models, we only need to store their weights and config. |
| 67 | +AINODE_FINETUNE_MODELS_DIR = os.path.join( |
| 68 | + IOTDB_AINODE_HOME, "data/ainode/models/finetune" |
| 69 | +) |
| 70 | +AINODE_USER_DEFINED_MODELS_DIR = os.path.join( |
| 71 | + IOTDB_AINODE_HOME, "data/ainode/models/user_defined" |
67 | 72 | ) |
68 | | -AINODE_MODELS_DIR = "data/ainode/models" |
69 | | -AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights" # For built-in models, we only need to store their weights and config. |
70 | 73 | AINODE_SYSTEM_DIR = "data/ainode/system" |
71 | 74 | AINODE_LOG_DIR = "logs" |
| 75 | +AINODE_CACHE_DIR = os.path.expanduser("~/.cache/ainode") |
72 | 76 |
|
73 | 77 | # AINode log |
74 | 78 | LOG_FILE_TYPE = ["all", "info", "warn", "error"] |
@@ -143,132 +147,8 @@ def name(self): |
143 | 147 | return self.value |
144 | 148 |
|
145 | 149 |
|
146 | | -class ForecastModelType(Enum): |
147 | | - DLINEAR = "dlinear" |
148 | | - DLINEAR_INDIVIDUAL = "dlinear_individual" |
149 | | - NBEATS = "nbeats" |
150 | | - |
151 | | - @classmethod |
152 | | - def values(cls) -> List[str]: |
153 | | - values = [] |
154 | | - for item in list(cls): |
155 | | - values.append(item.value) |
156 | | - return values |
157 | | - |
158 | | - |
159 | 150 | class ModelInputName(Enum): |
160 | 151 | DATA_X = "data_x" |
161 | 152 | TIME_STAMP_X = "time_stamp_x" |
162 | 153 | TIME_STAMP_Y = "time_stamp_y" |
163 | 154 | DEC_INP = "dec_inp" |
164 | | - |
165 | | - |
166 | | -class AttributeName(Enum): |
167 | | - # forecast Attribute |
168 | | - PREDICT_LENGTH = "predict_length" |
169 | | - |
170 | | - # NaiveForecaster |
171 | | - STRATEGY = "strategy" |
172 | | - SP = "sp" |
173 | | - |
174 | | - # STLForecaster |
175 | | - # SP = 'sp' |
176 | | - SEASONAL = "seasonal" |
177 | | - SEASONAL_DEG = "seasonal_deg" |
178 | | - TREND_DEG = "trend_deg" |
179 | | - LOW_PASS_DEG = "low_pass_deg" |
180 | | - SEASONAL_JUMP = "seasonal_jump" |
181 | | - TREND_JUMP = "trend_jump" |
182 | | - LOSS_PASS_JUMP = "low_pass_jump" |
183 | | - |
184 | | - # ExponentialSmoothing |
185 | | - DAMPED_TREND = "damped_trend" |
186 | | - INITIALIZATION_METHOD = "initialization_method" |
187 | | - OPTIMIZED = "optimized" |
188 | | - REMOVE_BIAS = "remove_bias" |
189 | | - USE_BRUTE = "use_brute" |
190 | | - |
191 | | - # Arima |
192 | | - ORDER = "order" |
193 | | - SEASONAL_ORDER = "seasonal_order" |
194 | | - METHOD = "method" |
195 | | - MAXITER = "maxiter" |
196 | | - SUPPRESS_WARNINGS = "suppress_warnings" |
197 | | - OUT_OF_SAMPLE_SIZE = "out_of_sample_size" |
198 | | - SCORING = "scoring" |
199 | | - WITH_INTERCEPT = "with_intercept" |
200 | | - TIME_VARYING_REGRESSION = "time_varying_regression" |
201 | | - ENFORCE_STATIONARITY = "enforce_stationarity" |
202 | | - ENFORCE_INVERTIBILITY = "enforce_invertibility" |
203 | | - SIMPLE_DIFFERENCING = "simple_differencing" |
204 | | - MEASUREMENT_ERROR = "measurement_error" |
205 | | - MLE_REGRESSION = "mle_regression" |
206 | | - HAMILTON_REPRESENTATION = "hamilton_representation" |
207 | | - CONCENTRATE_SCALE = "concentrate_scale" |
208 | | - |
209 | | - # GAUSSIAN_HMM |
210 | | - N_COMPONENTS = "n_components" |
211 | | - COVARIANCE_TYPE = "covariance_type" |
212 | | - MIN_COVAR = "min_covar" |
213 | | - STARTPROB_PRIOR = "startprob_prior" |
214 | | - TRANSMAT_PRIOR = "transmat_prior" |
215 | | - MEANS_PRIOR = "means_prior" |
216 | | - MEANS_WEIGHT = "means_weight" |
217 | | - COVARS_PRIOR = "covars_prior" |
218 | | - COVARS_WEIGHT = "covars_weight" |
219 | | - ALGORITHM = "algorithm" |
220 | | - N_ITER = "n_iter" |
221 | | - TOL = "tol" |
222 | | - PARAMS = "params" |
223 | | - INIT_PARAMS = "init_params" |
224 | | - IMPLEMENTATION = "implementation" |
225 | | - |
226 | | - # GMMHMM |
227 | | - # N_COMPONENTS = "n_components" |
228 | | - N_MIX = "n_mix" |
229 | | - # MIN_COVAR = "min_covar" |
230 | | - # STARTPROB_PRIOR = "startprob_prior" |
231 | | - # TRANSMAT_PRIOR = "transmat_prior" |
232 | | - WEIGHTS_PRIOR = "weights_prior" |
233 | | - |
234 | | - # MEANS_PRIOR = "means_prior" |
235 | | - # MEANS_WEIGHT = "means_weight" |
236 | | - # ALGORITHM = "algorithm" |
237 | | - # COVARIANCE_TYPE = "covariance_type" |
238 | | - # N_ITER = "n_iter" |
239 | | - # TOL = "tol" |
240 | | - # INIT_PARAMS = "init_params" |
241 | | - # PARAMS = "params" |
242 | | - # IMPLEMENTATION = "implementation" |
243 | | - |
244 | | - # STRAY |
245 | | - ALPHA = "alpha" |
246 | | - K = "k" |
247 | | - KNN_ALGORITHM = "knn_algorithm" |
248 | | - P = "p" |
249 | | - SIZE_THRESHOLD = "size_threshold" |
250 | | - OUTLIER_TAIL = "outlier_tail" |
251 | | - |
252 | | - # timerxl |
253 | | - INPUT_TOKEN_LEN = "input_token_len" |
254 | | - HIDDEN_SIZE = "hidden_size" |
255 | | - INTERMEDIATE_SIZE = "intermediate_size" |
256 | | - OUTPUT_TOKEN_LENS = "output_token_lens" |
257 | | - NUM_HIDDEN_LAYERS = "num_hidden_layers" |
258 | | - NUM_ATTENTION_HEADS = "num_attention_heads" |
259 | | - HIDDEN_ACT = "hidden_act" |
260 | | - USE_CACHE = "use_cache" |
261 | | - ROPE_THETA = "rope_theta" |
262 | | - ATTENTION_DROPOUT = "attention_dropout" |
263 | | - INITIALIZER_RANGE = "initializer_range" |
264 | | - MAX_POSITION_EMBEDDINGS = "max_position_embeddings" |
265 | | - CKPT_PATH = "ckpt_path" |
266 | | - |
267 | | - # sundial |
268 | | - DROPOUT_RATE = "dropout_rate" |
269 | | - FLOW_LOSS_DEPTH = "flow_loss_depth" |
270 | | - NUM_SAMPLING_STEPS = "num_sampling_steps" |
271 | | - DIFFUSION_BATCH_MUL = "diffusion_batch_mul" |
272 | | - |
273 | | - def name(self) -> str: |
274 | | - return self.value |
0 commit comments