Skip to content

Commit fce237d

Browse files
Attribute calculation for new uploaded records (#123)
* Start on attribute calculation (ac) for new records
* Check if tokenization is running and doc-bin exists
* Tokenization for newly uploaded records
* Added notification when all attribute calculation is done
* Removed debug print
* Update embeddings when new records are uploaded
* Added enumerate for looping embeddings
* Embeddings and attribute calculation for newly uploaded records
* Moved position for creating embeddings
* Replaced import functions; removed unused imports
* Addressed PR comments; merged submodules
1 parent aee43d5 commit fce237d

File tree

6 files changed

+206
-20
lines changed

6 files changed

+206
-20
lines changed

api/transfer.py

Lines changed: 178 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
import logging
22
import traceback
3+
import time
4+
from typing import Any, List
35

46
from controller import organization
7+
from controller.embedding import util as embedding_util
8+
from controller.embedding import connector as embedding_connector
59
from starlette.endpoints import HTTPEndpoint
610
from starlette.responses import PlainTextResponse, JSONResponse
711

812
from controller.transfer.labelstudio import import_preperator
13+
from submodules.model.business_objects.tokenization import is_doc_bin_creation_running
914
from submodules.s3 import controller as s3
10-
from submodules.model.business_objects import organization
15+
from submodules.model.business_objects import (
16+
attribute,
17+
embedding,
18+
general,
19+
organization,
20+
tokenization,
21+
)
1122

1223
from controller.transfer import manager as transfer_manager
1324
from controller.upload_task import manager as upload_task_manager
@@ -16,13 +27,13 @@
1627
from controller.transfer import association_transfer_manager
1728
from controller.auth import manager as auth
1829
from controller.project import manager as project_manager
30+
from controller.attribute import manager as attribute_manager
1931

2032
from submodules.model import enums, exceptions
2133
from util.notification import create_notification
22-
from submodules.model.enums import NotificationType
23-
from submodules.model.models import UploadTask
24-
from submodules.model.business_objects import general
25-
from util import notification
34+
from submodules.model.enums import AttributeState, NotificationType, UploadStates
35+
from submodules.model.models import Embedding, UploadTask
36+
from util import daemon, notification
2637
from controller.tokenization import tokenization_service
2738

2839
logging.basicConfig(level=logging.DEBUG)
@@ -221,6 +232,7 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
221232
import_preperator.prepare_label_studio_import(project_id, task)
222233
else:
223234
transfer_manager.import_records_from_file(project_id, task)
235+
calculate_missing_attributes(project_id, task.user_id)
224236
elif "project" in task.file_type:
225237
transfer_manager.import_project(project_id, task)
226238
elif "knowledge_base" in task.file_type:
@@ -234,7 +246,10 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
234246
is_global_update,
235247
)
236248
if task.file_type != "knowledge_base":
237-
tokenization_service.request_tokenize_project(project_id, str(task.user_id))
249+
only_usable_attributes = task.file_type == "records_add"
250+
tokenization_service.request_tokenize_project(
251+
project_id, str(task.user_id), True, only_usable_attributes
252+
)
238253

239254

