From c1b756b23ba7fb82ac96eaae912e0527d166c6b9 Mon Sep 17 00:00:00 2001 From: yashwanth510 Date: Mon, 2 Mar 2026 21:32:00 +0530 Subject: [PATCH 1/5] fix: serialize head_dtype in all ImageClassifier subclasses head_dtype was accepted by __init__() and used to set dtype policy for classifier head layers, but was never stored on self or included in get_config() in several ImageClassifier subclasses. Affected models: - ViTImageClassifier - DeiTImageClassifier - VGGImageClassifier - MobileNetImageClassifier - MobileNetV5ImageClassifier - HGNetV2ImageClassifier This is a follow-up to #2614 which fixed the same issue in the base ImageClassifier class. --- keras_hub/src/models/deit/deit_image_classifier.py | 2 ++ keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py | 2 ++ keras_hub/src/models/mobilenet/mobilenet_image_classifier.py | 2 ++ .../src/models/mobilenetv5/mobilenetv5_image_classifier.py | 2 ++ keras_hub/src/models/vgg/vgg_image_classifier.py | 2 ++ keras_hub/src/models/vit/vit_image_classifier.py | 2 ++ 6 files changed, 12 insertions(+) diff --git a/keras_hub/src/models/deit/deit_image_classifier.py b/keras_hub/src/models/deit/deit_image_classifier.py index 4c5da009e4..a8429a7881 100644 --- a/keras_hub/src/models/deit/deit_image_classifier.py +++ b/keras_hub/src/models/deit/deit_image_classifier.py @@ -115,6 +115,7 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy + self.head_dtype = head_dtype # === Layers === self.backbone = backbone @@ -166,6 +167,7 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, + "head_dtype": self.head_dtype, } ) return config diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py index bed9831563..9500a1bca4 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py @@ -117,6 +117,7 @@ def __init__( ): 
name = kwargs.get("name", "hgnetv2_image_classifier") head_dtype = head_dtype or backbone.dtype_policy + self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", "channels_last") channel_axis = -1 if data_format == "channels_last" else 1 self.head_filters = ( @@ -211,6 +212,7 @@ def get_config(self): "activation": self.activation, "dropout": self.dropout, "head_filters": self.head_filters, + "head_dtype": self.head_dtype, } ) return config diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index 1545c39615..af1b9b05ba 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -24,6 +24,7 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy + self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", None) # === Layers === @@ -79,6 +80,7 @@ def get_config(self): { "num_classes": self.num_classes, "num_features": self.num_features, + "head_dtype": self.head_dtype, } ) return config diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py index 0d1ab8d970..52c0456f03 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py @@ -88,6 +88,7 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy + self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", "channels_last") # === Layers === @@ -152,6 +153,7 @@ def get_config(self): "head_hidden_size": self.head_hidden_size, "global_pool": self.global_pool_type, "drop_rate": self.drop_rate, + "head_dtype": self.head_dtype, } ) return config diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index a72b256288..27f5c36e8a 
100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -114,6 +114,7 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy + self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", None) # === Layers === @@ -211,6 +212,7 @@ def get_config(self): "activation": self.activation, "pooling_hidden_dim": self.pooling_hidden_dim, "dropout": self.dropout, + "head_dtype": self.head_dtype, } ) return config diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index 6e8746d6b6..fde7791916 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -120,6 +120,7 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy + self.head_dtype = head_dtype # === Layers === self.backbone = backbone @@ -182,6 +183,7 @@ def get_config(self): "intermediate_dim": self.intermediate_dim, "activation": self.activation, "dropout": self.dropout, + "head_dtype": self.head_dtype, } ) return config From d063f4f1ff25e20bb370a185f53f2e0db1cefd5b Mon Sep 17 00:00:00 2001 From: yashwanth510 Date: Tue, 3 Mar 2026 08:17:18 +0530 Subject: [PATCH 2/5] fix: move self.head_dtype to config section and include base ImageClassifier fix --- keras_hub/src/models/deit/deit_image_classifier.py | 2 +- keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py | 2 +- keras_hub/src/models/image_classifier.py | 2 ++ keras_hub/src/models/mobilenet/mobilenet_image_classifier.py | 2 +- .../src/models/mobilenetv5/mobilenetv5_image_classifier.py | 2 +- keras_hub/src/models/vgg/vgg_image_classifier.py | 2 +- keras_hub/src/models/vit/vit_image_classifier.py | 2 +- 7 files changed, 8 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/deit/deit_image_classifier.py b/keras_hub/src/models/deit/deit_image_classifier.py index a8429a7881..7e4aa6d762 100644 --- 
a/keras_hub/src/models/deit/deit_image_classifier.py +++ b/keras_hub/src/models/deit/deit_image_classifier.py @@ -115,7 +115,6 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy - self.head_dtype = head_dtype # === Layers === self.backbone = backbone @@ -154,6 +153,7 @@ def __init__( # === config === self.num_classes = num_classes + self.head_dtype = head_dtype self.pooling = pooling self.activation = activation self.dropout = dropout diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py index 9500a1bca4..a406848137 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py @@ -117,7 +117,6 @@ def __init__( ): name = kwargs.get("name", "hgnetv2_image_classifier") head_dtype = head_dtype or backbone.dtype_policy - self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", "channels_last") channel_axis = -1 if data_format == "channels_last" else 1 self.head_filters = ( @@ -202,6 +201,7 @@ def __init__( self.pooling = pooling self.dropout = dropout self.num_classes = num_classes + self.head_dtype = head_dtype def get_config(self): config = Task.get_config(self) diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index e75e390899..6bcbe4728f 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -100,6 +100,7 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy + self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", None) # === Layers === @@ -161,6 +162,7 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, + "head_dtype": self.head_dtype, } ) return config diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py 
b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index af1b9b05ba..bb9b7cdb80 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -24,7 +24,6 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy - self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", None) # === Layers === @@ -71,6 +70,7 @@ def __init__( # === Config === self.num_classes = num_classes + self.head_dtype = head_dtype self.num_features = num_features def get_config(self): diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py index 52c0456f03..d713c030e4 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py @@ -88,7 +88,6 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy - self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", "channels_last") # === Layers === @@ -141,6 +140,7 @@ def __init__( # === Config === self.num_classes = num_classes + self.head_dtype = head_dtype self.head_hidden_size = head_hidden_size self.global_pool_type = global_pool self.drop_rate = drop_rate diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index 27f5c36e8a..c7ccecdd19 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -114,7 +114,6 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy - self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", None) # === Layers === @@ -196,6 +195,7 @@ def __init__( # === Config === self.num_classes = num_classes + self.head_dtype = head_dtype self.activation = activation self.pooling = pooling self.pooling_hidden_dim = 
pooling_hidden_dim diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index fde7791916..377253cec3 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -120,7 +120,6 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy - self.head_dtype = head_dtype # === Layers === self.backbone = backbone @@ -168,6 +167,7 @@ def __init__( # === config === self.num_classes = num_classes + self.head_dtype = head_dtype self.pooling = pooling self.intermediate_dim = intermediate_dim self.activation = activation From da89b6931bd36cd3af107dc3b7f9447e70459930 Mon Sep 17 00:00:00 2001 From: yashwanth510 Date: Tue, 3 Mar 2026 08:21:04 +0530 Subject: [PATCH 3/5] fix: move self.head_dtype to config section in base ImageClassifier --- keras_hub/src/models/image_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index 6bcbe4728f..d36c3be1cd 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -100,7 +100,6 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy - self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", None) # === Layers === @@ -152,6 +151,7 @@ def __init__( self.activation = activation self.pooling = pooling self.dropout = dropout + self.head_dtype = head_dtype def get_config(self): # Backbone serialized in `super` From a1a4660bfb1271bc18b06470c8fb36d195f5007f Mon Sep 17 00:00:00 2001 From: yashwanth510 Date: Tue, 3 Mar 2026 11:01:27 +0530 Subject: [PATCH 4/5] fix: serialize head_dtype using keras.dtype_policies.serialize --- keras_hub/src/models/deit/deit_image_classifier.py | 2 +- keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py | 2 +- keras_hub/src/models/image_classifier.py | 2 +- 
keras_hub/src/models/mobilenet/mobilenet_image_classifier.py | 2 +- .../src/models/mobilenetv5/mobilenetv5_image_classifier.py | 2 +- keras_hub/src/models/vgg/vgg_image_classifier.py | 2 +- keras_hub/src/models/vit/vit_image_classifier.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/keras_hub/src/models/deit/deit_image_classifier.py b/keras_hub/src/models/deit/deit_image_classifier.py index 7e4aa6d762..90a55755a2 100644 --- a/keras_hub/src/models/deit/deit_image_classifier.py +++ b/keras_hub/src/models/deit/deit_image_classifier.py @@ -167,7 +167,7 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, - "head_dtype": self.head_dtype, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py index a406848137..ebc78cc47b 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py @@ -212,7 +212,7 @@ def get_config(self): "activation": self.activation, "dropout": self.dropout, "head_filters": self.head_filters, - "head_dtype": self.head_dtype, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index d36c3be1cd..cef7051d99 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -162,7 +162,7 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, - "head_dtype": self.head_dtype, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index bb9b7cdb80..abcc0cbeb0 100644 --- 
a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -80,7 +80,7 @@ def get_config(self): { "num_classes": self.num_classes, "num_features": self.num_features, - "head_dtype": self.head_dtype, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py index d713c030e4..f6e50e2407 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py @@ -153,7 +153,7 @@ def get_config(self): "head_hidden_size": self.head_hidden_size, "global_pool": self.global_pool_type, "drop_rate": self.drop_rate, - "head_dtype": self.head_dtype, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index c7ccecdd19..0dc7902c71 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -212,7 +212,7 @@ def get_config(self): "activation": self.activation, "pooling_hidden_dim": self.pooling_hidden_dim, "dropout": self.dropout, - "head_dtype": self.head_dtype, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index 377253cec3..16bbd5374b 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -183,7 +183,7 @@ def get_config(self): "intermediate_dim": self.intermediate_dim, "activation": self.activation, "dropout": self.dropout, - "head_dtype": self.head_dtype, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config From 
0d754d9bc37ce80cca510d5a624b4deec396302d Mon Sep 17 00:00:00 2001 From: yashwanth510 Date: Sun, 8 Mar 2026 11:59:16 +0530 Subject: [PATCH 5/5] fix: add head_dtype deserialization in task.py and add test case --- keras_hub/src/models/task.py | 4 ++++ keras_hub/src/models/task_test.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index abf450093c..d8e80ae21e 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -119,6 +119,10 @@ def from_config(cls, config): config["preprocessor"] = keras.layers.deserialize( config["preprocessor"] ) + if "head_dtype" in config and isinstance(config["head_dtype"], dict): + config["head_dtype"] = keras.dtype_policies.deserialize( + config["head_dtype"] + ) return cls(**config) @classproperty diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py index d1dc60b67e..f3c616756b 100644 --- a/keras_hub/src/models/task_test.py +++ b/keras_hub/src/models/task_test.py @@ -273,6 +273,26 @@ def test_save_to_preset_custom_backbone_and_preprocessor(self): actual = restored_task.predict(batch) self.assertAllClose(expected, actual) + def test_image_classifier_head_dtype_serialization(self): + inputs = keras.Input(shape=(None, None, 3)) + outputs = keras.layers.Dense(8)(inputs) + backbone = keras.Model(inputs, outputs) + model = ImageClassifier( + backbone=backbone, + num_classes=10, + head_dtype="float32", + ) + # Verify head_dtype is in config + config = model.get_config() + self.assertIn("head_dtype", config) + + # Verify round-trip via from_config + restored = ImageClassifier.from_config(config) + self.assertEqual( + str(model.head_dtype), + str(restored.head_dtype), + ) + def _create_gemma_for_export_tests(self): proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm") tokenizer = GemmaTokenizer(proto=proto)