Skip to content

Commit c1b756b

Browse files
committed
fix: serialize head_dtype in all ImageClassifier subclasses
head_dtype was accepted by __init__() and used to set dtype policy for classifier head layers, but was never stored on self or included in get_config() in several ImageClassifier subclasses. Affected models: - VitImageClassifier - DeiTImageClassifier - VGGImageClassifier - MobileNetImageClassifier - MobileNetV5ImageClassifier - HGNetV2ImageClassifier This is a follow-up to keras-team#2614 which fixed the same issue in the base ImageClassifier class.
1 parent 9569150 commit c1b756b

File tree

6 files changed

+12
-0
lines changed

6 files changed

+12
-0
lines changed

keras_hub/src/models/deit/deit_image_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
**kwargs,
116116
):
117117
head_dtype = head_dtype or backbone.dtype_policy
118+
self.head_dtype = head_dtype
118119

119120
# === Layers ===
120121
self.backbone = backbone
@@ -166,6 +167,7 @@ def get_config(self):
166167
"pooling": self.pooling,
167168
"activation": self.activation,
168169
"dropout": self.dropout,
170+
"head_dtype": self.head_dtype,
169171
}
170172
)
171173
return config

keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(
117117
):
118118
name = kwargs.get("name", "hgnetv2_image_classifier")
119119
head_dtype = head_dtype or backbone.dtype_policy
120+
self.head_dtype = head_dtype
120121
data_format = getattr(backbone, "data_format", "channels_last")
121122
channel_axis = -1 if data_format == "channels_last" else 1
122123
self.head_filters = (
@@ -211,6 +212,7 @@ def get_config(self):
211212
"activation": self.activation,
212213
"dropout": self.dropout,
213214
"head_filters": self.head_filters,
215+
"head_dtype": self.head_dtype,
214216
}
215217
)
216218
return config

keras_hub/src/models/mobilenet/mobilenet_image_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
**kwargs,
2525
):
2626
head_dtype = head_dtype or backbone.dtype_policy
27+
self.head_dtype = head_dtype
2728
data_format = getattr(backbone, "data_format", None)
2829

2930
# === Layers ===
@@ -79,6 +80,7 @@ def get_config(self):
7980
{
8081
"num_classes": self.num_classes,
8182
"num_features": self.num_features,
83+
"head_dtype": self.head_dtype,
8284
}
8385
)
8486
return config

keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
**kwargs,
8989
):
9090
head_dtype = head_dtype or backbone.dtype_policy
91+
self.head_dtype = head_dtype
9192
data_format = getattr(backbone, "data_format", "channels_last")
9293

9394
# === Layers ===
@@ -152,6 +153,7 @@ def get_config(self):
152153
"head_hidden_size": self.head_hidden_size,
153154
"global_pool": self.global_pool_type,
154155
"drop_rate": self.drop_rate,
156+
"head_dtype": self.head_dtype,
155157
}
156158
)
157159
return config

keras_hub/src/models/vgg/vgg_image_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def __init__(
114114
**kwargs,
115115
):
116116
head_dtype = head_dtype or backbone.dtype_policy
117+
self.head_dtype = head_dtype
117118
data_format = getattr(backbone, "data_format", None)
118119

119120
# === Layers ===
@@ -211,6 +212,7 @@ def get_config(self):
211212
"activation": self.activation,
212213
"pooling_hidden_dim": self.pooling_hidden_dim,
213214
"dropout": self.dropout,
215+
"head_dtype": self.head_dtype,
214216
}
215217
)
216218
return config

keras_hub/src/models/vit/vit_image_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __init__(
120120
**kwargs,
121121
):
122122
head_dtype = head_dtype or backbone.dtype_policy
123+
self.head_dtype = head_dtype
123124

124125
# === Layers ===
125126
self.backbone = backbone
@@ -182,6 +183,7 @@ def get_config(self):
182183
"intermediate_dim": self.intermediate_dim,
183184
"activation": self.activation,
184185
"dropout": self.dropout,
186+
"head_dtype": self.head_dtype,
185187
}
186188
)
187189
return config

0 commit comments

Comments
 (0)