|
12 | 12 | import typer |
13 | 13 | import ultralytics |
14 | 14 | import wandb |
15 | | -from datasets import Dataset, Features |
16 | | -from datasets import Image as HFImage |
| 15 | +from datasets import Dataset |
17 | 16 | from huggingface_hub import HfApi, ModelCard, ModelCardData |
18 | 17 | from PIL import Image |
19 | 18 |
|
| 19 | +from labelr.dataset_features import OBJECT_DETECTION_DS_PREDICTION_FEATURES |
20 | 20 | from labelr.export import ( |
21 | 21 | _pickle_sample_generator, |
22 | 22 | export_from_hf_to_ultralytics_object_detection, |
23 | 23 | ) |
| 24 | +from labelr.utils import parse_hf_repo_id |
24 | 25 |
|
25 | 26 | CARD_TEMPLATE = """ |
26 | 27 | --- |
|
101 | 102 | """ |
102 | 103 |
|
103 | 104 |
|
104 | | -DS_PREDICTION_FEATURES = Features( |
105 | | - { |
106 | | - "image": HFImage(), |
107 | | - "image_with_prediction": HFImage(), |
108 | | - "image_id": datasets.Value("string"), |
109 | | - "detected": { |
110 | | - "bbox": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))), |
111 | | - "category_id": datasets.Sequence(datasets.Value("int64")), |
112 | | - "category_name": datasets.Sequence(datasets.Value("string")), |
113 | | - "confidence": datasets.Sequence(datasets.Value("float32")), |
114 | | - }, |
115 | | - "split": datasets.Value("string"), |
116 | | - "width": datasets.Value("int64"), |
117 | | - "height": datasets.Value("int64"), |
118 | | - "meta": { |
119 | | - "barcode": datasets.Value("string"), |
120 | | - "off_image_id": datasets.Value("string"), |
121 | | - "image_url": datasets.Value("string"), |
122 | | - }, |
123 | | - "objects": { |
124 | | - "bbox": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))), |
125 | | - "category_id": datasets.Sequence(datasets.Value("int64")), |
126 | | - "category_name": datasets.Sequence(datasets.Value("string")), |
127 | | - }, |
128 | | - } |
129 | | -) |
130 | | - |
131 | | - |
132 | 105 | def create_model_card( |
133 | 106 | dataset_repo_id: str, |
134 | 107 | dataset_revision: str, |
@@ -239,7 +212,7 @@ def create_predict_dataset( |
239 | 212 | # image |
240 | 213 | ds = Dataset.from_generator( |
241 | 214 | functools.partial(_pickle_sample_generator, tmp_dir), |
242 | | - features=DS_PREDICTION_FEATURES, |
| 215 | + features=OBJECT_DETECTION_DS_PREDICTION_FEATURES, |
243 | 216 | ) |
244 | 217 | ds.to_parquet(output_path) |
245 | 218 | typer.echo(f"Saved Hugging Face dataset as Parquet file to: {output_path}") |
@@ -334,11 +307,7 @@ def main( |
334 | 307 | dataset_dir = root_dir / "datasets" |
335 | 308 | run_dir = (Path(__file__).parent / project / run_name).absolute() |
336 | 309 |
|
337 | | - # We can specify a revision (branch, commit sha or tag) with '@' suffix |
338 | | - if "@" in hf_repo_id: |
339 | | - hf_repo_id, revision = hf_repo_id.split("@", 1) |
340 | | - else: |
341 | | - revision = "main" |
| 310 | + hf_repo_id, revision = parse_hf_repo_id(hf_repo_id) |
342 | 311 |
|
343 | 312 | # `skip_dataset_download` is an option to skip dataset download, useful |
344 | 313 | # for debugging locally |
|
0 commit comments