Skip to content

Commit fda7a6c

Browse files
authored
ITEP-33590 Dataset ie support for keypoint detection (#169)
1 parent a5c2eb2 commit fda7a6c

File tree

29 files changed

+531
-234
lines changed

29 files changed

+531
-234
lines changed

interactive_ai/libs/iai_core_py/iai_core/factories/project_validator.py

Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def _validate_empty_labels(cls, parser: ParserT) -> None:
146146
raise ReservedLabelNameException(label_name=empty_label_name)
147147

148148
@classmethod
149+
@abstractmethod
149150
def _validate_keypoint_structure(cls, parser: ParserT) -> None:
150151
"""
151152
Validates that a user defined label graph edge has exactly 2 nodes, node names match with existing labels,
@@ -160,29 +161,6 @@ def _validate_keypoint_structure(cls, parser: ParserT) -> None:
160161
:raises NodeNameNotInLabelsException: if a node name does not match any of the label names
161162
:raises NodePositionIsOutOfBoundsException: if a node is out of bounds (not in the range [0.0, 1.0])
162163
"""
163-
duplicate_list = []
164-
for task_name in parser.get_tasks_names():
165-
keypoint_structure = parser.get_keypoint_structure_data(task_name=task_name)
166-
if not keypoint_structure:
167-
continue
168-
label_names = parser.get_custom_labels_names_by_task(task_name=task_name)
169-
edges = keypoint_structure["edges"]
170-
for edge in edges:
171-
nodes = edge["nodes"]
172-
if len(nodes) != 2:
173-
raise WrongNumberOfNodesException
174-
if nodes[0] not in label_names or nodes[1] not in label_names:
175-
raise IncorrectNodeNameInGraphException
176-
if set(nodes) in duplicate_list:
177-
raise DuplicateEdgeInGraphException
178-
duplicate_list.append(set(nodes))
179-
180-
positions = keypoint_structure["positions"]
181-
for position in positions:
182-
if position["label"] not in label_names:
183-
raise NodeNameNotInLabelsException
184-
if not 0 <= position["x"] <= 1 or not 0 <= position["y"] <= 1:
185-
raise NodePositionIsOutOfBoundsException
186164

187165
@classmethod
188166
@abstractmethod
@@ -417,6 +395,45 @@ def __validate_parent_labels_in_parent_task(
417395
),
418396
)
419397

398+
@classmethod
399+
def _validate_keypoint_structure(cls, parser: ProjectParser) -> None:
400+
"""
401+
Validates that a user defined label graph edge has exactly 2 nodes, node names match with existing labels,
402+
and has no duplicate edges
403+
404+
This method must be run after labels validation since it assumes that its labels param is valid.
405+
406+
:param parser: A parser instance that can read and decode the information necessary to create a project
407+
:raises WrongNumberOfNodesException: if an edge does not have 2 vertices
408+
:raises IncorrectNodeNameInGraphException: if an edge has an incorrect name
409+
:raises DuplicateEdgeInGraphException: if the graph contains a duplicate edge
410+
:raises NodeNameNotInLabelsException: if a node name does not match any of the label names
411+
:raises NodePositionIsOutOfBoundsException: if a node is out of bounds (not in the range [0.0, 1.0])
412+
"""
413+
duplicate_list = []
414+
for task_name in parser.get_tasks_names():
415+
keypoint_structure = parser.get_keypoint_structure_data(task_name=task_name)
416+
if not keypoint_structure:
417+
continue
418+
label_names = parser.get_custom_labels_names_by_task(task_name=task_name)
419+
edges = keypoint_structure["edges"]
420+
for edge in edges:
421+
nodes = edge["nodes"]
422+
if len(nodes) != 2:
423+
raise WrongNumberOfNodesException
424+
if nodes[0] not in label_names or nodes[1] not in label_names:
425+
raise IncorrectNodeNameInGraphException
426+
if set(nodes) in duplicate_list:
427+
raise DuplicateEdgeInGraphException
428+
duplicate_list.append(set(nodes))
429+
430+
positions = keypoint_structure["positions"]
431+
for position in positions:
432+
if position["label"] not in label_names:
433+
raise NodeNameNotInLabelsException
434+
if not 0 <= position["x"] <= 1 or not 0 <= position["y"] <= 1:
435+
raise NodePositionIsOutOfBoundsException
436+
420437

