|
140 | 140 | "unexpected files are supplied. --bypass_prompts bypasses prompts, --nobypass_prompts " |
141 | 141 | "(default) does not.", |
142 | 142 | ) |
| 143 | +_IGNORE_EXISTING_PREDICTIONS = flags.DEFINE_bool( |
| 144 | + "ignore_existing_predictions", |
| 145 | + False, |
| 146 | + "Whether to ignore existing predictions in the output JSON file and reprocess all " |
| 147 | + "instances. --ignore_existing_predictions bypasses loading partial results, " |
| 148 | + "--noignore_existing_predictions (default) resumes from existing predictions.", |
| 149 | +) |
143 | 150 |
|
144 | 151 |
|
145 | 152 | def guess_predictions_source( |
@@ -313,37 +320,50 @@ def main(argv: list[str]) -> None: |
313 | 320 |
|
314 | 321 | # Check the compatibility of output predictions with existing partial predictions. |
315 | 322 | if _PREDICTIONS_JSON.value: |
316 | | - partial_predictions, _ = load_partial_predictions( |
317 | | - _PREDICTIONS_JSON.value, instances_dict["instances"] |
318 | | - ) |
319 | | - predictions_source = guess_predictions_source(partial_predictions) |
320 | | - |
321 | | - if _CLASSIFIER_ONLY.value and predictions_source not in [ |
322 | | - "classifier", |
323 | | - "unknown", |
324 | | - ]: |
325 | | - raise RuntimeError( |
326 | | - f"The classifier risks overwriting previous predictions from " |
327 | | - f"`{_PREDICTIONS_JSON.value}` that were produced by different " |
328 | | - f"components. Make sure to provide a different output location to " |
329 | | - f"--{_PREDICTIONS_JSON.name}." |
330 | | - ) |
331 | | - |
332 | | - if _DETECTOR_ONLY.value and predictions_source not in ["detector", "unknown"]: |
333 | | - raise RuntimeError( |
334 | | - f"The detector risks overwriting previous predictions from " |
335 | | - f"`{_PREDICTIONS_JSON.value}` that were produced by different " |
336 | | - f"components. Make sure to provide a different output location to " |
337 | | - f"--{_PREDICTIONS_JSON.name}." |
338 | | - ) |
339 | | - |
340 | | - if _ENSEMBLE_ONLY.value and predictions_source not in ["ensemble", "unknown"]: |
341 | | - raise RuntimeError( |
342 | | - f"The ensemble risks overwriting previous predictions from " |
343 | | - f"`{_PREDICTIONS_JSON.value}` that were produced by different " |
344 | | - f"components. Make sure to provide a different output location to " |
345 | | - f"--{_PREDICTIONS_JSON.name}." |
| 323 | + if _IGNORE_EXISTING_PREDICTIONS.value: |
| 324 | + # When ignoring existing predictions, delete the file to ensure all instances |
| 325 | + # are reprocessed from scratch. |
| 326 | + if local_file_exists(_PREDICTIONS_JSON.value): |
| 327 | + print(f"Deleting existing predictions in `{_PREDICTIONS_JSON.value}`.") |
| 328 | + Path(_PREDICTIONS_JSON.value).unlink() |
| 329 | + else: |
| 330 | + partial_predictions, _ = load_partial_predictions( |
| 331 | + _PREDICTIONS_JSON.value, instances_dict["instances"] |
346 | 332 | ) |
| 333 | + predictions_source = guess_predictions_source(partial_predictions) |
| 334 | + |
| 335 | + if _CLASSIFIER_ONLY.value and predictions_source not in [ |
| 336 | + "classifier", |
| 337 | + "unknown", |
| 338 | + ]: |
| 339 | + raise RuntimeError( |
| 340 | + f"The classifier risks overwriting previous predictions from " |
| 341 | + f"`{_PREDICTIONS_JSON.value}` that were produced by different " |
| 342 | + f"components. Make sure to provide a different output location to " |
| 343 | + f"--{_PREDICTIONS_JSON.name}." |
| 344 | + ) |
| 345 | + |
| 346 | + if _DETECTOR_ONLY.value and predictions_source not in [ |
| 347 | + "detector", |
| 348 | + "unknown", |
| 349 | + ]: |
| 350 | + raise RuntimeError( |
| 351 | + f"The detector risks overwriting previous predictions from " |
| 352 | + f"`{_PREDICTIONS_JSON.value}` that were produced by different " |
| 353 | + f"components. Make sure to provide a different output location to " |
| 354 | + f"--{_PREDICTIONS_JSON.name}." |
| 355 | + ) |
| 356 | + |
| 357 | + if _ENSEMBLE_ONLY.value and predictions_source not in [ |
| 358 | + "ensemble", |
| 359 | + "unknown", |
| 360 | + ]: |
| 361 | + raise RuntimeError( |
| 362 | + f"The ensemble risks overwriting previous predictions from " |
| 363 | + f"`{_PREDICTIONS_JSON.value}` that were produced by different " |
| 364 | + f"components. Make sure to provide a different output location to " |
| 365 | + f"--{_PREDICTIONS_JSON.name}." |
| 366 | + ) |
347 | 367 |
|
348 | 368 | else: |
349 | 369 | if not say_yes_to_continue( |
|
0 commit comments