Skip to content

Commit 0c7d0c7

Browse files
hertschuhtensorflower-gardener
authored andcommitted
Fix save_weights when saving weights in TF format with a remote path.
`SupportWriteToRemote` was incorrectly guessing the file format and was assuming an H5 file, but the TF format was used and the folder failed to be copied. PiperOrigin-RevId: 739001145
1 parent e19a3cb commit 0c7d0c7

File tree

2 files changed

+98
-14
lines changed

2 files changed

+98
-14
lines changed

tf_keras/saving/saving_api.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from tf_keras.saving import saving_lib
2626
from tf_keras.saving.legacy import save as legacy_sm_saving_lib
27+
from tf_keras.saving.legacy import saving_utils
2728
from tf_keras.utils import io_utils
2829

2930
try:
@@ -75,8 +76,7 @@ class SupportWriteToRemote:
7576
supports remoted saved model out of the box.
7677
"""
7778

78-
def __init__(self, filepath, overwrite=True, save_format=None):
79-
save_format = get_save_format(filepath, save_format=save_format)
79+
def __init__(self, filepath, overwrite, save_format):
8080
self.overwrite = overwrite
8181
if saving_lib.is_remote_path(filepath) and save_format != "tf":
8282
self.temp_directory = tempfile.TemporaryDirectory()
@@ -191,14 +191,14 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
191191
when loading the model. See the `custom_objects` argument in
192192
`tf.keras.saving.load_model`.
193193
"""
194+
save_format = get_save_format(filepath, save_format)
195+
194196
# Supports remote paths via a temporary file
195197
with SupportWriteToRemote(
196198
filepath,
197199
overwrite=overwrite,
198200
save_format=save_format,
199201
) as local_filepath:
200-
save_format = get_save_format(filepath, save_format)
201-
202202
# Deprecation warnings
203203
if save_format == "h5":
204204
warnings.warn(
@@ -307,8 +307,12 @@ def load_model(
307307

308308

309309
def save_weights(model, filepath, overwrite=True, **kwargs):
310+
save_format = get_save_weights_format(filepath)
311+
310312
# Supports remote paths via a temporary file
311-
with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath:
313+
with SupportWriteToRemote(
314+
filepath, overwrite=overwrite, save_format=save_format
315+
) as local_filepath:
312316
if str(local_filepath).endswith(".weights.h5"):
313317
# If file exists and should not be overwritten.
314318
try:
@@ -385,3 +389,12 @@ def get_save_format(filepath, save_format):
385389
return "tf"
386390
else:
387391
return "h5"
392+
393+
394+
def get_save_weights_format(filepath):
395+
filepath = io_utils.path_to_string(filepath)
396+
filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
397+
if filepath_is_h5:
398+
return "h5"
399+
else:
400+
return "tf"

tf_keras/saving/saving_lib_test.py

+80-9
Original file line numberDiff line numberDiff line change
@@ -533,26 +533,97 @@ def test_metadata(self):
533533
self.assertIn("keras_version", metadata)
534534
self.assertIn("date_saved", metadata)
535535

536-
def test_gfile_copy_called(self):
537-
temp_filepath = Path(
538-
os.path.join(self.get_temp_dir(), "my_model.keras")
536+
def test_save_keras_gfile_copy_called(self):
537+
path = Path(os.path.join(self.get_temp_dir(), "my_model.keras"))
538+
model = keras.Sequential(
539+
[
540+
keras.Input(shape=(1, 1)),
541+
keras.layers.Dense(4),
542+
]
539543
)
540-
model = CompileOverridingModel()
541544
with mock.patch(
542545
"re.match", autospec=True
543546
) as mock_re_match, mock.patch.object(
544547
tf.io.gfile, "copy"
545548
) as mock_gfile_copy:
546549
# Check regex matching
547550
mock_re_match.return_value = True
548-
model.save(temp_filepath, save_format="keras_v3")
551+
model.save(path, save_format="keras_v3")
549552
mock_re_match.assert_called()
550-
self.assertIn(str(temp_filepath), mock_re_match.call_args.args)
553+
self.assertIn(str(path), mock_re_match.call_args.args)
551554

552555
# Check gfile copied with filepath specified as destination
553-
self.assertEqual(
554-
str(temp_filepath), str(mock_gfile_copy.call_args.args[1])
555-
)
556+
mock_gfile_copy.assert_called()
557+
self.assertEqual(str(path), str(mock_gfile_copy.call_args.args[1]))
558+
559+
def test_save_tf_gfile_copy_not_called(self):
560+
path = Path(os.path.join(self.get_temp_dir(), "my_model.keras"))
561+
model = keras.Sequential(
562+
[
563+
keras.Input(shape=(1, 1)),
564+
keras.layers.Dense(4),
565+
]
566+
)
567+
with mock.patch(
568+
"re.match", autospec=True
569+
) as mock_re_match, mock.patch.object(
570+
tf.io.gfile, "copy"
571+
) as mock_gfile_copy:
572+
# Check regex matching
573+
mock_re_match.return_value = True
574+
model.save(path, save_format="tf")
575+
mock_re_match.assert_called()
576+
self.assertIn(str(path), mock_re_match.call_args.args)
577+
578+
# Check gfile.copy was not used.
579+
mock_gfile_copy.assert_not_called()
580+
581+
def test_save_weights_h5_gfile_copy_called(self):
582+
path = Path(os.path.join(self.get_temp_dir(), "my_model.weights.h5"))
583+
model = keras.Sequential(
584+
[
585+
keras.Input(shape=(1, 1)),
586+
keras.layers.Dense(4),
587+
]
588+
)
589+
model(tf.constant([[1.0]]))
590+
with mock.patch(
591+
"re.match", autospec=True
592+
) as mock_re_match, mock.patch.object(
593+
tf.io.gfile, "copy"
594+
) as mock_gfile_copy:
595+
# Check regex matching
596+
mock_re_match.return_value = True
597+
model.save_weights(path)
598+
mock_re_match.assert_called()
599+
self.assertIn(str(path), mock_re_match.call_args.args)
600+
601+
# Check gfile copied with filepath specified as destination
602+
mock_gfile_copy.assert_called()
603+
self.assertEqual(str(path), str(mock_gfile_copy.call_args.args[1]))
604+
605+
def test_save_weights_tf_gfile_copy_not_called(self):
606+
path = Path(os.path.join(self.get_temp_dir(), "my_model.ckpt"))
607+
model = keras.Sequential(
608+
[
609+
keras.Input(shape=(1, 1)),
610+
keras.layers.Dense(4),
611+
]
612+
)
613+
model(tf.constant([[1.0]]))
614+
with mock.patch(
615+
"re.match", autospec=True
616+
) as mock_re_match, mock.patch.object(
617+
tf.io.gfile, "copy"
618+
) as mock_gfile_copy:
619+
# Check regex matching
620+
mock_re_match.return_value = True
621+
model.save_weights(path)
622+
mock_re_match.assert_called()
623+
self.assertIn(str(path), mock_re_match.call_args.args)
624+
625+
# Check gfile.copy was not used.
626+
mock_gfile_copy.assert_not_called()
556627

557628
def test_load_model_api_endpoint(self):
558629
temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras"))

0 commit comments

Comments
 (0)