Skip to content

ITEP-33590 Dataset ie support for keypoint detection #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a669514
Dataset ie support for keypoint detection
A-Artemis May 8, 2025
7bc7c11
Update uv.lock
A-Artemis May 8, 2025
b62efd9
fix type
A-Artemis May 8, 2025
4a41127
Merge branch 'main' into aurelien/dataset-ie-keypoint-detection
A-Artemis May 9, 2025
bbbe747
Merge branch 'main' into aurelien/dataset-ie-keypoint-detection
A-Artemis May 15, 2025
c1ef2f2
Merge branch 'main' into aurelien/dataset-ie-keypoint-detection
A-Artemis May 15, 2025
cb9f000
Added keypoint detection to supported import projects
A-Artemis May 15, 2025
844c296
Added test for export, and updated project rest validator to handle I…
A-Artemis May 19, 2025
2cd1562
Merge branch 'main' of https://github.com/open-edge-platform/geti int…
A-Artemis May 19, 2025
a638c92
Updated project validator
A-Artemis May 19, 2025
087d54e
corrected type
A-Artemis May 19, 2025
bafd0ec
test fixes
A-Artemis May 19, 2025
00632b6
support for keypoint IDs and strings in project edit
A-Artemis May 20, 2025
d9f9333
Updated sc_extractor.py to use PointCategories
A-Artemis May 20, 2025
49d62b1
Merge branch 'main' of https://github.com/open-edge-platform/geti int…
A-Artemis May 20, 2025
548e383
Fix typing and mypy
A-Artemis May 20, 2025
725e75a
Merge branch 'main' into aurelien/dataset-ie-keypoint-detection
A-Artemis May 21, 2025
2f3b874
added points categories to label mapper
A-Artemis May 21, 2025
23fa604
Merge remote-tracking branch 'origin/aurelien/dataset-ie-keypoint-det…
A-Artemis May 21, 2025
2da9f7b
Added keypoint type for dm datasets which do not include a bbox annot…
A-Artemis May 21, 2025
829d2fa
Merge branch 'main' into aurelien/dataset-ie-keypoint-detection
A-Artemis May 21, 2025
3984d63
Changed order of checking for annotation type.
A-Artemis May 21, 2025
fe5dfae
added logging
A-Artemis May 22, 2025
e19911a
Added more logging
A-Artemis May 22, 2025
240ac1f
Added even more logging
A-Artemis May 22, 2025
17638f4
Added even more logging
A-Artemis May 22, 2025
2fa33e4
Merge branch 'main' into aurelien/dataset-ie-keypoint-detection
A-Artemis May 22, 2025
8b76518
Added even more logging
A-Artemis May 22, 2025
986f200
Added back parameter
A-Artemis May 23, 2025
4c89574
reverted test change
A-Artemis May 23, 2025
6cead27
Fixed index out of range when getting keypoint label name
A-Artemis May 23, 2025
92b1763
Removed logging, and fixed indexing issues
A-Artemis May 23, 2025
211087d
Merge branch 'main' into aurelien/dataset-ie-keypoint-detection
A-Artemis May 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _validate_empty_labels(cls, parser: ParserT) -> None:
raise ReservedLabelNameException(label_name=empty_label_name)

@classmethod
@abstractmethod
def _validate_keypoint_structure(cls, parser: ParserT) -> None:
"""
Validates that a user defined label graph edge has exactly 2 nodes, node names match with existing labels,
Expand All @@ -160,29 +161,6 @@ def _validate_keypoint_structure(cls, parser: ParserT) -> None:
:raises NodeNameNotInLabelsException: if a node name does not match any of the label names
:raises NodePositionIsOutOfBoundsException: if a node is out of bounds (not in the range [0.0, 1.0])
"""
duplicate_list = []
for task_name in parser.get_tasks_names():
keypoint_structure = parser.get_keypoint_structure_data(task_name=task_name)
if not keypoint_structure:
continue
label_names = parser.get_custom_labels_names_by_task(task_name=task_name)
edges = keypoint_structure["edges"]
for edge in edges:
nodes = edge["nodes"]
if len(nodes) != 2:
raise WrongNumberOfNodesException
if nodes[0] not in label_names or nodes[1] not in label_names:
raise IncorrectNodeNameInGraphException
if set(nodes) in duplicate_list:
raise DuplicateEdgeInGraphException
duplicate_list.append(set(nodes))

