Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions benchmarks/benchmark_objdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def run_benchmarking_analysis(
)
if eval_time > evaluation_timeout and evaluation_timeout != -1:
raise TimeoutError(
f"Base evaluation timed out with {evaluator.n_datums} datums."
f"Base evaluation timed out with {evaluator.metadata.number_of_datums} datums."
)

# evaluate - base metrics + detailed
Expand All @@ -337,16 +337,16 @@ def run_benchmarking_analysis(
and evaluation_timeout != -1
):
raise TimeoutError(
f"Detailed evaluation timed out with {evaluator.n_datums} datums."
f"Detailed evaluation timed out with {evaluator.metadata.number_of_datums} datums."
)

results.append(
Benchmark(
limit=limit,
n_datums=evaluator.n_datums,
n_groundtruths=evaluator.n_groundtruths,
n_predictions=evaluator.n_predictions,
n_labels=evaluator.n_labels,
n_datums=evaluator.metadata.number_of_datums,
n_groundtruths=evaluator.metadata.number_of_ground_truths,
n_predictions=evaluator.metadata.number_of_predictions,
n_labels=evaluator.metadata.number_of_labels,
gt_type=gt_type,
pd_type=pd_type,
chunk_size=chunk_size,
Expand Down
3 changes: 2 additions & 1 deletion src/valor_lite/object_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .annotation import Bitmask, BoundingBox, Detection, Polygon
from .manager import DataLoader, Evaluator
from .manager import DataLoader, Evaluator, Filter
from .metric import Metric, MetricType

__all__ = [
Expand All @@ -11,4 +11,5 @@
"MetricType",
"DataLoader",
"Evaluator",
"Filter",
]
91 changes: 91 additions & 0 deletions src/valor_lite/object_detection/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,97 @@ def compute_label_metadata(
return label_metadata


def filter_cache(
detailed_pairs: NDArray[np.float64],
mask_datums: NDArray[np.bool_],
mask_predictions: NDArray[np.bool_],
mask_ground_truths: NDArray[np.bool_],
n_labels: int,
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
"""
Performs filtering on a detailed cache.

Parameters
----------
detailed_pairs : NDArray[float64]
A list of sorted detailed pairs with size (N, 7).
mask_datums : NDArray[bool]
A boolean mask with size (N,).
mask_ground_truths : NDArray[bool]
A boolean mask with size (N,).
mask_predictions : NDArray[bool]
A boolean mask with size (N,).
n_labels : int
The total number of unique labels.

Returns
-------
NDArray[float64]
Filtered detailed pairs.
NDArray[float64]
Filtered ranked pairs.
NDArray[int32]
Label metadata.
"""
# filter datums
detailed_pairs = detailed_pairs[mask_datums].copy()

# filter ground truths
if mask_ground_truths.any():
invalid_groundtruth_indices = np.where(mask_ground_truths)[0]
detailed_pairs[
invalid_groundtruth_indices[:, None], (1, 3, 5)
] = np.array([[-1, -1, 0]])

# filter predictions
if mask_predictions.any():
invalid_prediction_indices = np.where(mask_predictions)[0]
detailed_pairs[
invalid_prediction_indices[:, None], (2, 4, 5, 6)
] = np.array([[-1, -1, 0, -1]])

# filter null pairs
mask_null_pairs = np.all(
np.isclose(
detailed_pairs[:, 1:5],
np.array([-1.0, -1.0, -1.0, -1.0]),
),
axis=1,
)
detailed_pairs = detailed_pairs[~mask_null_pairs]

if detailed_pairs.size == 0:
warnings.warn("no valid filtered pairs")
return (
np.array([], dtype=np.float64),
np.array([], dtype=np.float64),
np.zeros((n_labels, 2), dtype=np.int32),
)

# sorts by score, iou with ground truth id as a tie-breaker
indices = np.lexsort(
(
detailed_pairs[:, 1], # ground truth id
-detailed_pairs[:, 5], # iou
-detailed_pairs[:, 6], # score
)
)
detailed_pairs = detailed_pairs[indices]
label_metadata = compute_label_metadata(
ids=detailed_pairs[:, :5].astype(np.int32),
n_labels=n_labels,
)
ranked_pairs = rank_pairs(
detailed_pairs=detailed_pairs,
label_metadata=label_metadata,
)
return (
detailed_pairs,
ranked_pairs,
label_metadata,
)


def rank_pairs(
detailed_pairs: NDArray[np.float64],
label_metadata: NDArray[np.int32],
Expand Down
Loading