-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_loader.py
More file actions
114 lines (93 loc) · 4.08 KB
/
base_loader.py
File metadata and controls
114 lines (93 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from abc import abstractmethod
from collections.abc import Callable
from functools import wraps
from luxonis_ml.data.loaders import BaseLoader
from luxonis_ml.typing import LoaderOutput
from luxonis_eval.registry import DATALOADERS_REGISTRY
from luxonis_eval.utils.utils import check_loader_classes, check_loader_output
def validate_loader_output(func: Callable) -> Callable:
    """
    Wrap a loader's ``__getitem__`` so that every returned sample is checked.

    The wrapper calls ``func(self, idx)`` and passes the result to
    ``check_loader_output``. If the check passes, the result is returned
    unchanged. If the check raises ``TypeError``, a new ``TypeError`` is
    raised instead. Its message names the loader class and the offending
    index, and the original error is chained as its cause.

    Parameters
    ----------
    func : Callable
        The ``__getitem__`` implementation to wrap; invoked as
        ``func(self, idx)``.

    Returns
    -------
    Callable
        A ``functools.wraps``-decorated wrapper with the same call shape.
    """

    @wraps(func)
    def _checked_getitem(self: BaseEvalLoader, idx: int) -> LoaderOutput:
        sample = func(self, idx)
        try:
            check_loader_output(sample)
        except TypeError as err:
            raise TypeError(
                f"Invalid loader output for {self.__class__.__name__} at index {idx}: {err}"
            ) from err
        else:
            return sample

    return _checked_getitem
class BaseEvalLoader(BaseLoader, register=False):
    """Base class for evaluation data loaders.

    On construction, the mapping returned by ``load_classes`` is stored on
    ``self.classes`` and validated with ``check_loader_classes``. For every
    subclass, the ``__getitem__`` implementation it defines or introduces is
    wrapped by ``validate_loader_output``. As a result, each returned sample
    is checked with ``check_loader_output``.
    """

    REGISTRY = DATALOADERS_REGISTRY

    def __init__(self, **kwargs):
        """Load and validate the class mapping, then initialize the base loader.

        Parameters
        ----------
        **kwargs
            Keyword arguments forwarded unchanged to the parent initializer.

        Raises
        ------
        TypeError
            If ``check_loader_classes`` rejects the mapping returned by
            ``load_classes``.
        """
        # The mapping is loaded before the parent initializer runs, so
        # ``self.classes`` already exists when ``super().__init__`` executes.
        self.classes = self.load_classes()
        try:
            check_loader_classes(self.classes)
        except TypeError as e:
            raise TypeError(
                f"Invalid loader classes for {self.__class__.__name__}: {e}"
            ) from e
        super().__init__(**kwargs)

    def __init_subclass__(cls, **kwargs):
        """
        Initialize subclass with validation for its __getitem__ method.

        The ``__getitem__`` seen by ``cls`` is wrapped with
        ``validate_loader_output`` unless it is already validated or still
        abstract. Subclasses that do not define ``__getitem__`` themselves
        (sub-subclasses, intermediate abstract bases) are therefore accepted
        instead of failing with a ``KeyError``.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to the parent class.
        """
        super().__init_subclass__(**kwargs)
        # Find the class in the MRO that actually defines ``__getitem__``.
        owner = next(
            (klass for klass in cls.__mro__ if "__getitem__" in klass.__dict__),
            None,
        )
        if owner is None:
            return
        getitem = owner.__dict__["__getitem__"]
        if owner is not cls:
            if issubclass(owner, BaseEvalLoader):
                # ``owner`` wrapped its own ``__getitem__`` when it was
                # created; wrapping again would validate every sample twice.
                return
            if getattr(getitem, "__isabstractmethod__", False):
                # Still abstract: the concrete subclass that implements it
                # will be wrapped when it is defined.
                return
        cls.__getitem__ = validate_loader_output(getitem)

    @abstractmethod
    def load_classes(self) -> dict[str, int]:
        """Load and return the class mapping for the dataset.

        This method is called once from ``__init__``. Its return value is
        assigned to ``self.classes`` and validated with
        ``check_loader_classes``. Subclasses must implement it to provide a
        mapping of class names to their integer indices.

        Returns
        -------
        dict[str, int]
            A mapping of class name to class index, e.g. {"cat": 0, "dog": 1}.
        """

    @abstractmethod
    def get_class_mapping(
        self, **kwargs
    ) -> tuple[dict[int, str], dict[int, str], dict[int, int]]:
        """Return the LDF class map, native class map, and class index map.

        The LDF class map reflects how classes are indexed within LuxonisML's
        data format (LDF). There, classes are sorted alphabetically, so
        indices may differ from those used during model training. The native
        class map reflects the original class-to-index mapping the model was
        trained on (e.g. COCO ordering). The class index map bridges the two:
        it maps each LDF index to its corresponding native index, so that
        predictions line up correctly with ground-truth annotations.

        When implementing this method for a LuxonisLoader-backed dataset,
        the LDF and native class maps will generally differ. The class index
        map must then explicitly encode the remapping (e.g. {0: 3, 1: 0, ...}).

        When implementing this method for a custom dataset that inherits
        directly from the BaseEvalLoader class, the LDF and native class maps
        should be identical, both derived from self.classes. The class index
        map should then be an identity mapping ({0: 0, 1: 1, ...}).

        Parameters
        ----------
        **kwargs
            Additional keyword arguments that may be used to customize the
            class mapping.

        Returns
        -------
        tuple[dict[int, str], dict[int, str], dict[int, int]]
            A 3-tuple of:
            - LDF class map (dict[int, str]): LDF index to class name.
            - Native class map (dict[int, str]): original index the
              model was trained on to class name.
            - Class index map (dict[int, int]): mapping from each LDF
              index to its corresponding native index.
        """