|
| 1 | +import json |
| 2 | +from dataclasses import dataclass |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pyarrow.compute as pc |
| 7 | +import pyarrow.dataset as ds |
| 8 | +from numpy.typing import NDArray |
| 9 | + |
| 10 | +from valor_lite.cache import DataType |
| 11 | +from valor_lite.classification.computation import compute_metrics |
| 12 | +from valor_lite.classification.metric import Metric, MetricType |
| 13 | +from valor_lite.classification.utilities import ( |
| 14 | + unpack_precision_recall_iou_into_metric_lists, |
| 15 | +) |
| 16 | + |
| 17 | + |
| 18 | +@dataclass |
| 19 | +class EvaluatorInfo: |
| 20 | + number_of_rows: int = 0 |
| 21 | + number_of_datums: int = 0 |
| 22 | + number_of_labels: int = 0 |
| 23 | + number_of_pixels: int = 0 |
| 24 | + number_of_groundtruth_pixels: int = 0 |
| 25 | + number_of_prediction_pixels: int = 0 |
| 26 | + datum_metadata_types: dict[str, DataType] | None = None |
| 27 | + groundtruth_metadata_types: dict[str, DataType] | None = None |
| 28 | + prediction_metadata_types: dict[str, DataType] | None = None |
| 29 | + |
| 30 | + |
| 31 | +@dataclass |
| 32 | +class Filter: |
| 33 | + datums: pc.Expression | None = None |
| 34 | + groundtruths: pc.Expression | None = None |
| 35 | + predictions: pc.Expression | None = None |
| 36 | + |
| 37 | + |
| 38 | +class Evaluator: |
| 39 | + def __init__( |
| 40 | + self, |
| 41 | + name: str = "default", |
| 42 | + directory: str | Path = ".valor", |
| 43 | + labels_override: dict[int, str] | None = None, |
| 44 | + ): |
| 45 | + self._directory = Path(directory) |
| 46 | + self._name = name |
| 47 | + self._path = self._directory / name |
| 48 | + self._cache_path = self._path / "counts" |
| 49 | + self._metadata_path = self._path / "metadata.json" |
| 50 | + |
| 51 | + # link cache |
| 52 | + self._dataset = ds.dataset(self._cache_path, format="parquet") |
| 53 | + |
| 54 | + # build evaluator meta |
| 55 | + ( |
| 56 | + self._index_to_label, |
| 57 | + self._confusion_matrix, |
| 58 | + self._info, |
| 59 | + ) = self.generate_meta(self._dataset, labels_override) |
| 60 | + |
| 61 | + # read config |
| 62 | + with open(self._metadata_path, "r") as f: |
| 63 | + types = json.load(f) |
| 64 | + self._info.datum_metadata_types = types["datum"] |
| 65 | + self._info.groundtruth_metadata_types = types["groundtruth"] |
| 66 | + self._info.prediction_metadata_types = types["prediction"] |
| 67 | + with open(self._cache_path / ".cfg", "r") as f: |
| 68 | + cfg = json.load(f) |
| 69 | + self._detailed_batch_size = cfg["batch_size"] |
| 70 | + self._detailed_rows_per_file = cfg["rows_per_file"] |
| 71 | + self._detailed_compression = cfg["compression"] |
| 72 | + |
| 73 | + @property |
| 74 | + def dataset(self) -> ds.Dataset: |
| 75 | + return self._dataset |
| 76 | + |
| 77 | + @property |
| 78 | + def info(self) -> EvaluatorInfo: |
| 79 | + return self._info |
| 80 | + |
| 81 | + @staticmethod |
| 82 | + def generate_meta( |
| 83 | + dataset: ds.Dataset, |
| 84 | + labels_override: dict[int, str] | None, |
| 85 | + ) -> tuple[dict[int, str], NDArray[np.uint64], EvaluatorInfo]: |
| 86 | + """ |
| 87 | + Generate cache statistics. |
| 88 | +
|
| 89 | + Parameters |
| 90 | + ---------- |
| 91 | + dataset : Dataset |
| 92 | + Valor cache. |
| 93 | + labels_override : dict[int, str], optional |
| 94 | + Optional labels override. Use when operating over filtered data. |
| 95 | +
|
| 96 | + Returns |
| 97 | + ------- |
| 98 | + labels : dict[int, str] |
| 99 | + Mapping of label ID's to label values. |
| 100 | + confusion_matrix : NDArray[np.uint64] |
| 101 | + Array of size (n_labels + 1, n_labels + 1) containing pair counts. |
| 102 | + info : EvaluatorInfo |
| 103 | + Evaluator cache details. |
| 104 | + """ |
| 105 | + labels = labels_override if labels_override else {} |
| 106 | + info = EvaluatorInfo() |
| 107 | + |
| 108 | + for fragment in dataset.get_fragments(): |
| 109 | + tbl = fragment.to_table() |
| 110 | + columns = ( |
| 111 | + "datum_id", |
| 112 | + "gt_label_id", |
| 113 | + "pd_label_id", |
| 114 | + "count", |
| 115 | + ) |
| 116 | + ids = np.column_stack( |
| 117 | + [tbl[col].to_numpy() for col in columns] |
| 118 | + ).astype(np.int64) |
| 119 | + |
| 120 | + # count number of rows |
| 121 | + info.number_of_rows += int(tbl.shape[0]) |
| 122 | + |
| 123 | + # count unique datums |
| 124 | + datum_ids = np.unique(ids[:, 0]) |
| 125 | + info.number_of_datums += int(datum_ids.size) |
| 126 | + |
| 127 | + # get gt labels |
| 128 | + gt_label_ids = ids[:, 1] |
| 129 | + gt_label_ids, gt_indices = np.unique( |
| 130 | + gt_label_ids, return_index=True |
| 131 | + ) |
| 132 | + gt_labels = tbl["gt_label"].take(gt_indices).to_pylist() |
| 133 | + gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels)) |
| 134 | + gt_labels.pop(-1, None) |
| 135 | + labels.update(gt_labels) |
| 136 | + |
| 137 | + # get pd labels |
| 138 | + pd_label_ids = ids[:, 2] |
| 139 | + pd_label_ids, pd_indices, pd_counts = np.unique( |
| 140 | + pd_label_ids, return_index=True, return_counts=True |
| 141 | + ) |
| 142 | + pd_labels = tbl["pd_label"].take(pd_indices).to_pylist() |
| 143 | + pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels)) |
| 144 | + pd_labels.pop(-1, None) |
| 145 | + labels.update(pd_labels) |
| 146 | + |
| 147 | + # post-process |
| 148 | + labels.pop(-1, None) |
| 149 | + |
| 150 | + # create confusion matrix |
| 151 | + n_labels = len(labels) |
| 152 | + matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.uint64) |
| 153 | + for fragment in dataset.get_fragments(): |
| 154 | + tbl = fragment.to_table() |
| 155 | + columns = ( |
| 156 | + "datum_id", |
| 157 | + "gt_label_id", |
| 158 | + "pd_label_id", |
| 159 | + ) |
| 160 | + ids = np.column_stack( |
| 161 | + [tbl[col].to_numpy() for col in columns] |
| 162 | + ).astype(np.int64) |
| 163 | + counts = tbl["count"].to_numpy() |
| 164 | + |
| 165 | + mask_null_gts = ids[:, 1] == -1 |
| 166 | + mask_null_pds = ids[:, 2] == -1 |
| 167 | + matrix[0, 0] = counts[mask_null_gts & mask_null_pds].sum() |
| 168 | + for idx in range(n_labels): |
| 169 | + mask_gts = ids[:, 1] == idx |
| 170 | + for pidx in range(n_labels): |
| 171 | + mask_pds = ids[:, 2] == pidx |
| 172 | + matrix[idx + 1, pidx + 1] = counts[ |
| 173 | + mask_gts & mask_pds |
| 174 | + ].sum() |
| 175 | + |
| 176 | + mask_unmatched_gts = mask_gts & mask_null_pds |
| 177 | + matrix[idx + 1, 0] = counts[mask_unmatched_gts].sum() |
| 178 | + mask_unmatched_pds = mask_null_gts & (ids[:, 2] == idx) |
| 179 | + matrix[0, idx + 1] = counts[mask_unmatched_pds].sum() |
| 180 | + |
| 181 | + # complete info object |
| 182 | + info.number_of_labels = len(labels) |
| 183 | + info.number_of_pixels = matrix.sum() |
| 184 | + info.number_of_groundtruth_pixels = matrix[1:, :].sum() |
| 185 | + info.number_of_prediction_pixels = matrix[:, 1:].sum() |
| 186 | + |
| 187 | + return labels, matrix, info |
| 188 | + |
| 189 | + def filter( |
| 190 | + self, |
| 191 | + filter_expr: Filter, |
| 192 | + name: str | None = None, |
| 193 | + directory: str | Path | None = None, |
| 194 | + ) -> "Evaluator": |
| 195 | + """ |
| 196 | + Filter evaluator cache. |
| 197 | +
|
| 198 | + Parameters |
| 199 | + ---------- |
| 200 | + filter_expr : Filter |
| 201 | + An object containing filter expressions. |
| 202 | + name : str, optional |
| 203 | + Filtered cache name. |
| 204 | + directory : str | Path, optional |
| 205 | + The directory to store the filtered cache. |
| 206 | +
|
| 207 | + Returns |
| 208 | + ------- |
| 209 | + Evaluator |
| 210 | + A new evaluator object containing the filtered cache. |
| 211 | + """ |
| 212 | + name = name if name else "filtered" |
| 213 | + directory = directory if directory else self._directory |
| 214 | + from valor_lite.classification.loader import Loader |
| 215 | + |
| 216 | + return Loader.filter( |
| 217 | + name=name, |
| 218 | + directory=directory, |
| 219 | + evaluator=self, |
| 220 | + filter_expr=filter_expr, |
| 221 | + ) |
| 222 | + |
| 223 | + def compute_precision_recall_iou(self) -> dict[MetricType, list]: |
| 224 | + """ |
| 225 | + Performs an evaluation and returns metrics. |
| 226 | +
|
| 227 | + Returns |
| 228 | + ------- |
| 229 | + dict[MetricType, list] |
| 230 | + A dictionary mapping MetricType enumerations to lists of computed metrics. |
| 231 | + """ |
| 232 | + results = compute_metrics(counts=self._confusion_matrix) |
| 233 | + return unpack_precision_recall_iou_into_metric_lists( |
| 234 | + results=results, |
| 235 | + index_to_label=self._index_to_label, |
| 236 | + ) |
| 237 | + |
| 238 | + def evaluate(self) -> dict[MetricType, list[Metric]]: |
| 239 | + """ |
| 240 | + Computes all available metrics. |
| 241 | +
|
| 242 | + Returns |
| 243 | + ------- |
| 244 | + dict[MetricType, list[Metric]] |
| 245 | + Lists of metrics organized by metric type. |
| 246 | + """ |
| 247 | + return self.compute_precision_recall_iou() |
0 commit comments