Skip to content

Commit 910b526

Browse files
authored
Merge pull request #48 from google/ignore-existing-predictions
added --ignore_existing_predictions flag to run_model.py to overwrite existing predictions
2 parents 89ba610 + f358ab6 commit 910b526

File tree

3 files changed

+63
-39
lines changed

3 files changed

+63
-39
lines changed

notebooks/run_speciesnet_in_jupyter.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
"from speciesnet import SpeciesNet\n",
4949
"from speciesnet import SUPPORTED_MODELS\n",
5050
"\n",
51+
"\n",
5152
"def print_predictions(predictions_dict: dict) -> None:\n",
5253
" print(\"Predictions:\")\n",
5354
" for prediction in predictions_dict[\"predictions\"]:\n",

notebooks/run_speciesnet_on_colab.ipynb

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
"from speciesnet import SpeciesNet\n",
7373
"from speciesnet import SUPPORTED_MODELS\n",
7474
"\n",
75+
"\n",
7576
"def print_predictions(predictions_dict: dict) -> None:\n",
7677
" print(\"Predictions:\")\n",
7778
" for prediction in predictions_dict[\"predictions\"]:\n",
@@ -100,14 +101,15 @@
100101
"import shutil\n",
101102
"\n",
102103
"# Choose the folder we're going to download to\n",
103-
"model_path = '/content/models'\n",
104+
"model_path = \"/content/models\"\n",
104105
"os.makedirs(model_path, exist_ok=True)\n",
105106
"\n",
106107
"# Download the model (it will go to a folder like /kaggle/input/...)\n",
107-
"download_path = kagglehub.model_download('google/speciesnet/PyTorch/v4.0.1a',\n",
108-
" force_download=True)\n",
108+
"download_path = kagglehub.model_download(\n",
109+
" \"google/speciesnet/PyTorch/v4.0.1a\", force_download=True\n",
110+
")\n",
109111
"\n",
110-
"print('Model downloaded to temporary folder: {}'.format(download_path))\n",
112+
"print(\"Model downloaded to temporary folder: {}\".format(download_path))\n",
111113
"\n",
112114
"# List the contents of the downloaded directory to identify the actual files/subdirectories\n",
113115
"model_files = os.listdir(download_path)\n",
@@ -121,7 +123,7 @@
121123
" elif os.path.isdir(source_path):\n",
122124
" shutil.copytree(source_path, destination_path, dirs_exist_ok=True)\n",
123125
"\n",
124-
"print('{} files copied to: {}'.format(len(model_files),model_path))"
126+
"print(\"{} files copied to: {}\".format(len(model_files), model_path))"
125127
]
126128
},
127129
{
@@ -141,7 +143,7 @@
141143
},
142144
"outputs": [],
143145
"source": [
144-
"os.makedirs('/content/images',exist_ok=True)\n",
146+
"os.makedirs(\"/content/images\", exist_ok=True)\n",
145147
"!wget \"https://github.com/google/cameratrapai/blob/main/test_data/african_elephants.jpg?raw=true\" -O \"/content/images/african_elephants.jpg\"\n",
146148
"!wget \"https://github.com/google/cameratrapai/blob/main/test_data/american_black_bear.jpg?raw=true\" -O \"/content/images/american_black_bear.jpg\""
147149
]
@@ -176,9 +178,10 @@
176178
"source": [
177179
"# print the contents of the output json\n",
178180
"import json\n",
179-
"with open('/content/predictions-ensemble.json','r') as f:\n",
180-
" d = json.load(f)\n",
181-
"print(str(d))\n"
181+
"\n",
182+
"with open(\"/content/predictions-ensemble.json\", \"r\") as f:\n",
183+
" d = json.load(f)\n",
184+
"print(str(d))"
182185
]
183186
},
184187
{

speciesnet/scripts/run_model.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,13 @@
140140
"unexpected files are supplied. --bypass_prompts bypasses prompts, --nobypass_prompts "
141141
"(default) does not.",
142142
)
143+
_IGNORE_EXISTING_PREDICTIONS = flags.DEFINE_bool(
144+
"ignore_existing_predictions",
145+
False,
146+
"Whether to ignore existing predictions in the output JSON file and reprocess all "
147+
"instances. --ignore_existing_predictions bypasses loading partial results, "
148+
"--noignore_existing_predictions (default) resumes from existing predictions.",
149+
)
143150

