Skip to content

Commit 2503ce8

Browse files
committed
feat: add module to visualize predictions using fiftyone
1 parent a01e591 commit 2503ce8

File tree

3 files changed: +143 −0 lines changed

src/labelr/apps/evaluate.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Annotated
2+
3+
import typer
4+
5+
from labelr.evaluate import visualize as _visualize
6+
7+
app = typer.Typer()
8+
9+
10+
@app.command()
def visualize(
    hf_repo_id: Annotated[
        str,
        typer.Option(
            help="Hugging Face repository ID of the trained model. "
            "A `predictions.parquet` file is expected in the repo. Revision can be specified "
            "by appending `@<revision>` to the repo ID.",
        ),
    ],
    dataset_name: Annotated[
        str, typer.Option(help="Name of the FiftyOne dataset to create.")
    ],
    persistent: Annotated[
        bool,
        typer.Option(
            help="Whether to make the FiftyOne dataset persistent (i.e., saved to disk).",
        ),
    ] = False,
):
    """Visualize a model's predictions in the FiftyOne App.

    Thin CLI wrapper: all of the work is delegated to
    `labelr.evaluate.visualize`.
    """
    # With Annotated-style parameters, typer infers required/optional from the
    # presence (or absence) of a Python default value, so the `...` sentinel
    # previously passed to typer.Option is redundant and has been dropped
    # (see the typer docs on Annotated parameters).
    _visualize(
        hf_repo_id=hf_repo_id,
        dataset_name=dataset_name,
        persistent=persistent,
    )

src/labelr/evaluate.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import tempfile
2+
from pathlib import Path
3+
4+
import datasets
5+
import fiftyone as fo
6+
from huggingface_hub import hf_hub_download
7+
8+
from labelr.dataset_features import OBJECT_DETECTION_DS_PREDICTION_FEATURES
9+
from labelr.utils import parse_hf_repo_id
10+
11+
12+
def convert_bbox_to_fo_format(
    bbox: tuple[float, float, float, float],
) -> tuple[float, float, float, float]:
    """Convert a `(y_min, x_min, y_max, x_max)` box to FiftyOne's format.

    FiftyOne expects bounding boxes as relative values in [0, 1], laid out as
    `(top-left-x, top-left-y, width, height)`.
    """
    y_min, x_min, y_max, x_max = bbox
    width = x_max - x_min
    height = y_max - y_min
    return (x_min, y_min, width, height)
25+
26+
27+
def visualize(
    hf_repo_id: str,
    dataset_name: str,
    persistent: bool,
):
    """Download a model's `predictions.parquet` from the Hugging Face Hub and
    browse ground-truth vs. predicted detections in the FiftyOne App.

    :param hf_repo_id: Hugging Face model repo ID, optionally suffixed with
        `@<revision>`.
    :param dataset_name: name of the FiftyOne dataset to create.
    :param persistent: whether to persist the FiftyOne dataset to disk.
        NOTE(review): sample images are written to a temporary directory that
        is deleted when this function returns, so a persistent dataset will
        reference missing files in a later session — confirm this is intended.
    """
    hf_repo_id, hf_revision = parse_hf_repo_id(hf_repo_id)

    file_path = hf_hub_download(
        hf_repo_id,
        filename="predictions.parquet",
        revision=hf_revision,
        repo_type="model",
    )
    file_path = Path(file_path).absolute()
    prediction_dataset = datasets.load_dataset(
        "parquet",
        data_files=str(file_path),
        split="train",
        features=OBJECT_DETECTION_DS_PREDICTION_FEATURES,
    )
    fo_dataset = fo.Dataset(name=dataset_name, persistent=persistent)

    with tempfile.TemporaryDirectory() as tmpdir_str:
        tmp_dir = Path(tmpdir_str)
        for sample_idx, hf_sample in enumerate(prediction_dataset):
            image = hf_sample["image"]
            image_path = tmp_dir / f"{sample_idx}.jpg"
            image.save(image_path)
            split = hf_sample["split"]
            sample = fo.Sample(
                filepath=image_path,
                split=split,
                tags=[split],
                # presumably the source image identifier from the predictions
                # schema — verify against OBJECT_DETECTION_DS_PREDICTION_FEATURES
                image=hf_sample["image_id"],
            )
            # Use a detection-local index name (`det_idx`): the original code
            # reused `i` here, shadowing the outer sample index, which is
            # error-prone even though comprehension scoping made it harmless.
            ground_truth_detections = [
                fo.Detection(
                    label=hf_sample["objects"]["category_name"][det_idx],
                    bounding_box=convert_bbox_to_fo_format(
                        bbox=hf_sample["objects"]["bbox"][det_idx],
                    ),
                )
                for det_idx in range(len(hf_sample["objects"]["bbox"]))
            ]
            sample["ground_truth"] = fo.Detections(detections=ground_truth_detections)

            # Model predictions are optional: only attach them when the sample
            # carries a non-empty `detected` payload.
            if hf_sample["detected"] is not None and hf_sample["detected"]["bbox"]:
                model_detections = [
                    fo.Detection(
                        label=hf_sample["detected"]["category_name"][det_idx],
                        bounding_box=convert_bbox_to_fo_format(
                            bbox=hf_sample["detected"]["bbox"][det_idx]
                        ),
                        confidence=hf_sample["detected"]["confidence"][det_idx],
                    )
                    for det_idx in range(len(hf_sample["detected"]["bbox"]))
                ]
                sample["model"] = fo.Detections(detections=model_detections)

            fo_dataset.add_sample(sample)

        # View summary info about the dataset
        print(fo_dataset)

        # Print the first few samples in the dataset
        print(fo_dataset.head())

        # Visualize the dataset in the FiftyOne App. The launch/evaluate/wait
        # calls stay inside the `with` block so the temporary image directory
        # remains alive while the app session is open — FiftyOne stores file
        # paths, not copies of the media.
        session = fo.launch_app(fo_dataset)
        fo_dataset.evaluate_detections(
            "model", gt_field="ground_truth", eval_key="eval", compute_mAP=True
        )
        session.wait()

src/labelr/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from openfoodfacts.utils import get_logger
55

66
from labelr.apps import datasets as dataset_app
7+
from labelr.apps import evaluate as evaluate_app
78
from labelr.apps import projects as project_app
89
from labelr.apps import train as train_app
910
from labelr.apps import users as user_app
@@ -76,5 +77,11 @@ def predict(
7677
help="Train models",
7778
)
7879

80+
app.add_typer(
81+
evaluate_app.app,
82+
name="evaluate",
83+
help="Visualize and evaluate trained models",
84+
)
85+
7986
if __name__ == "__main__":
8087
app()

0 commit comments

Comments
 (0)