421438
class ProjectUpdateValidator(ProjectValidator[ProjectUpdateParser]):
422439
def validate(self, parser: ProjectUpdateParser) -> None:
@@ -726,3 +743,47 @@ def __validate_parent_labels_in_parent_task(
726743
if not is_found
727744
),
728745
)
746+
747+
@classmethod
748+
def _validate_keypoint_structure(cls, parser: ProjectUpdateParser) -> None:
749+
"""
750+
Validates that a user defined label graph edge has exactly 2 nodes, node names match with existing labels,
751+
and has no duplicate edges
752+
753+
This method must be run after labels validation since it assumes that its labels param is valid.
754+
755+
:param parser: A parser instance that can read and decode the information necessary to create a project
756+
:raises WrongNumberOfNodesException: if an edge does not have 2 vertices
757+
:raises IncorrectNodeNameInGraphException: if an edge has an incorrect name
758+
:raises DuplicateEdgeInGraphException: if the graph contains a duplicate edge
759+
:raises NodeNameNotInLabelsException: if a node name does not match any of the label names
760+
:raises NodePositionIsOutOfBoundsException: if a node is out of bounds (not in the range [0.0, 1.0])
761+
"""
762+
duplicate_list = []
763+
for task_name in parser.get_tasks_names():
764+
keypoint_structure = parser.get_keypoint_structure_data(task_name=task_name)
765+
if not keypoint_structure:
766+
continue
767+
label_names = list(parser.get_custom_labels_names_by_task(task_name=task_name))
768+
label_ids = [
769+
str(parser.get_label_id_by_name(task_name=task_name, label_name=label_name))
770+
for label_name in label_names
771+
]
772+
labels = label_names + label_ids
773+
edges = keypoint_structure["edges"]
774+
for edge in edges:
775+
nodes = edge["nodes"]
776+
if len(nodes) != 2:
777+
raise WrongNumberOfNodesException
778+
if nodes[0] not in labels or nodes[1] not in labels:
779+
raise IncorrectNodeNameInGraphException
780+
if set(nodes) in duplicate_list:
781+
raise DuplicateEdgeInGraphException
782+
duplicate_list.append(set(nodes))
783+
784+
positions = keypoint_structure["positions"]
785+
for position in positions:
786+
if position["label"] not in labels:
787+
raise NodeNameNotInLabelsException
788+
if not 0 <= position["x"] <= 1 or not 0 <= position["y"] <= 1:
789+
raise NodePositionIsOutOfBoundsException