144151

145152
def guess_predictions_source(
@@ -313,37 +320,50 @@ def main(argv: list[str]) -> None:
313320

314321
# Check the compatibility of output predictions with existing partial predictions.
315322
if _PREDICTIONS_JSON.value:
316-
partial_predictions, _ = load_partial_predictions(
317-
_PREDICTIONS_JSON.value, instances_dict["instances"]
318-
)
319-
predictions_source = guess_predictions_source(partial_predictions)
320-
321-
if _CLASSIFIER_ONLY.value and predictions_source not in [
322-
"classifier",
323-
"unknown",
324-
]:
325-
raise RuntimeError(
326-
f"The classifier risks overwriting previous predictions from "
327-
f"`{_PREDICTIONS_JSON.value}` that were produced by different "
328-
f"components. Make sure to provide a different output location to "
329-
f"--{_PREDICTIONS_JSON.name}."
330-
)
331-
332-
if _DETECTOR_ONLY.value and predictions_source not in ["detector", "unknown"]:
333-
raise RuntimeError(
334-
f"The detector risks overwriting previous predictions from "
335-
f"`{_PREDICTIONS_JSON.value}` that were produced by different "
336-
f"components. Make sure to provide a different output location to "
337-
f"--{_PREDICTIONS_JSON.name}."
338-
)
339-
340-
if _ENSEMBLE_ONLY.value and predictions_source not in ["ensemble", "unknown"]:
341-
raise RuntimeError(
342-
f"The ensemble risks overwriting previous predictions from "
343-
f"`{_PREDICTIONS_JSON.value}` that were produced by different "
344-
f"components. Make sure to provide a different output location to "
345-
f"--{_PREDICTIONS_JSON.name}."
323+
if _IGNORE_EXISTING_PREDICTIONS.value:
324+
# When ignoring existing predictions, delete the file to ensure all instances
325+
# are reprocessed from scratch.
326+
if local_file_exists(_PREDICTIONS_JSON.value):
327+
print(f"Deleting existing predictions in `{_PREDICTIONS_JSON.value}`.")
328+
Path(_PREDICTIONS_JSON.value).unlink()
329+
else:
330+
partial_predictions, _ = load_partial_predictions(
331+
_PREDICTIONS_JSON.value, instances_dict["instances"]
346332
)
333+
predictions_source = guess_predictions_source(partial_predictions)
334+
335+
if _CLASSIFIER_ONLY.value and predictions_source not in [
336+
"classifier",
337+
"unknown",
338+
]:
339+
raise RuntimeError(
340+
f"The classifier risks overwriting previous predictions from "
341+
f"`{_PREDICTIONS_JSON.value}` that were produced by different "
342+
f"components. Make sure to provide a different output location to "
343+
f"--{_PREDICTIONS_JSON.name}."
344+
)
345+
346+
if _DETECTOR_ONLY.value and predictions_source not in [
347+
"detector",
348+
"unknown",
349+
]:
350+
raise RuntimeError(
351+
f"The detector risks overwriting previous predictions from "
352+
f"`{_PREDICTIONS_JSON.value}` that were produced by different "
353+
f"components. Make sure to provide a different output location to "
354+
f"--{_PREDICTIONS_JSON.name}."
355+
)
356+
357+
if _ENSEMBLE_ONLY.value and predictions_source not in [
358+
"ensemble",
359+
"unknown",
360+
]:
361+
raise RuntimeError(
362+
f"The ensemble risks overwriting previous predictions from "
363+
f"`{_PREDICTIONS_JSON.value}` that were produced by different "
364+
f"components. Make sure to provide a different output location to "
365+
f"--{_PREDICTIONS_JSON.name}."
366+
)
347367

348368
else:
349369
if not say_yes_to_continue(

0 commit comments

Comments
 (0)