Skip to content

Commit 740598c

Browse files
Added retry to save_predictions and save_dataset
1 parent cc1674d commit 740598c

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/fmcore/framework/_dataset.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,6 @@ def save_dataset(
377377
dataset_destination.format,
378378
**kwargs,
379379
)
380-
json_writer: JsonWriter = JsonWriter()
381380

382381
if dataset_destination.is_path_valid_dir():
383382
dataset_params_file: FileMetadata = dataset_destination.file_in_dir(
@@ -403,8 +402,14 @@ def save_dataset(
403402
+ DATASET_PARAMS_SAVE_FILE_ENDING,
404403
format=FileFormat.JSON,
405404
)
405+
writer.write_metadata(
406+
file=dataset_destination,
407+
data=dataset.data,
408+
overwrite=overwrite,
409+
**kwargs,
410+
)
406411

407-
writer.write_metadata(file=dataset_destination, data=dataset.data, overwrite=overwrite, **kwargs)
412+
json_writer: JsonWriter = JsonWriter(**kwargs)
408413
json_writer.write_metadata(
409414
file=dataset_params_file,
410415
data=dataset.dict(exclude={"data", "data_idx", "data_position", "validated"}),

src/fmcore/framework/_predictions.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,6 @@ def save_predictions(
552552
predictions_destination.format,
553553
**kwargs,
554554
)
555-
json_writer: JsonWriter = JsonWriter()
556555

557556
if predictions_destination.is_path_valid_dir():
558557
predictions_params_file: FileMetadata = predictions_destination.file_in_dir(
@@ -579,7 +578,14 @@ def save_predictions(
579578
format=FileFormat.JSON,
580579
)
581580

582-
writer.write_metadata(file=predictions_destination, data=predictions.data, overwrite=overwrite, **kwargs)
581+
writer.write_metadata(
582+
file=predictions_destination,
583+
data=predictions.data,
584+
overwrite=overwrite,
585+
**kwargs,
586+
)
587+
588+
json_writer: JsonWriter = JsonWriter(**kwargs)
583589
json_writer.write_metadata(
584590
file=predictions_params_file,
585591
data=predictions.dict(exclude={"data", "data_idx", "data_position", "validated"}),

0 commit comments

Comments
 (0)