positions = keypoint_structure["positions"]
for position in positions:
if position["label"] not in label_names:
raise NodeNameNotInLabelsException
if not 0 <= position["x"] <= 1 or not 0 <= position["y"] <= 1:
raise NodePositionIsOutOfBoundsException

@classmethod
@abstractmethod
Expand Down Expand Up @@ -417,6 +395,45 @@ def __validate_parent_labels_in_parent_task(
),
)

@classmethod
def _validate_keypoint_structure(cls, parser: ProjectParser) -> None:
"""
Validates that a user defined label graph edge has exactly 2 nodes, node names match with existing labels,
and has no duplicate edges

This method must be run after labels validation since it assumes that its labels param is valid.

:param parser: A parser instance that can read and decode the information necessary to create a project
:raises WrongNumberOfNodesException: if an edge does not have 2 vertices
:raises IncorrectNodeNameInGraphException: if an edge has an incorrect name
:raises DuplicateEdgeInGraphException: if the graph contains a duplicate edge
:raises NodeNameNotInLabelsException: if a node name does not match any of the label names
:raises NodePositionIsOutOfBoundsException: if a node is out of bounds (not in the range [0.0, 1.0])
"""
duplicate_list = []
for task_name in parser.get_tasks_names():
keypoint_structure = parser.get_keypoint_structure_data(task_name=task_name)
if not keypoint_structure:
continue
label_names = parser.get_custom_labels_names_by_task(task_name=task_name)
edges = keypoint_structure["edges"]
for edge in edges:
nodes = edge["nodes"]
if len(nodes) != 2:
raise WrongNumberOfNodesException
if nodes[0] not in label_names or nodes[1] not in label_names:
raise IncorrectNodeNameInGraphException
if set(nodes) in duplicate_list:
raise DuplicateEdgeInGraphException
duplicate_list.append(set(nodes))

positions = keypoint_structure["positions"]
for position in positions:
if position["label"] not in label_names:
raise NodeNameNotInLabelsException
if not 0 <= position["x"] <= 1 or not 0 <= position["y"] <= 1:
raise NodePositionIsOutOfBoundsException


class ProjectUpdateValidator(ProjectValidator[ProjectUpdateParser]):
def validate(self, parser: ProjectUpdateParser) -> None:
Expand Down Expand Up @@ -726,3 +743,47 @@ def __validate_parent_labels_in_parent_task(
if not is_found
),
)

@classmethod
def _validate_keypoint_structure(cls, parser: ProjectUpdateParser) -> None:
"""
Validates that a user defined label graph edge has exactly 2 nodes, node names match with existing labels,
and has no duplicate edges

This method must be run after labels validation since it assumes that its labels param is valid.

:param parser: A parser instance that can read and decode the information necessary to create a project
:raises WrongNumberOfNodesException: if an edge does not have 2 vertices
:raises IncorrectNodeNameInGraphException: if an edge has an incorrect name
:raises DuplicateEdgeInGraphException: if the graph contains a duplicate edge
:raises NodeNameNotInLabelsException: if a node name does not match any of the label names
:raises NodePositionIsOutOfBoundsException: if a node is out of bounds (not in the range [0.0, 1.0])
"""
duplicate_list = []
for task_name in parser.get_tasks_names():
keypoint_structure = parser.get_keypoint_structure_data(task_name=task_name)
if not keypoint_structure:
continue
label_names = list(parser.get_custom_labels_names_by_task(task_name=task_name))
label_ids = [
str(parser.get_label_id_by_name(task_name=task_name, label_name=label_name))
for label_name in label_names
]
labels = label_names + label_ids
edges = keypoint_structure["edges"]
for edge in edges:
nodes = edge["nodes"]
if len(nodes) != 2:
raise WrongNumberOfNodesException
if nodes[0] not in labels or nodes[1] not in labels:
raise IncorrectNodeNameInGraphException
if set(nodes) in duplicate_list:
raise DuplicateEdgeInGraphException
duplicate_list.append(set(nodes))

