Skip to content

Commit 9b5b5e5

Browse files
committed
parameter name changes in utils.py, changed temp file handling in save_predictions
1 parent 5307a0d commit 9b5b5e5

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

speciesnet/utils.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from io import BytesIO
2929
import json
3030
from pathlib import Path
31-
import tempfile
3231
from typing import Any, Optional, Union
32+
import uuid
3333

3434
from absl import logging
3535
from cloudpathlib import CloudPath
@@ -179,42 +179,45 @@ def load_json(filepath: StrPath) -> dict:
179179
return json.load(fp)
180180

181181

182-
def limit_float_precision(obj: Any, precision: int) -> Any:
182+
def limit_float_precision(obj: Any, num_decimals: int) -> Any:
183183
"""Recursively limits precision of floating-point numbers in nested data structures.
184184
185185
Args:
186186
obj: The object to process (can be dict, list, float, or other types).
187-
precision: Number of decimal places to which we should round floating-point
187+
num_decimals: Number of decimal places to which we should round floating-point
188188
numbers.
189189
190190
Returns:
191191
The processed object with limited floating-point precision.
192192
"""
193193
if isinstance(obj, (float, np.floating)):
194-
return round(float(obj), precision)
194+
return round(float(obj), num_decimals)
195195
elif isinstance(obj, dict):
196196
return {
197-
key: limit_float_precision(value, precision) for key, value in obj.items()
197+
key: limit_float_precision(value, num_decimals)
198+
for key, value in obj.items()
198199
}
199200
elif isinstance(obj, list):
200-
return [limit_float_precision(item, precision) for item in obj]
201+
return [limit_float_precision(item, num_decimals) for item in obj]
201202
elif isinstance(obj, tuple):
202-
return tuple(limit_float_precision(item, precision) for item in obj)
203+
return tuple(limit_float_precision(item, num_decimals) for item in obj)
203204
else:
204205
return obj
205206

206207

207-
def write_json(data: Any, filepath: StrPath, precision: Optional[int] = None) -> None:
208+
def write_json(
209+
data: Any, filepath: StrPath, num_decimals: Optional[int] = None
210+
) -> None:
208211
"""Writes JSON-serializable data to a file with UTF-8 encoding.
209212
210213
Args:
211214
data: The JSON-serializable data to write.
212215
filepath: Path where to write the JSON file.
213-
precision: Optional number of decimal places to which we should round
216+
num_decimals: Optional number of decimal places to which we should round
214217
floating-point numbers. If None, no precision limiting is applied.
215218
"""
216-
if precision is not None:
217-
data = limit_float_precision(data, precision)
219+
if num_decimals is not None:
220+
data = limit_float_precision(data, num_decimals)
218221

219222
with open(filepath, mode="w", encoding="utf-8") as fp:
220223
json.dump(data, fp, ensure_ascii=False, indent=1)
@@ -492,13 +495,10 @@ def save_predictions(predictions_dict: dict, output_json: StrPath) -> None:
492495
"""
493496

494497
output_json = Path(output_json)
495-
output_json_tmp = Path(
496-
tempfile.mktemp(
497-
dir=output_json.parent,
498-
prefix=f"{output_json.name}.tmp.",
499-
)
500-
)
498+
stem = output_json.stem
499+
suffix = output_json.suffix
500+
output_json_tmp = output_json.parent / f"{stem}.tmp.{uuid.uuid4()}{suffix}"
501501
logging.info("Saving predictions to `%s`.", output_json_tmp)
502-
write_json(predictions_dict, output_json_tmp, precision=4)
502+
write_json(predictions_dict, output_json_tmp, num_decimals=4)
503503
logging.info("Moving `%s` to `%s`.", output_json_tmp, output_json)
504504
output_json_tmp.replace(output_json) # Atomic operation.

speciesnet/utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def test_write_json_with_precision(self, tmp_path) -> None:
658658
}
659659

660660
output_file = tmp_path / "test_precision.json"
661-
write_json(test_data, output_file, precision=3)
661+
write_json(test_data, output_file, num_decimals=3)
662662

663663
# Read the file back and verify precision was limited
664664
loaded_data = load_json(output_file)

0 commit comments

Comments
 (0)