Skip to content

Commit 61ec871

Browse files
JWittmeyerJWittmeyer
and
JWittmeyer
authored
Different Bugs and small issues collector (#81)
* Adds check for crowd heuristic on run * Submodule change * Adds label rename logic * Removes required falg for project id on recommendations for zero shot and embedders * Submodule change * PR comments * Submodule change Co-authored-by: JWittmeyer <[email protected]>
1 parent 20c97ee commit 61ec871

File tree

9 files changed

+189
-9
lines changed

9 files changed

+189
-9
lines changed

controller/labeling_task_label/manager.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,18 @@
33
labeling_task_label,
44
labeling_task,
55
general,
6+
knowledge_base,
7+
information_source,
8+
record_label_association,
69
)
710
from controller.knowledge_base.util import create_knowledge_base_if_not_existing
811
from submodules.model.enums import LabelingTaskType
12+
from submodules.model import enums
13+
from util import notification
14+
15+
16+
from typing import List, Any, Dict
17+
import re
918

1019

1120
def get_label(project_id: str, label_id: str) -> LabelingTaskLabel:
@@ -24,6 +33,33 @@ def update_label_hotkey(project_id: str, label_id: str, hotkey: str) -> None:
2433
general.commit()
2534

2635

36+
def update_label_name(project_id: str, label_id: str, new_name: str) -> None:
37+
label = get_label(project_id, label_id)
38+
label.name = new_name
39+
general.commit()
40+
41+
42+
def handle_label_rename_warning(project_id: str, warning_data: Dict[str, str]) -> None:
43+
if warning_data["key"] == enums.CommentCategory.KNOWLEDGE_BASE.value:
44+
knowledge_base_item = knowledge_base.get_by_name(
45+
project_id, warning_data["old"]
46+
)
47+
knowledge_base_item.name = warning_data["new"]
48+
general.commit()
49+
50+
notification.send_organization_update(
51+
project_id, f"knowledge_base_updated:{str(knowledge_base_item.id)}"
52+
)
53+
elif warning_data["key"] == enums.CommentCategory.HEURISTIC.value:
54+
information_source_item = information_source.get(project_id, warning_data["id"])
55+
information_source_item.source_code = warning_data["new"]
56+
general.commit()
57+
58+
notification.send_organization_update(
59+
project_id, f"information_source_updated:{str(information_source_item.id)}"
60+
)
61+
62+
2763
def create_label(
2864
project_id: str, name: str, labeling_task_id: str, label_color: str
2965
) -> LabelingTaskLabel:
@@ -39,3 +75,81 @@ def create_label(
3975

4076
def delete_label(project_id: str, label_id: str) -> None:
4177
labeling_task_label.delete(project_id, label_id, with_commit=True)
78+
79+
80+
def check_rename_label(project_id: str, label_id: str, new_name: str) -> List[Any]:
81+
label_item = labeling_task_label.get(project_id, label_id)
82+
if not label_item or label_item.name == new_name:
83+
return []
84+
return __get_change_dict(project_id, label_item, new_name)
85+
86+
87+
def __get_change_dict(
88+
project_id: str, label: LabelingTaskLabel, new_name: str
89+
) -> Dict[str, str]:
90+
return_values = {
91+
"errors": __check_errors_label_rename(project_id, label, new_name),
92+
"warnings": __check_warnings_label_rename(project_id, label, new_name),
93+
"infos": [],
94+
}
95+
__check_label_rename_knowledge_base(project_id, label, new_name, return_values)
96+
if len(return_values["errors"]) == 0 and len(return_values["warnings"]) == 0:
97+
return_values["infos"].append(__get_msg_dict("No issues detected"))
98+
return return_values
99+
100+
101+
def __get_msg_dict(msg: str) -> Dict[str, str]:
102+
return {"msg": msg}
103+
104+
105+
def __check_errors_label_rename(
106+
project_id: str, label: LabelingTaskLabel, new_name: str
107+
) -> List[Dict[str, Any]]:
108+
append_me = []
109+
existing_with_name = labeling_task_label.get_by_name(
110+
project_id, str(label.labeling_task_id), new_name
111+
)
112+
if existing_with_name:
113+
append_me.append(__get_msg_dict("Label with name already exists"))
114+
return append_me
115+
116+
117+
def __check_warnings_label_rename(
118+
project_id: str, label: LabelingTaskLabel, new_name: str
119+
) -> List[Dict[str, Any]]:
120+
append_me = []
121+
122+
information_sources = information_source.get_all(project_id)
123+
for information_source_item in information_sources:
124+
current_code = information_source_item.source_code
125+
new_code = re.sub(r"\b%s\b" % label.name, new_name, current_code)
126+
if current_code != new_code:
127+
entry = __get_msg_dict(
128+
"Information source with matching word was detected."
129+
)
130+
entry["key"] = enums.CommentCategory.HEURISTIC.value
131+
entry["id"] = str(information_source_item.id)
132+
entry["old"] = current_code
133+
entry["new"] = new_code
134+
entry["old_name"] = label.name
135+
entry["new_name"] = new_name
136+
append_me.append(entry)
137+
138+
return append_me
139+
140+
141+
def __check_label_rename_knowledge_base(
142+
project_id: str, label: LabelingTaskLabel, new_name: str, append_to
143+
) -> None:
144+
knowledge_base_item = knowledge_base.get_by_name(project_id, label.name)
145+
if knowledge_base_item:
146+
entry = __get_msg_dict("Lookup list with same name as label exists.")
147+
entry["key"] = enums.CommentCategory.KNOWLEDGE_BASE.value
148+
entry["old"] = knowledge_base_item.name
149+
entry["new"] = new_name
150+
knowledge_base_item_new = knowledge_base.get_by_name(project_id, new_name)
151+
if knowledge_base_item_new:
152+
entry["msg"] += "\n\tNew label name however, already exists as lookup list."
153+
append_to["errors"].append(entry)
154+
else:
155+
append_to["warnings"].append(entry)

controller/payload/manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def create_payload(
1919
user_id: str,
2020
asynchronous: Optional[bool] = True,
2121
) -> InformationSourcePayload:
22+
information_source_item = information_source.get(project_id, information_source_id)
23+
if information_source_item.type == enums.InformationSourceType.CROWD_LABELER:
24+
return None
2225
return payload_scheduler.create_payload(
2326
info, project_id, information_source_id, user_id, asynchronous
2427
)
@@ -75,7 +78,9 @@ def get_labeling_function_on_10_records(
7578
)
7679

7780

78-
def fill_missing_record_ids(sample_records: List[str], calculated_labels: Dict[str, List[Any]]) -> List[str]:
81+
def fill_missing_record_ids(
82+
sample_records: List[str], calculated_labels: Dict[str, List[Any]]
83+
) -> List[str]:
7984
for record_item in sample_records:
8085
record_id = record_item[0]
8186
if record_id not in calculated_labels:

controller/zero_shot/manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ def get_zero_shot_recommendations(
4646
project_id: Optional[str] = None,
4747
) -> List[Dict[str, str]]:
4848
recommendations = zs_service.get_recommended_models()
49-
project_item = project.get(project_id)
49+
if not project_id:
50+
return recommendations
5051

52+
project_item = project.get(project_id)
5153
if project_item and project_item.tokenizer_blank:
5254
recommendations = [
5355
r for r in recommendations if r["language"] == project_item.tokenizer_blank

graphql_api/mutation/labeling_task_label.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from controller.auth import manager as auth
88
from controller.project import manager as project_manager
99
from util import doc_ock, notification
10+
from typing import Dict
1011

1112

1213
class CreateLabelingTaskLabel(graphene.Mutation):
@@ -83,6 +84,35 @@ def mutate(
8384
return UpdateLabelingTaskLabelColor(ok=True)
8485

8586

87+
class UpdateLabelingTaskLabelName(graphene.Mutation):
88+
class Arguments:
89+
project_id = graphene.ID(required=True)
90+
labeling_task_label_id = graphene.ID(required=True)
91+
new_name = graphene.String(required=True)
92+
93+
ok = graphene.Boolean()
94+
95+
def mutate(self, info, project_id: str, labeling_task_label_id: str, new_name: str):
96+
auth.check_demo_access(info)
97+
auth.check_project_access(info, project_id)
98+
manager.update_label_name(project_id, labeling_task_label_id, new_name)
99+
return UpdateLabelingTaskLabelColor(ok=True)
100+
101+
102+
class HandleLabelRenameWarnings(graphene.Mutation):
103+
class Arguments:
104+
project_id = graphene.ID(required=True)
105+
warning_data = graphene.JSONString(required=True)
106+
107+
ok = graphene.Boolean()
108+
109+
def mutate(self, info, project_id: str, warning_data: Dict[str, str]):
110+
auth.check_demo_access(info)
111+
auth.check_project_access(info, project_id)
112+
manager.handle_label_rename_warning(project_id, warning_data)
113+
return UpdateLabelingTaskLabelColor(ok=True)
114+
115+
86116
class DeleteLabelingTaskLabel(graphene.Mutation):
87117
class Arguments:
88118
project_id = graphene.ID(required=True)
@@ -107,3 +137,5 @@ class LabelingTaskLabelMutation(graphene.ObjectType):
107137
delete_label = DeleteLabelingTaskLabel.Field()
108138
update_label_color = UpdateLabelingTaskLabelColor.Field()
109139
update_label_hotkey = UpdateLabelingTaskLabelHotkey.Field()
140+
update_label_name = UpdateLabelingTaskLabelName.Field()
141+
handle_label_rename_warnings = HandleLabelRenameWarnings.Field()

graphql_api/query/embedding.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
import graphene
44

@@ -13,7 +13,7 @@
1313
class EmbeddingQuery(graphene.ObjectType):
1414
recommended_encoders = graphene.Field(
1515
graphene.List(Encoder),
16-
project_id=graphene.ID(required=True),
16+
project_id=graphene.ID(required=False),
1717
)
1818

1919
language_models = graphene.Field(graphene.List(LanguageModel))
@@ -22,9 +22,12 @@ class EmbeddingQuery(graphene.ObjectType):
2222
RecordTokenizationTask, project_id=graphene.ID(required=True)
2323
)
2424

25-
def resolve_recommended_encoders(self, info, project_id: str) -> List[Encoder]:
25+
def resolve_recommended_encoders(
26+
self, info, project_id: Optional[str] = None
27+
) -> List[Encoder]:
2628
auth.check_demo_access(info)
27-
auth.check_project_access(info, project_id)
29+
if project_id:
30+
auth.check_project_access(info, project_id)
2831
return manager.get_recommended_encoders()
2932

3033
def resolve_language_models(self, info) -> List[LanguageModel]:
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import graphene
2+
3+
from controller.auth import manager as auth
4+
from graphql_api.types import LabelingTask
5+
from controller.labeling_task_label import manager
6+
7+
8+
class LabelingTaskLabelQuery(graphene.ObjectType):
9+
check_rename_label = graphene.Field(
10+
graphene.JSONString,
11+
project_id=graphene.ID(required=True),
12+
label_id=graphene.ID(required=True),
13+
new_name=graphene.String(required=True),
14+
)
15+
16+
def resolve_check_rename_label(
17+
self, info, project_id: str, label_id: str, new_name: str
18+
) -> LabelingTask:
19+
auth.check_demo_access(info)
20+
auth.check_project_access(info, project_id)
21+
return manager.check_rename_label(project_id, label_id, new_name)

graphql_api/query/zero_shot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class ZeroShotQuery(graphene.ObjectType):
2222

2323
zero_shot_recommendations = graphene.Field(
2424
graphene.JSONString,
25-
project_id=graphene.ID(),
25+
project_id=graphene.ID(required=False),
2626
)
2727

2828
zero_shot_10_records = graphene.Field(
@@ -58,7 +58,8 @@ def resolve_zero_shot_recommendations(
5858
self, info, project_id: Optional[str] = None
5959
) -> List[Dict[str, str]]:
6060
auth_manager.check_demo_access(info)
61-
auth_manager.check_project_access(info, project_id)
61+
if project_id:
62+
auth_manager.check_project_access(info, project_id)
6263
return manager.get_zero_shot_recommendations(project_id)
6364

6465
def resolve_zero_shot_10_records(

graphql_api/schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from graphql_api.query.knowledge_base import KnowledgeBaseQuery
1212
from graphql_api.query.knowledge_term import KnowledgeTermQuery
1313
from graphql_api.query.labeling_task import LabelingTaskQuery
14+
from graphql_api.query.labeling_task_label import LabelingTaskLabelQuery
1415
from graphql_api.query.misc import MiscQuery
1516
from graphql_api.query.notification import NotificationQuery
1617
from graphql_api.query.organization import OrganizationQuery
@@ -58,6 +59,7 @@ class Query(
5859
KnowledgeBaseQuery,
5960
KnowledgeTermQuery,
6061
LabelingTaskQuery,
62+
LabelingTaskLabelQuery,
6163
MiscQuery,
6264
ModelProviderQuery,
6365
OrganizationQuery,

0 commit comments

Comments
 (0)