Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras_hub/src/models/deit/deit_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(

# === config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.pooling = pooling
self.activation = activation
self.dropout = dropout
Expand All @@ -166,6 +167,7 @@ def get_config(self):
"pooling": self.pooling,
"activation": self.activation,
"dropout": self.dropout,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
2 changes: 2 additions & 0 deletions keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(
self.pooling = pooling
self.dropout = dropout
self.num_classes = num_classes
self.head_dtype = head_dtype

def get_config(self):
config = Task.get_config(self)
Expand All @@ -211,6 +212,7 @@ def get_config(self):
"activation": self.activation,
"dropout": self.dropout,
"head_filters": self.head_filters,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
2 changes: 2 additions & 0 deletions keras_hub/src/models/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(
self.activation = activation
self.pooling = pooling
self.dropout = dropout
self.head_dtype = head_dtype

def get_config(self):
# Backbone serialized in `super`
Expand All @@ -161,6 +162,7 @@ def get_config(self):
"pooling": self.pooling,
"activation": self.activation,
"dropout": self.dropout,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
Expand Down
2 changes: 2 additions & 0 deletions keras_hub/src/models/mobilenet/mobilenet_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(

# === Config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.num_features = num_features

def get_config(self):
Expand All @@ -79,6 +80,7 @@ def get_config(self):
{
"num_classes": self.num_classes,
"num_features": self.num_features,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(

# === Config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.head_hidden_size = head_hidden_size
self.global_pool_type = global_pool
self.drop_rate = drop_rate
Expand All @@ -152,6 +153,7 @@ def get_config(self):
"head_hidden_size": self.head_hidden_size,
"global_pool": self.global_pool_type,
"drop_rate": self.drop_rate,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
2 changes: 2 additions & 0 deletions keras_hub/src/models/vgg/vgg_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(

# === Config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.activation = activation
self.pooling = pooling
self.pooling_hidden_dim = pooling_hidden_dim
Expand All @@ -211,6 +212,7 @@ def get_config(self):
"activation": self.activation,
"pooling_hidden_dim": self.pooling_hidden_dim,
"dropout": self.dropout,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
2 changes: 2 additions & 0 deletions keras_hub/src/models/vit/vit_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(

# === config ===
self.num_classes = num_classes
self.head_dtype = head_dtype
self.pooling = pooling
self.intermediate_dim = intermediate_dim
self.activation = activation
Expand All @@ -182,6 +183,7 @@ def get_config(self):
"intermediate_dim": self.intermediate_dim,
"activation": self.activation,
"dropout": self.dropout,
"head_dtype": keras.dtype_policies.serialize(self.head_dtype),
}
)
return config
Loading