-
Notifications
You must be signed in to change notification settings - Fork 330
fix: jax & NNX TraceContextError in T5Backbone #2602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
4bebc0b
fdc3c04
e4ad862
c3648e0
9d82a6b
e7c661e
033e446
d2cc9a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,6 +75,10 @@ def __init__( | |
| dtype=None, | ||
| **kwargs, | ||
| ): | ||
| import keras | ||
|
|
||
| nnx_enabled = keras.config.is_nnx_enabled() | ||
|
|
||
| # Token embedding layer. This layer is shared by encoder and decoder. | ||
| self.token_embedding = ReversibleEmbedding( | ||
| input_dim=vocabulary_size, | ||
|
|
@@ -89,6 +93,7 @@ def __init__( | |
| dtype=dtype, | ||
| name="encoder_embedding_dropout", | ||
| ) | ||
|
|
||
| self.encoder_transformer_layers = [] | ||
| for i in range(num_layers): | ||
| layer = T5TransformerLayer( | ||
|
|
@@ -106,21 +111,25 @@ def __init__( | |
| name=f"transformer_encoder_layer_{i}", | ||
| ) | ||
| self.encoder_transformer_layers.append(layer) | ||
|
|
||
| self.encoder_layer_norm = T5LayerNorm( | ||
| epsilon=layer_norm_epsilon, | ||
| dtype=dtype, | ||
| name="encoder_output_layer_norm", | ||
| ) | ||
|
|
||
| self.encoder_dropout = keras.layers.Dropout( | ||
| dropout, | ||
| dtype=dtype, | ||
| name="encoder_output_dropout", | ||
| ) | ||
|
|
||
| self.decoder_embedding_dropout = keras.layers.Dropout( | ||
| dropout, | ||
| dtype=dtype, | ||
| name="decoder_embedding_dropout", | ||
| ) | ||
|
|
||
| self.decoder_transformer_layers = [] | ||
| for i in range(num_layers): | ||
| layer = T5TransformerLayer( | ||
|
|
@@ -138,80 +147,89 @@ def __init__( | |
| name=f"transformer_decoder_layer_{i}", | ||
| ) | ||
| self.decoder_transformer_layers.append(layer) | ||
|
|
||
| self.decoder_layer_norm = T5LayerNorm( | ||
| epsilon=layer_norm_epsilon, | ||
| dtype=dtype, | ||
| name="decoder_output_layer_norm", | ||
| ) | ||
|
|
||
| self.decoder_dropout = keras.layers.Dropout( | ||
| dropout, | ||
| dtype=dtype, | ||
| name="decoder_output_dropout", | ||
| ) | ||
|
|
||
| # === Functional Model === | ||
| encoder_token_id_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="encoder_token_ids" | ||
| ) | ||
| encoder_padding_mask_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="encoder_padding_mask" | ||
| ) | ||
| decoder_token_id_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="decoder_token_ids" | ||
| ) | ||
| decoder_padding_mask_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="decoder_padding_mask" | ||
| ) | ||
| # Encoder. | ||
| x = self.token_embedding(encoder_token_id_input) | ||
| x = self.encoder_embedding_dropout(x) | ||
| encoder_attention_mask = encoder_padding_mask_input[:, None, :] | ||
| position_bias = None | ||
| for transformer_layer in self.encoder_transformer_layers: | ||
| output = transformer_layer( | ||
| x, | ||
| attention_mask=encoder_attention_mask, | ||
| position_bias=position_bias, | ||
| use_causal_mask=False, | ||
| if not nnx_enabled: | ||
| # === Functional Model === | ||
| encoder_token_id_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="encoder_token_ids" | ||
| ) | ||
| if isinstance(output, tuple): | ||
| x, position_bias = output | ||
| x = self.encoder_layer_norm(x) | ||
| x = self.encoder_dropout(x) | ||
| encoder_output = x | ||
| # Decoder. | ||
| x = self.token_embedding(decoder_token_id_input) | ||
| x = self.decoder_embedding_dropout(x) | ||
| decoder_attention_mask = decoder_padding_mask_input[:, None, :] | ||
| position_bias = None | ||
| for transformer_layer in self.decoder_transformer_layers: | ||
| output = transformer_layer( | ||
| x, | ||
| attention_mask=decoder_attention_mask, | ||
| position_bias=position_bias, | ||
| encoder_hidden_states=encoder_output, | ||
| encoder_attention_mask=encoder_attention_mask, | ||
| use_causal_mask=True, | ||
| encoder_padding_mask_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="encoder_padding_mask" | ||
| ) | ||
| if isinstance(output, tuple): | ||
| x, position_bias = output | ||
| x = self.decoder_layer_norm(x) | ||
| x = self.decoder_dropout(x) | ||
| decoder_output = x | ||
| super().__init__( | ||
| { | ||
| "encoder_token_ids": encoder_token_id_input, | ||
| "encoder_padding_mask": encoder_padding_mask_input, | ||
| "decoder_token_ids": decoder_token_id_input, | ||
| "decoder_padding_mask": decoder_padding_mask_input, | ||
| }, | ||
| outputs={ | ||
| "encoder_sequence_output": encoder_output, | ||
| "decoder_sequence_output": decoder_output, | ||
| }, | ||
| dtype=dtype, | ||
| **kwargs, | ||
| ) | ||
| decoder_token_id_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="decoder_token_ids" | ||
| ) | ||
| decoder_padding_mask_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="decoder_padding_mask" | ||
| ) | ||
|
|
||
| # Encoder. | ||
| x = self.token_embedding(encoder_token_id_input) | ||
| x = self.encoder_embedding_dropout(x) | ||
| encoder_attention_mask = encoder_padding_mask_input[:, None, :] | ||
| position_bias = None | ||
| for transformer_layer in self.encoder_transformer_layers: | ||
| output = transformer_layer( | ||
| x, | ||
| attention_mask=encoder_attention_mask, | ||
| position_bias=position_bias, | ||
| use_causal_mask=False, | ||
| ) | ||
| if isinstance(output, tuple): | ||
| x, position_bias = output | ||
| x = self.encoder_layer_norm(x) | ||
| x = self.encoder_dropout(x) | ||
| encoder_output = x | ||
| # Decoder. | ||
| x = self.token_embedding(decoder_token_id_input) | ||
| x = self.decoder_embedding_dropout(x) | ||
| decoder_attention_mask = decoder_padding_mask_input[:, None, :] | ||
| position_bias = None | ||
| for transformer_layer in self.decoder_transformer_layers: | ||
| output = transformer_layer( | ||
| x, | ||
| attention_mask=decoder_attention_mask, | ||
| position_bias=position_bias, | ||
| encoder_hidden_states=encoder_output, | ||
| encoder_attention_mask=encoder_attention_mask, | ||
| use_causal_mask=True, | ||
| ) | ||
| if isinstance(output, tuple): | ||
| x, position_bias = output | ||
| x = self.decoder_layer_norm(x) | ||
| x = self.decoder_dropout(x) | ||
| decoder_output = x | ||
|
|
||
| super().__init__( | ||
| { | ||
| "encoder_token_ids": encoder_token_id_input, | ||
| "encoder_padding_mask": encoder_padding_mask_input, | ||
| "decoder_token_ids": decoder_token_id_input, | ||
| "decoder_padding_mask": decoder_padding_mask_input, | ||
| }, | ||
| outputs={ | ||
| "encoder_sequence_output": encoder_output, | ||
| "decoder_sequence_output": decoder_output, | ||
| }, | ||
| dtype=dtype, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| else: | ||
| # NNX-safe subclassed model path | ||
| super().__init__(dtype=dtype, **kwargs) | ||
|
Comment on lines
+202
to
+203
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is intended to fix an issue with JAX and NNX by avoiding the Functional API. However, by only calling To complete the fix, you need to implement the Here is a suggested implementation for the def call(self, inputs, training=False):
encoder_token_ids = inputs["encoder_token_ids"]
encoder_padding_mask = inputs["encoder_padding_mask"]
decoder_token_ids = inputs["decoder_token_ids"]
decoder_padding_mask = inputs["decoder_padding_mask"]
# Encoder.
x = self.token_embedding(encoder_token_ids)
x = self.encoder_embedding_dropout(x, training=training)
encoder_attention_mask = encoder_padding_mask[:, None, :]
position_bias = None
for transformer_layer in self.encoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=encoder_attention_mask,
position_bias=position_bias,
use_causal_mask=False,
training=training,
)
if isinstance(output, tuple):
x, position_bias = output
x = self.encoder_layer_norm(x)
x = self.encoder_dropout(x, training=training)
encoder_output = x
# Decoder.
x = self.token_embedding(decoder_token_ids)
x = self.decoder_embedding_dropout(x, training=training)
decoder_attention_mask = decoder_padding_mask[:, None, :]
position_bias = None
for transformer_layer in self.decoder_transformer_layers:
output = transformer_layer(
x,
attention_mask=decoder_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_attention_mask,
use_causal_mask=True,
training=training,
)
if isinstance(output, tuple):
x, position_bias = output
x = self.decoder_layer_norm(x)
x = self.decoder_dropout(x, training=training)
decoder_output = x
return {
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
} |
||
|
|
||
| # === Config === | ||
| self.vocabulary_size = vocabulary_size | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
import kerasstatement is redundant askerasis already imported at the top of the file. This local import and the following blank line can be removed.