**Describe the bug**
When training a classification model that uses the Gemma3 `token_embedding` layer, the Colab kernel dies as soon as training starts.
**To Reproduce**
Colab notebook: https://colab.research.google.com/drive/12BAorKsFy_1651K7LLKPbglG0Pe951pI?usp=sharing

Here is the relevant code:
```python
import keras
import keras_hub


class GemmaEncoder(keras.Layer):
    def __init__(
        self,
        preprocessor: keras_hub.models.Gemma3CausalLMPreprocessor,
        backbone: keras_hub.models.Gemma3Backbone,
        pooling_layer: keras.layers.Layer,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.preprocessor = preprocessor
        self.backbone = backbone
        self.pooling_layer = pooling_layer

    @classmethod
    def from_preset(
        cls,
        preset: str = "gemma3_1b",
        pooling_layer: keras.layers.Layer = None,
        name="gemma_encoder",
        **kwargs,
    ):
        preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
            preset, sequence_length=128
        )
        backbone = keras_hub.models.Gemma3Backbone.from_preset(preset)
        if pooling_layer is None:
            pooling_layer = keras.layers.GlobalAveragePooling1D(
                name=name + "_global_average_pooling1d"
            )
        return cls(
            preprocessor=preprocessor,
            backbone=backbone,
            pooling_layer=pooling_layer,
            name=name,
            **kwargs,
        )

    def call(self, inputs):
        if isinstance(inputs, dict) and "prompts" in inputs:
            adapted = inputs
        else:
            adapted = {
                "prompts": keras.ops.array(inputs),
                "responses": keras.ops.array([""]),
            }
        tokenized = self.preprocessor(adapted)
        embedded = self.backbone.token_embedding(tokenized[0]["token_ids"])
        pooled = self.pooling_layer(embedded)
        return pooled


gse_layer = GemmaEncoder.from_preset(preset="gemma3_1b")
gse_layer(inputs=["oranges and lemons are sour", "lemons and oranges are tart"])

headline_input = keras.layers.Input(shape=(), dtype="string", name="headline")
headline_featurizer = gse_layer(headline_input)
dense_16 = keras.layers.Dense(16, activation="relu", name="dense_16")(headline_featurizer)
activation = keras.layers.Dense(1, activation="sigmoid", name="activation")(dense_16)

inputs = [headline_input]
outputs = [activation]
nn_model = keras.Model(inputs=inputs, outputs=outputs, name="nn_model")

optimizer = keras.optimizers.Adam(learning_rate=0.001)  # keras.optimizers.Nadam(learning_rate = 0.00007)
nn_model.compile(
    optimizer=optimizer,
    loss="binary_crossentropy",
    metrics=["accuracy"],
    run_eagerly=True,
)

x_train = {"headline": keras.ops.array(["hello", "goodbye", "see you soon"])}
y_train = keras.ops.array([[1], [0], [0]])
nn_model_history = nn_model.fit(
    x=x_train,
    y=y_train,
    # batch_size = 1,
    epochs=3,
    verbose=1,
)
```
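In case it helps with triage, here is a stripped-down variant. This is an untested sketch: it assumes the crash is triggered by running the preprocessor and `token_embedding` inside `fit()`, rather than by anything in the classifier head, and strips away everything else. The layer and variable names here are mine, not from the notebook.

```python
import keras
import keras_hub

preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "gemma3_1b", sequence_length=128
)
backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_1b")


class MinimalEncoder(keras.Layer):
    def call(self, inputs):
        # Same preprocessor + token_embedding path as the full code above.
        tokenized = preprocessor(
            {"prompts": keras.ops.array(inputs), "responses": keras.ops.array([""])}
        )
        embedded = backbone.token_embedding(tokenized[0]["token_ids"])
        # Mean over the sequence axis stands in for GlobalAveragePooling1D.
        return keras.ops.mean(embedded, axis=1)


text_in = keras.layers.Input(shape=(), dtype="string")
out = keras.layers.Dense(1, activation="sigmoid")(MinimalEncoder()(text_in))
model = keras.Model(text_in, out)
model.compile(optimizer="adam", loss="binary_crossentropy", run_eagerly=True)

# If my assumption is right, the kernel should die here too.
model.fit(
    x=keras.ops.array(["hello", "goodbye"]),
    y=keras.ops.array([[1], [0]]),
    epochs=1,
)
```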
**Expected behavior**
The kernel shouldn't die; training should complete normally.
**Additional context**
This code is a variation on another open issue of mine that uses a Gemma (not Gemma3) model. In that case, the Gemma-based model trains without crashing, but it emits some concerning warnings and doesn't work when deployed to an endpoint. In this case, with the Gemma3-based model, the kernel crashes immediately after training begins.
**Would you like to help us fix it?**
I'm happy to provide any information I can to help fix the issue, but I suspect it's a bug in the KerasHub Gemma3 code.
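For example, I can report the exact package versions from the Colab runtime:

```python
import keras
import keras_hub

print("keras:", keras.__version__)
print("keras_hub:", keras_hub.__version__)
```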