Skip to content

Commit 3edf658

Browse files
committed
cleaned up
1 parent 6652c4b commit 3edf658

2 files changed

Lines changed: 265 additions & 293 deletions

File tree

src/valor_lite/object_detection/evaluator.py

Lines changed: 104 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import heapq
21
import json
3-
import tempfile
42
from collections import defaultdict
53
from dataclasses import dataclass
64
from pathlib import Path
@@ -9,18 +7,16 @@
97
import pyarrow as pa
108
import pyarrow.compute as pc
119
import pyarrow.dataset as ds
12-
import pyarrow.parquet as pq
1310
from numpy.typing import NDArray
1411

15-
from valor_lite.cache import CacheReader, CacheWriter, DataType
12+
from valor_lite.cache import CacheReader, DataType
1613
from valor_lite.object_detection.computation import (
1714
compute_average_precision,
1815
compute_average_recall,
1916
compute_confusion_matrix,
2017
compute_counts,
2118
compute_pair_classifications,
2219
compute_precision_recall_f1,
23-
rank_table,
2420
)
2521
from valor_lite.object_detection.format import PathFormatter
2622
from valor_lite.object_detection.metric import Metric, MetricType
@@ -72,24 +68,30 @@ def __init__(
7268
number_of_groundtruths_per_label
7369
)
7470

