
Crash on Gemma3 token_embedding Layer During Training #2205

Closed
@rlcauvin

Description

Describe the bug
When training a classification model that uses the Gemma3 backbone's token_embedding layer, the Python kernel dies as soon as fitting begins.

To Reproduce
https://colab.research.google.com/drive/12BAorKsFy_1651K7LLKPbglG0Pe951pI?usp=sharing

Here is the relevant code:

import keras
import keras_hub


class GemmaEncoder(keras.Layer):
  """Wraps a Gemma3 preprocessor, token embedding, and pooling layer into one encoder."""

  def __init__(
    self,
    preprocessor: keras_hub.models.Gemma3CausalLMPreprocessor,
    backbone: keras_hub.models.Gemma3Backbone,
    pooling_layer: keras.layers.Layer,
    **kwargs):

    super().__init__(**kwargs)

    self.preprocessor = preprocessor
    self.backbone = backbone
    self.pooling_layer = pooling_layer

  @classmethod
  def from_preset(
    cls,
    preset: str = "gemma3_1b",
    pooling_layer: keras.layers.Layer = None,
    name = "gemma_encoder",
    **kwargs):

    preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(preset, sequence_length = 128)
    backbone = keras_hub.models.Gemma3Backbone.from_preset(preset)
    if pooling_layer is None:
      pooling_layer = keras.layers.GlobalAveragePooling1D(name = name + "_global_average_pooling1d")

    return cls(preprocessor = preprocessor, backbone = backbone, pooling_layer = pooling_layer, name = name, **kwargs)

  def call(self, inputs):

    # Accept either a ready-made {"prompts", "responses"} dict or a plain
    # batch of strings.
    if isinstance(inputs, dict) and "prompts" in inputs:
      adapted = inputs
    else:
      adapted = {
        "prompts": keras.ops.array(inputs),
        "responses": keras.ops.array([""])
        }

    # The preprocessor returns an (x, y, sample_weight) tuple; x carries the
    # "token_ids" that feed the token embedding.
    tokenized = self.preprocessor(adapted)
    embedded = self.backbone.token_embedding(tokenized[0]["token_ids"])
    pooled = self.pooling_layer(embedded)

    return pooled

gse_layer = GemmaEncoder.from_preset(preset = "gemma3_1b")

# Calling the layer eagerly on a batch of strings works; only fit() crashes.
gse_layer(inputs = ["oranges and lemons are sour", "lemons and oranges are tart"])

# Build a binary classification model on top of the encoder.
headline_input = keras.layers.Input(shape = (), dtype = "string", name = "headline")
headline_featurizer = gse_layer(headline_input)
dense_16 = keras.layers.Dense(16, activation = "relu", name = "dense_16")(headline_featurizer)
activation = keras.layers.Dense(1, activation = "sigmoid", name = "activation")(dense_16)

inputs = [headline_input]
outputs = [activation]
nn_model = keras.Model(inputs = inputs, outputs = outputs, name = "nn_model")

optimizer = keras.optimizers.Adam(learning_rate=0.001) # keras.optimizers.Nadam(learning_rate = 0.00007)
nn_model.compile(optimizer = optimizer, loss = "binary_crossentropy", metrics = ["accuracy"], run_eagerly = True)

# A tiny training set is enough to reproduce the crash.
x_train = {"headline" : keras.ops.array(["hello", "goodbye", "see you soon"])}
y_train = keras.ops.array([[1], [0], [0]])

nn_model_history = nn_model.fit(
  x = x_train,
  y = y_train,
  # batch_size = 1,
  epochs = 3,
  verbose = 1)

Expected behavior
The kernel shouldn't die.

Additional context
This code is a variation on the code in another open issue of mine that uses a Gemma (not Gemma3) model. In that case, the Gemma-based model trains without crashing but emits some concerning warnings and doesn't work when deployed to an endpoint. In this case, with the Gemma3-based model, the kernel crashes immediately after training begins.

Would you like to help us fix it?
I'm happy to provide any information I can to assist with fixing the issue, but I suspect it's a bug in KerasHub Gemma3 code.
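
In case it helps triage: a workaround I'm experimenting with is to move the preprocessor out of the model entirely and feed token IDs straight into the embedding, so that only tensor ops run under fit(). This is an untested sketch; it assumes the crash comes from running the preprocessor inside call() during training, and the tokenize helper and shapes here are hypothetical:

import keras
import keras_hub
import tensorflow as tf

preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset("gemma3_1b", sequence_length = 128)
backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_1b")

def tokenize(text, label):
  # The preprocessor returns (x, y, sample_weight); keep only the token IDs.
  x, _, _ = preprocessor({"prompts": text, "responses": ""})
  return x["token_ids"], label

train_ds = tf.data.Dataset.from_tensor_slices(
  (["hello", "goodbye", "see you soon"], [[1], [0], [0]]))
train_ds = train_ds.map(tokenize).batch(3)

# The model now starts from integer token IDs, not raw strings.
token_ids = keras.layers.Input(shape = (128,), dtype = "int32", name = "token_ids")
embedded = backbone.token_embedding(token_ids)
pooled = keras.layers.GlobalAveragePooling1D()(embedded)
dense_16 = keras.layers.Dense(16, activation = "relu")(pooled)
output = keras.layers.Dense(1, activation = "sigmoid")(dense_16)

nn_model = keras.Model(inputs = token_ids, outputs = output, name = "nn_model")
nn_model.compile(optimizer = "adam", loss = "binary_crossentropy", metrics = ["accuracy"])
nn_model.fit(train_ds, epochs = 3)

If this variant also kills the kernel, that would point at the token_embedding layer itself rather than at running the preprocessor inside call().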