240255
def file_import_error_handling(
@@ -258,3 +273,160 @@ def file_import_error_handling(
258273
notification.send_organization_update(
259274
project_id, f"file_upload:{str(task.id)}:state:{task.state}", is_global_update
260275
)
276+
277+
278+
def calculate_missing_attributes(project_id: str, user_id: str) -> None:
    """Kick off attribute recalculation for newly uploaded records.

    Delegates the actual work to a background daemon thread so the upload
    request is not blocked.
    """
    daemon.run(__calculate_missing_attributes, project_id, user_id)
284+
285+
286+
def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
    """Recalculate user attributes after a record upload, then refresh embeddings.

    Waits for the tokenization service's doc-bin creation to finish, resets every
    USABLE attribute to INITIAL, recalculates them one by one (polling until each
    finishes tokenizing) and finally triggers the embedding-tensor recreation.
    Runs inside a daemon thread started by `calculate_missing_attributes`.
    """
    # wait a moment to ensure the process is started in the tokenization service
    time.sleep(5)
    ctx_token = general.get_ctx_token()
    attributes_usable = attribute.get_all_ordered(
        project_id,
        True,
        state_filter=[
            enums.AttributeState.USABLE.value,
        ],
    )
    if len(attributes_usable) == 0:
        # fix: release the session token on the early-return path (was leaked)
        general.remove_and_refresh_session(ctx_token, False)
        return
    # stored as plain list of ids so later session refreshes do not detach the rows
    attribute_ids = [str(att_usable.id) for att_usable in attributes_usable]
    for att_id in attribute_ids:
        attribute.update(project_id, att_id, state=enums.AttributeState.INITIAL.value)
    general.commit()
    notification.send_organization_update(
        project_id=project_id, message="calculate_attribute:started:all"
    )
    # first wait until project-wide tokenization (doc-bin creation) has completed
    i = 0
    while True:
        i += 1
        if i >= 60:
            i = 0
            # periodically refresh the db session so it does not time out
            ctx_token = general.remove_and_refresh_session(ctx_token, True)
        if tokenization.is_doc_bin_creation_running(project_id):
            time.sleep(5)
            continue
        else:
            break
    # next, ensure that the attributes are calculated and tokenized
    i = 0
    while True:
        time.sleep(1)
        i += 1
        if len(attribute_ids) == 0:
            notification.send_organization_update(
                project_id=project_id,
                message="calculate_attribute:finished:all",
            )
            break
        if i >= 60:
            i = 0
            ctx_token = general.remove_and_refresh_session(ctx_token, True)

        # process the attributes front-to-back; an id is only popped once its
        # calculation AND per-attribute tokenization are done
        current_att_id = attribute_ids[0]
        current_att = attribute.get(project_id, current_att_id)
        if current_att.state == enums.AttributeState.RUNNING.value:
            continue
        elif current_att.state == enums.AttributeState.INITIAL.value:
            attribute_manager.calculate_user_attribute_all_records(
                project_id, user_id, current_att_id, True
            )
        else:
            if tokenization.is_doc_bin_creation_running_for_attribute(
                project_id, current_att.name
            ):
                time.sleep(5)
                continue
            else:
                attribute_ids.pop(0)
                notification.send_organization_update(
                    project_id=project_id,
                    message=f"calculate_attribute:finished:{current_att_id}",
                )
        time.sleep(5)

    general.remove_and_refresh_session(ctx_token, False)
    calculate_missing_embedding_tensors(project_id, user_id)
358+
359+
360+
def calculate_missing_embedding_tensors(project_id: str, user_id: str) -> None:
    """Kick off embedding-tensor recreation for newly uploaded records.

    Delegates to a background daemon thread so the caller is not blocked.
    """
    daemon.run(__calculate_missing_embedding_tensors, project_id, user_id)
366+
367+
368+
def __calculate_missing_embedding_tensors(project_id: str, user_id: str) -> None:
    """Recreate all finished embeddings so tensors exist for newly uploaded records.

    Marks every finished embedding as WAITING, then deletes and re-requests each
    one via `__create_embeddings`. On failure, remaining WAITING embeddings are
    marked FAILED. Always notifies the organization and releases the db session.
    """
    ctx_token = general.get_ctx_token()
    embeddings = embedding.get_finished_embeddings_by_started_at(project_id)
    if len(embeddings) == 0:
        # fix: release the session token on the early-return path (was leaked)
        general.remove_and_refresh_session(ctx_token, False)
        return

    # stored as plain ids so session refreshes do not detach the rows
    embedding_ids = [str(embed.id) for embed in embeddings]
    for embed_id in embedding_ids:
        embedding.update_embedding_state_waiting(project_id, embed_id)
    general.commit()

    try:
        ctx_token = __create_embeddings(project_id, embedding_ids, user_id, ctx_token)
    except Exception as e:
        # fix: use the configured logging instead of print for error reporting
        logging.error(
            f"Error while recreating embeddings for {project_id} when new records are uploaded : {e}"
        )
        # anything still waiting did not finish — mark it failed
        get_waiting_embeddings = embedding.get_waiting_embeddings(project_id)
        for embed in get_waiting_embeddings:
            embedding.update_embedding_state_failed(project_id, str(embed.id))
        general.commit()
    finally:
        notification.send_organization_update(
            project_id=project_id, message="embedding:finished:all"
        )
        general.remove_and_refresh_session(ctx_token, False)
394+
395+
396+
def __create_embeddings(
    project_id: str,
    embedding_ids: List[str],
    user_id: str,
    ctx_token: Any,
) -> Any:
    """Delete and re-request each given embedding, waiting between requests.

    For every embedding id: deletes the existing embedding via the embedding
    service, derives the original config string from the embedding's name
    (by stripping the "<attribute>-classification-"/"<attribute>-extraction-"
    prefix), and requests a fresh attribute- or token-level embedding.
    Returns the (possibly refreshed) session context token.
    """
    notification.send_organization_update(
        project_id=project_id, message="embedding:started:all"
    )
    for embedding_id in embedding_ids:
        # refresh the db session per iteration so long-running work doesn't time out
        ctx_token = general.remove_and_refresh_session(ctx_token, request_new=True)
        embedding_item = embedding.get(project_id, embedding_id)
        if not embedding_item:
            continue

        embedding_connector.request_deleting_embedding(project_id, embedding_id)

        attribute_id = str(embedding_item.attribute_id)
        attribute_name = attribute.get(project_id, attribute_id).name
        # the config string is encoded in the embedding name after a type-specific prefix
        if embedding_item.type == enums.EmbeddingType.ON_ATTRIBUTE.value:
            prefix = f"{attribute_name}-classification-"
            config_string = embedding_item.name[len(prefix) :]
            embedding_connector.request_creating_attribute_level_embedding(
                project_id, attribute_id, user_id, config_string
            )
        else:
            prefix = f"{attribute_name}-extraction-"
            config_string = embedding_item.name[len(prefix) :]
            embedding_connector.request_creating_token_level_embedding(
                project_id, attribute_id, user_id, config_string
            )
        # give the embedding service time to register the new encoder
        time.sleep(5)
        while embedding_util.has_encoder_running(project_id):
            # NOTE(review): `embedding_item` was fetched before the delete/recreate
            # above, so its `state` may be stale here (likely still WAITING, which
            # would break out of the wait immediately) — confirm intended behavior
            if embedding_item.state == enums.EmbeddingState.WAITING.value:
                break
            time.sleep(1)
    return ctx_token

controller/attribute/manager.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import List, Tuple
2+
from typing import List, Optional, Tuple
33
from controller.tokenization.tokenization_service import (
44
request_tokenize_calculated_attribute,
55
request_tokenize_project,
@@ -150,7 +150,7 @@ def add_running_id(
150150

151151

152152
def calculate_user_attribute_all_records(
153-
project_id: str, user_id: str, attribute_id: str
153+
project_id: str, user_id: str, attribute_id: str, include_rats: bool = True
154154
) -> None:
155155
if attribute.get_all(
156156
project_id=project_id, state_filter=[AttributeState.RUNNING.value]
@@ -185,7 +185,6 @@ def calculate_user_attribute_all_records(
185185
append_to_logs=False,
186186
)
187187
return
188-
189188
attribute.update(
190189
project_id=project_id,
191190
attribute_id=attribute_id,
@@ -201,15 +200,18 @@ def calculate_user_attribute_all_records(
201200
project_id,
202201
user_id,
203202
attribute_id,
203+
include_rats,
204204
)
205205

206206

207207
def __calculate_user_attribute_all_records(
208-
project_id: str, user_id: str, attribute_id: str
208+
project_id: str, user_id: str, attribute_id: str, include_rats: bool
209209
) -> None:
210210
try:
211211
calculated_attributes = util.run_attribute_calculation_exec_env(
212-
attribute_id=attribute_id, project_id=project_id, doc_bin="docbin_full"
212+
attribute_id=attribute_id,
213+
project_id=project_id,
214+
doc_bin="docbin_full",
213215
)
214216
if not calculated_attributes:
215217
__notify_attribute_calculation_failed(
@@ -258,7 +260,7 @@ def __calculate_user_attribute_all_records(
258260
)
259261
try:
260262
request_tokenize_calculated_attribute(
261-
project_id, user_id, attribute_item.id
263+
project_id, user_id, attribute_item.id, include_rats
262264
)
263265
except:
264266
record.delete_user_created_attribute(

controller/attribute/util.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def add_log_to_attribute_logs(
3232

3333

3434
def prepare_sample_records_doc_bin(attribute_id: str, project_id: str) -> str:
35-
3635
sample_records = record.get_attribute_calculation_sample_records(project_id)
3736

3837
sample_records_doc_bin = tokenization.get_doc_bin_table_to_json(
@@ -54,7 +53,6 @@ def prepare_sample_records_doc_bin(attribute_id: str, project_id: str) -> str:
5453
def run_attribute_calculation_exec_env(
5554
attribute_id: str, project_id: str, doc_bin: str
5655
) -> None:
57-
5856
attribute_item = attribute.get(project_id, attribute_id)
5957

6058
prefixed_function_name = f"{attribute_id}_fn"

controller/tokenization/tokenization_service.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,32 @@ def request_tokenize_record(project_id: str, record_id: str) -> None:
1515
service_requests.post_call_or_raise(url, data)
1616

1717

def request_tokenize_project(
    project_id: str,
    user_id: str,
    include_rats: bool = True,
    only_uploaded_attributes: bool = False,
) -> None:
    """Ask the tokenization service to (re)tokenize an entire project.

    `include_rats` controls record-attribute-token-statistics creation;
    `only_uploaded_attributes` restricts tokenization to uploaded attributes.
    """
    payload = {
        "project_id": str(project_id),
        "record_id": "",
        "user_id": str(user_id),
        "include_rats": include_rats,
        "only_uploaded_attributes": only_uploaded_attributes,
    }
    service_requests.post_call_or_raise(f"{BASE_URI}/tokenize_project", payload)
2633

2734

2835
def request_tokenize_calculated_attribute(
29-
project_id: str, user_id: str, attribute_id: str
36+
project_id: str, user_id: str, attribute_id: str, include_rats: bool = True
3037
) -> None:
3138
url = f"{BASE_URI}/tokenize_calculated_attribute"
3239
data = {
3340
"project_id": str(project_id),
3441
"user_id": str(user_id),
3542
"attribute_id": str(attribute_id),
43+
"include_rats": include_rats,
3644
}
3745
service_requests.post_call_or_raise(url, data)
3846

controller/transfer/checks.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from controller.transfer.valid_arguments import valid_arguments
66
import pandas as pd
77
from util.notification import create_notification
8-
from submodules.model.enums import NotificationType
8+
from submodules.model.enums import AttributeState, NotificationType
99
from submodules.model.business_objects import attribute, record, general
1010
from controller.labeling_task.util import infer_labeling_task_name
1111
import logging
@@ -52,7 +52,6 @@ def run_checks(df: pd.DataFrame, project_id, user_id) -> None:
5252
duplicated_task_names = set()
5353
task_names_set = set()
5454
for task_name in task_names:
55-
5655
if task_name in task_names_set:
5756
duplicated_task_names.add(task_name)
5857
else:
@@ -69,9 +68,16 @@ def run_checks(df: pd.DataFrame, project_id, user_id) -> None:
6968
errors["DuplicatedTaskNames"] = notification.message
7069

7170
# check attribute equality
72-
attribute_entities = attribute.get_all(project_id)
71+
attribute_entities = attribute.get_all(
72+
project_id,
73+
state_filter=[
74+
AttributeState.UPLOADED.value,
75+
AttributeState.AUTOMATICALLY_CREATED.value,
76+
],
77+
)
7378
attribute_names = [attribute_item.name for attribute_item in attribute_entities]
7479
differences = set(attribute_names).difference(set(attributes))
80+
7581
if differences:
7682
guard = True
7783
notification = create_notification(

0 commit comments

Comments
 (0)