Skip to content

Commit a4450a2

Browse files
committed
remove label metadata
1 parent a576630 commit a4450a2

6 files changed

Lines changed: 63 additions & 159 deletions

File tree

src/valor_lite/semantic_segmentation/computation.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,12 @@
22
from numpy.typing import NDArray
33

44

5-
def compute_label_metadata(
6-
confusion_matrices: NDArray[np.int64],
7-
n_labels: int,
8-
) -> NDArray[np.int64]:
9-
"""
10-
Computes label metadata returning a count of annotations per label.
11-
12-
Parameters
13-
----------
14-
confusion_matrices : NDArray[np.int64]
15-
Confusion matrices per datum with shape (n_datums, n_labels + 1, n_labels + 1).
16-
n_labels : int
17-
The total number of unique labels.
18-
19-
Returns
20-
-------
21-
NDArray[np.int64]
22-
The label metadata array with shape (n_labels, 2).
23-
Index 0 - Ground truth label count
24-
Index 1 - Prediction label count
25-
"""
26-
label_metadata = np.zeros((n_labels, 2), dtype=np.int64)
27-
label_metadata[:, 0] = confusion_matrices[:, 1:, :].sum(axis=(0, 2))
28-
label_metadata[:, 1] = confusion_matrices[:, :, 1:].sum(axis=(0, 1))
29-
return label_metadata
30-
31-
325
def filter_cache(
336
confusion_matrices: NDArray[np.int64],
347
datum_mask: NDArray[np.bool_],
358
label_mask: NDArray[np.bool_],
369
number_of_labels: int,
37-
) -> tuple[NDArray[np.int64], NDArray[np.int64]]:
10+
) -> tuple[NDArray[np.int64]]:
3811
"""
3912
Performs the filter operation over the internal cache.
4013
@@ -75,11 +48,7 @@ def filter_cache(
7548

7649
confusion_matrices = confusion_matrices[datum_mask]
7750

78-
label_metadata = compute_label_metadata(
79-
confusion_matrices=confusion_matrices,
80-
n_labels=number_of_labels,
81-
)
82-
return confusion_matrices, label_metadata
51+
return confusion_matrices
8352

8453

8554
def compute_intermediates(
@@ -141,7 +110,6 @@ def compute_intermediates(
141110

142111
def compute_metrics(
143112
confusion_matrices: NDArray[np.int64],
144-
label_metadata: NDArray[np.int64],
145113
n_pixels: int,
146114
) -> tuple[
147115
NDArray[np.float64],
@@ -183,7 +151,10 @@ def compute_metrics(
183151
NDArray[np.float64]
184152
Unmatched ground truth ratios.
185153
"""
186-
n_labels = label_metadata.shape[0]
154+
n_labels = confusion_matrices.shape[-1] - 1
155+
label_metadata = np.zeros((n_labels, 2), dtype=np.int64)
156+
label_metadata[:, 0] = confusion_matrices[:, 1:, :].sum(axis=(0, 2))
157+
label_metadata[:, 1] = confusion_matrices[:, :, 1:].sum(axis=(0, 1))
187158
gt_counts = label_metadata[:, 0]
188159
pd_counts = label_metadata[:, 1]
189160

src/valor_lite/semantic_segmentation/evaluator.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
import json
2+
from collections import defaultdict
13
from dataclasses import asdict, dataclass
4+
from pathlib import Path
25

3-
import json
46
import numpy as np
57
from numpy.typing import NDArray
6-
from tqdm import tqdm
7-
from pathlib import Path
8-
from collections import defaultdict
9-
108
from pyarrow import pa
119
from pyarrow.compute import pc
1210
from pyarrow.dataset import ds
11+
from tqdm import tqdm
1312

