Commit a799026

Merge pull request #42 from google/fix-json-handling
Fix json encoding issue
2 parents 429f4d2 + 9b5b5e5 commit a799026

File tree

.gitignore
speciesnet/utils.py
speciesnet/utils_test.py

3 files changed, +261 -30 lines changed

.gitignore

Lines changed: 10 additions & 0 deletions
@@ -269,3 +269,13 @@ tags
 
 /benchmark_data/
 /models/
+
+################
+### AI stuff ###
+################
+
+GEMINI.md
+claude.md
+**/*settings.json
+**/*settings.local.json
+

speciesnet/utils.py

Lines changed: 82 additions & 30 deletions
@@ -20,14 +20,16 @@
     "file_exists",
     "load_rgb_image",
     "prepare_instances_dict",
+    "load_json",
+    "write_json",
 ]
 
 from dataclasses import dataclass
 from io import BytesIO
 import json
 from pathlib import Path
-import tempfile
-from typing import Optional, Union
+from typing import Any, Optional, Union
+import uuid
 
 from absl import logging
 from cloudpathlib import CloudPath
@@ -99,8 +101,7 @@ def __init__(self, model_name: str) -> None:
         base_dir = Path(base_dir)
 
         # Load model info.
-        with open(base_dir / "info.json", mode="r", encoding="utf-8") as fp:
-            info = json.load(fp)
+        info = load_json(base_dir / "info.json")
 
         # Download detector weights if not provided with the other model files.
         filepath_or_url = info["detector"]
@@ -165,6 +166,63 @@ class BBox:
     height: float
 
 
+def load_json(filepath: StrPath) -> dict:
+    """Loads a JSON file with UTF-8 encoding.
+
+    Args:
+        filepath: Path to the JSON file to load.
+
+    Returns:
+        The loaded JSON data as a dictionary.
+    """
+    with open(filepath, mode="r", encoding="utf-8") as fp:
+        return json.load(fp)
+
+
+def limit_float_precision(obj: Any, num_decimals: int) -> Any:
+    """Recursively limits precision of floating-point numbers in nested data structures.
+
+    Args:
+        obj: The object to process (can be dict, list, float, or other types).
+        num_decimals: Number of decimal places to which we should round floating-point
+            numbers.
+
+    Returns:
+        The processed object with limited floating-point precision.
+    """
+    if isinstance(obj, (float, np.floating)):
+        return round(float(obj), num_decimals)
+    elif isinstance(obj, dict):
+        return {
+            key: limit_float_precision(value, num_decimals)
+            for key, value in obj.items()
+        }
+    elif isinstance(obj, list):
+        return [limit_float_precision(item, num_decimals) for item in obj]
+    elif isinstance(obj, tuple):
+        return tuple(limit_float_precision(item, num_decimals) for item in obj)
+    else:
+        return obj
+
+
+def write_json(
+    data: Any, filepath: StrPath, num_decimals: Optional[int] = None
+) -> None:
+    """Writes JSON-serializable data to a file with UTF-8 encoding.
+
+    Args:
+        data: The JSON-serializable data to write.
+        filepath: Path where to write the JSON file.
+        num_decimals: Optional number of decimal places to which we should round
+            floating-point numbers. If None, no precision limiting is applied.
+    """
+    if num_decimals is not None:
+        data = limit_float_precision(data, num_decimals)
+
+    with open(filepath, mode="w", encoding="utf-8") as fp:
+        json.dump(data, fp, ensure_ascii=False, indent=1)
+
+
 def only_one_true(*args) -> bool:
     """Checks that only one of the given arguments is `True`."""
 
@@ -331,8 +389,7 @@ def _enforce_location(
         )
 
     if instances_json is not None:
-        with open(instances_json, mode="r", encoding="utf-8") as fp:
-            instances_dict = json.load(fp)
+        instances_dict = load_json(instances_json)
     if instances_dict is not None:
         return _enforce_location(instances_dict, country, admin1_region)
 
@@ -403,21 +460,20 @@ def load_partial_predictions(
 
     partial_predictions = {}
     target_filepaths = {instance["filepath"] for instance in instances}
-    with open(predictions_json, mode="r", encoding="utf-8") as fp:
-        predictions_dict = json.load(fp)
-        for prediction in predictions_dict["predictions"]:
-            filepath = prediction["filepath"]
-            if filepath not in target_filepaths:
-                raise RuntimeError(
-                    f"Filepath from loaded predictions is missing from the set of "
-                    f"instances to process: `{filepath}`. Make sure you're resuming "
-                    f"the work using the same set of instances."
-                )
-
-            if "failures" in prediction:
-                continue
-
-            partial_predictions[prediction["filepath"]] = prediction
+    predictions_dict = load_json(predictions_json)
+    for prediction in predictions_dict["predictions"]:
+        filepath = prediction["filepath"]
+        if filepath not in target_filepaths:
+            raise RuntimeError(
+                f"Filepath from loaded predictions is missing from the set of "
+                f"instances to process: `{filepath}`. Make sure you're resuming "
+                f"the work using the same set of instances."
+            )
+
+        if "failures" in prediction:
+            continue
+
+        partial_predictions[prediction["filepath"]] = prediction
 
     instances_to_process = [
         instance
@@ -439,14 +495,10 @@ def save_predictions(predictions_dict: dict, output_json: StrPath) -> None:
     """
 
     output_json = Path(output_json)
-    with tempfile.NamedTemporaryFile(
-        mode="w",
-        dir=output_json.parent,
-        prefix=f"{output_json.name}.tmp.",
-        delete=False,
-    ) as fp:
-        logging.info("Saving predictions to `%s`.", fp.name)
-        output_json_tmp = Path(fp.name)
-        json.dump(predictions_dict, fp, ensure_ascii=False, indent=4)
+    stem = output_json.stem
+    suffix = output_json.suffix
+    output_json_tmp = output_json.parent / f"{stem}.tmp.{uuid.uuid4()}{suffix}"
+    logging.info("Saving predictions to `%s`.", output_json_tmp)
+    write_json(predictions_dict, output_json_tmp, num_decimals=4)
     logging.info("Moving `%s` to `%s`.", output_json_tmp, output_json)
     output_json_tmp.replace(output_json)  # Atomic operation.
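
For context, a minimal sketch of how the new helpers are meant to be used together; the path and payload below are made up for illustration and are not part of the PR:

    from speciesnet.utils import load_json, write_json

    # Hypothetical payload with more float precision than we want to persist.
    predictions = {"predictions": [{"filepath": "img_0001.jpg", "score": 0.123456789}]}

    # write_json() rounds every float to the requested number of decimals
    # (via limit_float_precision) before dumping as UTF-8 with indent=1.
    write_json(predictions, "predictions.json", num_decimals=4)

    # load_json() is the matching UTF-8 reader now used throughout utils.py.
    assert load_json("predictions.json")["predictions"][0]["score"] == 0.1235

Note also that save_predictions now writes to a uuid-suffixed temporary file in the destination directory, so the final output_json_tmp.replace(output_json) stays an atomic rename on the same filesystem, as the in-code comment points out.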

speciesnet/utils_test.py

Lines changed: 169 additions & 0 deletions
@@ -18,14 +18,18 @@
 from pathlib import Path
 from typing import Generator
 
+import numpy as np
 import pytest
 
 from speciesnet.utils import file_exists
+from speciesnet.utils import limit_float_precision
+from speciesnet.utils import load_json
 from speciesnet.utils import load_partial_predictions
 from speciesnet.utils import load_rgb_image
 from speciesnet.utils import ModelInfo
 from speciesnet.utils import prepare_instances_dict
 from speciesnet.utils import save_predictions
+from speciesnet.utils import write_json
 
 # fmt: off
 # pylint: disable=line-too-long
@@ -514,3 +518,168 @@ def test_failed_saving(self, tmp_path) -> None:
         }
         with pytest.raises(TypeError):
             save_predictions(predictions, tmp_path)
+
+
+class TestPrecisionLimiting:
+    """Tests for precision limiting functionality in JSON operations."""
+
+    def test_limit_float_precision_simple_float(self) -> None:
+        """Test precision limiting for simple floats."""
+        assert limit_float_precision(3.14159265359, 2) == 3.14
+        assert limit_float_precision(3.14159265359, 4) == 3.1416
+        assert limit_float_precision(2.0, 2) == 2.0
+
+    def test_limit_float_precision_numpy_float(self) -> None:
+        """Test precision limiting for numpy floats."""
+        assert limit_float_precision(np.float32(3.14159265359), 2) == 3.14
+        assert limit_float_precision(np.float64(3.14159265359), 4) == 3.1416
+        assert limit_float_precision(np.float16(2.0), 2) == 2.0
+
+    def test_limit_float_precision_non_float_types(self) -> None:
+        """Test that non-float types are unchanged."""
+        assert limit_float_precision("string", 2) == "string"
+        assert limit_float_precision(42, 2) == 42
+        assert limit_float_precision(True, 2) is True
+        assert limit_float_precision(None, 2) is None
+
+    def test_limit_float_precision_simple_list(self) -> None:
+        """Test precision limiting for lists with floats."""
+        input_list = [1.23456, 2.0, "string", 42, 9.87654]
+        expected = [1.23, 2.0, "string", 42, 9.88]
+        assert limit_float_precision(input_list, 2) == expected
+
+    def test_limit_float_precision_simple_dict(self) -> None:
+        """Test precision limiting for dictionaries with floats."""
+        input_dict = {
+            "float_val": 3.14159,
+            "string_val": "test",
+            "int_val": 42,
+            "another_float": 2.718281828,
+        }
+        expected = {
+            "float_val": 3.14,
+            "string_val": "test",
+            "int_val": 42,
+            "another_float": 2.72,
+        }
+        assert limit_float_precision(input_dict, 2) == expected
+
+    def test_limit_float_precision_nested_dict_in_list(self) -> None:
+        """Test precision limiting for dictionaries nested in lists."""
+        input_data = [
+            {"score": 0.123456, "name": "item1"},
+            {"score": 0.987654, "name": "item2"},
+            "string_item",
+            42,
+        ]
+        expected = [
+            {"score": 0.123, "name": "item1"},
+            {"score": 0.988, "name": "item2"},
+            "string_item",
+            42,
+        ]
+        assert limit_float_precision(input_data, 3) == expected
+
+    def test_limit_float_precision_nested_list_in_dict(self) -> None:
+        """Test precision limiting for lists nested in dictionaries."""
+        input_data = {
+            "scores": [0.12345, 0.67890, 0.99999],
+            "names": ["A", "B", "C"],
+            "metadata": {"threshold": 0.54321, "version": "1.0"},
+        }
+        expected = {
+            "scores": [0.12, 0.68, 1.0],
+            "names": ["A", "B", "C"],
+            "metadata": {"threshold": 0.54, "version": "1.0"},
+        }
+        assert limit_float_precision(input_data, 2) == expected
+
+    def test_limit_float_precision_deeply_nested(self) -> None:
+        """Test precision limiting for deeply nested structures."""
+        input_data = {
+            "level1": {
+                "level2": {
+                    "level3": [
+                        {"deep_float": 3.14159265359, "items": [1.23456, 7.89012]},
+                        {"another_deep": 2.71828, "values": [9.87654, 5.43210]},
+                    ]
+                }
+            }
+        }
+        expected = {
+            "level1": {
+                "level2": {
+                    "level3": [
+                        {"deep_float": 3.1416, "items": [1.2346, 7.8901]},
+                        {"another_deep": 2.7183, "values": [9.8765, 5.4321]},
+                    ]
+                }
+            }
+        }
+        assert limit_float_precision(input_data, 4) == expected
+
+    def test_limit_float_precision_tuples(self) -> None:
+        """Test precision limiting for tuples."""
+        input_tuple = (1.23456, "string", 7.89012, 42)
+        expected = (1.23, "string", 7.89, 42)
+        assert limit_float_precision(input_tuple, 2) == expected
+
+    def test_limit_float_precision_mixed_numpy_types(self) -> None:
+        """Test precision limiting with mixed numpy and Python floats."""
+        input_data = {
+            "python_float": 3.14159,
+            "numpy_float32": np.float32(2.71828),
+            "numpy_float64": np.float64(1.41421),
+            "list_mixed": [1.23456, np.float32(9.87654), "string", np.float64(5.55555)],
+        }
+        expected = {
+            "python_float": 3.14,
+            "numpy_float32": 2.72,
+            "numpy_float64": 1.41,
+            "list_mixed": [1.23, 9.88, "string", 5.56],
+        }
+        assert limit_float_precision(input_data, 2) == expected
+
+    def test_write_json_with_precision(self, tmp_path) -> None:
+        """Test write_json function with precision parameter."""
+        test_data = {
+            "predictions": [
+                {
+                    "filepath": "test.jpg",
+                    "scores": [0.123456789, 0.987654321],
+                    "bbox": [0.111111, 0.222222, 0.333333, 0.444444],
+                    "confidence": 0.876543210,
+                    "nested": {
+                        "value": 1.414213562,
+                        "items": [2.718281828, 3.141592654],
+                    },
+                }
+            ]
+        }
+
+        output_file = tmp_path / "test_precision.json"
+        write_json(test_data, output_file, num_decimals=3)
+
+        # Read the file back and verify precision was limited
+        loaded_data = load_json(output_file)
+        prediction = loaded_data["predictions"][0]
+
+        assert prediction["scores"] == [0.123, 0.988]
+        assert prediction["bbox"] == [0.111, 0.222, 0.333, 0.444]
+        assert prediction["confidence"] == 0.877
+        assert prediction["nested"]["value"] == 1.414
+        assert prediction["nested"]["items"] == [2.718, 3.142]
+
+    def test_write_json_without_precision(self, tmp_path) -> None:
+        """Test write_json function without precision parameter.
+
+        Should preserve original precision.
+        """
+        test_data = {"value": 3.14159265359, "scores": [0.123456789, 0.987654321]}
+
+        output_file = tmp_path / "test_no_precision.json"
+        write_json(test_data, output_file)
+
+        loaded_data = load_json(output_file)
+        assert loaded_data["value"] == 3.14159265359
+        assert loaded_data["scores"] == [0.123456789, 0.987654321]
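
One detail the numpy-flavored tests above touch on: the standard library json encoder rejects numpy scalar types such as np.float32, which is presumably why limit_float_precision passes values through float() before rounding. A small illustrative sketch, not part of this PR:

    import json

    import numpy as np

    value = np.float32(0.123456)
    try:
        json.dumps({"score": value})  # Rejected: float32 is not JSON serializable.
    except TypeError:
        pass
    # Rounding via round(float(value), n) yields a plain Python float, which serializes fine.
    print(json.dumps({"score": round(float(value), 3)}))  # {"score": 0.123}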
