Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 79 additions & 61 deletions keras_hub/src/models/t5/t5_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def __init__(
dtype=None,
**kwargs,
):
import keras

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The `import keras` statement is redundant, as `keras` is already imported at the top of the file. This local import and the following blank line can be removed.

nnx_enabled = keras.config.is_nnx_enabled()

# Token embedding layer. This layer is shared by encoder and decoder.
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
Expand All @@ -89,6 +93,7 @@ def __init__(
dtype=dtype,
name="encoder_embedding_dropout",
)

self.encoder_transformer_layers = []
for i in range(num_layers):
layer = T5TransformerLayer(
Expand All @@ -106,21 +111,25 @@ def __init__(
name=f"transformer_encoder_layer_{i}",
)
self.encoder_transformer_layers.append(layer)

self.encoder_layer_norm = T5LayerNorm(
epsilon=layer_norm_epsilon,
dtype=dtype,
name="encoder_output_layer_norm",
)

self.encoder_dropout = keras.layers.Dropout(
dropout,
dtype=dtype,
name="encoder_output_dropout",
)

self.decoder_embedding_dropout = keras.layers.Dropout(
dropout,
dtype=dtype,
name="decoder_embedding_dropout",
)

self.decoder_transformer_layers = []
for i in range(num_layers):
layer = T5TransformerLayer(
Expand All @@ -138,80 +147,89 @@ def __init__(
name=f"transformer_decoder_layer_{i}",
)
self.decoder_transformer_layers.append(layer)

self.decoder_layer_norm = T5LayerNorm(
epsilon=layer_norm_epsilon,
dtype=dtype,
name="decoder_output_layer_norm",
)

self.decoder_dropout = keras.layers.Dropout(
dropout,
dtype=dtype,
name="decoder_output_dropout",
)

# === Functional Model ===
encoder_token_id_input = keras.Input(
shape=(None,), dtype="int32", name="encoder_token_ids"
)
encoder_padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="encoder_padding_mask"
)
decoder_token_id_input = keras.Input(
shape=(None,), dtype="int32", name="decoder_token_ids"
)
decoder_padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="decoder_padding_mask"
)
# Encoder.
x = self.token_embedding(encoder_token_id_input)
x = self.encoder_embedding_dropout(x)
encoder_attention_mask = encoder_padding_mask_input[:, None, :]
position_bias = None
for transformer_layer in self.encoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=encoder_attention_mask,
position_bias=position_bias,
use_causal_mask=False,
if not nnx_enabled:
# === Functional Model ===
encoder_token_id_input = keras.Input(
shape=(None,), dtype="int32", name="encoder_token_ids"
)
if isinstance(output, tuple):
x, position_bias = output
x = self.encoder_layer_norm(x)
x = self.encoder_dropout(x)
encoder_output = x
# Decoder.
x = self.token_embedding(decoder_token_id_input)
x = self.decoder_embedding_dropout(x)
decoder_attention_mask = decoder_padding_mask_input[:, None, :]
position_bias = None
for transformer_layer in self.decoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=decoder_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_attention_mask,
use_causal_mask=True,
encoder_padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="encoder_padding_mask"
)
if isinstance(output, tuple):
x, position_bias = output
x = self.decoder_layer_norm(x)
x = self.decoder_dropout(x)
decoder_output = x
super().__init__(
{
"encoder_token_ids": encoder_token_id_input,
"encoder_padding_mask": encoder_padding_mask_input,
"decoder_token_ids": decoder_token_id_input,
"decoder_padding_mask": decoder_padding_mask_input,
},
outputs={
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)
decoder_token_id_input = keras.Input(
shape=(None,), dtype="int32", name="decoder_token_ids"
)
decoder_padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="decoder_padding_mask"
)

# Encoder.
x = self.token_embedding(encoder_token_id_input)
x = self.encoder_embedding_dropout(x)
encoder_attention_mask = encoder_padding_mask_input[:, None, :]
position_bias = None
for transformer_layer in self.encoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=encoder_attention_mask,
position_bias=position_bias,
use_causal_mask=False,
)
if isinstance(output, tuple):
x, position_bias = output
x = self.encoder_layer_norm(x)
x = self.encoder_dropout(x)
encoder_output = x
# Decoder.
x = self.token_embedding(decoder_token_id_input)
x = self.decoder_embedding_dropout(x)
decoder_attention_mask = decoder_padding_mask_input[:, None, :]
position_bias = None
for transformer_layer in self.decoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=decoder_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_attention_mask,
use_causal_mask=True,
)
if isinstance(output, tuple):
x, position_bias = output
x = self.decoder_layer_norm(x)
x = self.decoder_dropout(x)
decoder_output = x

