Skip to content

Commit 3d8d098

Browse files
JWittmeyerSirDegrafJWittmeyerFelixKirschKern
authored
release updates (#82)
* Adds tokenization to project creation by label studio * Filters extraction tasks out of ls import preparation * Adds cancel mutation for zero shot runs * Checks if to_name has an attribute equivalent and defaults to full record task if not * Removes data from minio on project delete * Adds handling for missing or wrong typed date of models in model provider info * Fixes record delete embedding issue * label renaming findings Co-authored-by: SirDeGraf <[email protected]> Co-authored-by: JWittmeyer <[email protected]> Co-authored-by: felix0496 <[email protected]>
1 parent aa3711a commit 3d8d098

File tree

9 files changed

+154
-16
lines changed

9 files changed

+154
-16
lines changed

controller/labeling_task_label/manager.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,21 +118,59 @@ def __check_warnings_label_rename(
118118
project_id: str, label: LabelingTaskLabel, new_name: str
119119
) -> List[Dict[str, Any]]:
120120
append_me = []
121-
122121
information_sources = information_source.get_all(project_id)
122+
task_type = labeling_task.get(project_id, label.labeling_task_id).task_type
123+
124+
old_var_name = label.name.replace(" ", "_")
125+
new_var_name = new_name.replace(" ", "_")
126+
123127
for information_source_item in information_sources:
128+
old_highlighting, new_highlighting = [], []
124129
current_code = information_source_item.source_code
125-
new_code = re.sub(r"\b%s\b" % label.name, new_name, current_code)
130+
new_code = current_code
131+
132+
if task_type == LabelingTaskType.INFORMATION_EXTRACTION.value:
133+
if re.search("import knowledge", new_code):
134+
format_import = r"(?<=\bknowledge\.)({label_name})(?=\b)"
135+
pattern_import = format_import.format(label_name=old_var_name)
136+
old_highlighting.append(pattern_import)
137+
new_highlighting.append(format_import.format(label_name=new_var_name))
138+
new_code = re.sub(pattern_import, f"{new_var_name}", new_code)
139+
if re.search(rf"(?<=from knowledge import).*?\b{old_var_name}\b", new_code):
140+
format_relative_import = r"(?<!\['\"])(\b{label_name}\b)(?!['\"])"
141+
pattern_relative_import = format_relative_import.format(
142+
label_name=old_var_name
143+
)
144+
old_highlighting.append(pattern_relative_import)
145+
new_highlighting.append(
146+
format_relative_import.format(label_name=new_var_name)
147+
)
148+
new_code = re.sub(pattern_relative_import, f"{new_var_name}", new_code)
149+
150+
if information_source_item.labeling_task_id == label.labeling_task_id:
151+
format_label = r"['\"]{label_name}['\"]"
152+
pattern_label = format_label.format(label_name=label.name)
153+
old_highlighting.append(pattern_label)
154+
new_highlighting.append(format_label.format(label_name=new_name))
155+
new_code = re.sub(pattern_label, f'"{new_name}"', new_code)
156+
126157
if current_code != new_code:
127158
entry = __get_msg_dict(
128-
"Information source with matching word was detected."
159+
f"Matching label found in information source {information_source_item.name}."
129160
)
130161
entry["key"] = enums.CommentCategory.HEURISTIC.value
131162
entry["id"] = str(information_source_item.id)
163+
entry["information_source_name"] = information_source_item.name
132164
entry["old"] = current_code
133165
entry["new"] = new_code
134166
entry["old_name"] = label.name
135167
entry["new_name"] = new_name
168+
entry["old_highlighting"] = old_highlighting
169+
entry["new_highlighting"] = new_highlighting
170+
entry[
171+
"href"
172+
] = f"/projects/{project_id}/information_sources/{information_source_item.id}"
173+
136174
append_me.append(entry)
137175

138176
return append_me
@@ -152,4 +190,4 @@ def __check_label_rename_knowledge_base(
152190
entry["msg"] += "\n\tNew label name however, already exists as lookup list."
153191
append_to["errors"].append(entry)
154192
else:
155-
append_to["warnings"].append(entry)
193+
append_to["warnings"].insert(0, entry)

controller/model_provider/manager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ def get_model_provider_info() -> List[ModelProviderInfoResult]:
1515
# parse dates to datetime format
1616
for model in model_info:
1717
if model["date"]:
18-
model["date"] = datetime.fromisoformat(model["date"])
18+
try:
19+
date = datetime.fromisoformat(model["date"])
20+
if date:
21+
model["date"] = date
22+
except ValueError:
23+
pass
24+
except TypeError:
25+
pass
1926

2027
return model_info
2128

controller/project/manager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,14 @@ def update_project(
8181
def delete_project(project_id: str) -> None:
8282
org_id = organization.get_id_by_project_id(project_id)
8383
project.delete_by_id(project_id, with_commit=True)
84-
daemon.run(s3.archive_bucket, org_id, project_id + "/")
84+
85+
daemon.run(__delete_project_data_from_minio, org_id, project_id)
86+
87+
88+
def __delete_project_data_from_minio(org_id, project_id: str) -> None:
89+
objects = s3.get_bucket_objects(org_id, project_id + "/")
90+
for obj in objects:
91+
s3.delete_object(org_id, obj)
8592

8693

8794
def import_sample_project(user_id: str, organization_id: str, name: str) -> Project:

controller/record/manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from graphql_api.types import ExtendedSearch
44
from submodules.model import Record, Attribute
5-
from submodules.model.business_objects import general, record, user_session
5+
from submodules.model.business_objects import general, record, user_session, embedding
66
from service.search import search
77

88
from controller.record import neural_search_connector
9+
from controller.embedding import manager as embedding_manager
10+
from util import daemon
911

1012

1113
def get_record(project_id: str, record_id: str) -> Record:
@@ -89,7 +91,14 @@ def get_records_by_extended_search(
8991

9092
def delete_record(project_id: str, record_id: str) -> None:
9193
record.delete(project_id, record_id, with_commit=True)
94+
daemon.run(__reupload_embeddings, project_id)
9295

9396

9497
def delete_all_records(project_id: str) -> None:
9598
record.delete_all(project_id, with_commit=True)
99+
100+
101+
def __reupload_embeddings(project_id: str) -> None:
102+
embeddings = embedding.get_finished_embeddings(project_id)
103+
for e in embeddings:
104+
embedding_manager.request_tensor_upload(project_id, str(e.id))

controller/transfer/labelstudio/import_preperator.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def analyze_file(
5050
ex_predictions = None
5151
ex_extraction = None
5252
ex_multiple_choices = None
53+
ex_to_names_check = None
5354
# multiple annotation for a user within the same record/task
5455
ex_multiple_annotations = None
5556

@@ -65,6 +66,10 @@ def analyze_file(
6566
ex_predictions = f"\n\tExample: record {record_id}"
6667
if not ex_multiple_annotations and __check_record_has_multi_annotation(record):
6768
ex_multiple_annotations = f"\n\tExample: record {record_id}"
69+
if not ex_to_names_check and __check_to_names_without_attribute_equivalent(
70+
record
71+
):
72+
ex_to_names_check = f"\n\tExample: record {record_id}"
6873
if (
6974
is_project_update
7075
and not ex_no_kern_id
@@ -108,6 +113,11 @@ def analyze_file(
108113
"Named Entity Recognition / extraction labels are not supported.\nThese annotations will be ignored if you proceed."
109114
+ ex_extraction
110115
)
116+
if ex_to_names_check:
117+
file_additional_info["warnings"].append(
118+
"Task targets found without equivalent in records attributes \nThese will be created as full record tasks if you proceed."
119+
+ ex_to_names_check
120+
)
111121
if ex_multiple_choices:
112122
file_additional_info["warnings"].append(
113123
"Multiple choices for a result set are not supported.\nThese annotations will be ignored if you proceed."
@@ -130,17 +140,33 @@ def analyze_file(
130140
file_additional_info["file_info"]["annotations"] = user_id_counts
131141

132142

133-
def __add_annotation_target(annotation: Dict[str, Any], tasks: Set[str]) -> None:
143+
def __add_annotation_target(
144+
annotation: Dict[str, Any], tasks: Set[str]
145+
) -> None:
134146
tasks |= __get_annotation_targets(annotation)
135147

136148

137149
def __get_annotation_targets(annotation: Dict[str, Any]) -> Set[str]:
138150
target = annotation.get("result")
139151
if target and len(target) > 0:
140-
return {t["from_name"] for t in target if "from_name" in t}
152+
return {
153+
t["from_name"]
154+
for t in target
155+
if "from_name" in t and t["type"] == "choices"
156+
}
141157
return {}
142158

143159

160+
def __check_to_names_without_attribute_equivalent(
161+
record: Dict[str, Any]
162+
) -> bool:
163+
for annotation in record.get("annotations"):
164+
target = annotation.get("result")
165+
to_names = [t["to_name"] for t in target if "to_name" in t and t["type"] == "choices"]
166+
167+
return len(set(to_names) - set(record.get("data"))) != 0
168+
169+
144170
def __check_record_has_values_for(
145171
record: Dict[str, Any], key: str, sub_key: Optional[str] = None
146172
) -> bool:

controller/transfer/labelstudio/project_creation_manager.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def manage_data_import(project_id: str, task_id: str) -> None:
2323
task = upload_task.get(project_id, task_id)
2424
file_path = download_file(project_id, task)
2525
mappings = json.loads(task.mappings)
26+
attribute_names = []
2627
user_mapping = mappings.get("users")
2728
user_mapping = create_unknown_users(user_mapping)
2829
attribute_task_mapping = mappings.get("tasks")
@@ -32,18 +33,17 @@ def manage_data_import(project_id: str, task_id: str) -> None:
3233
first_record_item = data[0]
3334
for attribute_name, attribute_value in first_record_item.get("data").items():
3435
__create_attribute(project_id, attribute_name, attribute_value)
36+
attribute_names.append(attribute_name)
3537

3638
labeling_tasks, records, record_label_associations = __extract_data(
37-
data, user_mapping, attribute_task_mapping
39+
data, user_mapping, attribute_task_mapping, attribute_names
3840
)
3941
label_id_lookup = __create_labeling_tasks(project_id, labeling_tasks)
4042

4143
CHUNK_SIZE = 500
42-
chunks = [records[x: x + CHUNK_SIZE] for x in range(0, len(records), CHUNK_SIZE)]
44+
chunks = [records[x : x + CHUNK_SIZE] for x in range(0, len(records), CHUNK_SIZE)]
4345
for idx, chunk in enumerate(chunks):
44-
__create_records(
45-
project_id, chunk, record_label_associations, label_id_lookup
46-
)
46+
__create_records(project_id, chunk, record_label_associations, label_id_lookup)
4747
number_records = len(records)
4848

4949
upload_task_manager.update_upload_task_to_finished(task)
@@ -87,7 +87,9 @@ def __create_records(
8787
)
8888

8989

90-
def __create_labeling_tasks(project_id: str, labeling_tasks: Dict[str, Any]) -> Dict[str, Any]:
90+
def __create_labeling_tasks(
91+
project_id: str, labeling_tasks: Dict[str, Any]
92+
) -> Dict[str, Any]:
9193
label_id_lookup = {}
9294

9395
attribute_ids_by_names = {
@@ -120,7 +122,12 @@ def __infer_target(target_attribute: str) -> str:
120122
)
121123

122124

123-
def __extract_data(data: Any, user_mapping: Dict[str, Any], attribute_task_mapping: Dict[str, Any]) -> Tuple[Dict[str, Any], List, Dict[str, Any]]:
125+
def __extract_data(
126+
data: Any,
127+
user_mapping: Dict[str, Any],
128+
attribute_task_mapping: Dict[str, Any],
129+
attribute_names: List[str],
130+
) -> Tuple[Dict[str, Any], List, Dict[str, Any]]:
124131
labeling_tasks = {}
125132
records = []
126133
record_label_associations = {}
@@ -160,6 +167,7 @@ def __extract_data(data: Any, user_mapping: Dict[str, Any], attribute_task_mappi
160167
if (
161168
attribute_task_mapping.get(task_name)
162169
== enums.RecordImportMappingValues.ATTRIBUTE_SPECIFIC.value
170+
and result.get("to_name") in attribute_names
163171
):
164172
labeling_tasks.get(task_name)["attribute"] = result.get("to_name")
165173

controller/transfer/manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import traceback
55
from typing import Any, List, Optional, Dict
66
import zipfile
7+
8+
from controller.tokenization import tokenization_service
79
from controller.transfer import export_parser
810
from controller.transfer.knowledge_base_transfer_manager import (
911
import_knowledge_base_file,
@@ -272,6 +274,8 @@ def import_label_studio_file(project_id: str, upload_task_id: str) -> None:
272274
project_update_manager.manage_data_import(project_id, upload_task_id)
273275
else:
274276
project_creation_manager.manage_data_import(project_id, upload_task_id)
277+
task = upload_task.get(project_id, upload_task_id)
278+
tokenization_service.request_tokenize_project(project_id, str(task.user_id))
275279
upload_task.update(project_id, upload_task_id, state=enums.UploadStates.DONE.value)
276280
except Exception:
277281
general.rollback()

controller/zero_shot/manager.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,18 @@ def __start_zero_shot_for_project(
184184
f"Can't calculate stats for zero shot project {project_id}, is {information_source_id}",
185185
flush=True,
186186
)
187+
188+
189+
def cancel_zero_shot_run(
190+
project_id: str,
191+
information_source_id: str,
192+
payload_id: str,
193+
) -> None:
194+
item = information_source.get_payload(project_id, payload_id)
195+
if not item:
196+
raise ValueError("unknown payload:" + payload_id)
197+
if str(item.source_id) != information_source_id:
198+
raise ValueError("payload does not belong to information source")
199+
# setting the state to failed with be noted by the thread in zs service and handled
200+
item.state = enums.PayloadState.FAILED.value
201+
general.commit()

graphql_api/mutation/zero_shot.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,29 @@ def mutate(
2727
return ZeroShotProject(ok=True)
2828

2929

30+
class CancelZeroShotRun(graphene.Mutation):
31+
class Arguments:
32+
project_id = graphene.ID(required=True)
33+
information_source_id = graphene.ID(required=True)
34+
payload_id = graphene.ID(required=True)
35+
36+
ok = graphene.Boolean()
37+
38+
def mutate(
39+
self,
40+
info,
41+
project_id: str,
42+
information_source_id: str,
43+
payload_id: str,
44+
):
45+
auth_manager.check_demo_access(info)
46+
auth_manager.check_project_access(info, project_id)
47+
48+
manager.cancel_zero_shot_run(project_id, information_source_id, payload_id)
49+
50+
return ZeroShotProject(ok=True)
51+
52+
3053
class CreateZeroShotInformationSource(graphene.Mutation):
3154
class Arguments:
3255
project_id = graphene.ID(required=True)
@@ -60,3 +83,4 @@ def mutate(
6083
class ZeroShotMutation(graphene.ObjectType):
6184
zero_shot_project = ZeroShotProject.Field()
6285
create_zero_shot_information_source = CreateZeroShotInformationSource.Field()
86+
cancel_zero_shot_run = CancelZeroShotRun.Field()

0 commit comments

Comments
 (0)