positions = keypoint_structure["positions"]
for position in positions:
if position["label"] not in labels:
raise NodeNameNotInLabelsException
if not 0 <= position["x"] <= 1 or not 0 <= position["y"] <= 1:
raise NodePositionIsOutOfBoundsException
Original file line number Diff line number Diff line change
Expand Up @@ -827,15 +827,16 @@ def _build_keypoint_structure(
:return: the KeypointStructure
"""
label_name_to_id: dict[str, ID] = {label.name: label.id_ for label in labels}
label_ids = [label.id_ for label in labels]
edges = []
for edge in keypoint_structure_data["edges"]:
node_1 = label_name_to_id[edge["nodes"][0]]
node_2 = label_name_to_id[edge["nodes"][1]]
node_1 = ID(edge["nodes"][0]) if ID(edge["nodes"][0]) in label_ids else label_name_to_id[edge["nodes"][0]]
node_2 = ID(edge["nodes"][1]) if ID(edge["nodes"][1]) in label_ids else label_name_to_id[edge["nodes"][1]]
edges.append(KeypointEdge(node_1=node_1, node_2=node_2))

positions = []
for position in keypoint_structure_data["positions"]:
node = label_name_to_id[position["label"]]
node = ID(position["label"]) if ID(position["label"]) in label_ids else label_name_to_id[position["label"]]
x = position["x"]
y = position["y"]
positions.append(KeypointPosition(node=node, x=x, y=y))
Expand Down
29 changes: 18 additions & 11 deletions interactive_ai/libs/iai_core_py/iai_core/utils/project_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from iai_core.entities.model_template import ModelTemplate, NullModelTemplate
from iai_core.entities.project import Project
from iai_core.entities.task_graph import TaskEdge, TaskGraph
from iai_core.entities.task_node import TaskNode, TaskProperties
from iai_core.entities.task_node import TaskNode, TaskProperties, TaskType
from iai_core.repos import (
ActiveModelStateRepo,
ConfigurableParametersRepo,
Expand Down Expand Up @@ -109,6 +109,7 @@ def create_project_with_task_graph( # noqa: PLR0913
project_id: ID | None = None,
user_names: list[str] | None = None,
hidden: bool = False,
keypoint_structure: KeypointStructure | None = None,
) -> Project:
"""
Create a project given a task graph
Expand All @@ -122,6 +123,7 @@ def create_project_with_task_graph( # noqa: PLR0913
:param model_templates: List of model templates to create the model storages for each task
:param user_names: User names to assign to the project
:param hidden: Whether to keep the project as hidden after creation
:param keypoint_structure: Keypoint structure to assign to the project, only for Keypoint Detection projects
:return: created project
"""
if project_id is None:
Expand All @@ -146,15 +148,16 @@ def create_project_with_task_graph( # noqa: PLR0913
_id=DatasetStorageRepo.generate_id(),
)
DatasetStorageRepo(project_identifier).save(dataset_storage)
keypoint_structure = None
if FeatureFlagProvider.is_enabled(FEATURE_FLAG_KEYPOINT_DETECTION):

if FeatureFlagProvider.is_enabled(FEATURE_FLAG_KEYPOINT_DETECTION) and keypoint_structure is None:
keypoint_structure = KeypointStructure(
edges=[KeypointEdge(node_1=ID("node_1"), node_2=ID("node_2"))],
positions=[
KeypointPosition(node=ID("node_1"), x=0.123, y=0.123),
KeypointPosition(node=ID("node_2"), x=1, y=1),
],
)

# Create graph with one task
project = Project(
id=project_id,
Expand Down Expand Up @@ -252,6 +255,7 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
empty_label_name: str | None = None,
is_multi_label_classification: bool | None = False,
hidden: bool = False,
keypoint_structure: KeypointStructure | None = None,
) -> Project:
"""
Create a project with one task in the pipeline.
Expand All @@ -267,19 +271,19 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
This attribute is ignored when label_schema is provided.
:param model_template_id: Model template for the project
(either the model template ID or the model template itself)
:param user_names: User names to assign to the project
:param configurable_parameters: Optional, configurable parameters to assign
to the task node in the Project.
:param workspace: Optional, workspace
:param user_names: Usernames to assign to the project
:param label_schema: Optional, label schema relative to the project.
If provided, then label_names is ignored
If unspecified, the default workspace is used.
:param label_groups: Optional. label group metadata
:param labelname_to_parent: Optional. label tree structure
:param configurable_parameters: Optional, configurable parameters to assign
to the task node in the Project.
:param empty_label_name: Optional. If an empty label needs to be created,
this parameter is used to customize its name.
:param is_multi_label_classification: Optional. True if created project is multi-label classification
:param hidden: Whether to keep the project as hidden after creation.
:param label_groups: Optional. label group metadata
:param labelname_to_parent: Optional. label tree structure
:param keypoint_structure: Keypoint structure to assign to the project, only for Keypoint Detection projects
:return: Created project
"""
logger.warning("Method `create_project_single_task` is deprecated.")
Expand All @@ -293,7 +297,6 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
if isinstance(model_template, NullModelTemplate):
raise ModelTemplateError("A NullModelTemplate was created.")

CTX_SESSION_VAR.get().workspace_id
project_id = ProjectRepo.generate_id()
dataset_template = ModelTemplateList().get_by_id("dataset")
task_node_id = TaskNodeRepo.generate_id()
Expand Down Expand Up @@ -323,6 +326,9 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
task_edge = TaskEdge(from_task=image_dataset_task_node, to_task=task_node)
task_graph.add_task_edge(task_edge)

if task_node.task_properties.task_type == TaskType.KEYPOINT_DETECTION and not keypoint_structure:
raise ValueError("Please provide a keypoint structure for keypoint detection projects.")

project = ProjectFactory.create_project_with_task_graph(
project_id=project_id,
name=name,
Expand All @@ -332,6 +338,7 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
task_graph=task_graph,
model_templates=model_templates,
hidden=hidden,
keypoint_structure=keypoint_structure,
)

project_labels: list[Label]
Expand Down Expand Up @@ -374,7 +381,7 @@ def create_project_single_task( # noqa: PLR0915, PLR0913
label_groups=label_groups, labelname_to_label=labelname_to_label
)

# labels not have an explicite grouping should be included to an exclusive_group
# labels not have an explicit grouping should be included to an exclusive_group
ungrouped_label_names = [label for label in project_labels if label.name not in grouped_label_names]
exclusive_group = LabelGroup(
name="labels",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ properties:
nodes:
type: array
items:
$ref: '../../../mongo_id.yaml'
anyOf:
- type: string
- $ref: '../../../mongo_id.yaml'
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ required:
- y
properties:
label:
$ref: '../../../mongo_id.yaml'
anyOf:
- type: string
- $ref: '../../../mongo_id.yaml'
x:
type: number
format: float
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Domain.ANOMALY_DETECTION,
Domain.ANOMALY_SEGMENTATION,
Domain.ROTATED_DETECTION,
Domain.KEYPOINT_DETECTION,
]


Expand Down Expand Up @@ -142,6 +143,7 @@ def get_validated_task_type(cls, project: Project) -> TaskType:
TaskType.ANOMALY_DETECTION,
TaskType.ANOMALY_SEGMENTATION,
TaskType.ROTATED_DETECTION,
TaskType.KEYPOINT_DETECTION,
]

trainable_tasks = project.get_trainable_task_nodes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,5 @@ class KeypointDetectionCounterConfig(DatasetCounterConfig):
description="The minimum number of new annotations required "
"before auto-train is triggered. Auto-training will start every time "
"that this number of annotations is created.",
visible_in_ui=False,
visible_in_ui=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def _validate_keypoint_structure(data: dict[str, Any], labels: list[LabelPropert
:raises NodeNameNotInLabelsException: if a node name does not match any of the label names
:raises NodePositionIsOutOfBoundsException: if a node is out of bounds (not in the range [0.0, 1.0])
"""
existing_labels = [label.name for label in labels] + [label.id for label in labels]
existing_labels = [label.name for label in labels] + [str(label.id) for label in labels]
pipeline_data = data[PIPELINE]
duplicate_list = []
is_anomaly_reduced = FeatureFlagProvider.is_enabled(FeatureFlag.FEATURE_FLAG_ANOMALY_REDUCTION)
Expand Down
Loading
Loading