Skip to content

Commit 108719c

Browse files
committed
wip cache upgrade
1 parent bd84d8d commit 108719c

1 file changed

Lines changed: 157 additions & 126 deletions

File tree

src/valor_lite/semantic_segmentation/evaluator.py

Lines changed: 157 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
from pathlib import Path
44

55
import numpy as np
6+
import pyarrow as pa
67
import pyarrow.compute as pc
78
import pyarrow.dataset as ds
89
from numpy.typing import NDArray
910

10-
from valor_lite.cache import DataType
11+
from valor_lite.cache import CacheReader, DataType
12+
from valor_lite.exceptions import EmptyCacheError
1113
from valor_lite.semantic_segmentation.computation import compute_metrics
14+
from valor_lite.semantic_segmentation.format import PathFormatter
1215
from valor_lite.semantic_segmentation.metric import Metric, MetricType
1316
from valor_lite.semantic_segmentation.utilities import (
1417
unpack_precision_recall_iou_into_metric_lists,
@@ -35,44 +38,171 @@ class Filter:
3538
predictions: pc.Expression | None = None
3639

3740

38-
class Evaluator:
41+
class Evaluator(PathFormatter):
3942
def __init__(
4043
self,
41-
name: str = "default",
42-
directory: str | Path = ".valor",
43-
labels_override: dict[int, str] | None = None,
44+
path: str | Path,
45+
cache: CacheReader,
46+
info: EvaluatorInfo,
47+
index_to_label: dict[int, str],
48+
confusion_matrix: NDArray[np.uint64],
4449
):
45-
self._directory = Path(directory)
46-
self._name = name
47-
self._path = self._directory / name
48-
self._cache_path = self._path / "counts"
49-
self._metadata_path = self._path / "metadata.json"
50+
self._path = Path(path)
51+
self._cache = cache
52+
self._info = info
53+
self._index_to_label = index_to_label
54+
self._confusion_matrix = confusion_matrix
55+
56+
@classmethod
57+
def load(
58+
cls,
59+
path: str | Path,
60+
index_to_label_override: dict[int, str] | None = None,
61+
):
62+
# validate path
63+
path = Path(path)
64+
if not path.exists():
65+
raise FileNotFoundError(f"Directory does not exist: {path}")
66+
elif not path.is_dir():
67+
raise NotADirectoryError(
68+
f"Path exists but is not a directory: {path}"
69+
)
5070

51-
# link cache
52-
self._dataset = ds.dataset(self._cache_path, format="parquet")
71+
# load cache
72+
cache = CacheReader.load(cls._generate_cache_path(path))
5373

5474
# build evaluator meta
5575
(
56-
self._index_to_label,
57-
self._confusion_matrix,
58-
self._info,
59-
) = self.generate_meta(self._dataset, labels_override)
76+
index_to_label,
77+
confusion_matrix,
78+
info,
79+
) = cls.generate_meta(cache.dataset, index_to_label_override)
6080

6181
# read config
62-
with open(self._metadata_path, "r") as f:
82+
metadata_path = cls._generate_metadata_path(path)
83+
with open(metadata_path, "r") as f:
6384
types = json.load(f)
64-
self._info.datum_metadata_types = types["datum"]
65-
self._info.groundtruth_metadata_types = types["groundtruth"]
66-
self._info.prediction_metadata_types = types["prediction"]
67-
with open(self._cache_path / ".cfg", "r") as f:
68-
cfg = json.load(f)
69-
self._detailed_batch_size = cfg["batch_size"]
70-
self._detailed_rows_per_file = cfg["rows_per_file"]
71-
self._detailed_compression = cfg["compression"]
85+
info.datum_metadata_types = types["datum_metadata_types"]
86+
info.groundtruth_metadata_types = types[
87+
"groundtruth_metadata_types"
88+
]
89+
info.prediction_metadata_types = types["prediction_metadata_types"]
90+
91+
return cls(
92+
path=path,
93+
cache=cache,
94+
info=info,
95+
index_to_label=index_to_label,
96+
confusion_matrix=confusion_matrix,
97+
)
98+
99+
def filter(
100+
self,
101+
path: str | Path,
102+
filter_expr: Filter,
103+
) -> "Evaluator":
104+
"""
105+
Filter evaluator cache.
106+
107+
Parameters
108+
----------
109+
path : str | Path
110+
Where to store the filtered cache.
111+
filter_expr : Filter
112+
An object containing filter expressions.
113+
114+
Returns
115+
-------
116+
Evaluator
117+
A new evaluator object containing the filtered cache.
118+
"""
119+
from valor_lite.semantic_segmentation.loader import Loader
120+
121+
loader = Loader.create(
122+
path=path,
123+
batch_size=self.cache.batch_size,
124+
rows_per_file=self.cache.rows_per_file,
125+
compression=self.cache.compression,
126+
datum_metadata_types=self.info.datum_metadata_types,
127+
groundtruth_metadata_types=self.info.groundtruth_metadata_types,
128+
prediction_metadata_types=self.info.prediction_metadata_types,
129+
)
130+
for fragment in self.cache.dataset.get_fragments():
131+
tbl = fragment.to_table(filter=filter_expr.datums)
132+
133+
columns = (
134+
"datum_id",
135+
"gt_label_id",
136+
"pd_label_id",
137+
)
138+
pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
139+
140+
n_pairs = pairs.shape[0]
141+
gt_ids = pairs[:, (0, 1)].astype(np.int64)
142+
pd_ids = pairs[:, (0, 2)].astype(np.int64)
143+
144+
if filter_expr.groundtruths is not None:
145+
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
146+
gt_tbl = tbl.filter(filter_expr.groundtruths)
147+
gt_pairs = np.column_stack(
148+
[
149+
gt_tbl[col].to_numpy()
150+
for col in ("datum_id", "gt_label_id")
151+
]
152+
).astype(np.int64)
153+
for gt in np.unique(gt_pairs, axis=0):
154+
mask_valid_gt |= (gt_ids == gt).all(axis=1)
155+
else:
156+
mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
157+
158+
if filter_expr.predictions is not None:
159+
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
160+
pd_tbl = tbl.filter(filter_expr.predictions)
161+
pd_pairs = np.column_stack(
162+
[
163+
pd_tbl[col].to_numpy()
164+
for col in ("datum_id", "pd_label_id")
165+
]
166+
).astype(np.int64)
167+
for pd in np.unique(pd_pairs, axis=0):
168+
mask_valid_pd |= (pd_ids == pd).all(axis=1)
169+
else:
170+
mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
171+
172+
mask_valid = mask_valid_gt | mask_valid_pd
173+
mask_valid_gt &= mask_valid
174+
mask_valid_pd &= mask_valid
175+
176+
pairs[~mask_valid_gt, 1] = -1
177+
pairs[~mask_valid_pd, 2] = -1
178+
179+
for idx, col in enumerate(columns):
180+
tbl = tbl.set_column(
181+
tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
182+
)
183+
loader._cache.write_table(tbl)
184+
185+
loader._cache.flush()
186+
if loader._cache.dataset.count_rows() == 0:
187+
raise EmptyCacheError()
188+
189+
return loader.finalize()
190+
191+
def delete(self):
192+
"""
193+
Delete evaluator cache.
194+
"""
195+
from valor_lite.semantic_segmentation.loader import Loader
196+
197+
Loader.delete(self.path)
198+
199+
@property
200+
def path(self) -> Path:
201+
return self._path
72202