interactive_ai/libs/iai_core_py/iai_core/utils/project_builder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -827,15 +827,16 @@ def _build_keypoint_structure(
827827
:return: the KeypointStructure
828828
"""
829829
label_name_to_id: dict[str, ID] = {label.name: label.id_ for label in labels}
830+
label_ids = [label.id_ for label in labels]
830831
edges = []
831832
for edge in keypoint_structure_data["edges"]:
832-
node_1 = label_name_to_id[edge["nodes"][0]]
833-
node_2 = label_name_to_id[edge["nodes"][1]]
833+
node_1 = ID(edge["nodes"][0]) if ID(edge["nodes"][0]) in label_ids else label_name_to_id[edge["nodes"][0]]
834+
node_2 = ID(edge["nodes"][1]) if ID(edge["nodes"][1]) in label_ids else label_name_to_id[edge["nodes"][1]]
834835
edges.append(KeypointEdge(node_1=node_1, node_2=node_2))
835836

836837
positions = []
837838
for position in keypoint_structure_data["positions"]:
838-
node = label_name_to_id[position["label"]]
839+
node = ID(position["label"]) if ID(position["label"]) in label_ids else label_name_to_id[position["label"]]
839840
x = position["x"]
840841
y = position["y"]
841842
positions.append(KeypointPosition(node=node, x=x, y=y))

interactive_ai/libs/iai_core_py/iai_core/utils/project_factory.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from iai_core.entities.model_template import ModelTemplate, NullModelTemplate
2323
from iai_core.entities.project import Project
2424
from iai_core.entities.task_graph import TaskEdge, TaskGraph
25-
from iai_core.entities.task_node import TaskNode, TaskProperties
25+
from iai_core.entities.task_node import TaskNode, TaskProperties, TaskType
2626
from iai_core.repos import (
2727
ActiveModelStateRepo,
2828
ConfigurableParametersRepo,
@@ -109,6 +109,7 @@ def create_project_with_task_graph( # noqa: PLR0913
109109
project_id: ID | None = None,
110110
user_names: list[str] | None = None,
111111
hidden: bool = False,
112+
keypoint_structure: KeypointStructure | None = None,
112113
) -> Project:
113114
"""
114115
Create a project given a task graph
@@ -122,6 +123,7 @@ def create_project_with_task_graph( # noqa: PLR0913
122123
:param model_templates: List of model templates to create the model storages for each task
123124
:param user_names: User names to assign to the project
124125
:param hidden: Whether to keep the project as hidden after creation
126+
:param keypoint_structure: Keypoint structure to assign to the project, only for Keypoint Detection projects
125127
:return: created project
126128
"""
127129
if project_id is None:
@@ -146,15 +148,16 @@ def create_project_with_task_graph( # noqa: PLR0913
146148
_id=DatasetStorageRepo.generate_id(),
147149
)
148150
DatasetStorageRepo(project_identifier).save(dataset_storage)
149-
keypoint_structure = None
150-
if FeatureFlagProvider.is_enabled(FEATURE_FLAG_KEYPOINT_DETECTION):
151+
152+
if FeatureFlagProvider.is_enabled(FEATURE_FLAG_KEYPOINT_DETECTION) and keypoint_structure is None:
151153
keypoint_structure = KeypointStructure(
152154
edges=[KeypointEdge(node_1=ID("node_1"), node_2=ID("node_2"))],
153155
positions=[
154156
KeypointPosition(node=ID("node_1"), x=0.123, y=0.123),
155157
KeypointPosition(node=ID("node_2"), x=1, y=1),
156158
],
157159
)
160+
158161
# Create graph with one task
159162
project = Project(
160163
id=project_id,
@@ -252,6 +255,7 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
252255
empty_label_name: str | None = None,
253256
is_multi_label_classification: bool | None = False,
254257
hidden: bool = False,
258+
keypoint_structure: KeypointStructure | None = None,
255259
) -> Project:
256260
"""
257261
Create a project with one task in the pipeline.
@@ -267,19 +271,19 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
267271
This attribute is ignored when label_schema is provided.
268272
:param model_template_id: Model template for the project
269273
(either the model template ID or the model template itself)
270-
:param user_names: User names to assign to the project
271-
:param configurable_parameters: Optional, configurable parameters to assign
272-
to the task node in the Project.
273-
:param workspace: Optional, workspace
274+
:param user_names: Usernames to assign to the project
274275
:param label_schema: Optional, label schema relative to the project.
275276
If provided, then label_names is ignored
276277
If unspecified, the default workspace is used.
278+
:param label_groups: Optional. label group metadata
279+
:param labelname_to_parent: Optional. label tree structure
280+
:param configurable_parameters: Optional, configurable parameters to assign
281+
to the task node in the Project.
277282
:param empty_label_name: Optional. If an empty label needs to be created,
278283
this parameter is used to customize its name.
279284
:param is_multi_label_classification: Optional. True if created project is multi-label classification
280285
:param hidden: Whether to keep the project as hidden after creation.
281-
:param label_groups: Optional. label group metadata
282-
:param labelname_to_parent: Optional. label tree structure
286+
:param keypoint_structure: Keypoint structure to assign to the project, only for Keypoint Detection projects
283287
:return: Created project
284288
"""
285289
logger.warning("Method `create_project_single_task` is deprecated.")
@@ -293,7 +297,6 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
293297
if isinstance(model_template, NullModelTemplate):
294298
raise ModelTemplateError("A NullModelTemplate was created.")
295299

296-
CTX_SESSION_VAR.get().workspace_id
297300
project_id = ProjectRepo.generate_id()
298301
dataset_template = ModelTemplateList().get_by_id("dataset")
299302
task_node_id = TaskNodeRepo.generate_id()
@@ -323,6 +326,9 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
323326
task_edge = TaskEdge(from_task=image_dataset_task_node, to_task=task_node)
324327
task_graph.add_task_edge(task_edge)
325328

329+
if task_node.task_properties.task_type == TaskType.KEYPOINT_DETECTION and not keypoint_structure:
330+
raise ValueError("Please provide a keypoint structure for keypoint detection projects.")
331+
326332
project = ProjectFactory.create_project_with_task_graph(
327333
project_id=project_id,
328334
name=name,
@@ -332,6 +338,7 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
332338
task_graph=task_graph,
333339
model_templates=model_templates,
334340
hidden=hidden,
341+
keypoint_structure=keypoint_structure,
335342
)
336343

337344
project_labels: list[Label]
@@ -374,7 +381,7 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
374381
label_groups=label_groups, labelname_to_label=labelname_to_label
375382
)
376383

377-
# labels not have an explicite grouping should be included to an exclusive_group
384+
# labels not have an explicit grouping should be included to an exclusive_group
378385
ungrouped_label_names = [label for label in project_labels if label.name not in grouped_label_names]
379386
exclusive_group = LabelGroup(
380387
name="labels",

interactive_ai/services/api/schemas/projects/requests/put/keypoint_edge.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ properties:
55
nodes:
66
type: array
77
items:
8-
$ref: '../../../mongo_id.yaml'
8+
anyOf:
9+
- type: string
10+
- $ref: '../../../mongo_id.yaml'

interactive_ai/services/api/schemas/projects/requests/put/keypoint_position.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ required:
55
- y
66
properties:
77
label:
8-
$ref: '../../../mongo_id.yaml'
8+
anyOf:
9+
- type: string
10+
- $ref: '../../../mongo_id.yaml'
911
x:
1012
type: number
1113
format: float

interactive_ai/services/dataset_ie/app/communication/helpers/import_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Domain.ANOMALY_DETECTION,
2929
Domain.ANOMALY_SEGMENTATION,
3030
Domain.ROTATED_DETECTION,
31+
Domain.KEYPOINT_DETECTION,
3132
]
3233

3334

@@ -142,6 +143,7 @@ def get_validated_task_type(cls, project: Project) -> TaskType:
142143
TaskType.ANOMALY_DETECTION,
143144
TaskType.ANOMALY_SEGMENTATION,
144145
TaskType.ROTATED_DETECTION,
146+
TaskType.KEYPOINT_DETECTION,
145147
]
146148

147149
trainable_tasks = project.get_trainable_task_nodes()

interactive_ai/services/director/app/coordination/dataset_manager/dataset_counter_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,5 +127,5 @@ class KeypointDetectionCounterConfig(DatasetCounterConfig):
127127
description="The minimum number of new annotations required "
128128
"before auto-train is triggered. Auto-training will start every time "
129129
"that this number of annotations is created.",
130-
visible_in_ui=False,
130+
visible_in_ui=True,
131131
)

interactive_ai/services/resource/app/communication/rest_data_validator/project_rest_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def _validate_keypoint_structure(data: dict[str, Any], labels: list[LabelPropert
670670
:raises NodeNameNotInLabelsException: if a node name does not match any of the label names
671671
:raises NodePositionIsOutOfBoundsException: if a node is out of bounds (not in the range [0.0, 1.0])
672672
"""
673-
existing_labels = [label.name for label in labels] + [label.id for label in labels]
673+
existing_labels = [label.name for label in labels] + [str(label.id) for label in labels]
674674
pipeline_data = data[PIPELINE]
675675
duplicate_list = []
676676
is_anomaly_reduced = FeatureFlagProvider.is_enabled(FeatureFlag.FEATURE_FLAG_ANOMALY_REDUCTION)

0 commit comments

Comments
 (0)