-
Notifications
You must be signed in to change notification settings - Fork 327
Custom RetinaNet save and load, added test case #1928
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -233,6 +233,49 @@ def test_saved_model(self): | |
| tf.nest.map_structure(ops.convert_to_numpy, restored_output), | ||
| ) | ||
|
|
||
| @pytest.mark.large | ||
| def test_custom_saved_model(self, save_format, filename): | ||
| class CustomPredictionHead( | ||
| keras_cv.models.object_detection.retinanet.PredictionHead | ||
| ): | ||
| pass | ||
|
|
||
| class CustomFeaturePyramid( | ||
| keras_cv.models.object_detection.retinanet.FeaturePyramid | ||
| ): | ||
| pass | ||
|
|
||
| model = keras_cv.models.RetinaNet( | ||
| num_classes=20, | ||
| bounding_box_format="xywh", | ||
| backbone=keras_cv.models.ResNet18V2Backbone(), | ||
| box_head=CustomPredictionHead(9 * 4, tf.keras.initializers.Zeros()), | ||
| classification_head=CustomPredictionHead( | ||
| 9 * 20, tf.keras.initializers.Zeros() | ||
| ), | ||
| feature_pyramid=CustomFeaturePyramid(), | ||
| ) | ||
|
|
||
| input_batch = tf.ones(shape=(2, 224, 224, 3)) | ||
| model_output = model(input_batch) | ||
| save_path = os.path.join(self.get_temp_dir(), filename) | ||
| model.save(save_path, save_format=save_format) | ||
| restored_model = keras.models.load_model( | ||
| save_path, | ||
| { | ||
| "RetinaNet": keras_cv.models.RetinaNet, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to provide this as a custom object? If so, why isn't it needed in all tests?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This, my friend, I still don't understand, but from what I remember that's the only way to make the whole thing work using the different formats of saving/loading. I'll look into it and let you know why I did it that way.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem is with the "h5" saving/loading format. Without the By the way, I noticed that in |
||
| "CustomFeaturePyramid": CustomFeaturePyramid, | ||
| "CustomPredictionHead": CustomPredictionHead, | ||
| }, | ||
| ) | ||
|
|
||
| # Check we got the real object back. | ||
| self.assertIsInstance(restored_model, keras_cv.models.RetinaNet) | ||
|
|
||
| # Check that output matches. | ||
| restored_output = restored_model(input_batch) | ||
| self.assertAllClose(model_output, restored_output) | ||
|
|
||
| def test_call_with_custom_label_encoder(self): | ||
| anchor_generator = keras_cv.models.RetinaNet.default_anchor_generator( | ||
| "xywh" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we make sense simple convolutions or something with real variables?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, how does the forward pass work if these classes are ill-defined?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ran the test and it worked, because the custom classes are inheriting from the classes that are used as default in the RetinaNet class, they just have different names. If you think we should also test for custom architecture, e.g. creating a totally new and independent class, I'll do it, but I don't think it would change anything.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also test for saving/loading with custom Label Encoder, Label Decoder, Anchor Generator, Backbone, and so on.