Skip to content
This repository was archived by the owner on Mar 10, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 41 additions & 18 deletions keras_cv/models/object_detection/retinanet/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,22 +524,45 @@ def compute_metrics(self, x, y, y_pred, sample_weight):
return metrics

def get_config(self):
return {
"num_classes": self.num_classes,
"bounding_box_format": self.bounding_box_format,
"backbone": keras.saving.serialize_keras_object(self.backbone),
"label_encoder": keras.saving.serialize_keras_object(
self.label_encoder
),
"prediction_decoder": self._prediction_decoder,
"classification_head": keras.saving.serialize_keras_object(
self.classification_head
),
"box_head": keras.saving.serialize_keras_object(self.box_head),
}
config = super().get_config()
config.update(
{
"num_classes": self.num_classes,
"bounding_box_format": self.bounding_box_format,
"backbone": keras.utils.serialize_keras_object(self.backbone),
"label_encoder": keras.utils.serialize_keras_object(
self.label_encoder
),
"prediction_decoder": keras.utils.serialize_keras_object(
self._prediction_decoder
),
"classification_head": keras.utils.serialize_keras_object(
self.classification_head
),
"box_head": keras.utils.serialize_keras_object(self.box_head),
"feature_pyramid": keras.utils.serialize_keras_object(
self.feature_pyramid
),
}
)
return config

@classmethod
def from_config(cls, config):
if "backbone" in config and isinstance(config["backbone"], dict):
config["backbone"] = keras.layers.deserialize(config["backbone"])
if "prediction_decoder" in config and isinstance(
config["prediction_decoder"], dict
):
config["prediction_decoder"] = keras.layers.deserialize(
config["prediction_decoder"]
)
if "label_encoder" in config and isinstance(
config["label_encoder"], dict
):
config["label_encoder"] = keras.layers.deserialize(
config["label_encoder"]
)
if "box_head" in config and isinstance(config["box_head"], dict):
config["box_head"] = keras.layers.deserialize(config["box_head"])
if "classification_head" in config and isinstance(
Expand All @@ -548,13 +571,13 @@ def from_config(cls, config):
config["classification_head"] = keras.layers.deserialize(
config["classification_head"]
)
if "label_encoder" in config and isinstance(
config["label_encoder"], dict
if "feature_pyramid" in config and isinstance(
config["feature_pyramid"], dict
):
config["label_encoder"] = keras.layers.deserialize(
config["label_encoder"]
config["feature_pyramid"] = keras.layers.deserialize(
config["feature_pyramid"]
)
return super().from_config(config)
return cls(**config)

@classproperty
def presets(cls):
Expand Down
43 changes: 43 additions & 0 deletions keras_cv/models/object_detection/retinanet/retinanet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,49 @@ def test_saved_model(self):
tf.nest.map_structure(ops.convert_to_numpy, restored_output),
)

@pytest.mark.large
def test_custom_saved_model(self, save_format, filename):
class CustomPredictionHead(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make sense simple convolutions or something with real variables?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, how does the forward pass work if these classes are ill-defined?

Copy link
Contributor Author

@gianlucasama gianlucasama Jul 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran the test and it worked, because the custom classes are inheriting from the classes that are used as default in the RetinaNet class, they just have different names. If you think we should also test for custom architecture, e.g. creating a totally new and independent class, I'll do it, but I don't think it would change anything.

Copy link
Contributor Author

@gianlucasama gianlucasama Jul 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also test for saving/loading with custom Label Encoder, Label Decoder, Anchor Generator, Backbone, and so on.

keras_cv.models.object_detection.retinanet.PredictionHead
):
pass

class CustomFeaturePyramid(
keras_cv.models.object_detection.retinanet.FeaturePyramid
):
pass

model = keras_cv.models.RetinaNet(
num_classes=20,
bounding_box_format="xywh",
backbone=keras_cv.models.ResNet18V2Backbone(),
box_head=CustomPredictionHead(9 * 4, tf.keras.initializers.Zeros()),
classification_head=CustomPredictionHead(
9 * 20, tf.keras.initializers.Zeros()
),
feature_pyramid=CustomFeaturePyramid(),
)

input_batch = tf.ones(shape=(2, 224, 224, 3))
model_output = model(input_batch)
save_path = os.path.join(self.get_temp_dir(), filename)
model.save(save_path, save_format=save_format)
restored_model = keras.models.load_model(
save_path,
{
"RetinaNet": keras_cv.models.RetinaNet,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to provide this as a custom object? If so, why isn't it needed in all tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This, my friend, I still don't understand, but from what I remember that's the only way to make the whole thing work using the different formats of saving/loading. I'll look into it and let you know why I did it that way.

Copy link
Contributor Author

@gianlucasama gianlucasama Jul 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is with the "h5" saving/loading format. Without the "RetinaNet": keras_cv.models.RetinaNet "h5" loading doesn't work, it gives you:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/gianlucasama/Github/keras-cv/keras_cv/models/object_detection/retinanet/retinanet_test.py", line 374, in <module>
    t.test_custom_saved_model("h5", "model")
  File "/home/gianlucasama/Github/keras-cv/keras_cv/models/object_detection/retinanet/retinanet_test.py", line 274, in test_custom_saved_model
    restored_model = keras.models.load_model(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/saving_api.py", line 212, in load_model
    return legacy_sm_saving_lib.load_model(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 61, in error_handler
    return fn(*args, **kwargs)
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/save.py", line 245, in load_model
    return hdf5_format.load_model_from_hdf5(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/hdf5_format.py", line 192, in load_model_from_hdf5
    model = model_config_lib.model_from_config(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/model_config.py", line 55, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File ["/home/gianlucasama/.local/lib/python3.10/site-packages/keras/layers/serialization.py",] line 265, in deserialize
    return legacy_serialization.deserialize_keras_object(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/serialization.py", line 486, in deserialize_keras_object
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/serialization.py", line 368, in class_and_config_for_serialized_keras_object
    raise ValueError(
ValueError: Unknown layer: 'RetinaNet'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

By the way, I noticed that in RetinaNetTest.test_saved_model, line 228, we are not testing the "h5" saving/loading format and that's probably why we are not adding the "RetinaNet": keras_cv.models.RetinaNet there.

"CustomFeaturePyramid": CustomFeaturePyramid,
"CustomPredictionHead": CustomPredictionHead,
},
)

# Check we got the real object back.
self.assertIsInstance(restored_model, keras_cv.models.RetinaNet)

# Check that output matches.
restored_output = restored_model(input_batch)
self.assertAllClose(model_output, restored_output)

def test_call_with_custom_label_encoder(self):
anchor_generator = keras_cv.models.RetinaNet.default_anchor_generator(
"xywh"
Expand Down