Skip to content

Commit f175f52

Browse files
committed
docstring
1 parent b5ac471 commit f175f52

1 file changed

Lines changed: 57 additions & 23 deletions

File tree

src/valor_lite/object_detection/evaluator.py

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pyarrow.compute as pc
1111
import pyarrow.dataset as ds
1212
import pyarrow.parquet as pq
13+
from numpy.typing import NDArray
1314

1415
from valor_lite.cache import CacheReader, CacheWriter, DataType
1516
from valor_lite.object_detection.computation import (
@@ -131,8 +132,28 @@ def info(self) -> EvaluatorInfo:
131132

132133
@staticmethod
133134
def generate_meta(
134-
dataset: ds.Dataset, labels_override: dict[int, str] | None
135-
):
135+
dataset: ds.Dataset,
136+
labels_override: dict[int, str] | None,
137+
) -> tuple[dict[int, str], NDArray[np.uint64], EvaluatorInfo,]:
138+
"""
139+
Generate cache statistics.
140+
141+
Parameters
142+
----------
143+
dataset : Dataset
144+
Valor cache.
145+
labels_override : dict[int, str], optional
146+
Optional labels override. Use when operating over filtered data.
147+
148+
Returns
149+
-------
150+
labels : dict[int, str]
151+
Mapping of label ID's to label values.
152+
number_of_groundtruths_per_label : NDArray[np.uint64]
153+
Array of size (n_labels,) containing ground truth counts.
154+
info : EvaluatorInfo
155+
Evaluator cache details.
156+
"""
136157
gt_counts_per_lbl = defaultdict(int)
137158
labels = labels_override if labels_override else {}
138159
info = EvaluatorInfo()
@@ -236,6 +257,23 @@ def filter(
236257
name: str | None = None,
237258
directory: str | Path | None = None,
238259
) -> "Evaluator":
260+
"""
261+
Filter evaluator cache.
262+
263+
Parameters
264+
----------
265+
filter_expr : Filter
266+
An object containing filter expressions.
267+
name : str, optional
268+
Filtered cache name.
269+
directory : str | Path, optional
270+
The directory to store the filtered cache.
271+
272+
Returns
273+
-------
274+
Evaluator
275+
A new evaluator object containing the filtered cache.
276+
"""
239277
name = name if name else "filtered"
240278
directory = directory if directory else self._directory
241279
from valor_lite.object_detection.loader import Loader
@@ -255,6 +293,23 @@ def rank(
255293
write_batch_size: int | None = None,
256294
read_batch_size: int = 1_000,
257295
):
296+
"""
297+
Create a ranked pair cache.
298+
299+
Parameters
300+
----------
301+
where : str | Path
302+
Where to store the ranked cache.
303+
rows_per_file : int, optional
304+
Sets the maximum number of rows per file. Defaults to value from detailed cache.
305+
compression : str, optional
306+
Sets the compression method. Defaults to value from detailed cache.
307+
write_batch_size : int, optional
308+
Sets the batch size for writing. Defaults to value from detailed cache.
309+
read_batch_size : int, default=1_000
310+
Sets the batch size for reading. Defaults to 1_000.
311+
"""
312+
258313
n_labels = len(self._index_to_label)
259314
detailed_cache = CacheReader(self._detailed_path)
260315

@@ -456,8 +511,6 @@ def compute_confusion_matrix(
456511
A list of IOU thresholds to compute metrics over.
457512
score_thresholds : list[float]
458513
A list of score thresholds to compute metrics over.
459-
filter_ : Filter, optional
460-
A collection of filter parameters and masks.
461514
462515
Returns
463516
-------
@@ -548,8 +601,6 @@ def compute_examples(
548601
A list of IOU thresholds to compute metrics over.
549602
score_thresholds : list[float]
550603
A list of score thresholds to compute metrics over.
551-
filter_ : Filter, optional
552-
A collection of filter parameters and masks.
553604
554605
Returns
555606
-------
@@ -637,8 +688,6 @@ def compute_confusion_matrix_with_examples(
637688
A list of IOU thresholds to compute metrics over.
638689
score_thresholds : list[float]
639690
A list of score thresholds to compute metrics over.
640-
filter_ : Filter, optional
641-
A collection of filter parameters and masks.
642691
643692
Returns
644693
-------
@@ -716,18 +765,3 @@ def compute_confusion_matrix_with_examples(
716765
)
717766

718767
return [m for inner in metrics.values() for m in inner.values()]
719-
720-
def evaluate(
721-
self,
722-
iou_thresholds: list[float] = [0.1, 0.5, 0.75],
723-
score_thresholds: list[float] = [0.5, 0.75, 0.9],
724-
) -> dict[MetricType, list[Metric]]:
725-
metrics = self.compute_precision_recall(
726-
iou_thresholds=iou_thresholds,
727-
score_thresholds=score_thresholds,
728-
)
729-
metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
730-
iou_thresholds=iou_thresholds,
731-
score_thresholds=score_thresholds,
732-
)
733-
return metrics

0 commit comments

Comments
 (0)