diff --git a/keras_cv/models/object_detection/retinanet/retinanet.py b/keras_cv/models/object_detection/retinanet/retinanet.py index 3154f819d5..154b0839fb 100644 --- a/keras_cv/models/object_detection/retinanet/retinanet.py +++ b/keras_cv/models/object_detection/retinanet/retinanet.py @@ -524,22 +524,45 @@ def compute_metrics(self, x, y, y_pred, sample_weight): return metrics def get_config(self): - return { - "num_classes": self.num_classes, - "bounding_box_format": self.bounding_box_format, - "backbone": keras.saving.serialize_keras_object(self.backbone), - "label_encoder": keras.saving.serialize_keras_object( - self.label_encoder - ), - "prediction_decoder": self._prediction_decoder, - "classification_head": keras.saving.serialize_keras_object( - self.classification_head - ), - "box_head": keras.saving.serialize_keras_object(self.box_head), - } + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "backbone": keras.utils.serialize_keras_object(self.backbone), + "label_encoder": keras.utils.serialize_keras_object( + self.label_encoder + ), + "prediction_decoder": keras.utils.serialize_keras_object( + self._prediction_decoder + ), + "classification_head": keras.utils.serialize_keras_object( + self.classification_head + ), + "box_head": keras.utils.serialize_keras_object(self.box_head), + "feature_pyramid": keras.utils.serialize_keras_object( + self.feature_pyramid + ), + } + ) + return config @classmethod def from_config(cls, config): + if "backbone" in config and isinstance(config["backbone"], dict): + config["backbone"] = keras.layers.deserialize(config["backbone"]) + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = keras.layers.deserialize( + config["prediction_decoder"] + ) + if "label_encoder" in config and isinstance( + config["label_encoder"], dict + ): + config["label_encoder"] = keras.layers.deserialize( + config["label_encoder"] + ) if "box_head" in config and isinstance(config["box_head"], dict): config["box_head"] = keras.layers.deserialize(config["box_head"]) if "classification_head" in config and isinstance( @@ -548,13 +571,13 @@ def from_config(cls, config): config["classification_head"] = keras.layers.deserialize( config["classification_head"] ) - if "label_encoder" in config and isinstance( - config["label_encoder"], dict + if "feature_pyramid" in config and isinstance( + config["feature_pyramid"], dict ): - config["label_encoder"] = keras.layers.deserialize( - config["label_encoder"] + config["feature_pyramid"] = keras.layers.deserialize( + config["feature_pyramid"] ) - return super().from_config(config) + return cls(**config) @classproperty def presets(cls): diff --git a/keras_cv/models/object_detection/retinanet/retinanet_test.py b/keras_cv/models/object_detection/retinanet/retinanet_test.py index cc0f7c9131..5aa764fbf2 100644 --- a/keras_cv/models/object_detection/retinanet/retinanet_test.py +++ b/keras_cv/models/object_detection/retinanet/retinanet_test.py @@ -233,6 +233,49 @@ def test_saved_model(self): tf.nest.map_structure(ops.convert_to_numpy, restored_output), ) + @pytest.mark.large + def test_custom_saved_model(self, save_format, filename): + class CustomPredictionHead( + keras_cv.models.object_detection.retinanet.PredictionHead + ): + pass + + class CustomFeaturePyramid( + keras_cv.models.object_detection.retinanet.FeaturePyramid + ): + pass + + model = keras_cv.models.RetinaNet( + num_classes=20, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet18V2Backbone(), + box_head=CustomPredictionHead(9 * 4, tf.keras.initializers.Zeros()), + classification_head=CustomPredictionHead( + 9 * 20, tf.keras.initializers.Zeros() + ), + feature_pyramid=CustomFeaturePyramid(), + ) + + input_batch = tf.ones(shape=(2, 224, 224, 3)) + model_output = model(input_batch) + save_path = os.path.join(self.get_temp_dir(), filename) + model.save(save_path, save_format=save_format) + restored_model = keras.models.load_model( + save_path, + { + "RetinaNet": keras_cv.models.RetinaNet, + "CustomFeaturePyramid": CustomFeaturePyramid, + "CustomPredictionHead": CustomPredictionHead, + }, + ) + + # Check we got the real object back. + self.assertIsInstance(restored_model, keras_cv.models.RetinaNet) + + # Check that output matches. + restored_output = restored_model(input_batch) + self.assertAllClose(model_output, restored_output) + def test_call_with_custom_label_encoder(self): anchor_generator = keras_cv.models.RetinaNet.default_anchor_generator( "xywh"