super().__init__(
{
"encoder_token_ids": encoder_token_id_input,
"encoder_padding_mask": encoder_padding_mask_input,
"decoder_token_ids": decoder_token_id_input,
"decoder_padding_mask": decoder_padding_mask_input,
},
outputs={
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)

else:
# NNX-safe subclassed model path
super().__init__(dtype=dtype, **kwargs)
Comment on lines +202 to +203
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change is intended to fix an issue with JAX and NNX by avoiding the Functional API. However, by only calling super().__init__() for the NNX path, the model becomes a subclassed model without a call method. This will cause a NotImplementedError when the model is called.

To complete the fix, you need to implement the `call` method to define the forward pass. The logic for this method is the same as what's currently inside the `if not nnx_enabled:` block.

Here is a suggested implementation for the call method to be added to the T5Backbone class:

    def call(self, inputs, training=False):
        """Forward pass for the NNX-safe subclassed-model path.

        Mirrors the functional-API graph built under ``if not nnx_enabled:``:
        encoder tokens are embedded, passed through the encoder stack, then
        the decoder attends over its own tokens and the encoder output.

        Args:
            inputs: Dict with keys ``"encoder_token_ids"``,
                ``"encoder_padding_mask"``, ``"decoder_token_ids"``, and
                ``"decoder_padding_mask"``. Each is an int32 tensor of shape
                ``(batch, sequence_length)`` (matching the functional path's
                ``keras.Input(shape=(None,), dtype="int32")`` inputs).
            training: Bool forwarded to dropout and transformer layers so
                dropout is only applied during training.

        Returns:
            Dict with ``"encoder_sequence_output"`` and
            ``"decoder_sequence_output"`` hidden-state tensors.
        """
        encoder_token_ids = inputs["encoder_token_ids"]
        encoder_padding_mask = inputs["encoder_padding_mask"]
        decoder_token_ids = inputs["decoder_token_ids"]
        decoder_padding_mask = inputs["decoder_padding_mask"]

        # Encoder.
        # `token_embedding` is shared between encoder and decoder.
        x = self.token_embedding(encoder_token_ids)
        x = self.encoder_embedding_dropout(x, training=training)
        # Expand (batch, seq) padding mask to (batch, 1, seq) so it
        # broadcasts over query positions in attention.
        encoder_attention_mask = encoder_padding_mask[:, None, :]
        # The first layer computes relative-position bias and returns it in a
        # tuple; subsequent layers reuse the same bias.
        position_bias = None
        for transformer_layer in self.encoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=encoder_attention_mask,
                position_bias=position_bias,
                use_causal_mask=False,
                training=training,
            )
            if isinstance(output, tuple):
                x, position_bias = output
        x = self.encoder_layer_norm(x)
        x = self.encoder_dropout(x, training=training)
        encoder_output = x

        # Decoder.
        x = self.token_embedding(decoder_token_ids)
        x = self.decoder_embedding_dropout(x, training=training)
        decoder_attention_mask = decoder_padding_mask[:, None, :]
        # Decoder-side position bias is tracked separately from the
        # encoder's; it is reset here and rebuilt by the first decoder layer.
        position_bias = None
        for transformer_layer in self.decoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=decoder_attention_mask,
                position_bias=position_bias,
                # Cross-attention over the encoder's final hidden states.
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_attention_mask,
                use_causal_mask=True,
                training=training,
            )
            if isinstance(output, tuple):
                x, position_bias = output
        x = self.decoder_layer_norm(x)
        x = self.decoder_dropout(x, training=training)
        decoder_output = x

        return {
            "encoder_sequence_output": encoder_output,
            "decoder_sequence_output": decoder_output,
        }


# === Config ===
self.vocabulary_size = vocabulary_size
Expand Down
Loading