73203
@property
74-
def dataset(self) -> ds.Dataset:
75-
return self._dataset
204+
def cache(self) -> CacheReader:
205+
return self._cache
76206

77207
@property
78208
def info(self) -> EvaluatorInfo:
@@ -185,105 +315,6 @@ def generate_meta(
185315

186316
return labels, matrix, info
187317

188-
def filter(
189-
self,
190-
filter_expr: Filter,
191-
name: str | None = None,
192-
directory: str | Path | None = None,
193-
) -> "Evaluator":
194-
"""
195-
Filter evaluator cache.
196-
197-
Parameters
198-
----------
199-
filter_expr : Filter
200-
An object containing filter expressions.
201-
name : str, optional
202-
Filtered cache name.
203-
directory : str | Path, optional
204-
The directory to store the filtered cache.
205-
206-
Returns
207-
-------
208-
Evaluator
209-
A new evaluator object containing the filtered cache.
210-
"""
211-
loader = cls(
212-
directory=directory,
213-
name=name,
214-
batch_size=evaluator._detailed_batch_size,
215-
rows_per_file=evaluator._detailed_rows_per_file,
216-
compression=evaluator._detailed_compression,
217-
datum_metadata_types=evaluator.info.datum_metadata_types,
218-
groundtruth_metadata_types=evaluator.info.groundtruth_metadata_types,
219-
prediction_metadata_types=evaluator.info.prediction_metadata_types,
220-
)
221-
for fragment in evaluator.dataset.get_fragments():
222-
tbl = fragment.to_table(filter=filter_expr.datums)
223-
224-
columns = (
225-
"datum_id",
226-
"gt_label_id",
227-
"pd_label_id",
228-
)
229-
pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
230-
231-
n_pairs = pairs.shape[0]
232-
gt_ids = pairs[:, (0, 1)].astype(np.int64)
233-
pd_ids = pairs[:, (0, 2)].astype(np.int64)
234-
235-
if filter_expr.groundtruths is not None:
236-
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
237-
gt_tbl = tbl.filter(filter_expr.groundtruths)
238-
gt_pairs = np.column_stack(
239-
[
240-
gt_tbl[col].to_numpy()
241-
for col in ("datum_id", "gt_label_id")
242-
]
243-
).astype(np.int64)
244-
for gt in np.unique(gt_pairs, axis=0):
245-
mask_valid_gt |= (gt_ids == gt).all(axis=1)
246-
else:
247-
mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
248-
249-
if filter_expr.predictions is not None:
250-
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
251-
pd_tbl = tbl.filter(filter_expr.predictions)
252-
pd_pairs = np.column_stack(
253-
[
254-
pd_tbl[col].to_numpy()
255-
for col in ("datum_id", "pd_label_id")
256-
]
257-
).astype(np.int64)
258-
for pd in np.unique(pd_pairs, axis=0):
259-
mask_valid_pd |= (pd_ids == pd).all(axis=1)
260-
else:
261-
mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
262-
263-
mask_valid = mask_valid_gt | mask_valid_pd
264-
mask_valid_gt &= mask_valid
265-
mask_valid_pd &= mask_valid
266-
267-
pairs[~mask_valid_gt, 1] = -1
268-
pairs[~mask_valid_pd, 2] = -1
269-
270-
for idx, col in enumerate(columns):
271-
tbl = tbl.set_column(
272-
tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
273-
)
274-
loader._cache.write_table(tbl)
275-
276-
loader._cache.flush()
277-
if loader._cache.dataset.count_rows() == 0:
278-
raise EmptyCacheError()
279-
280-
evaluator = Evaluator(
281-
directory=loader._directory,
282-
name=loader._name,
283-
labels_override=evaluator._index_to_label,
284-
)
285-
return evaluator
286-
287318
def compute_precision_recall_iou(self) -> dict[MetricType, list]:
288319
"""
289320
Performs an evaluation and returns metrics.

0 commit comments

Comments
 (0)