71+
@property
72+
def path(self) -> Path:
73+
return self._path
74+
75+
@property
76+
def detailed(self) -> CacheReader:
77+
return self._detailed_cache
78+
79+
@property
80+
def ranked(self) -> CacheReader:
81+
return self._ranked_cache
82+
83+
@property
84+
def info(self) -> EvaluatorInfo:
85+
return self._info
86+
7587
@classmethod
76-
def create(
88+
def load(
7789
cls,
7890
path: str | Path,
79-
batch_size: int = 1_000,
8091
index_to_label_override: dict[int, str] | None = None,
8192
):
82-
"""
83-
Create a ranked pair cache.
84-
85-
Parameters
86-
----------
87-
path : str | Path
88-
Where to store the evaluator cache.
89-
batch_size : int, default=1_000
90-
Sets the batch size for reading. Defaults to 1_000.
91-
"""
9293
detailed_cache = CacheReader(cls._generate_detailed_cache_path(path))
94+
ranked_cache = CacheReader(cls._generate_ranked_cache_path(path))
9395

9496
# build evaluator meta
9597
(
@@ -108,119 +110,6 @@ def create(
108110
]
109111
info.prediction_metadata_types = types["prediction_metadata_types"]
110112

111-
# create ranked cache schema
112-
annotation_metadata_keys = {
113-
*(
114-
set(info.groundtruth_metadata_types.keys())
115-
if info.groundtruth_metadata_types
116-
else {}
117-
),
118-
*(
119-
set(info.prediction_metadata_types.keys())
120-
if info.prediction_metadata_types
121-
else {}
122-
),
123-
}
124-
pruned_schema = pa.schema(
125-
[
126-
field
127-
for field in detailed_cache.schema
128-
if field.name not in annotation_metadata_keys
129-
]
130-
)
131-
ranked_schema = pruned_schema.append(
132-
pa.field("iou_prev", pa.float64())
133-
)
134-
ranked_schema = ranked_schema.append(
135-
pa.field("high_score", pa.bool_())
136-
)
137-
138-
n_labels = len(index_to_label)
139-
140-
with CacheWriter.create(
141-
path=cls._generate_ranked_cache_path(path),
142-
schema=ranked_schema,
143-
batch_size=detailed_cache.batch_size,
144-
rows_per_file=detailed_cache.rows_per_file,
145-
compression=detailed_cache.compression,
146-
) as ranked_cache:
147-
if detailed_cache.num_dataset_files == 1:
148-
pf = pq.ParquetFile(detailed_cache.dataset_files[0])
149-
tbl = pf.read()
150-
ranked_tbl = rank_table(tbl, n_labels)
151-
ranked_cache.write_table(ranked_tbl)
152-
else:
153-
pruned_detailed_columns = [
154-
field.name for field in pruned_schema
155-
]
156-
with tempfile.TemporaryDirectory() as tmpdir:
157-
158-
# rank individual files
159-
tmpfiles = []
160-
for idx, fragment in enumerate(
161-
detailed_cache.dataset.get_fragments()
162-
):
163-
fragment_path = Path(tmpdir) / f"{idx:06d}.parquet"
164-
tbl = fragment.to_table(
165-
columns=pruned_detailed_columns
166-
)
167-
ranked_tbl = rank_table(tbl, n_labels)
168-
pq.write_table(ranked_tbl, fragment_path)
169-
tmpfiles.append(fragment_path)
170-
171-
def generate_heap_item(batches, batch_idx, row_idx):
172-
score = batches[batch_idx]["score"][row_idx].as_py()
173-
iou = batches[batch_idx]["iou"][row_idx].as_py()
174-
return (
175-
-score,
176-
-iou,
177-
batch_idx,
178-
row_idx,
179-
)
180-
181-
# merge sorted rows
182-
heap = []
183-
batch_iterators = []
184-
batches = []
185-
for batch_idx, batch_path in enumerate(tmpfiles):
186-
pf = pq.ParquetFile(batch_path)
187-
batch_iter = pf.iter_batches(batch_size=batch_size)
188-
batch_iterators.append(batch_iter)
189-
batches.append(next(batch_iterators[batch_idx], None))
190-
if (
191-
batches[batch_idx] is not None
192-
and len(batches[batch_idx]) > 0
193-
):
194-
heapq.heappush(
195-
heap, generate_heap_item(batches, batch_idx, 0)
196-
)
197-
198-
while heap:
199-
_, _, batch_idx, row_idx = heapq.heappop(heap)
200-
row_table = batches[batch_idx].slice(row_idx, 1)
201-
ranked_cache.write_batch(row_table)
202-
row_idx += 1
203-
if row_idx < len(batches[batch_idx]):
204-
heapq.heappush(
205-
heap,
206-
generate_heap_item(
207-
batches, batch_idx, row_idx
208-
),
209-
)
210-
else:
211-
batches[batch_idx] = next(
212-
batch_iterators[batch_idx], None
213-
)
214-
if (
215-
batches[batch_idx] is not None
216-
and len(batches[batch_idx]) > 0
217-
):
218-
heapq.heappush(
219-
heap,
220-
generate_heap_item(batches, batch_idx, 0),
221-
)
222-
223-
ranked_cache = CacheReader(cls._generate_ranked_cache_path(path))
224113
return cls(
225114
path=path,
226115
detailed_cache=detailed_cache,
@@ -230,56 +119,101 @@ def generate_heap_item(batches, batch_idx, row_idx):
230119
number_of_groundtruths_per_label=number_of_groundtruths_per_label,
231120
)
232121

233-
@classmethod
234-
def load(
235-
cls,
122+
def filter(
123+
self,
236124
path: str | Path,
237-
index_to_label_override: dict[int, str] | None = None,
238-
):
239-
detailed_cache = CacheReader(cls._generate_detailed_cache_path(path))
240-
ranked_cache = CacheReader(cls._generate_ranked_cache_path(path))
125+
filter_expr: Filter,
126+
batch_size: int = 1_000,
127+
) -> "Evaluator":
128+
"""
129+
Filter evaluator cache.
241130
242-
# build evaluator meta
243-
(
244-
index_to_label,
245-
number_of_groundtruths_per_label,
246-
info,
247-
) = cls.generate_meta(detailed_cache.dataset, index_to_label_override)
131+
Parameters
132+
----------
133+
path : str | Path
134+
Where to store the filtered cache.
135+
filter_expr : Filter
136+
An object containing filter expressions.
137+
batch_size : int
138+
The maximum number of rows read into memory per file.
248139
249-
# read config
250-
metadata_path = cls._generate_metadata_path(path)
251-
with open(metadata_path, "r") as f:
252-
types = json.load(f)
253-
info.datum_metadata_types = types["datum_metadata_types"]
254-
info.groundtruth_metadata_types = types[
255-
"groundtruth_metadata_types"
256-
]
257-
info.prediction_metadata_types = types["prediction_metadata_types"]
140+
Returns
141+
-------
142+
Evaluator
143+
A new evaluator object containing the filtered cache.
144+
"""
145+
from valor_lite.object_detection.loader import Loader
258146

259-
return cls(
147+
loader = Loader.create(
260148
path=path,
261-
detailed_cache=detailed_cache,
262-
ranked_cache=ranked_cache,
263-
info=info,
264-
index_to_label=index_to_label,
265-
number_of_groundtruths_per_label=number_of_groundtruths_per_label,
149+
batch_size=self.detailed.batch_size,
150+
rows_per_file=self.detailed.rows_per_file,
151+
compression=self.detailed.compression,
152+
datum_metadata_types=self.info.datum_metadata_types,
153+
groundtruth_metadata_types=self.info.groundtruth_metadata_types,
154+
prediction_metadata_types=self.info.prediction_metadata_types,
266155
)
156+
for fragment in self.detailed.dataset.get_fragments():
157+
tbl = fragment.to_table(filter=filter_expr.datums)
267158

268-
@property
269-
def path(self) -> Path:
270-
return self._path
159+
columns = (
160+
"datum_id",
161+
"gt_id",
162+
"pd_id",
163+
"gt_label_id",
164+
"pd_label_id",
165+
"iou",
166+
"score",
167+
)
168+
pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
169+
170+
n_pairs = pairs.shape[0]
171+
gt_ids = pairs[:, (0, 1)].astype(np.int64)
172+
pd_ids = pairs[:, (0, 2)].astype(np.int64)
173+
174+
if filter_expr.groundtruths is not None:
175+
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
176+
gt_tbl = tbl.filter(filter_expr.groundtruths)
177+
gt_pairs = np.column_stack(
178+
[gt_tbl[col].to_numpy() for col in ("datum_id", "gt_id")]
179+
).astype(np.int64)
180+
for gt in np.unique(gt_pairs, axis=0):
181+
mask_valid_gt |= (gt_ids == gt).all(axis=1)
182+
else:
183+
mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
184+
185+
if filter_expr.predictions is not None:
186+
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
187+
pd_tbl = tbl.filter(filter_expr.predictions)
188+
pd_pairs = np.column_stack(
189+
[pd_tbl[col].to_numpy() for col in ("datum_id", "pd_id")]
190+
).astype(np.int64)
191+
for pd in np.unique(pd_pairs, axis=0):
192+
mask_valid_pd |= (pd_ids == pd).all(axis=1)
193+
else:
194+
mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
271195

272-
@property
273-
def detailed(self) -> CacheReader:
274-
return self._detailed_cache
196+
mask_valid = mask_valid_gt | mask_valid_pd
197+
mask_valid_gt &= mask_valid
198+
mask_valid_pd &= mask_valid
275199

276-
@property
277-
def ranked(self) -> CacheReader:
278-
return self._ranked_cache
200+
pairs[np.ix_(~mask_valid_gt, (1, 3))] = -1.0 # type: ignore - numpy ix_
201+
pairs[np.ix_(~mask_valid_pd, (2, 4, 6))] = -1.0 # type: ignore - numpy ix_
202+
pairs[~mask_valid_pd | ~mask_valid_gt, 5] = 0.0
279203

280-
@property
281-
def info(self) -> EvaluatorInfo:
282-
return self._info
204+
for idx, col in enumerate(columns):
205+
tbl = tbl.set_column(
206+
tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
207+
)
208+
209+
mask_invalid = ~mask_valid | (pairs[:, (1, 2)] < 0).all(axis=1)
210+
filtered_tbl = tbl.filter(pa.array(~mask_invalid))
211+
loader._cache.write_table(filtered_tbl)
212+
213+
return loader.finalize(
214+
batch_size=batch_size,
215+
index_to_label_override=self._index_to_label,
216+
)
283217

284218
@staticmethod
285219
def generate_meta(
@@ -406,34 +340,6 @@ def iterate_pairs_with_table(
406340
[tbl[col].to_numpy() for col in columns]
407341
)
408342

409-
def filter(
410-
self,
411-
path: str | Path,
412-
filter_expr: Filter,
413-
) -> "Evaluator":
414-
"""
415-
Filter evaluator cache.
416-
417-
Parameters
418-
----------
419-
path : str | Path
420-
Where to store the filtered cache.
421-
filter_expr : Filter
422-
An object containing filter expressions.
423-
424-
Returns
425-
-------
426-
Evaluator
427-
A new evaluator object containing the filtered cache.
428-
"""
429-
from valor_lite.object_detection.loader import Loader
430-
431-
return Loader.filter(
432-
path=path,
433-
evaluator=self,
434-
filter_expr=filter_expr,
435-
)
436-
437343
def compute_precision_recall(
438344
self,
439345
iou_thresholds: list[float],

0 commit comments

Comments
 (0)