Skip to content

Commit 3d49193

Browse files
authored
Merge pull request #34 from dbinfrago/feat/annotation-type-check
feat: add annotation type-per-sensor check
2 parents ef38258 + 64d2806 commit 3d49193

5 files changed

Lines changed: 262 additions & 0 deletions

File tree

raillabel_providerkit/validation/issue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class IssueType(Enum):
2828
SENSOR_TYPE_WRONG = "SensorTypeWrong"
2929
UNEXPECTED_CLASS = "UnexpectedClassIssue"
3030
URI_FORMAT = "UriFormatIssue"
31+
ANNOTATION_SENSOR_MISMATCH = "AnnotationSensorMismatch"
3132

3233
@classmethod
3334
def names(cls) -> list[str]:

raillabel_providerkit/validation/validate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
validate_sensors,
2323
validate_uris,
2424
)
25+
from .validate_annotation_type_per_sensor.validate_annotation_type_per_sensor import (
26+
validate_annotation_type_per_sensor,
27+
)
2528

2629

2730
def validate( # noqa: C901, PLR0913
@@ -34,6 +37,7 @@ def validate( # noqa: C901, PLR0913
3437
validate_for_uris: bool = True,
3538
validate_for_dimensions: bool = True,
3639
validate_for_horizon: bool = True,
40+
validate_for_annotation_type_per_sensor: bool = True,
3741
) -> list[Issue]:
3842
"""Validate a scene based on the Deutsche Bahn Requirements.
3943
@@ -56,6 +60,8 @@ def validate( # noqa: C901, PLR0913
5660
validate_for_dimensions: If True, issues are returned if the dimensions of cuboids are
5761
outside the expected values range.
5862
validate_for_horizon: If True, issues are returned if annotations cross the horizon.
63+
validate_for_annotation_type_per_sensor: Validate that annotation types match sensor types.
64+
5965
6066
Returns:
6167
List of all requirement errors in the scene. If an empty list is returned, then there are no
@@ -96,4 +102,7 @@ def validate( # noqa: C901, PLR0913
96102
if validate_for_horizon:
97103
errors.extend(validate_horizon(scene))
98104

105+
if validate_for_annotation_type_per_sensor:
106+
errors.extend(validate_annotation_type_per_sensor(scene))
107+
99108
return errors
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright DB InfraGO AG and contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Validation for annotation type per sensor."""
5+
6+
from .validate_annotation_type_per_sensor import validate_annotation_type_per_sensor
7+
8+
__all__ = ["validate_annotation_type_per_sensor"]
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright DB InfraGO AG and contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from collections.abc import Iterable
7+
from typing import Any
8+
from uuid import UUID
9+
10+
import raillabel
11+
12+
from raillabel_providerkit.validation.issue import Issue, IssueIdentifiers, IssueType
13+
14+
# Mapping:
15+
# - camera: Bbox
16+
# - lidar: Cuboid
17+
# - radar: und Cuboid
18+
_ALLOWED_BY_SENSOR_TYPE: dict[str, tuple[str, ...]] = {
19+
"camera": ("Bbox",),
20+
"lidar": ("Cuboid",),
21+
"radar": ("Bbox", "Cuboid"),
22+
}
23+
24+
25+
def _normalize_anno_type(t: str) -> str:
26+
t_low = (t or "").strip().lower()
27+
mapping = {
28+
"bbox": "Bbox",
29+
"cuboid": "Cuboid",
30+
"poly2d": "Poly2d",
31+
"poly3d": "Poly3d",
32+
"num": "Num",
33+
"seg3d": "Seg3d",
34+
}
35+
return mapping.get(t_low, t if t else "")
36+
37+
38+
def _iter_annotations(scene: raillabel.Scene) -> Iterable[tuple[int, Any]]:
39+
"""Iterate over all annotations in the scene."""
40+
frames = getattr(scene, "frames", {}) or {}
41+
for frame_id, frame in getattr(frames, "items", lambda: frames.items())():
42+
annotations = getattr(frame, "annotations", {}) or {}
43+
for anno in annotations.values():
44+
yield frame_id, anno
45+
46+
47+
def _get_annotation_id(anno: object) -> UUID | None:
48+
return getattr(anno, "id", None)
49+
50+
51+
def _get_annotation_type(anno: object) -> str | None:
52+
t = getattr(anno, "type", None)
53+
if not t:
54+
return None
55+
return _normalize_anno_type(str(t))
56+
57+
58+
def _get_sensor_id_from_annotation(anno: object) -> str | None:
59+
# 1) coordinate.sensor / coordinate.sensor_id
60+
coord = getattr(anno, "coordinate", None)
61+
if coord is not None:
62+
sid = getattr(coord, "sensor", None) or getattr(coord, "sensor_id", None)
63+
if sid:
64+
return str(sid)
65+
66+
# 2) anno.sensor / anno.sensor_id
67+
for attr in ("sensor", "sensor_id"):
68+
sid = getattr(anno, attr, None)
69+
if sid:
70+
return str(sid)
71+
72+
# 3) coordinates-Collection
73+
coords = getattr(anno, "coordinates", None)
74+
if isinstance(coords, list | tuple):
75+
for c in coords:
76+
sid = getattr(c, "sensor", None) or getattr(c, "sensor_id", None)
77+
if sid:
78+
return str(sid)
79+
80+
return None
81+
82+
83+
def _sensor_type(scene: raillabel.Scene, sensor_id: str) -> str | None:
84+
"""Return the sensor type ('camera', 'lidar', 'radar') or None if unknown."""
85+
sensors = getattr(scene, "sensors", {}) or {}
86+
sensor = sensors.get(sensor_id)
87+
if sensor is None and hasattr(sensors, "get"):
88+
sensor = sensors.get(sensor_id)
89+
if sensor is None:
90+
return None
91+
s_type = getattr(sensor, "type", None)
92+
return str(s_type).lower() if s_type else None
93+
94+
95+
def validate_annotation_type_per_sensor(scene: raillabel.Scene) -> list[Issue]:
96+
"""Validate that each annotation type is compatible with its sensor type.
97+
98+
Returns:
99+
List of Issues if mismatches are found.
100+
"""
101+
issues: list[Issue] = []
102+
103+
for frame_id, anno in _iter_annotations(scene):
104+
a_type = _get_annotation_type(anno)
105+
if not a_type:
106+
continue
107+
108+
sensor_id = _get_sensor_id_from_annotation(anno)
109+
if not sensor_id:
110+
continue
111+
112+
s_type = _sensor_type(scene, sensor_id)
113+
if not s_type:
114+
continue
115+
116+
allowed = _ALLOWED_BY_SENSOR_TYPE.get(s_type, ())
117+
if allowed and a_type not in allowed:
118+
issue = Issue(
119+
type=IssueType.ANNOTATION_SENSOR_MISMATCH,
120+
identifiers=IssueIdentifiers(
121+
frame=frame_id,
122+
sensor=sensor_id,
123+
annotation=_get_annotation_id(anno),
124+
annotation_type=a_type,
125+
),
126+
reason=(
127+
f"Annotation type '{a_type}' not allowed for sensor type '{s_type}'. "
128+
f"Allowed types: {', '.join(allowed) if allowed else 'n/a'}."
129+
),
130+
)
131+
issues.append(issue)
132+
133+
return issues
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright DB InfraGO AG and contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
from dataclasses import dataclass
6+
from typing import Dict, Optional, List
7+
from uuid import uuid4
8+
9+
import pytest
10+
11+
from raillabel_providerkit.validation.issue import IssueType
12+
from raillabel_providerkit.validation.validate_annotation_type_per_sensor import (
13+
validate_annotation_type_per_sensor,
14+
)
15+
16+
17+
@dataclass
18+
class _Coord:
19+
sensor: Optional[str] = None
20+
sensor_id: Optional[str] = None
21+
22+
23+
@dataclass
24+
class _Annotation:
25+
id: str
26+
type: str
27+
coordinate: Optional[_Coord] = None
28+
29+
30+
@dataclass
31+
class _Frame:
32+
annotations: Dict[str, _Annotation]
33+
34+
35+
@dataclass
36+
class _Sensor:
37+
id: str
38+
type: str
39+
40+
41+
@dataclass
42+
class _Scene:
43+
sensors: Dict[str, _Sensor]
44+
frames: Dict[int, _Frame]
45+
46+
47+
# ----------------------------------------------------------------
48+
49+
50+
def _mk_bbox(sensor_id: str) -> _Annotation:
51+
return _Annotation(id=str(uuid4()), type="Bbox", coordinate=_Coord(sensor=sensor_id))
52+
53+
54+
def _mk_cuboid(sensor_id: str) -> _Annotation:
55+
return _Annotation(id=str(uuid4()), type="Cuboid", coordinate=_Coord(sensor=sensor_id))
56+
57+
58+
def test_camera_allows_only_bbox():
59+
scene = _Scene(
60+
sensors={"cam0": _Sensor(id="cam0", type="camera")},
61+
frames={0: _Frame(annotations={"a": _mk_bbox("cam0")})},
62+
)
63+
issues = validate_annotation_type_per_sensor(scene)
64+
assert issues == []
65+
66+
scene_bad = _Scene(
67+
sensors={"cam0": _Sensor(id="cam0", type="camera")},
68+
frames={0: _Frame(annotations={"a": _mk_cuboid("cam0")})},
69+
)
70+
issues = validate_annotation_type_per_sensor(scene_bad)
71+
assert len(issues) == 1
72+
assert issues[0].type == IssueType.ANNOTATION_SENSOR_MISMATCH
73+
assert issues[0].identifiers.annotation_type == "Cuboid"
74+
assert issues[0].identifiers.sensor == "cam0"
75+
assert issues[0].identifiers.frame == 0
76+
77+
78+
def test_lidar_allows_only_cuboid():
79+
scene = _Scene(
80+
sensors={"lid0": _Sensor(id="lid0", type="lidar")},
81+
frames={1: _Frame(annotations={"a": _mk_cuboid("lid0")})},
82+
)
83+
assert validate_annotation_type_per_sensor(scene) == []
84+
85+
scene_bad = _Scene(
86+
sensors={"lid0": _Sensor(id="lid0", type="lidar")},
87+
frames={1: _Frame(annotations={"a": _mk_bbox("lid0")})},
88+
)
89+
issues = validate_annotation_type_per_sensor(scene_bad)
90+
assert len(issues) == 1
91+
assert issues[0].type == IssueType.ANNOTATION_SENSOR_MISMATCH
92+
assert issues[0].identifiers.annotation_type == "Bbox"
93+
assert issues[0].identifiers.sensor == "lid0"
94+
assert issues[0].identifiers.frame == 1
95+
96+
97+
def test_radar_allows_bbox_and_cuboid():
98+
scene = _Scene(
99+
sensors={"rad0": _Sensor(id="rad0", type="radar")},
100+
frames={2: _Frame(annotations={"a": _mk_bbox("rad0"), "b": _mk_cuboid("rad0")})},
101+
)
102+
assert validate_annotation_type_per_sensor(scene) == []
103+
104+
105+
def test_missing_sensor_binding_is_ignored():
106+
ann = _Annotation(id=str(uuid4()), type="Bbox", coordinate=None) # kein Sensor
107+
scene = _Scene(
108+
sensors={"cam0": _Sensor(id="cam0", type="camera")},
109+
frames={3: _Frame(annotations={"a": ann})},
110+
)
111+
assert validate_annotation_type_per_sensor(scene) == []

0 commit comments

Comments
 (0)