13+
from valor_lite.cache import (
14+
CacheReader,
15+
DataType,
16+
convert_type_mapping_to_schema,
17+
)
1418
from valor_lite.exceptions import EmptyCacheError, EmptyFilterError
1519
from valor_lite.semantic_segmentation.annotation import Segmentation
1620
from valor_lite.semantic_segmentation.computation import (
@@ -23,15 +27,15 @@
2327
from valor_lite.semantic_segmentation.utilities import (
2428
unpack_precision_recall_iou_into_metric_lists,
2529
)
26-
from valor_lite.cache import CacheReader, DataType, convert_type_mapping_to_schema
2730

2831

2932
@dataclass
3033
class EvaluatorInfo:
3134
number_of_datums: int = 0
32-
number_of_groundtruth_annotations: int = 0
33-
number_of_prediction_annotations: int = 0
3435
number_of_labels: int = 0
36+
number_of_pixels: int = 0
37+
number_of_groundtruth_pixels: int = 0
38+
number_of_prediction_pixels: int = 0
3539
number_of_rows: int = 0
3640
datum_metadata_types: dict[str, DataType] | None = None
3741
groundtruth_metadata_types: dict[str, DataType] | None = None
@@ -120,10 +124,9 @@ def generate_meta(
120124
tbl = fragment.to_table()
121125
columns = (
122126
"datum_id",
123-
"gt_id",
124-
"pd_id",
125127
"gt_label_id",
126128
"pd_label_id",
129+
"counts",
127130
)
128131
ids = np.column_stack(
129132
[tbl[col].to_numpy() for col in columns]
@@ -136,18 +139,8 @@ def generate_meta(
136139
datum_ids = np.unique(ids[:, 0])
137140
info.number_of_datums += int(datum_ids.size)
138141

139-
# count unique groundtruths
140-
gt_ids = ids[:, 1]
141-
gt_ids = np.unique(gt_ids[gt_ids >= 0])
142-
info.number_of_groundtruth_annotations += int(gt_ids.shape[0])
143-
144-
# count unique predictions
145-
pd_ids = ids[:, 2]
146-
pd_ids = np.unique(pd_ids[pd_ids >= 0])
147-
info.number_of_prediction_annotations += int(pd_ids.shape[0])
148-
149142
# get gt labels
150-
gt_label_ids = ids[:, 3]
143+
gt_label_ids = ids[:, 1]
151144
gt_label_ids, gt_indices = np.unique(
152145
gt_label_ids, return_index=True
153146
)
@@ -157,17 +150,17 @@ def generate_meta(
157150
labels.update(gt_labels)
158151

159152
# get pd labels
160-
pd_label_ids = ids[:, 4]
161-
pd_label_ids, pd_indices = np.unique(
162-
pd_label_ids, return_index=True
153+
pd_label_ids = ids[:, 2]
154+
pd_label_ids, pd_indices, pd_counts = np.unique(
155+
pd_label_ids, return_index=True, return_counts=True
163156
)
164157
pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
165158
pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
166159
pd_labels.pop(-1, None)
167160
labels.update(pd_labels)
168161

169162
# count gts per label
170-
gts = ids[:, (1, 3)].astype(np.int64)
163+
gts = ids[:, 1].astype(np.int64)
171164
unique_ann = np.unique(gts[gts[:, 0] >= 0], axis=0)
172165
unique_labels, label_counts = np.unique(
173166
unique_ann[:, 1], return_counts=True
@@ -181,14 +174,35 @@ def generate_meta(
181174
# complete info object
182175
info.number_of_labels = len(labels)
183176

184-
# convert gt counts to numpy
185-
number_of_groundtruths_per_label = np.zeros(
186-
len(labels), dtype=np.uint64
187-
)
188-
for k, v in gt_counts_per_lbl.items():
189-
number_of_groundtruths_per_label[int(k)] = v
177+
# create confusion matrix
178+
n_labels = len(labels)
179+
matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.uint64)
180+
for fragment in dataset.get_fragments():
181+
tbl = fragment.to_table()
182+
columns = (
183+
"datum_id",
184+
"gt_label_id",
185+
"pd_label_id",
186+
)
187+
ids = np.column_stack(
188+
[tbl[col].to_numpy() for col in columns]
189+
).astype(np.int64)
190+
counts = tbl["counts"].to_numpy()
191+
192+
for idx in range(n_labels):
193+
mask_gts = ids[:, 1] == idx
194+
for pidx in range(n_labels):
195+
mask_pds = ids[:, 2] == pidx
196+
matrix[idx + 1, pidx + 1] = counts[
197+
mask_gts & mask_pds
198+
].sum()
199+
200+
mask_unmatched_gts = mask_gts & (ids[:, 2] == -1)
201+
matrix[idx + 1, 0] = counts[mask_unmatched_gts].sum()
202+
mask_unmatched_pds = (ids[:, 1] == -1) & (ids[:, 2] == idx)
203+
matrix[0, idx + 1] = counts[mask_unmatched_pds]
190204

191-
return labels, number_of_groundtruths_per_label, info
205+
return labels, matrix, info
192206

193207
@staticmethod
194208
def iterate_pairs(
@@ -241,8 +255,8 @@ def filter(
241255
from valor_lite.semantic_segmentation.loader import Loader
242256

243257
return Loader.filter(
244-
directory=directory,
245258
name=name,
259+
directory=directory,
246260
evaluator=self,
247261
filter_expr=filter_expr,
248-
)
262+
)

src/valor_lite/semantic_segmentation/manager.py

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from valor_lite.semantic_segmentation.annotation import Segmentation
99
from valor_lite.semantic_segmentation.computation import (
1010
compute_intermediates,
11-
compute_label_metadata,
1211
compute_metrics,
1312
filter_cache,
1413
)
@@ -103,28 +102,6 @@ def __init__(self):
103102
def metadata(self) -> Metadata:
104103
return self._metadata
105104

106-
@property
107-
def ignored_prediction_labels(self) -> list[str]:
108-
"""
109-
Prediction labels that are not present in the ground truth set.
110-
"""
111-
glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
112-
plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
113-
return [
114-
self.index_to_label[label_id] for label_id in (plabels - glabels)
115-
]
116-
117-
@property
118-
def missing_prediction_labels(self) -> list[str]:
119-
"""
120-
Ground truth labels that are not present in the prediction set.
121-
"""
122-
glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
123-
plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
124-
return [
125-
self.index_to_label[label_id] for label_id in (glabels - plabels)
126-
]
127-
128105
def create_filter(
129106
self,
130107
datums: list[str] | NDArray[np.int64] | None = None,
@@ -208,7 +185,7 @@ def create_filter(
208185
).any(axis=1)
209186
label_mask[~mask_valid_labels] = True
210187

211-
filtered_confusion_matrices, _ = filter_cache(
188+
filtered_confusion_matrices = filter_cache(
212189
confusion_matrices=self._confusion_matrices.copy(),
213190
datum_mask=datum_mask,
214191
label_mask=label_mask,
@@ -223,9 +200,7 @@ def create_filter(
223200
),
224201
)
225202

226-
def filter(
227-
self, filter_: Filter
228-
) -> tuple[NDArray[np.int64], NDArray[np.int64]]:
203+
def filter(self, filter_: Filter) -> tuple[NDArray[np.int64]]:
229204
"""
230205
Performs the filter operation over the internal cache.
231206
@@ -260,7 +235,7 @@ def compute_precision_recall_iou(
260235
A dictionary mapping MetricType enumerations to lists of computed metrics.
261236
"""
262237
if filter_ is not None:
263-
confusion_matrices, label_metadata = self.filter(filter_)
238+
confusion_matrices = self.filter(filter_)
264239
n_pixels = filter_.metadata.number_of_pixels
265240
else:
266241
confusion_matrices = self._confusion_matrices
@@ -269,12 +244,10 @@ def compute_precision_recall_iou(
269244

270245
results = compute_metrics(
271246
confusion_matrices=confusion_matrices,
272-
label_metadata=label_metadata,
273247
n_pixels=n_pixels,
274248
)
275249
return unpack_precision_recall_iou_into_metric_lists(
276250
results=results,
277-
label_metadata=label_metadata,
278251
index_to_label=self.index_to_label,
279252
)
280253

@@ -436,10 +409,6 @@ def finalize(self) -> Evaluator:
436409
for idx, matrix in enumerate(self.matrices):
437410
h, w = matrix.shape
438411
self._evaluator._confusion_matrices[idx, :h, :w] = matrix
439-
self._evaluator._label_metadata = compute_label_metadata(
440-
confusion_matrices=self._evaluator._confusion_matrices,
441-
n_labels=n_labels,
442-
)
443412
self._evaluator._metadata = Metadata.create(
444413
confusion_matrices=self._evaluator._confusion_matrices,
445414
)

src/valor_lite/semantic_segmentation/utilities.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
def unpack_precision_recall_iou_into_metric_lists(
1010
results: tuple,
11-
label_metadata: NDArray[np.int64],
1211
index_to_label: list[str],
1312
) -> dict[MetricType, list[Metric]]:
1413

@@ -39,24 +38,20 @@ def unpack_precision_recall_iou_into_metric_lists(
3938
"iou": float(ious[gt_label_idx, pd_label_idx])
4039
}
4140
for pd_label_idx in range(n_labels)
42-
if label_metadata[pd_label_idx, 0] > 0
4341
}
4442
for gt_label_idx in range(n_labels)
45-
if label_metadata[gt_label_idx, 0] > 0
4643
},
4744
unmatched_predictions={
4845
index_to_label[pd_label_idx]: {
4946
"ratio": float(unmatched_prediction_ratios[pd_label_idx])
5047
}
5148
for pd_label_idx in range(n_labels)
52-
if label_metadata[pd_label_idx, 0] > 0
5349
},
5450
unmatched_ground_truths={
5551
index_to_label[gt_label_idx]: {
5652
"ratio": float(unmatched_ground_truth_ratios[gt_label_idx])
5753
}
5854
for gt_label_idx in range(n_labels)
59-
if label_metadata[gt_label_idx, 0] > 0
6055
},
6156
)
6257
]
@@ -73,10 +68,6 @@ def unpack_precision_recall_iou_into_metric_lists(
7368
"label": label,
7469
}
7570

76-
# if no groundtruths exists for a label, skip it.
77-
if label_metadata[label_idx, 0] == 0:
78-
continue
79-
8071
metrics[MetricType.Precision].append(
8172
Metric.precision(
8273
value=float(precision[label_idx]),

0 commit comments

Comments
 (0)