Skip to content

Commit 51e478b

Browse files
committed
ingestion fix
1 parent 4034429 commit 51e478b

2 files changed

Lines changed: 29 additions & 15 deletions

File tree

src/valor_lite/object_detection/evaluator.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,37 +179,39 @@ def generate_meta(
179179
info.number_of_datums += int(datum_ids.size)
180180

181181
# count unique groundtruths
182-
gt_ids = ids[:, (0, 1)]
183-
gt_ids = gt_ids[gt_ids[:, 1] >= 0]
184-
gt_ids = np.unique(gt_ids, axis=0)
182+
gt_ids = ids[:, 1]
183+
gt_ids = np.unique(gt_ids[gt_ids >= 0])
185184
info.number_of_groundtruth_annotations += int(gt_ids.shape[0])
186185

187186
# count unique predictions
188-
pd_ids = ids[:, (0, 2)]
189-
pd_ids = pd_ids[pd_ids[:, 1] >= 0]
190-
pd_ids = np.unique(pd_ids, axis=0)
187+
pd_ids = ids[:, 2]
188+
pd_ids = np.unique(pd_ids[pd_ids >= 0])
191189
info.number_of_prediction_annotations += int(pd_ids.shape[0])
192190

193191
# get gt labels
194-
gt_label_ids, gt_indices = np.unique(ids[:, 3], return_index=True)
192+
gt_label_ids = ids[:, 3]
193+
gt_label_ids, gt_indices = np.unique(
194+
gt_label_ids[gt_label_ids >= 0], return_index=True
195+
)
195196
gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
196197
gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
197198
labels.update(gt_labels)
198199

199200
# get pd labels
200-
pd_label_ids, pd_indices = np.unique(ids[:, 4], return_index=True)
201+
pd_label_ids = ids[:, 4]
202+
pd_label_ids, pd_indices = np.unique(
203+
pd_label_ids[pd_label_ids >= 0], return_index=True
204+
)
201205
pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
202206
pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
203207
labels.update(pd_labels)
204208

205209
# count gts per label
206-
gts = ids[:, (0, 1, 3)].astype(np.int64)
207-
unique_ann = np.unique(gts, axis=0)
210+
gts = ids[:, (1, 3)].astype(np.int64)
211+
unique_ann = np.unique(gts[gts[:, 0] >= 0], axis=0)
208212
unique_labels, label_counts = np.unique(
209-
unique_ann[:, 2], return_counts=True
213+
unique_ann[:, 1], return_counts=True
210214
)
211-
label_counts = label_counts[unique_labels >= 0]
212-
unique_labels = unique_labels[unique_labels >= 0]
213215
for label_id, count in zip(unique_labels, label_counts):
214216
gt_counts_per_lbl[int(label_id)] += int(count)
215217

src/valor_lite/object_detection/loader.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,13 @@ def filter(
440440
evaluator.rank(where=loader._ranked_path)
441441
return evaluator
442442

443-
def finalize(self):
443+
def finalize(
444+
self,
445+
rows_per_file: int | None = None,
446+
compression: str | None = None,
447+
write_batch_size: int | None = None,
448+
read_batch_size: int = 1000,
449+
):
444450
"""
445451
Performs data finalization and some preprocessing steps.
446452
@@ -457,5 +463,11 @@ def finalize(self):
457463
directory=self._directory,
458464
name=self._name,
459465
)
460-
evaluator.rank(where=self._ranked_path)
466+
evaluator.rank(
467+
where=self._ranked_path,
468+
rows_per_file=rows_per_file,
469+
compression=compression,
470+
write_batch_size=write_batch_size,
471+
read_batch_size=read_batch_size,
472+
)
461473
return evaluator

0 commit comments

Comments
 (0)