Skip to content

Commit 0d34879

Browse files
committed
chore(train-yolo): replace definitions and funcs by labelr
implementations
1 parent 90627ad commit 0d34879

File tree

3 files changed

+11
-41
lines changed

3 files changed

+11
-41
lines changed

packages/train-yolo/main.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@
1212
import typer
1313
import ultralytics
1414
import wandb
15-
from datasets import Dataset, Features
16-
from datasets import Image as HFImage
15+
from datasets import Dataset
1716
from huggingface_hub import HfApi, ModelCard, ModelCardData
1817
from PIL import Image
1918

19+
from labelr.dataset_features import OBJECT_DETECTION_DS_PREDICTION_FEATURES
2020
from labelr.export import (
2121
_pickle_sample_generator,
2222
export_from_hf_to_ultralytics_object_detection,
2323
)
24+
from labelr.utils import parse_hf_repo_id
2425

2526
CARD_TEMPLATE = """
2627
---
@@ -101,34 +102,6 @@
101102
"""
102103

103104

104-
DS_PREDICTION_FEATURES = Features(
105-
{
106-
"image": HFImage(),
107-
"image_with_prediction": HFImage(),
108-
"image_id": datasets.Value("string"),
109-
"detected": {
110-
"bbox": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))),
111-
"category_id": datasets.Sequence(datasets.Value("int64")),
112-
"category_name": datasets.Sequence(datasets.Value("string")),
113-
"confidence": datasets.Sequence(datasets.Value("float32")),
114-
},
115-
"split": datasets.Value("string"),
116-
"width": datasets.Value("int64"),
117-
"height": datasets.Value("int64"),
118-
"meta": {
119-
"barcode": datasets.Value("string"),
120-
"off_image_id": datasets.Value("string"),
121-
"image_url": datasets.Value("string"),
122-
},
123-
"objects": {
124-
"bbox": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))),
125-
"category_id": datasets.Sequence(datasets.Value("int64")),
126-
"category_name": datasets.Sequence(datasets.Value("string")),
127-
},
128-
}
129-
)
130-
131-
132105
def create_model_card(
133106
dataset_repo_id: str,
134107
dataset_revision: str,
@@ -239,7 +212,7 @@ def create_predict_dataset(
239212
# image
240213
ds = Dataset.from_generator(
241214
functools.partial(_pickle_sample_generator, tmp_dir),
242-
features=DS_PREDICTION_FEATURES,
215+
features=OBJECT_DETECTION_DS_PREDICTION_FEATURES,
243216
)
244217
ds.to_parquet(output_path)
245218
typer.echo(f"Saved Hugging Face dataset as Parquet file to: {output_path}")
@@ -334,11 +307,7 @@ def main(
334307
dataset_dir = root_dir / "datasets"
335308
run_dir = (Path(__file__).parent / project / run_name).absolute()
336309

337-
# We can specify a revision (branch, commit sha or tag) with '@' suffix
338-
if "@" in hf_repo_id:
339-
hf_repo_id, revision = hf_repo_id.split("@", 1)
340-
else:
341-
revision = "main"
310+
hf_repo_id, revision = parse_hf_repo_id(hf_repo_id)
342311

343312
# `skip_dataset_download` is an option to skip dataset download, useful
344313
# for debugging locally

packages/train-yolo/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies = [
1010
"Pillow",
1111
"ultralytics==8.3.223",
1212
"albumentations",
13-
"labelr==0.6.0",
13+
"labelr==0.7.0",
1414
"wandb==0.22.3",
1515
"torch==2.9.0",
1616
"Jinja2==3.1.6",

packages/train-yolo/uv.lock

Lines changed: 5 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)