Commit 20c97ee

SirDegraf and JWittmeyer authored
Label studio import (#80)
* Adds extended functionality of upload tasks for label studio import
* Renames table field user_mappings to mappings
* Adds simple analysis of the imported label studio file
* Adds endpoint for setting mapping and uses the existing update structure for integrating websocket updates (e.g. into upload task updating)
* Adds check for state change
* Adds logic for preparing task_names
* Adds task selection fix
* Adds project cleans action
* Adds logic for import converter
* Changes task collection to work with multiple results
* Makes changes to converter logic
* Adds handling of IGNORE and adds task_id to file name
* Calls import logic to import the converted file
* Adds check to label studio file import
* Submodule change
* Changes from error to warning
* Removes MANUAL from task name
* Removes tempfile
* Adds check for project update kern id
* Adds dedicated label studio import
* Major fixes and dedicated update logic
* Minor code style changes
* Adds filter for IGNORE value in label update
* Enables tokenization in label studio upload
* Resolves PR comments
* Updates submodule mode
* Adds error handling with db rollback

Co-authored-by: JWittmeyer <[email protected]>
Co-authored-by: root <[email protected]>
1 parent a783c8f commit 20c97ee

File tree

18 files changed: +800 -41 lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 .DS_Store
 .idea
-tmpfile.
+tmpfile.*
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
New Alembic migration (revision 09311360f8b9)

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+"""Adds upload task fields for label studio
+
+Revision ID: 09311360f8b9
+Revises: 87f463aa5112
+Create Date: 2022-11-07 10:32:10.881495
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '09311360f8b9'
+down_revision = '87f463aa5112'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('upload_task', sa.Column('upload_type', sa.String(), nullable=True))
+    op.add_column('upload_task', sa.Column('file_additional_info', sa.String(), nullable=True))
+    op.add_column('upload_task', sa.Column('mappings', sa.String(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('upload_task', 'mappings')
+    op.drop_column('upload_task', 'file_additional_info')
+    op.drop_column('upload_task', 'upload_type')
+    # ### end Alembic commands ###
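The three new columns are plain string columns, so the structured analysis result is serialized before it is written. A minimal sketch of the kind of value that ends up in upload_task.file_additional_info after the preparation step further down (field names mirror __get_blank_file_additional_info in import_preperator.py; the concrete example values are assumptions):

    import json

    # Assumed example of the analysis summary stored in the new
    # upload_task.file_additional_info column; the column is sa.String,
    # so the dict is dumped to a JSON string rather than a JSON type.
    file_additional_info = {
        "user_ids": [1],
        "tasks": ["sentiment"],
        "errors": [],
        "warnings": ["Label Studio drafts are not supported.\n\tExample: record 7"],
        "info": [],
        "file_info": {"records": 25, "annotations": {1: 25}},
    }
    dumped_info = json.dumps(file_additional_info)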

api/transfer.py

Lines changed: 19 additions & 6 deletions
@@ -1,13 +1,16 @@
 import logging
 import traceback
 
+import controller.transfer.labelstudio.import_preperator
 from controller import organization
 from starlette.endpoints import HTTPEndpoint
 from starlette.responses import PlainTextResponse, JSONResponse
+
+from controller.transfer.labelstudio import import_preperator
 from submodules.s3 import controller as s3
 from submodules.model.business_objects import organization
 
-from controller.transfer import manager as transfer_manager
+from controller.transfer import manager as transfer_manager, record_transfer_manager
 from controller.upload_task import manager as upload_task_manager
 from controller.auth import manager as auth_manager
 from controller.transfer import manager as transfer_manager
@@ -191,17 +194,27 @@ def get(self, request) -> JSONResponse:
 
 
 def init_file_import(task: UploadTask, project_id: str, is_global_update: bool) -> None:
+    task_state = task.state
     if "records" in task.file_type:
-        transfer_manager.import_records_from_file(project_id, task)
+        if task.upload_type == enums.UploadTypes.LABEL_STUDIO.value:
+            import_preperator.prepare_label_studio_import(project_id, task)
+        else:
+            transfer_manager.import_records_from_file(project_id, task)
     elif "project" in task.file_type:
         transfer_manager.import_project(project_id, task)
     elif "knowledge_base" in task.file_type:
        transfer_manager.import_knowledge_base(project_id, task)
 
-    notification.send_organization_update(
-        project_id, f"file_upload:{str(task.id)}:state:{task.state}", is_global_update
-    )
-    if task.file_type != "knowledge_base":
+    if task.state == task_state:
+        # update is sent in update task if it was updated (e.g. with labeling studio)
+        notification.send_organization_update(
+            project_id,
+            f"file_upload:{str(task.id)}:state:{task.state}",
+            is_global_update,
+        )
+    if (
+        task.file_type != "knowledge_base"
+    ):
         tokenization_service.request_tokenize_project(project_id, str(task.user_id))
 
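init_file_import now branches on task.upload_type and uses the captured task_state to skip the duplicate websocket notification when the Label Studio preparation already updated the task (update_task sends its own organization update in that case). The UploadTypes enum referenced here lives in the model submodule bumped by this commit and is not shown in the diff; a minimal sketch of what the member presumably looks like (the member name LABEL_STUDIO comes from the call site, everything else is an assumption):

    from enum import Enum

    class UploadTypes(Enum):
        # DEFAULT is an assumed placeholder for the regular record upload path
        # handled by transfer_manager; only LABEL_STUDIO appears in this diff.
        DEFAULT = "DEFAULT"
        LABEL_STUDIO = "LABEL_STUDIO"

task.upload_type is the plain string column added by the migration above, which is why the comparison goes through the enum's .value.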

controller/record_label_association/manager.py

Lines changed: 1 addition & 3 deletions
@@ -300,9 +300,7 @@ def delete_record_label_association(
         project_id, record_id, association_ids
     )
     task_ids = get_labeling_tasks_from_ids(project_id, association_ids)
-    record_label_association.delete_by_ids(
-        project_id, record_id, association_ids, with_commit=True
-    )
+    record_label_association.delete_by_ids(project_id, association_ids, record_id, with_commit=True)
     for task_id in task_ids:
         update_is_relevant_manual_label(project_id, task_id, record_id)
     general.commit()
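The call now passes the association ids before the record id, presumably to match a changed delete_by_ids signature in the model submodule updated alongside this commit; the signature itself is not part of this diff. A sketch of the assumed parameter order, with names inferred only from the call site above:

    from typing import List

    # Assumed shape of the business-object helper in the model submodule;
    # only the argument order used by the caller is actually shown in this commit.
    def delete_by_ids(
        project_id: str,
        association_ids: List[str],
        record_id: str,
        with_commit: bool = False,
    ) -> None:
        ...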

controller/transfer/knowledge_base_transfer_manager.py

Lines changed: 2 additions & 6 deletions
@@ -8,9 +8,7 @@
 
 
 def import_knowledge_base_file(project_id: str, task: UploadTask) -> None:
-    upload_task_manager.update_task(
-        project_id, task.id, state=enums.UploadStates.PENDING.value
-    )
+    upload_task_manager.update_task(project_id, task.id, state=enums.UploadStates.PENDING.value)
     general.commit()
 
     file_type = task.file_name.rsplit("_", 1)[0].rsplit(".", 1)[1]
@@ -45,9 +43,7 @@ def import_knowledge_base_file(project_id: str, task: UploadTask) -> None:
         project_id, list_id, to_add, with_commit=True
     )
 
-    upload_task_manager.update_task(
-        project_id, task.id, state=enums.UploadStates.IN_PROGRESS.value
-    )
+    upload_task_manager.update_task(project_id, task.id, state=enums.UploadStates.IN_PROGRESS.value)
     task.state = enums.UploadStates.DONE.value
     general.commit()
 
controller/transfer/labelstudio/import_preperator.py

Lines changed: 202 additions & 0 deletions

@@ -0,0 +1,202 @@
+import json
+import traceback
+import os
+
+from controller.transfer.record_transfer_manager import download_file
+from submodules.model import UploadTask, enums
+from submodules.model.business_objects import project, record
+from controller.upload_task import manager as task_manager
+from typing import Set, Dict, Any, Optional
+
+
+def prepare_label_studio_import(project_id: str, task: UploadTask) -> None:
+    # pre init to ensure we can always append an error
+    file_additional_info = __get_blank_file_additional_info()
+    project_item = project.get(project_id)
+    if not project_item:
+        file_additional_info["errors"].append("Can't find project")
+    try:
+        is_project_update = record.count(project_id) != 0
+        file_path = download_file(project_id, task)
+        _, extension = os.path.splitext(file_path)
+        if extension == ".json":
+            with open(file_path) as file:
+                data = json.load(file)
+            analyze_file(data, file_additional_info, is_project_update)
+        else:
+            file_additional_info["errors"].append(f"Unsupported file type {extension}")
+    except Exception as e:
+        file_additional_info["errors"].append(
+            "Error while analyzing file: {}".format(e)
+        )
+        print(traceback.format_exc(), flush=True)
+    dumped_info = json.dumps(file_additional_info)
+    task_manager.update_task(
+        project_id,
+        task.id,
+        state=enums.UploadStates.PREPARED.value,
+        file_additional_info=dumped_info,
+    )
+
+
+def analyze_file(
+    data: Dict[str, Any], file_additional_info: Dict[str, Any], is_project_update: bool
+) -> None:
+    user_id_counts = {}
+    tasks = set()
+    record_count = 0
+    ex_no_kern_id = None
+    ex_drafts = None
+    ex_predictions = None
+    ex_extraction = None
+    ex_multiple_choices = None
+    # multiple annotation for a user within the same record/task
+    ex_multiple_annotations = None
+
+    for record in data:
+        if type(record) is not dict:
+            file_additional_info["errors"].append("Import format not recognized")
+            break
+        record_count += 1
+        record_id = record["id"]
+        if not ex_drafts and __check_record_has_values_for(record, "drafts"):
+            ex_drafts = f"\n\tExample: record {record_id}"
+        if not ex_predictions and __check_record_has_values_for(record, "predictions"):
+            ex_predictions = f"\n\tExample: record {record_id}"
+        if not ex_multiple_annotations and __check_record_has_multi_annotation(record):
+            ex_multiple_annotations = f"\n\tExample: record {record_id}"
+        if (
+            is_project_update
+            and not ex_no_kern_id
+            and not __check_record_has_values_for(
+                record, "data", "kern_refinery_record_id"
+            )
+        ):
+            ex_no_kern_id = f"\n\tExample: record {record_id}"
+        for annotation in record["annotations"]:
+            annotation_id = annotation["id"]
+            if not ex_extraction and __check_annotation_has_extraction(annotation):
+                ex_extraction = (
+                    f"\n\tExample: record {record_id} - annotation {annotation_id}"
+                )
+            if not ex_multiple_choices and __check_annotation_has_multiclass(
+                annotation
+            ):
+                ex_multiple_choices = (
+                    f"\n\tExample: record {record_id} - annotation {annotation_id}"
+                )
+            user_id = annotation["completed_by"]
+            __add_annotation_target(annotation, tasks)
+
+            if user_id in user_id_counts:
+                user_id_counts[user_id] += 1
+            else:
+                user_id_counts[user_id] = 1
+
+    if ex_drafts:
+        file_additional_info["warnings"].append(
+            "Label Studio drafts are not supported." + ex_drafts
+        )
+
+    if ex_predictions:
+        file_additional_info["warnings"].append(
+            "Label Studio predictions are not supported." + ex_predictions
+        )
+
+    if ex_extraction:
+        file_additional_info["warnings"].append(
+            "Named Entity Recognition / extraction labels are not supported.\nThese annotations will be ignored if you proceed."
+            + ex_extraction
+        )
+    if ex_multiple_choices:
+        file_additional_info["warnings"].append(
+            "Multiple choices for a result set are not supported.\nThese annotations will be ignored if you proceed."
+            + ex_multiple_choices
+        )
+    if ex_multiple_annotations:
+        file_additional_info["errors"].append(
+            "Multiple annotations for the same user within the same record\ntargeting the same task are not supported."
+            + ex_multiple_annotations
+        )
+    if ex_no_kern_id:
+        file_additional_info["errors"].append(
+            "Project update without kern record id. Can't update project (see restrictions)."
+            + ex_no_kern_id
+        )
+
+    file_additional_info["user_ids"] = list(user_id_counts.keys())
+    file_additional_info["tasks"] = list(tasks)
+    file_additional_info["file_info"]["records"] = record_count
+    file_additional_info["file_info"]["annotations"] = user_id_counts
+
+
+def __add_annotation_target(annotation: Dict[str, Any], tasks: Set[str]) -> None:
+    tasks |= __get_annotation_targets(annotation)
+
+
+def __get_annotation_targets(annotation: Dict[str, Any]) -> Set[str]:
+    target = annotation.get("result")
+    if target and len(target) > 0:
+        return {t["from_name"] for t in target if "from_name" in t}
+    return set()
+
+
+def __check_record_has_values_for(
+    record: Dict[str, Any], key: str, sub_key: Optional[str] = None
+) -> bool:
+    value = record.get(key)
+    if value:
+        if not sub_key:
+            return True
+        else:
+            return __check_record_has_values_for(value, sub_key)
+    return False
+
+
+def __check_record_has_multi_annotation(record: Dict[str, Any]) -> bool:
+    annotations = record.get("annotations")
+    if not annotations or len(annotations) < 2:
+        return False
+    lookup = {}
+    for annotation in annotations:
+        user_id = annotation.get("completed_by")
+        if user_id not in lookup:
+            lookup[user_id] = {}
+        targets = __get_annotation_targets(annotation)
+        for target in targets:
+            if target not in lookup[user_id]:
+                lookup[user_id][target] = 1
+            else:
+                return True
+    return False
+
+
+def __check_annotation_has_extraction(annotation: Dict[str, Any]) -> bool:
+    results = annotation.get("result")
+    if not results:
+        return False
+    for result in results:
+        if result.get("type") != "choices":
+            return True
+    return False
+
+
+def __check_annotation_has_multiclass(annotation: Dict[str, Any]) -> bool:
+    results = annotation.get("result")
+    if not results:
+        return False
+    for result in results:
+        if result.get("type") == "choices" and len(result["value"]["choices"]) > 1:
+            return True
+    return False
+
+
+def __get_blank_file_additional_info() -> Dict[str, Any]:
+    return {
+        "user_ids": [],
+        "tasks": [],
+        "errors": [],
+        "warnings": [],
+        "info": [],
+        "file_info": {"records": 0, "annotations": {}},
+    }
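For reference, a minimal Label Studio export that analyze_file accepts; the field names follow what the checks above actually read (id, data, annotations, completed_by, result, from_name, type, value.choices), while the concrete label, task and id values are made up for illustration:

    example_export = [
        {
            "id": 1,
            # kern_refinery_record_id is only required when the project already
            # contains records (the project-update case guarded by ex_no_kern_id)
            "data": {"text": "Great product!", "kern_refinery_record_id": "42"},
            "annotations": [
                {
                    "id": 10,
                    "completed_by": 1,
                    "result": [
                        {
                            "from_name": "sentiment",  # collected as a task name
                            "type": "choices",  # anything else counts as extraction
                            "value": {"choices": ["positive"]},  # >1 choice triggers a warning
                        }
                    ],
                }
            ],
            "drafts": [],  # non-empty drafts only produce a warning
            "predictions": [],  # non-empty predictions only produce a warning
        }
    ]

With a file like this, prepare_label_studio_import records one record, one user id and one task name in file_additional_info and moves the upload task to the PREPARED state.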
