diff --git a/keras_hub/src/models/deit/deit_image_classifier.py b/keras_hub/src/models/deit/deit_image_classifier.py index 4c5da009e4..90a55755a2 100644 --- a/keras_hub/src/models/deit/deit_image_classifier.py +++ b/keras_hub/src/models/deit/deit_image_classifier.py @@ -153,6 +153,7 @@ def __init__( # === config === self.num_classes = num_classes + self.head_dtype = head_dtype self.pooling = pooling self.activation = activation self.dropout = dropout @@ -166,6 +167,7 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py index bed9831563..ebc78cc47b 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py @@ -201,6 +201,7 @@ def __init__( self.pooling = pooling self.dropout = dropout self.num_classes = num_classes + self.head_dtype = head_dtype def get_config(self): config = Task.get_config(self) @@ -211,6 +212,7 @@ def get_config(self): "activation": self.activation, "dropout": self.dropout, "head_filters": self.head_filters, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index e75e390899..cef7051d99 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -151,6 +151,7 @@ def __init__( self.activation = activation self.pooling = pooling self.dropout = dropout + self.head_dtype = head_dtype def get_config(self): # Backbone serialized in `super` @@ -161,6 +162,7 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index 1545c39615..abcc0cbeb0 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -70,6 +70,7 @@ def __init__( # === Config === self.num_classes = num_classes + self.head_dtype = head_dtype self.num_features = num_features def get_config(self): @@ -79,6 +80,7 @@ def get_config(self): { "num_classes": self.num_classes, "num_features": self.num_features, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py index 0d1ab8d970..f6e50e2407 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py @@ -140,6 +140,7 @@ def __init__( # === Config === self.num_classes = num_classes + self.head_dtype = head_dtype self.head_hidden_size = head_hidden_size self.global_pool_type = global_pool self.drop_rate = drop_rate @@ -152,6 +153,7 @@ def get_config(self): "head_hidden_size": self.head_hidden_size, "global_pool": self.global_pool_type, "drop_rate": self.drop_rate, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index abf450093c..d8e80ae21e 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -119,6 +119,10 @@ def from_config(cls, config): config["preprocessor"] = keras.layers.deserialize( config["preprocessor"] ) + if "head_dtype" in config and isinstance(config["head_dtype"], dict): + config["head_dtype"] = keras.dtype_policies.deserialize( + config["head_dtype"] + ) return cls(**config) @classproperty diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py index d1dc60b67e..f3c616756b 100644 --- a/keras_hub/src/models/task_test.py +++ b/keras_hub/src/models/task_test.py @@ -273,6 +273,26 @@ def test_save_to_preset_custom_backbone_and_preprocessor(self): actual = restored_task.predict(batch) self.assertAllClose(expected, actual) + def test_image_classifier_head_dtype_serialization(self): + inputs = keras.Input(shape=(None, None, 3)) + outputs = keras.layers.Dense(8)(inputs) + backbone = keras.Model(inputs, outputs) + model = ImageClassifier( + backbone=backbone, + num_classes=10, + head_dtype="float32", + ) + # Verify head_dtype is in config + config = model.get_config() + self.assertIn("head_dtype", config) + + # Verify round-trip via from_config + restored = ImageClassifier.from_config(config) + self.assertEqual( + str(model.head_dtype), + str(restored.head_dtype), + ) + def _create_gemma_for_export_tests(self): proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm") tokenizer = GemmaTokenizer(proto=proto) diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index a72b256288..0dc7902c71 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -195,6 +195,7 @@ def __init__( # === Config === self.num_classes = num_classes + self.head_dtype = head_dtype self.activation = activation self.pooling = pooling self.pooling_hidden_dim = pooling_hidden_dim @@ -211,6 +212,7 @@ def get_config(self): "activation": self.activation, "pooling_hidden_dim": self.pooling_hidden_dim, "dropout": self.dropout, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index 6e8746d6b6..16bbd5374b 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -167,6 +167,7 @@ def __init__( # === config === self.num_classes = num_classes + self.head_dtype = head_dtype self.pooling = pooling self.intermediate_dim = intermediate_dim self.activation = activation @@ -182,6 +183,7 @@ def get_config(self): "intermediate_dim": self.intermediate_dim, "activation": self.activation, "dropout": self.dropout, + "head_dtype": keras.dtype_policies.serialize(self.head_dtype), } ) return config