|
28 | 28 | from io import BytesIO |
29 | 29 | import json |
30 | 30 | from pathlib import Path |
31 | | -import tempfile |
32 | 31 | from typing import Any, Optional, Union |
| 32 | +import uuid |
33 | 33 |
|
34 | 34 | from absl import logging |
35 | 35 | from cloudpathlib import CloudPath |
@@ -179,42 +179,45 @@ def load_json(filepath: StrPath) -> dict: |
179 | 179 | return json.load(fp) |
180 | 180 |
|
181 | 181 |
|
182 | | -def limit_float_precision(obj: Any, precision: int) -> Any: |
| 182 | +def limit_float_precision(obj: Any, num_decimals: int) -> Any: |
183 | 183 | """Recursively limits precision of floating-point numbers in nested data structures. |
184 | 184 |
|
185 | 185 | Args: |
186 | 186 | obj: The object to process (can be dict, list, float, or other types). |
187 | | - precision: Number of decimal places to which we should round floating-point |
| 187 | + num_decimals: Number of decimal places to which we should round floating-point |
188 | 188 | numbers. |
189 | 189 |
|
190 | 190 | Returns: |
191 | 191 | The processed object with limited floating-point precision. |
192 | 192 | """ |
193 | 193 | if isinstance(obj, (float, np.floating)): |
194 | | - return round(float(obj), precision) |
| 194 | + return round(float(obj), num_decimals) |
195 | 195 | elif isinstance(obj, dict): |
196 | 196 | return { |
197 | | - key: limit_float_precision(value, precision) for key, value in obj.items() |
| 197 | + key: limit_float_precision(value, num_decimals) |
| 198 | + for key, value in obj.items() |
198 | 199 | } |
199 | 200 | elif isinstance(obj, list): |
200 | | - return [limit_float_precision(item, precision) for item in obj] |
| 201 | + return [limit_float_precision(item, num_decimals) for item in obj] |
201 | 202 | elif isinstance(obj, tuple): |
202 | | - return tuple(limit_float_precision(item, precision) for item in obj) |
| 203 | + return tuple(limit_float_precision(item, num_decimals) for item in obj) |
203 | 204 | else: |
204 | 205 | return obj |
205 | 206 |
|
206 | 207 |
|
207 | | -def write_json(data: Any, filepath: StrPath, precision: Optional[int] = None) -> None: |
| 208 | +def write_json( |
| 209 | + data: Any, filepath: StrPath, num_decimals: Optional[int] = None |
| 210 | +) -> None: |
208 | 211 | """Writes JSON-serializable data to a file with UTF-8 encoding. |
209 | 212 |
|
210 | 213 | Args: |
211 | 214 | data: The JSON-serializable data to write. |
212 | 215 | filepath: Path where to write the JSON file. |
213 | | - precision: Optional number of decimal places to which we should round |
| 216 | + num_decimals: Optional number of decimal places to which we should round |
214 | 217 | floating-point numbers. If None, no precision limiting is applied. |
215 | 218 | """ |
216 | | - if precision is not None: |
217 | | - data = limit_float_precision(data, precision) |
| 219 | + if num_decimals is not None: |
| 220 | + data = limit_float_precision(data, num_decimals) |
218 | 221 |
|
219 | 222 | with open(filepath, mode="w", encoding="utf-8") as fp: |
220 | 223 | json.dump(data, fp, ensure_ascii=False, indent=1) |
@@ -492,13 +495,10 @@ def save_predictions(predictions_dict: dict, output_json: StrPath) -> None: |
492 | 495 | """ |
493 | 496 |
|
494 | 497 | output_json = Path(output_json) |
495 | | - output_json_tmp = Path( |
496 | | - tempfile.mktemp( |
497 | | - dir=output_json.parent, |
498 | | - prefix=f"{output_json.name}.tmp.", |
499 | | - ) |
500 | | - ) |
| 498 | + stem = output_json.stem |
| 499 | + suffix = output_json.suffix |
| 500 | + output_json_tmp = output_json.parent / f"{stem}.tmp.{uuid.uuid4()}{suffix}" |
501 | 501 | logging.info("Saving predictions to `%s`.", output_json_tmp) |
502 | | - write_json(predictions_dict, output_json_tmp, precision=4) |
| 502 | + write_json(predictions_dict, output_json_tmp, num_decimals=4) |
503 | 503 | logging.info("Moving `%s` to `%s`.", output_json_tmp, output_json) |
504 | 504 | output_json_tmp.replace(output_json) # Atomic operation. |
0 commit comments