Skip to content

Commit 0529ccd

Browse files
authored
Fixes Saving Format related test failures for Keras 3 (#2180)
* Deprecate keras_v3 save format for Keras3 * Deprecate keras_v3 save format for Keras3 * Deprecate keras_v3 save format for Keras3
1 parent 91d8a58 commit 0529ccd

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

keras_cv/models/segmentation/deeplab_v3_plus/deeplab_v3_plus_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from keras_cv.backend import keras
2323
from keras_cv.backend import ops
24+
from keras_cv.backend.config import keras_3
2425
from keras_cv.models import DeepLabV3Plus
2526
from keras_cv.models import ResNet18V2Backbone
2627
from keras_cv.models.backbones.test_backbone_presets import (
@@ -86,12 +87,8 @@ def test_with_model_preset_forward_pass(self):
8687
expected_output = np.zeros((1, 512, 512, 1))
8788
self.assertAllClose(output, expected_output)
8889

89-
@parameterized.named_parameters(
90-
("tf_format", "tf", "model"),
91-
("keras_format", "keras_v3", "model.keras"),
92-
)
9390
@pytest.mark.large # Saving is slow, so mark these large.
94-
def test_saved_model(self, save_format, filename):
91+
def test_saved_model(self):
9592
target_size = [512, 512, 3]
9693

9794
backbone = ResNet18V2Backbone(input_shape=target_size)
@@ -100,8 +97,11 @@ def test_saved_model(self, save_format, filename):
10097
input_batch = np.ones(shape=[2] + target_size)
10198
model_output = model(input_batch)
10299

103-
save_path = os.path.join(self.get_temp_dir(), filename)
104-
model.save(save_path, save_format=save_format)
100+
save_path = os.path.join(self.get_temp_dir(), "model.keras")
101+
if keras_3():
102+
model.save(save_path)
103+
else:
104+
model.save(save_path, save_format="keras_v3")
105105
restored_model = keras.models.load_model(save_path)
106106

107107
# Check we got the real object back.

keras_cv/models/segmentation/segformer/segformer_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from keras_cv.backend import keras
2222
from keras_cv.backend import ops
23+
from keras_cv.backend.config import keras_3
2324
from keras_cv.models import MiTBackbone
2425
from keras_cv.models import SegFormer
2526
from keras_cv.tests.test_case import TestCase
@@ -81,7 +82,10 @@ def test_saved_model(self):
8182
model_output = model(input_batch)
8283

8384
save_path = os.path.join(self.get_temp_dir(), "model.keras")
84-
model.save(save_path, save_format="keras_v3")
85+
if keras_3():
86+
model.save(save_path)
87+
else:
88+
model.save(save_path, save_format="keras_v3")
8589
restored_model = keras.models.load_model(save_path)
8690

8791
# Check we got the real object back.

keras_cv/models/segmentation/segment_anything/sam_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from keras_cv.backend import keras
2424
from keras_cv.backend import ops
25+
from keras_cv.backend.config import keras_3
2526
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone
2627
from keras_cv.models.segmentation.segment_anything.sam import (
2728
SegmentAnythingModel,
@@ -282,7 +283,10 @@ def test_end_to_end_model_save(self):
282283

283284
# Save the model
284285
save_path = os.path.join(self.get_temp_dir(), "model.keras")
285-
model.save(save_path, save_format="keras_v3")
286+
if keras_3():
287+
model.save(save_path)
288+
else:
289+
model.save(save_path, save_format="keras_v3")
286290
restored_model = keras.models.load_model(save_path)
287291

288292
# Check we got the real object back.

0 commit comments

Comments
 (0)