diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index e75e390899..6bcbe4728f 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -100,6 +100,7 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy + self.head_dtype = head_dtype data_format = getattr(backbone, "data_format", None) # === Layers === @@ -161,6 +162,7 @@ def get_config(self): "pooling": self.pooling, "activation": self.activation, "dropout": self.dropout, + "head_dtype": self.head_dtype, } ) return config