165 changes: 104 additions & 61 deletions keras_hub/src/models/t5/t5_backbone.py
@@ -75,6 +75,8 @@ def __init__(
dtype=None,
**kwargs,
):
nnx_enabled = keras.config.is_nnx_enabled()

# Token embedding layer. This layer is shared by encoder and decoder.
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
@@ -84,11 +86,13 @@ def __init__(
dtype=dtype,
name="token_embedding",
)

self.encoder_embedding_dropout = keras.layers.Dropout(
dropout,
dtype=dtype,
name="encoder_embedding_dropout",
)

self.encoder_transformer_layers = []
for i in range(num_layers):
layer = T5TransformerLayer(
@@ -106,21 +110,25 @@ def __init__(
name=f"transformer_encoder_layer_{i}",
)
self.encoder_transformer_layers.append(layer)

self.encoder_layer_norm = T5LayerNorm(
epsilon=layer_norm_epsilon,
dtype=dtype,
name="encoder_output_layer_norm",
)

self.encoder_dropout = keras.layers.Dropout(
dropout,
dtype=dtype,
name="encoder_output_dropout",
)

self.decoder_embedding_dropout = keras.layers.Dropout(
dropout,
dtype=dtype,
name="decoder_embedding_dropout",
)

self.decoder_transformer_layers = []
for i in range(num_layers):
layer = T5TransformerLayer(
@@ -138,80 +146,112 @@ def __init__(
name=f"transformer_decoder_layer_{i}",
)
self.decoder_transformer_layers.append(layer)

self.decoder_layer_norm = T5LayerNorm(
epsilon=layer_norm_epsilon,
dtype=dtype,
name="decoder_output_layer_norm",
)

self.decoder_dropout = keras.layers.Dropout(
dropout,
dtype=dtype,
name="decoder_output_dropout",
)

# === Functional Model ===
encoder_token_id_input = keras.Input(
shape=(None,), dtype="int32", name="encoder_token_ids"
)
encoder_padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="encoder_padding_mask"
)
decoder_token_id_input = keras.Input(
shape=(None,), dtype="int32", name="decoder_token_ids"
)
decoder_padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="decoder_padding_mask"
)
# Encoder.
x = self.token_embedding(encoder_token_id_input)
x = self.encoder_embedding_dropout(x)
encoder_attention_mask = encoder_padding_mask_input[:, None, :]
position_bias = None
for transformer_layer in self.encoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=encoder_attention_mask,
position_bias=position_bias,
use_causal_mask=False,
def _forward(inputs, training=None):
Contributor (medium):

The _forward method is a private helper and should ideally be placed after all public methods and properties, or at least grouped consistently with other private methods. While not a functional bug, moving it would improve maintainability.
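
For illustration only, a minimal skeleton of that ordering, assuming the closure were promoted to a regular private method (the class, names, and bodies below are placeholders, not the actual T5Backbone implementation): public API methods such as call and get_config come first, with private helpers grouped at the end of the class body.

    import keras

    class ExampleBackbone(keras.Model):
        # Public API first.
        def call(self, inputs, training=None):
            return self._forward(inputs, training=training)

        def get_config(self):
            return super().get_config()

        # Private helpers grouped last.
        def _forward(self, inputs, training=None):
            # Placeholder body; the real helper would run the encoder/decoder stacks.
            return inputs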

encoder_token_ids = inputs["encoder_token_ids"]
encoder_padding_mask = inputs["encoder_padding_mask"]
decoder_token_ids = inputs["decoder_token_ids"]
decoder_padding_mask = inputs["decoder_padding_mask"]
# Encoder
x = self.token_embedding(encoder_token_ids)
x = self.encoder_embedding_dropout(x, training=training)
Contributor (medium):

The training argument should be passed to the token_embedding layer call as well, to ensure consistent behavior, especially if the embedding layer has any dropout or batch normalization components that behave differently during training and inference. While ReversibleEmbedding might not directly use training, it's good practice for consistency.

            # Encoder
            x = self.token_embedding(encoder_token_ids, training=training)
            x = self.encoder_embedding_dropout(x, training=training)

encoder_attention_mask = encoder_padding_mask[:, None, :]
position_bias = None

for transformer_layer in self.encoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=encoder_attention_mask,
position_bias=position_bias,
use_causal_mask=False,
training=training,
)
if isinstance(output, tuple):
x, position_bias = output

x = self.encoder_layer_norm(x)
x = self.encoder_dropout(x, training=training)
encoder_output = x
# Decoder
x = self.token_embedding(decoder_token_ids)
Contributor (medium):

Similar to the encoder, the training argument should be passed to the token_embedding layer call for the decoder path to ensure consistent behavior.

            encoder_output = x
            # Decoder
            x = self.token_embedding(decoder_token_ids, training=training)
            x = self.decoder_embedding_dropout(x, training=training)

x = self.decoder_embedding_dropout(x, training=training)
decoder_attention_mask = decoder_padding_mask[:, None, :]
position_bias = None

for transformer_layer in self.decoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=decoder_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_attention_mask,
use_causal_mask=True,
training=training,
)
if isinstance(output, tuple):
x, position_bias = output

x = self.decoder_layer_norm(x)
x = self.decoder_dropout(x, training=training)

return {
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": x,
}

self._forward = _forward

if not nnx_enabled:
# === Functional Model ===
encoder_token_id_input = keras.Input(
shape=(None,), dtype="int32", name="encoder_token_ids"
)
if isinstance(output, tuple):
x, position_bias = output
x = self.encoder_layer_norm(x)
x = self.encoder_dropout(x)
encoder_output = x
# Decoder.
x = self.token_embedding(decoder_token_id_input)
x = self.decoder_embedding_dropout(x)
decoder_attention_mask = decoder_padding_mask_input[:, None, :]
position_bias = None
for transformer_layer in self.decoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=decoder_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_attention_mask,
use_causal_mask=True,
encoder_padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="encoder_padding_mask"
)
decoder_token_id_input = keras.Input(
shape=(None,), dtype="int32", name="decoder_token_ids"
)
decoder_padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="decoder_padding_mask"
)

outputs = self._forward(
{
"encoder_token_ids": encoder_token_id_input,
"encoder_padding_mask": encoder_padding_mask_input,
"decoder_token_ids": decoder_token_id_input,
"decoder_padding_mask": decoder_padding_mask_input,
},
training=False,
)
if isinstance(output, tuple):
x, position_bias = output
x = self.decoder_layer_norm(x)
x = self.decoder_dropout(x)
decoder_output = x
super().__init__(
{
"encoder_token_ids": encoder_token_id_input,
"encoder_padding_mask": encoder_padding_mask_input,
"decoder_token_ids": decoder_token_id_input,
"decoder_padding_mask": decoder_padding_mask_input,
},
outputs={
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)

super().__init__(
{
"encoder_token_ids": encoder_token_id_input,
"encoder_padding_mask": encoder_padding_mask_input,
"decoder_token_ids": decoder_token_id_input,
"decoder_padding_mask": decoder_padding_mask_input,
},
outputs=outputs,
dtype=dtype,
**kwargs,
)
else:
# NNX-safe subclassed model path
super().__init__(dtype=dtype, **kwargs)
Comment on lines +202 to +203
Contributor (critical):

This change is intended to fix an issue with JAX and NNX by avoiding the Functional API. However, by only calling super().__init__() for the NNX path, the model becomes a subclassed model without a call method. This will cause a NotImplementedError when the model is called.

To complete the fix, you need to implement the call method to define the forward pass. The logic for this method is the same as what's currently inside the if not nnx_enabled: block.

Here is a suggested implementation for the call method to be added to the T5Backbone class:

    def call(self, inputs, training=False):
        encoder_token_ids = inputs["encoder_token_ids"]
        encoder_padding_mask = inputs["encoder_padding_mask"]
        decoder_token_ids = inputs["decoder_token_ids"]
        decoder_padding_mask = inputs["decoder_padding_mask"]

        # Encoder.
        x = self.token_embedding(encoder_token_ids)
        x = self.encoder_embedding_dropout(x, training=training)
        encoder_attention_mask = encoder_padding_mask[:, None, :]
        position_bias = None
        for transformer_layer in self.encoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=encoder_attention_mask,
                position_bias=position_bias,
                use_causal_mask=False,
                training=training,
            )
            if isinstance(output, tuple):
                x, position_bias = output
        x = self.encoder_layer_norm(x)
        x = self.encoder_dropout(x, training=training)
        encoder_output = x

        # Decoder.
        x = self.token_embedding(decoder_token_ids)
        x = self.decoder_embedding_dropout(x, training=training)
        decoder_attention_mask = decoder_padding_mask[:, None, :]
        position_bias = None
        for transformer_layer in self.decoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=decoder_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_attention_mask,
                use_causal_mask=True,
                training=training,
            )
            if isinstance(output, tuple):
                x, position_bias = output
        x = self.decoder_layer_norm(x)
        x = self.decoder_dropout(x, training=training)
        decoder_output = x

        return {
            "encoder_sequence_output": encoder_output,
            "decoder_sequence_output": decoder_output,
        }
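
As an illustrative follow-up check only (it assumes backbone is an already-constructed T5Backbone instance; the toy inputs and shapes below are made up, while the dict keys match the input names used in this file), calling the subclassed model should then return both sequence outputs instead of raising NotImplementedError:

    import numpy as np

    # Hypothetical smoke test with toy inputs of shape (batch=1, seq_len=8).
    toy_inputs = {
        "encoder_token_ids": np.ones((1, 8), dtype="int32"),
        "encoder_padding_mask": np.ones((1, 8), dtype="int32"),
        "decoder_token_ids": np.ones((1, 8), dtype="int32"),
        "decoder_padding_mask": np.ones((1, 8), dtype="int32"),
    }
    outputs = backbone(toy_inputs)
    assert set(outputs) == {"encoder_sequence_output", "decoder_sequence_output"}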


# === Config ===
self.vocabulary_size = vocabulary_size
Expand All @@ -226,6 +266,9 @@ def __init__(
self.layer_norm_epsilon = layer_norm_epsilon
self.tie_embedding_weights = tie_embedding_weights

def call(self, inputs, training=None):
return self._forward(inputs, training=training)

def get_config(self):
config = super().get_config()
config.update(