import logging
import traceback
+import time
+from typing import Any, List

from controller import organization
+from controller.embedding import util as embedding_util
+from controller.embedding import connector as embedding_connector
from starlette.endpoints import HTTPEndpoint
from starlette.responses import PlainTextResponse, JSONResponse

from controller.transfer.labelstudio import import_preperator
+from submodules.model.business_objects.tokenization import is_doc_bin_creation_running
from submodules.s3 import controller as s3
-from submodules.model.business_objects import organization
+from submodules.model.business_objects import (
+    attribute,
+    embedding,
+    general,
+    organization,
+    tokenization,
+)

from controller.transfer import manager as transfer_manager
from controller.upload_task import manager as upload_task_manager
from controller.transfer import association_transfer_manager
from controller.auth import manager as auth
from controller.project import manager as project_manager
+from controller.attribute import manager as attribute_manager

from submodules.model import enums, exceptions
from util.notification import create_notification
-from submodules.model.enums import NotificationType
-from submodules.model.models import UploadTask
-from submodules.model.business_objects import general
-from util import notification
+from submodules.model.enums import AttributeState, NotificationType, UploadStates
+from submodules.model.models import Embedding, UploadTask
+from util import daemon, notification

from controller.tokenization import tokenization_service

logging.basicConfig(level=logging.DEBUG)
@@ -221,6 +232,7 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
            import_preperator.prepare_label_studio_import(project_id, task)
        else:
            transfer_manager.import_records_from_file(project_id, task)
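+            # trigger recalculation of user attributes for the newly imported records (runs in the background)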
+            calculate_missing_attributes(project_id, task.user_id)
    elif "project" in task.file_type:
        transfer_manager.import_project(project_id, task)
    elif "knowledge_base" in task.file_type:
@@ -234,7 +246,10 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
        is_global_update,
    )
    if task.file_type != "knowledge_base":
-        tokenization_service.request_tokenize_project(project_id, str(task.user_id))
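+        # "records_add" = records were added to an existing project; the flag limits tokenization to attributes already marked usable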
+        only_usable_attributes = task.file_type == "records_add"
+        tokenization_service.request_tokenize_project(
+            project_id, str(task.user_id), True, only_usable_attributes
+        )


def file_import_error_handling(
@@ -258,3 +273,160 @@ def file_import_error_handling(
    notification.send_organization_update(
        project_id, f"file_upload:{str(task.id)}:state:{task.state}", is_global_update
    )
+
+
+def calculate_missing_attributes(project_id: str, user_id: str) -> None:
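+    # hand off to a daemon thread so the import request does not wait for the recalculation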
+    daemon.run(
+        __calculate_missing_attributes,
+        project_id,
+        user_id,
+    )
+
+
+def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
+    # wait a few seconds to ensure the tokenization service has started processing
+    time.sleep(5)
+    ctx_token = general.get_ctx_token()
+    attributes_usable = attribute.get_all_ordered(
+        project_id,
+        True,
+        state_filter=[
+            enums.AttributeState.USABLE.value,
+        ],
+    )
+    if len(attributes_usable) == 0:
+        return
+    # store plain id strings so refreshing the db session does not invalidate the objects we iterate over
+    attribute_ids = [str(att_usable.id) for att_usable in attributes_usable]
+    for att_id in attribute_ids:
+        attribute.update(project_id, att_id, state=enums.AttributeState.INITIAL.value)
+    general.commit()
+    notification.send_organization_update(
+        project_id=project_id, message="calculate_attribute:started:all"
+    )
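+    # the polling loops below refresh the db session every 60 iterations so it does not go stale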
+    # first check that the project tokenization has completed
+    i = 0
+    while True:
+        i += 1
+        if i >= 60:
+            i = 0
+            ctx_token = general.remove_and_refresh_session(ctx_token, True)
+        if tokenization.is_doc_bin_creation_running(project_id):
+            time.sleep(5)
+            continue
+        else:
+            break
+    # next, ensure that the attributes are calculated and tokenized
+    i = 0
+    while True:
+        time.sleep(1)
+        i += 1
+        if len(attribute_ids) == 0:
+            notification.send_organization_update(
+                project_id=project_id,
+                message="calculate_attribute:finished:all",
+            )
+            break
+        if i >= 60:
+            i = 0
+            ctx_token = general.remove_and_refresh_session(ctx_token, True)
+
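+        # work on the head of the list; it is only popped once it is no longer running and its tokenization has finished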
+        current_att_id = attribute_ids[0]
+        current_att = attribute.get(project_id, current_att_id)
+        if current_att.state == enums.AttributeState.RUNNING.value:
+            continue
+        elif current_att.state == enums.AttributeState.INITIAL.value:
+            attribute_manager.calculate_user_attribute_all_records(
+                project_id, user_id, current_att_id, True
+            )
+        else:
+            if tokenization.is_doc_bin_creation_running_for_attribute(
+                project_id, current_att.name
+            ):
+                time.sleep(5)
+                continue
+            else:
+                attribute_ids.pop(0)
+                notification.send_organization_update(
+                    project_id=project_id,
+                    message=f"calculate_attribute:finished:{current_att_id}",
+                )
+            time.sleep(5)
+
+    general.remove_and_refresh_session(ctx_token, False)
+    calculate_missing_embedding_tensors(project_id, user_id)
+
+
+def calculate_missing_embedding_tensors(project_id: str, user_id: str) -> None:
+    daemon.run(
+        __calculate_missing_embedding_tensors,
+        project_id,
+        user_id,
+    )
+
+
+def __calculate_missing_embedding_tensors(project_id: str, user_id: str) -> None:
+    ctx_token = general.get_ctx_token()
+    embeddings = embedding.get_finished_embeddings_by_started_at(project_id)
+    if len(embeddings) == 0:
+        return
+
+    embedding_ids = [str(embed.id) for embed in embeddings]
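+    # flag all finished embeddings as waiting; they are either rebuilt below or marked failed on error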
+    for embed_id in embedding_ids:
+        embedding.update_embedding_state_waiting(project_id, embed_id)
+    general.commit()
+
+    try:
+        ctx_token = __create_embeddings(project_id, embedding_ids, user_id, ctx_token)
+    except Exception as e:
+        print(
+            f"Error while recreating embeddings for {project_id} when new records are uploaded: {e}"
+        )
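+        # embeddings still in waiting state were not rebuilt; mark them failed instead of leaving them pending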
+        get_waiting_embeddings = embedding.get_waiting_embeddings(project_id)
+        for embed in get_waiting_embeddings:
+            embedding.update_embedding_state_failed(project_id, str(embed.id))
+        general.commit()
+    finally:
+        notification.send_organization_update(
+            project_id=project_id, message="embedding:finished:all"
+        )
+        general.remove_and_refresh_session(ctx_token, False)
+
+
+def __create_embeddings(
+    project_id: str,
+    embedding_ids: List[str],
+    user_id: str,
+    ctx_token: Any,
+) -> Any:
+    notification.send_organization_update(
+        project_id=project_id, message="embedding:started:all"
+    )
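+    # for each embedding: delete the old one, then request a fresh calculation; the config string is taken from the embedding name after the attribute prefix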
+    for embedding_id in embedding_ids:
+        ctx_token = general.remove_and_refresh_session(ctx_token, request_new=True)
+        embedding_item = embedding.get(project_id, embedding_id)
+        if not embedding_item:
+            continue
+
+        embedding_connector.request_deleting_embedding(project_id, embedding_id)
+
+        attribute_id = str(embedding_item.attribute_id)
+        attribute_name = attribute.get(project_id, attribute_id).name
+        if embedding_item.type == enums.EmbeddingType.ON_ATTRIBUTE.value:
+            prefix = f"{attribute_name}-classification-"
+            config_string = embedding_item.name[len(prefix) :]
+            embedding_connector.request_creating_attribute_level_embedding(
+                project_id, attribute_id, user_id, config_string
+            )
+        else:
+            prefix = f"{attribute_name}-extraction-"
+            config_string = embedding_item.name[len(prefix) :]
+            embedding_connector.request_creating_token_level_embedding(
+                project_id, attribute_id, user_id, config_string
+            )
+        time.sleep(5)
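+        # wait until the embedding service reports no running encoder for this project before starting the next one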
+        while embedding_util.has_encoder_running(project_id):
+            if embedding_item.state == enums.EmbeddingState.WAITING.value:
+                break
+            time.sleep(1)
+    return ctx_token