keras_hub/src/models/t5/t5_backbone.py (259 changes: 148 additions & 111 deletions)
@@ -75,7 +75,19 @@ def __init__(
         dtype=None,
         **kwargs,
     ):
-        # Token embedding layer. This layer is shared by encoder and decoder.
+        self.vocabulary_size = vocabulary_size
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.activation = keras.activations.get(activation)
+        self.key_value_dim = key_value_dim
+        self.dropout = dropout
+        self.use_gated_activation = use_gated_activation
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.tie_embedding_weights = tie_embedding_weights
+
+        # Token embedding layer. Shared by encoder and decoder.
         self.token_embedding = ReversibleEmbedding(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
@@ -84,106 +96,170 @@ def __init__(
             dtype=dtype,
             name="token_embedding",
         )

         self.encoder_embedding_dropout = keras.layers.Dropout(
-            dropout,
-            dtype=dtype,
-            name="encoder_embedding_dropout",
+            dropout, dtype=dtype, name="encoder_embedding_dropout"
         )

         self.encoder_transformer_layers = []
         for i in range(num_layers):
-            layer = T5TransformerLayer(
-                is_decoder=False,
-                hidden_dim=hidden_dim,
-                intermediate_dim=intermediate_dim,
-                key_value_dim=key_value_dim or hidden_dim // num_heads,
-                dropout=dropout,
-                activation=activation,
-                layer_norm_epsilon=layer_norm_epsilon,
-                num_heads=num_heads,
-                use_gated_activation=use_gated_activation,
-                use_relative_attention_bias=bool(i == 0),
-                dtype=dtype,
-                name=f"transformer_encoder_layer_{i}",
-            )
-            self.encoder_transformer_layers.append(layer)
+            self.encoder_transformer_layers.append(
+                T5TransformerLayer(
+                    is_decoder=False,
+                    hidden_dim=hidden_dim,
+                    intermediate_dim=intermediate_dim,
+                    key_value_dim=key_value_dim or hidden_dim // num_heads,
+                    dropout=dropout,
+                    activation=activation,
+                    layer_norm_epsilon=layer_norm_epsilon,
+                    num_heads=num_heads,
+                    use_gated_activation=use_gated_activation,
+                    use_relative_attention_bias=bool(i == 0),
+                    dtype=dtype,
+                    name=f"transformer_encoder_layer_{i}",
+                )
+            )

         self.encoder_layer_norm = T5LayerNorm(
             epsilon=layer_norm_epsilon,
             dtype=dtype,
             name="encoder_output_layer_norm",
         )
         self.encoder_dropout = keras.layers.Dropout(
-            dropout,
-            dtype=dtype,
-            name="encoder_output_dropout",
+            dropout, dtype=dtype, name="encoder_output_dropout"
         )

         self.decoder_embedding_dropout = keras.layers.Dropout(
-            dropout,
-            dtype=dtype,
-            name="decoder_embedding_dropout",
+            dropout, dtype=dtype, name="decoder_embedding_dropout"
         )

         self.decoder_transformer_layers = []
         for i in range(num_layers):
-            layer = T5TransformerLayer(
-                is_decoder=True,
-                hidden_dim=hidden_dim,
-                intermediate_dim=intermediate_dim,
-                key_value_dim=key_value_dim or hidden_dim // num_heads,
-                dropout=dropout,
-                activation=activation,
-                layer_norm_epsilon=layer_norm_epsilon,
-                num_heads=num_heads,
-                use_gated_activation=use_gated_activation,
-                use_relative_attention_bias=bool(i == 0),
-                dtype=dtype,
-                name=f"transformer_decoder_layer_{i}",
-            )
-            self.decoder_transformer_layers.append(layer)
+            self.decoder_transformer_layers.append(
+                T5TransformerLayer(
+                    is_decoder=True,
+                    hidden_dim=hidden_dim,
+                    intermediate_dim=intermediate_dim,
+                    key_value_dim=key_value_dim or hidden_dim // num_heads,
+                    dropout=dropout,
+                    activation=activation,
+                    layer_norm_epsilon=layer_norm_epsilon,
+                    num_heads=num_heads,
+                    use_gated_activation=use_gated_activation,
+                    use_relative_attention_bias=bool(i == 0),
+                    dtype=dtype,
+                    name=f"transformer_decoder_layer_{i}",
+                )
+            )

         self.decoder_layer_norm = T5LayerNorm(
             epsilon=layer_norm_epsilon,
             dtype=dtype,
             name="decoder_output_layer_norm",
         )
         self.decoder_dropout = keras.layers.Dropout(
-            dropout,
-            dtype=dtype,
-            name="decoder_output_dropout",
+            dropout, dtype=dtype, name="decoder_output_dropout"
         )

-        # === Functional Model ===
-        encoder_token_id_input = keras.Input(
-            shape=(None,), dtype="int32", name="encoder_token_ids"
-        )
-        encoder_padding_mask_input = keras.Input(
-            shape=(None,), dtype="int32", name="encoder_padding_mask"
-        )
-        decoder_token_id_input = keras.Input(
-            shape=(None,), dtype="int32", name="decoder_token_ids"
-        )
-        decoder_padding_mask_input = keras.Input(
-            shape=(None,), dtype="int32", name="decoder_padding_mask"
-        )
+        # NNX Initialization
+        nnx_enabled = keras.config.is_nnx_enabled()
+        if not nnx_enabled:
+            # === Functional Model ===
+            encoder_token_id_input = keras.Input(
+                shape=(None,), dtype="int32", name="encoder_token_ids"
+            )
+            encoder_padding_mask_input = keras.Input(
+                shape=(None,), dtype="int32", name="encoder_padding_mask"
+            )
+            decoder_token_id_input = keras.Input(
+                shape=(None,), dtype="int32", name="decoder_token_ids"
+            )
+            decoder_padding_mask_input = keras.Input(
+                shape=(None,), dtype="int32", name="decoder_padding_mask"
+            )
+
+            outputs = self._forward(
+                {
+                    "encoder_token_ids": encoder_token_id_input,
+                    "encoder_padding_mask": encoder_padding_mask_input,
+                    "decoder_token_ids": decoder_token_id_input,
+                    "decoder_padding_mask": decoder_padding_mask_input,
+                },
+                training=None,
+            )
+
+            super().__init__(
+                inputs={
+                    "encoder_token_ids": encoder_token_id_input,
+                    "encoder_padding_mask": encoder_padding_mask_input,
+                    "decoder_token_ids": decoder_token_id_input,
+                    "decoder_padding_mask": decoder_padding_mask_input,
+                },
+                outputs=outputs,
+                dtype=dtype,
+                **kwargs,
+            )
+        else:
+            super().__init__(dtype=dtype, **kwargs)
Contributor comment on lines +202 to +203 (severity: critical):

This change is intended to fix an issue with JAX and NNX by avoiding the Functional API. However, by only calling super().__init__() for the NNX path, the model becomes a subclassed model without a call method. This will cause a NotImplementedError when the model is called.

To complete the fix, you need to implement the call method to define the forward pass. The logic for this method is the same as what's currently inside the if not nnx_enabled: block.

Here is a suggested implementation for the call method to be added to the T5Backbone class:

    def call(self, inputs, training=False):
        encoder_token_ids = inputs["encoder_token_ids"]
        encoder_padding_mask = inputs["encoder_padding_mask"]
        decoder_token_ids = inputs["decoder_token_ids"]
        decoder_padding_mask = inputs["decoder_padding_mask"]

        # Encoder.
        x = self.token_embedding(encoder_token_ids)
        x = self.encoder_embedding_dropout(x, training=training)
        encoder_attention_mask = encoder_padding_mask[:, None, :]
        position_bias = None
        for transformer_layer in self.encoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=encoder_attention_mask,
                position_bias=position_bias,
                use_causal_mask=False,
                training=training,
            )
            if isinstance(output, tuple):
                x, position_bias = output
        x = self.encoder_layer_norm(x)
        x = self.encoder_dropout(x, training=training)
        encoder_output = x

        # Decoder.
        x = self.token_embedding(decoder_token_ids)
        x = self.decoder_embedding_dropout(x, training=training)
        decoder_attention_mask = decoder_padding_mask[:, None, :]
        position_bias = None
        for transformer_layer in self.decoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=decoder_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_attention_mask,
                use_causal_mask=True,
                training=training,
            )
            if isinstance(output, tuple):
                x, position_bias = output
        x = self.decoder_layer_norm(x)
        x = self.decoder_dropout(x, training=training)
        decoder_output = x

        return {
            "encoder_sequence_output": encoder_output,
            "decoder_sequence_output": decoder_output,
        }
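
To verify the fix end to end, a small smoke test could look like the following. This is a sketch, not part of this PR: the tiny hyperparameters are arbitrary, and it assumes the publicly exported keras_hub.models.T5Backbone.

import numpy as np

from keras_hub.models import T5Backbone

# Arbitrary tiny dimensions, chosen only to keep the check fast.
backbone = T5Backbone(
    vocabulary_size=128,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
)

batch_size, seq_len = 2, 8
ids = np.ones((batch_size, seq_len), dtype="int32")
outputs = backbone(
    {
        "encoder_token_ids": ids,
        "encoder_padding_mask": ids,
        "decoder_token_ids": ids,
        "decoder_padding_mask": ids,
    }
)

# Both sequence outputs should be (batch_size, seq_len, hidden_dim).
assert outputs["encoder_sequence_output"].shape == (2, 8, 32)
assert outputs["decoder_sequence_output"].shape == (2, 8, 32)

The same call should work on both paths: through the functional graph when NNX is disabled, and through call()/_forward() when it is enabled.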


+    def call(self, inputs, training=None):
+        return self._forward(inputs, training=training)
+
+    # Config
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "activation": keras.activations.serialize(self.activation),
+                "key_value_dim": self.key_value_dim,
+                "dropout": self.dropout,
+                "use_gated_activation": self.use_gated_activation,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "tie_embedding_weights": self.tie_embedding_weights,
+            }
+        )
+        return config
+
-        # Encoder.
-        x = self.token_embedding(encoder_token_id_input)
-        x = self.encoder_embedding_dropout(x)
-        encoder_attention_mask = encoder_padding_mask_input[:, None, :]
+    def _forward(self, inputs, training=None):
+        encoder_token_ids = inputs["encoder_token_ids"]
+        encoder_padding_mask = inputs["encoder_padding_mask"]
+        decoder_token_ids = inputs["decoder_token_ids"]
+        decoder_padding_mask = inputs["decoder_padding_mask"]
+
+        # Encoder
+        x = self.token_embedding(encoder_token_ids)
+        x = self.encoder_embedding_dropout(x, training=training)
+
+        encoder_attention_mask = encoder_padding_mask[:, None, :]
         position_bias = None
+
         for transformer_layer in self.encoder_transformer_layers:
             output = transformer_layer(
                 x,
                 attention_mask=encoder_attention_mask,
                 position_bias=position_bias,
                 use_causal_mask=False,
+                training=training,
             )
             if isinstance(output, tuple):
                 x, position_bias = output
-        x = self.encoder_layer_norm(x)
-        x = self.encoder_dropout(x)
-        encoder_output = x
-        # Decoder.
-        x = self.token_embedding(decoder_token_id_input)
-        x = self.decoder_embedding_dropout(x)
-        decoder_attention_mask = decoder_padding_mask_input[:, None, :]
+
+        encoder_output = self.encoder_dropout(
+            self.encoder_layer_norm(x), training=training
+        )
+
+        # Decoder
+        x = self.token_embedding(decoder_token_ids)
+        x = self.decoder_embedding_dropout(x, training=training)
+
+        decoder_attention_mask = decoder_padding_mask[:, None, :]
         position_bias = None
+
         for transformer_layer in self.decoder_transformer_layers:
             output = transformer_layer(
                 x,
@@ -192,55 +268,16 @@ def __init__(
                 encoder_hidden_states=encoder_output,
                 encoder_attention_mask=encoder_attention_mask,
                 use_causal_mask=True,
+                training=training,
             )
             if isinstance(output, tuple):
                 x, position_bias = output
-        x = self.decoder_layer_norm(x)
-        x = self.decoder_dropout(x)
-        decoder_output = x
-        super().__init__(
-            {
-                "encoder_token_ids": encoder_token_id_input,
-                "encoder_padding_mask": encoder_padding_mask_input,
-                "decoder_token_ids": decoder_token_id_input,
-                "decoder_padding_mask": decoder_padding_mask_input,
-            },
-            outputs={
-                "encoder_sequence_output": encoder_output,
-                "decoder_sequence_output": decoder_output,
-            },
-            dtype=dtype,
-            **kwargs,
-        )
-
-        # === Config ===
-        self.vocabulary_size = vocabulary_size
-        self.hidden_dim = hidden_dim
-        self.intermediate_dim = intermediate_dim
-        self.num_layers = num_layers
-        self.num_heads = num_heads
-        self.activation = keras.activations.get(activation)
-        self.key_value_dim = key_value_dim
-        self.dropout = dropout
-        self.use_gated_activation = use_gated_activation
-        self.layer_norm_epsilon = layer_norm_epsilon
-        self.tie_embedding_weights = tie_embedding_weights
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "vocabulary_size": self.vocabulary_size,
-                "hidden_dim": self.hidden_dim,
-                "intermediate_dim": self.intermediate_dim,
-                "num_layers": self.num_layers,
-                "num_heads": self.num_heads,
-                "activation": keras.activations.serialize(self.activation),
-                "key_value_dim": self.key_value_dim,
-                "dropout": self.dropout,
-                "use_gated_activation": self.use_gated_activation,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-                "tie_embedding_weights": self.tie_embedding_weights,
-            }
-        )
-        return config
+
+        decoder_output = self.decoder_dropout(
+            self.decoder_layer_norm(x), training=training
+        )
+
+        return {
+            "encoder_sequence_output": encoder_output,
+            "decoder_sequence_output": decoder_output,
+        }
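
Since get_config round-trips every constructor argument, a serialization check is a natural follow-up to this change. A minimal sketch, again assuming the public keras_hub.models.T5Backbone export and arbitrary small dimensions:

from keras_hub.models import T5Backbone

# Arbitrary small dimensions, purely for illustration.
original = T5Backbone(
    vocabulary_size=128,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
)

# get_config returns plain serializable values (the activation goes
# through keras.activations.serialize), so from_config can rebuild
# an equivalent model.
config = original.get_config()
restored = T5Backbone.from_config(config)

assert restored.vocabulary_size == original.vocabulary_size
assert restored.num_layers == original.num_layers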