Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras_hub/src/models/deit/deit_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(

# === config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.pooling = pooling
self.activation = activation
self.dropout = dropout
Expand All @@ -166,6 +167,7 @@ def get_config(self):
"pooling": self.pooling,
"activation": self.activation,
"dropout": self.dropout,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
2 changes: 2 additions & 0 deletions keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(
self.pooling = pooling
self.dropout = dropout
self.num_classes = num_classes
self.head_dtype = head_dtype

def get_config(self):
config = Task.get_config(self)
Expand All @@ -211,6 +212,7 @@ def get_config(self):
"activation": self.activation,
"dropout": self.dropout,
"head_filters": self.head_filters,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
2 changes: 2 additions & 0 deletions keras_hub/src/models/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(
self.activation = activation
self.pooling = pooling
self.dropout = dropout
self.head_dtype = head_dtype

def get_config(self):
# Backbone serialized in `super`
Expand All @@ -161,6 +162,7 @@ def get_config(self):
"pooling": self.pooling,
"activation": self.activation,
"dropout": self.dropout,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
Expand Down
2 changes: 2 additions & 0 deletions keras_hub/src/models/mobilenet/mobilenet_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(

# === Config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.num_features = num_features

def get_config(self):
Expand All @@ -79,6 +80,7 @@ def get_config(self):
{
"num_classes": self.num_classes,
"num_features": self.num_features,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(

# === Config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.head_hidden_size = head_hidden_size
self.global_pool_type = global_pool
self.drop_rate = drop_rate
Expand All @@ -152,6 +153,7 @@ def get_config(self):
"head_hidden_size": self.head_hidden_size,
"global_pool": self.global_pool_type,
"drop_rate": self.drop_rate,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
4 changes: 4 additions & 0 deletions keras_hub/src/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def from_config(cls, config):
config["preprocessor"] = keras.layers.deserialize(
config["preprocessor"]
)
if "head_dtype" in config and isinstance(config["head_dtype"], dict):
config["head_dtype"] = keras.dtype_policies.deserialize(
config["head_dtype"]
)
return cls(**config)

@classproperty
Expand Down
20 changes: 20 additions & 0 deletions keras_hub/src/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,26 @@ def test_save_to_preset_custom_backbone_and_preprocessor(self):
actual = restored_task.predict(batch)
self.assertAllClose(expected, actual)

def test_image_classifier_head_dtype_serialization(self):
inputs = keras.Input(shape=(None, None, 3))
outputs = keras.layers.Dense(8)(inputs)
backbone = keras.Model(inputs, outputs)
model = ImageClassifier(
backbone=backbone,
num_classes=10,
head_dtype="float32",
)
# Verify head_dtype is in config
config = model.get_config()
self.assertIn("head_dtype", config)

# Verify round-trip via from_config
restored = ImageClassifier.from_config(config)
self.assertEqual(
str(model.head_dtype),
str(restored.head_dtype),
)

def _create_gemma_for_export_tests(self):
proto = os.path.join(self.get_test_data_dir(), "gemma_export_vocab.spm")
tokenizer = GemmaTokenizer(proto=proto)
Expand Down
2 changes: 2 additions & 0 deletions keras_hub/src/models/vgg/vgg_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(

# === Config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.activation = activation
self.pooling = pooling
self.pooling_hidden_dim = pooling_hidden_dim
Expand All @@ -211,6 +212,7 @@ def get_config(self):
"activation": self.activation,
"pooling_hidden_dim": self.pooling_hidden_dim,
"dropout": self.dropout,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
2 changes: 2 additions & 0 deletions keras_hub/src/models/vit/vit_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(

# === config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.pooling = pooling
self.intermediate_dim = intermediate_dim
self.activation = activation
Expand All @@ -182,6 +183,7 @@ def get_config(self):
"intermediate_dim": self.intermediate_dim,
"activation": self.activation,
"dropout": self.dropout,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
Loading