From f91b26ebc80218db58c6184b9227222a8cf9eb61 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Mon, 2 Mar 2026 14:46:25 -0800 Subject: [PATCH 1/9] Add initial Text focused T5Gemma2 model files --- keras_hub/api/models/__init__.py | 12 + keras_hub/api/tokenizers/__init__.py | 3 + keras_hub/src/models/t5gemma2/__init__.py | 5 + .../src/models/t5gemma2/t5gemma2_attention.py | 726 ++++++++++++++++++ .../src/models/t5gemma2/t5gemma2_backbone.py | 357 +++++++++ .../models/t5gemma2/t5gemma2_backbone_test.py | 116 +++ .../src/models/t5gemma2/t5gemma2_decoder.py | 346 +++++++++ .../src/models/t5gemma2/t5gemma2_encoder.py | 219 ++++++ .../src/models/t5gemma2/t5gemma2_layers.py | 118 +++ .../src/models/t5gemma2/t5gemma2_presets.py | 4 + .../models/t5gemma2/t5gemma2_seq_2_seq_lm.py | 331 ++++++++ .../t5gemma2_seq_2_seq_lm_preprocessor.py | 158 ++++ .../t5gemma2/t5gemma2_seq_2_seq_lm_test.py | 98 +++ .../src/models/t5gemma2/t5gemma2_tokenizer.py | 55 ++ 14 files changed, 2548 insertions(+) create mode 100644 keras_hub/src/models/t5gemma2/__init__.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_attention.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_backbone.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_decoder.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_encoder.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_layers.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_presets.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_test.py create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_tokenizer.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 9287e2edee..8cd5e048e2 100644 --- 
a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -762,6 +762,18 @@ from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( T5GemmaTokenizer as T5GemmaTokenizer, ) +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import ( + T5Gemma2Backbone as T5Gemma2Backbone, +) +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm import ( + T5Gemma2Seq2SeqLM as T5Gemma2Seq2SeqLM, +) +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm_preprocessor import ( + T5Gemma2Seq2SeqLMPreprocessor as T5Gemma2Seq2SeqLMPreprocessor, +) +from keras_hub.src.models.t5gemma2.t5gemma2_tokenizer import ( + T5Gemma2Tokenizer as T5Gemma2Tokenizer, +) from keras_hub.src.models.task import Task as Task from keras_hub.src.models.text_classifier import TextClassifier as Classifier from keras_hub.src.models.text_classifier import ( diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 7ad25aea51..6366ade1a0 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -115,6 +115,9 @@ from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( T5GemmaTokenizer as T5GemmaTokenizer, ) +from keras_hub.src.models.t5gemma2.t5gemma2_tokenizer import ( + T5Gemma2Tokenizer as T5Gemma2Tokenizer, +) from keras_hub.src.models.video_prism.video_prism_tokenizer import ( VideoPrismTokenizer as VideoPrismTokenizer, ) diff --git a/keras_hub/src/models/t5gemma2/__init__.py b/keras_hub/src/models/t5gemma2/__init__.py new file mode 100644 index 0000000000..94d56d0c9d --- /dev/null +++ b/keras_hub/src/models/t5gemma2/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.models.t5gemma2.t5gemma2_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, T5Gemma2Backbone) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_attention.py 
b/keras_hub/src/models/t5gemma2/t5gemma2_attention.py new file mode 100644 index 0000000000..28b8d53e0f --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_attention.py @@ -0,0 +1,726 @@ +import inspect + +import keras + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma2.t5gemma2_layers import ( + t5gemma2_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +def repeat_kv(hidden_states, n_rep): + """Repeats the key/value hidden states for Grouped Query Attention. + + Args: + hidden_states: Tensor with shape + `(batch, sequence_length, num_key_value_heads, head_dim)`. + n_rep: int, number of times to repeat. + + Returns: + Tensor with shape + `(batch, sequence_length, num_query_heads, head_dim)`. + """ + if n_rep == 1: + return hidden_states + batch, slen, num_key_value_heads, head_dim = keras.ops.shape(hidden_states) + hidden_states = keras.ops.expand_dims(hidden_states, 3) + hidden_states = keras.ops.tile(hidden_states, (1, 1, 1, n_rep, 1)) + return keras.ops.reshape( + hidden_states, (batch, slen, num_key_value_heads * n_rep, head_dim) + ) + + +class T5Gemma2Attention(CachedGemmaAttention): + """Self-attention layer for T5Gemma2 encoder and decoder. + + This layer performs self-attention with Rotary Positional Embeddings + (RoPE), optional Q/K normalization (Gemma3-style), and optional + attention logit softcapping. Supports Grouped Query Attention (GQA). + + Used in `T5Gemma2EncoderLayer` for bidirectional self-attention + and can also be used in the decoder for self-attention. + + Args: + hidden_size: int, The dimensionality of the hidden states. + num_attention_heads: int, The number of attention heads. + num_key_value_heads: int, The number of key-value heads for GQA. 
+ query_pre_attn_scalar: float, Scalar to multiply queries by + before attention. + attention_bias: bool, Whether to include bias in dense layers. + head_dim: int, The dimensionality of each attention head. + initializer_range: float, The range for the initializer. + Defaults to `0.02`. + attention_dropout: float, Dropout rate for attention weights. + Defaults to `0.0`. + attn_logit_softcapping: float, optional, Softcapping value. + Defaults to `None`. + rope_max_wavelength: float, Maximum wavelength for RoPE. + Defaults to `10000.0`. + use_query_key_norm: bool, Whether to apply RMS normalization on + query and key. Defaults to `True` (Gemma3-style). + rms_norm_eps: float, Epsilon for RMS normalization. + Defaults to `1e-6`. + dtype: The dtype for computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + head_dim, + initializer_range=0.02, + attention_dropout=0.0, + attn_logit_softcapping=None, + rope_max_wavelength=10000.0, + use_query_key_norm=True, + rms_norm_eps=1e-6, + dtype=None, + **kwargs, + ): + super().__init__( + head_dim=head_dim, + num_query_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + kernel_initializer=t5gemma2_kernel_initializer(initializer_range), + logit_soft_cap=attn_logit_softcapping, + dropout=attention_dropout, + query_head_dim_normalize=False, + use_sliding_window_attention=False, + dtype=dtype, + **kwargs, + ) + self.hidden_size = hidden_size + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.rope_max_wavelength = rope_max_wavelength + self.use_query_key_norm = use_query_key_norm + self.rms_norm_eps = rms_norm_eps + self.num_key_value_groups = ( + self.num_query_heads // self.num_key_value_heads + ) + self.scaling 
= self.query_pre_attn_scalar**-0.5 + + def build(self, input_shape): + self._kernel_initializer = t5gemma2_kernel_initializer( + self.initializer_range + ) + hidden_states_shape = input_shape + self.hidden_dim = hidden_states_shape[-1] + + self.query_dense = keras.layers.EinsumDense( + equation="btd,dnh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="nh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(hidden_states_shape) + + self.key_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(hidden_states_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(hidden_states_shape) + + self.output_dense = keras.layers.EinsumDense( + equation="btnh,nhd->btd", + output_shape=(None, self.hidden_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="d" if self.attention_bias else None, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build( + ( + hidden_states_shape[0], + hidden_states_shape[1], + self.num_query_heads, + self.head_dim, + ) + ) + + # Q/K normalization (Gemma3-style). 
+ if self.use_query_key_norm: + self.query_norm = RMSNormalization( + epsilon=self.rms_norm_eps, + dtype=self.dtype_policy, + name="query_norm", + ) + self.query_norm.build( + self.query_dense.compute_output_shape(hidden_states_shape) + ) + self.key_norm = RMSNormalization( + epsilon=self.rms_norm_eps, + dtype=self.dtype_policy, + name="key_norm", + ) + self.key_norm.build( + self.key_dense.compute_output_shape(hidden_states_shape) + ) + + self.rotary_embedding = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + sequence_axis=1, + feature_axis=3, + name="rotary_embedding", + dtype=self.dtype_policy, + ) + + self.dropout_layer = keras.layers.Dropout( + rate=self.attention_dropout, + dtype=self.dtype_policy, + ) + self.softmax = keras.layers.Softmax(axis=-1, dtype="float32") + self.built = True + + def _compute_attention_without_fused_op( + self, query_states, key_states, value_states, attention_mask, training + ): + attn_weights = keras.ops.einsum( + "btnh,bsnh->bnts", query_states, key_states + ) + attn_weights *= self.scaling + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.einsum( + "bnts,bsnh->btnh", attn_weights, value_states + ) + return attn_output + + def _compute_attention( + self, query_states, key_states, value_states, attention_mask, training + ): + if self._use_fused_attention_op(): + kwargs = {"bias": attention_mask} + if self.logit_soft_cap is not None: + sig = inspect.signature(keras.ops.dot_product_attention) + if "attn_logits_soft_cap" in sig.parameters: + kwargs["attn_logits_soft_cap"] = self.logit_soft_cap + return keras.ops.dot_product_attention( + 
query=query_states, + key=key_states, + value=value_states, + scale=self.scaling, + **kwargs, + ) + return self._compute_attention_without_fused_op( + query_states, + key_states, + value_states, + attention_mask, + training, + ) + + def call( + self, + inputs, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + hidden_states = inputs + query_states = self.query_dense(hidden_states) + key_states = self.key_dense(hidden_states) + value_states = self.value_dense(hidden_states) + + # Apply Q/K normalization. + if self.use_query_key_norm: + query_states = self.query_norm(query_states) + key_states = self.key_norm(key_states) + + # Apply RoPE. + start_index = 0 if cache_update_index is None else cache_update_index + query_states = self.rotary_embedding( + query_states, start_index=start_index + ) + key_states = self.rotary_embedding(key_states, start_index=start_index) + + # Handle caching for autoregressive generation. + if cache is not None: + if cache_update_index is None: + raise ValueError( + "Both `cache` and `cache_update_index` must be " + "passed for self-attention caching." + ) + key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...] + start = [0, cache_update_index, 0, 0] + key_states = keras.ops.slice_update(key_cache, start, key_states) + value_states = keras.ops.slice_update( + value_cache, start, value_states + ) + cache = keras.ops.stack((key_states, value_states), axis=1) + elif cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is `None`." + ) + else: + cache = keras.ops.stack((key_states, value_states), axis=1) + + # Repeat K/V for GQA. 
+ key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = self._compute_attention( + query_states, + key_states, + value_states, + attention_mask, + training, + ) + attn_output = self.output_dense(attn_output) + return attn_output, cache + + def compute_output_shape(self, input_shape): + hidden_states_shape = input_shape + attn_output_shape = hidden_states_shape + kv_len = hidden_states_shape[1] + cache_shape = ( + hidden_states_shape[0], + 2, + kv_len, + self.num_key_value_heads, + self.head_dim, + ) + return attn_output_shape, cache_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "head_dim": self.head_dim, + "num_attention_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "attn_logit_softcapping": self.logit_soft_cap, + "rope_max_wavelength": self.rope_max_wavelength, + "use_query_key_norm": self.use_query_key_norm, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config + + +class T5Gemma2MergedAttention(CachedGemmaAttention): + """Merged self-attention and cross-attention for T5Gemma2 decoder. + + This layer fuses self-attention and cross-attention into a single + attention computation. The decoder's Q/K/V are computed from the + decoder hidden states (self-attention), while additional K/V are + computed from the encoder hidden states (cross-attention). The + self-attention and cross-attention K/V are concatenated, and a + single attention computation is performed over the merged K/V. + + This merged approach is the key architectural difference between + T5Gemma2 and T5Gemma1. + + Args: + hidden_size: int, Dimensionality of the decoder hidden states. 
+ num_attention_heads: int, Number of attention heads. + num_key_value_heads: int, Number of key-value heads for GQA. + query_pre_attn_scalar: float, Scalar for query normalization. + attention_bias: bool, Whether to include bias. + head_dim: int, Dimensionality of each attention head. + cross_attention_hidden_size: int, optional, Hidden size of the + encoder states. Defaults to `hidden_size`. + initializer_range: float, Range for the initializer. + Defaults to `0.02`. + attention_dropout: float, Dropout rate. + Defaults to `0.0`. + attn_logit_softcapping: float, optional, Softcapping value. + Defaults to `None`. + rope_max_wavelength: float, Maximum wavelength for RoPE. + Defaults to `10000.0`. + use_query_key_norm: bool, Whether to apply Q/K norm. + Defaults to `True`. + rms_norm_eps: float, Epsilon for RMS normalization. + Defaults to `1e-6`. + dtype: The dtype for computations. Defaults to `None`. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + head_dim, + cross_attention_hidden_size=None, + initializer_range=0.02, + attention_dropout=0.0, + attn_logit_softcapping=None, + rope_max_wavelength=10000.0, + use_query_key_norm=True, + rms_norm_eps=1e-6, + dtype=None, + **kwargs, + ): + super().__init__( + head_dim=head_dim, + num_query_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + kernel_initializer=t5gemma2_kernel_initializer(initializer_range), + logit_soft_cap=attn_logit_softcapping, + dropout=attention_dropout, + query_head_dim_normalize=False, + use_sliding_window_attention=False, + dtype=dtype, + **kwargs, + ) + self.hidden_size = hidden_size + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or hidden_size + ) + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.initializer_range = initializer_range + self.attention_dropout = 
attention_dropout + self.rope_max_wavelength = rope_max_wavelength + self.use_query_key_norm = use_query_key_norm + self.rms_norm_eps = rms_norm_eps + self.num_key_value_groups = ( + self.num_query_heads // self.num_key_value_heads + ) + self.scaling = self.query_pre_attn_scalar**-0.5 + + def build(self, input_shape): + # Only decoder_shape needed; K/V projections are shared. + if isinstance(input_shape, (list, tuple)): + decoder_shape = input_shape[0] + else: + decoder_shape = input_shape + + self._kernel_initializer = t5gemma2_kernel_initializer( + self.initializer_range + ) + self.hidden_dim = decoder_shape[-1] + + # Q projection from decoder hidden states. + self.query_dense = keras.layers.EinsumDense( + equation="btd,dnh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="nh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(decoder_shape) + + # K/V projections shared for self-attn and cross-attn. + self.key_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(decoder_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(decoder_shape) + + # Output projection. 
+ self.output_dense = keras.layers.EinsumDense( + equation="btnh,nhd->btd", + output_shape=(None, self.hidden_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="d" if self.attention_bias else None, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build( + ( + decoder_shape[0], + decoder_shape[1], + self.num_query_heads, + self.head_dim, + ) + ) + + # Q/K normalization (Gemma3-style). + if self.use_query_key_norm: + self.query_norm = RMSNormalization( + epsilon=self.rms_norm_eps, + dtype=self.dtype_policy, + name="query_norm", + ) + self.query_norm.build( + self.query_dense.compute_output_shape(decoder_shape) + ) + self.key_norm = RMSNormalization( + epsilon=self.rms_norm_eps, + dtype=self.dtype_policy, + name="key_norm", + ) + self.key_norm.build( + self.key_dense.compute_output_shape(decoder_shape) + ) + + self.rotary_embedding = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + sequence_axis=1, + feature_axis=3, + name="rotary_embedding", + dtype=self.dtype_policy, + ) + + self.dropout_layer = keras.layers.Dropout( + rate=self.attention_dropout, + dtype=self.dtype_policy, + ) + self.softmax = keras.layers.Softmax(axis=-1, dtype="float32") + self.built = True + + def _compute_attention_without_fused_op( + self, query_states, key_states, value_states, attention_mask, training + ): + attn_weights = keras.ops.einsum( + "btnh,bsnh->bnts", query_states, key_states + ) + attn_weights *= self.scaling + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.einsum( + "bnts,bsnh->btnh", attn_weights, value_states + ) + return 
attn_output + + def _compute_attention( + self, query_states, key_states, value_states, attention_mask, training + ): + if self._use_fused_attention_op(): + kwargs = {"bias": attention_mask} + if self.logit_soft_cap is not None: + sig = inspect.signature(keras.ops.dot_product_attention) + if "attn_logits_soft_cap" in sig.parameters: + kwargs["attn_logits_soft_cap"] = self.logit_soft_cap + return keras.ops.dot_product_attention( + query=query_states, + key=key_states, + value=value_states, + scale=self.scaling, + **kwargs, + ) + return self._compute_attention_without_fused_op( + query_states, + key_states, + value_states, + attention_mask, + training, + ) + + def call( + self, + inputs, + encoder_hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Forward pass of merged self+cross attention. + + Args: + inputs: Decoder hidden states, shape + `(batch, decoder_seq_len, hidden_dim)`. + encoder_hidden_states: Encoder output, shape + `(batch, encoder_seq_len, hidden_dim)`. + attention_mask: Merged attention mask, shape + `(batch, 1, decoder_seq_len, decoder_kv_len + encoder_len)`. + This is the concatenation of the causal self-attention + mask and the bidirectional cross-attention mask. + cache: Tuple of (self_attn_cache, cross_attn_cache) or None. + cache_update_index: int, the current position in the + sequence for caching. + training: bool, whether in training mode. + + Returns: + Tuple of (attn_output, updated_cache). + """ + hidden_states = inputs + self_attention_cache, cross_attention_cache = ( + cache if cache is not None else (None, None) + ) + + # Self-attention Q/K/V from decoder hidden states. + query_states = self.query_dense(hidden_states) + key_states = self.key_dense(hidden_states) + value_states = self.value_dense(hidden_states) + + # Apply Q/K normalization. 
+ if self.use_query_key_norm: + query_states = self.query_norm(query_states) + key_states = self.key_norm(key_states) + + # Apply RoPE to self-attention Q/K only (not cross-attention). + start_index = 0 if cache_update_index is None else cache_update_index + query_states = self.rotary_embedding( + query_states, start_index=start_index + ) + key_states = self.rotary_embedding(key_states, start_index=start_index) + + # Update self-attention cache. + if self_attention_cache is not None: + key_cache_self = self_attention_cache[:, 0, ...] + value_cache_self = self_attention_cache[:, 1, ...] + start = [0, cache_update_index, 0, 0] + key_states = keras.ops.slice_update( + key_cache_self, start, key_states + ) + value_states = keras.ops.slice_update( + value_cache_self, start, value_states + ) + updated_self_cache = keras.ops.stack( + (key_states, value_states), axis=1 + ) + else: + updated_self_cache = keras.ops.stack( + (key_states, value_states), axis=1 + ) + + # Cross-attention K/V from encoder hidden states. + if cross_attention_cache is not None: + # Reuse cached encoder K/V. + cross_key_states = cross_attention_cache[:, 0, ...] + cross_value_states = cross_attention_cache[:, 1, ...] + updated_cross_cache = cross_attention_cache + else: + cross_key_states = self.key_dense(encoder_hidden_states) + cross_value_states = self.value_dense(encoder_hidden_states) + # Apply K normalization to cross-attention keys. + if self.use_query_key_norm: + cross_key_states = self.key_norm(cross_key_states) + updated_cross_cache = keras.ops.stack( + (cross_key_states, cross_value_states), axis=1 + ) + + # Merge self-attention and cross-attention K/V. + merged_key_states = keras.ops.concatenate( + [key_states, cross_key_states], axis=1 + ) + merged_value_states = keras.ops.concatenate( + [value_states, cross_value_states], axis=1 + ) + + # Repeat K/V for GQA. 
+ merged_key_states = repeat_kv( + merged_key_states, self.num_key_value_groups + ) + merged_value_states = repeat_kv( + merged_value_states, self.num_key_value_groups + ) + + attn_output = self._compute_attention( + query_states, + merged_key_states, + merged_value_states, + attention_mask, + training, + ) + attn_output = self.output_dense(attn_output) + updated_cache = (updated_self_cache, updated_cross_cache) + return attn_output, updated_cache + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, (list, tuple)): + decoder_shape, encoder_shape = input_shape + else: + decoder_shape = input_shape + encoder_shape = input_shape + attn_output_shape = decoder_shape + dec_kv_len = decoder_shape[1] + enc_kv_len = encoder_shape[1] + self_cache_shape = ( + decoder_shape[0], + 2, + dec_kv_len, + self.num_key_value_heads, + self.head_dim, + ) + cross_cache_shape = ( + decoder_shape[0], + 2, + enc_kv_len, + self.num_key_value_heads, + self.head_dim, + ) + return attn_output_shape, (self_cache_shape, cross_cache_shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "head_dim": self.head_dim, + "num_attention_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "cross_attention_hidden_size": ( + self.cross_attention_hidden_size + ), + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "attn_logit_softcapping": self.logit_soft_cap, + "rope_max_wavelength": self.rope_max_wavelength, + "use_query_key_norm": self.use_query_key_norm, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py b/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py new file mode 100644 index 0000000000..58bd719c51 --- /dev/null +++ 
b/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py @@ -0,0 +1,357 @@ +import keras +from keras.layers import ReversibleEmbedding + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma2.t5gemma2_decoder import T5Gemma2DecoderLayer +from keras_hub.src.models.t5gemma2.t5gemma2_encoder import T5Gemma2EncoderLayer +from keras_hub.src.models.t5gemma2.t5gemma2_layers import ( + t5gemma2_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +@keras_hub_export("keras_hub.models.T5Gemma2Backbone") +class T5Gemma2Backbone(Backbone): + """T5Gemma2 backbone model. + + This class implements the encoder-decoder backbone of the T5Gemma2 + model. T5Gemma2 is based on Gemma3 and features merged + self+cross attention in the decoder (unlike T5Gemma1 which used + separate attention sublayers), Gemma3-style Q/K normalization, + and per-layer-type sliding window attention patterns. + + Args: + vocabulary_size: int, The size of the vocabulary. + encoder_hidden_dim: int, Encoder hidden dimensionality. + encoder_intermediate_dim: int, Encoder FFN intermediate size. + encoder_num_layers: int, Number of encoder layers. + encoder_num_attention_heads: int, Encoder attention heads. + encoder_num_key_value_heads: int, Encoder KV heads for GQA. + encoder_head_dim: int, Encoder head dimensionality. + encoder_layer_types: list of str, Attention layer types for + each encoder layer (`"full_attention"` or + `"sliding_attention"`). + decoder_hidden_dim: int, Decoder hidden dimensionality. + decoder_intermediate_dim: int, Decoder FFN intermediate size. + decoder_num_layers: int, Number of decoder layers. + decoder_num_attention_heads: int, Decoder attention heads. + decoder_num_key_value_heads: int, Decoder KV heads for GQA. + decoder_head_dim: int, Decoder head dimensionality. 
+ decoder_layer_types: list of str, Attention layer types for + each decoder layer. + dropout_rate: float, Dropout rate. Defaults to `0.0`. + rms_norm_eps: float, RMS normalization epsilon. Defaults to + `1e-6`. + query_pre_attn_scalar: float, Query scalar. Defaults to `1.0`. + attention_bias: bool, Attention bias. Defaults to `False`. + hidden_activation: str, FFN activation. Defaults to + `"gelu_approximate"`. + tie_word_embeddings: bool, Tie input/output embeddings. + Defaults to `True`. + initializer_range: float, Initializer range. Defaults to + `0.02`. + attention_dropout: float, Attention dropout. Defaults to `0.0`. + sliding_window: int, optional, Sliding window size. + cross_attention_hidden_size: int, optional, Cross-attention + hidden size. Defaults to `encoder_hidden_dim`. + attn_logit_softcapping: float, optional, Attention softcapping. + final_logit_softcapping: float, optional, Final logit + softcapping. + rope_max_wavelength: float, RoPE maximum wavelength. + Defaults to `10000.0`. + use_query_key_norm: bool, Whether to use Gemma3-style Q/K + normalization. Defaults to `True`. + dtype: dtype for computations. Defaults to `None`. + **kwargs: Additional keyword arguments. 
+ + Examples: + ```python + import numpy as np + from keras_hub.models import T5Gemma2Backbone + + input_data = { + "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"), + "encoder_padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32" + ), + "decoder_token_ids": np.ones(shape=(1, 8), dtype="int32"), + "decoder_padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1]], dtype="int32" + ), + } + + model = T5Gemma2Backbone( + vocabulary_size=32000, + encoder_hidden_dim=256, + encoder_intermediate_dim=512, + encoder_num_layers=4, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_head_dim=64, + encoder_layer_types=["full_attention"] * 4, + decoder_hidden_dim=256, + decoder_intermediate_dim=512, + decoder_num_layers=4, + decoder_num_attention_heads=4, + decoder_num_key_value_heads=2, + decoder_head_dim=64, + decoder_layer_types=["full_attention"] * 4, + dropout_rate=0.1, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + encoder_hidden_dim, + encoder_intermediate_dim, + encoder_num_layers, + encoder_num_attention_heads, + encoder_num_key_value_heads, + encoder_head_dim, + encoder_layer_types, + decoder_hidden_dim, + decoder_intermediate_dim, + decoder_num_layers, + decoder_num_attention_heads, + decoder_num_key_value_heads, + decoder_head_dim, + decoder_layer_types, + dropout_rate=0.0, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + tie_word_embeddings=True, + initializer_range=0.02, + attention_dropout=0.0, + sliding_window=None, + cross_attention_hidden_size=None, + attn_logit_softcapping=None, + final_logit_softcapping=None, + rope_max_wavelength=10000.0, + use_query_key_norm=True, + dtype=None, + **kwargs, + ): + self.kernel_initializer = t5gemma2_kernel_initializer(initializer_range) + + # === 
Layers === + self.token_embedding = keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=encoder_hidden_dim, + embeddings_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, + name="encoder_token_embedding", + ) + self.decoder_token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=decoder_hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, + name="decoder_token_embedding", + ) + self.encoder_layers = [ + T5Gemma2EncoderLayer( + hidden_size=encoder_hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=encoder_num_attention_heads, + num_key_value_heads=encoder_num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=encoder_intermediate_dim, + hidden_activation=hidden_activation, + head_dim=encoder_head_dim, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + layer_type=encoder_layer_types[i], + sliding_window=sliding_window, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=rope_max_wavelength, + use_query_key_norm=use_query_key_norm, + name=f"encoder_layer_{i}", + dtype=dtype, + ) + for i in range(encoder_num_layers) + ] + self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) + self.encoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) + self.decoder_layers = [ + T5Gemma2DecoderLayer( + hidden_size=decoder_hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=decoder_num_attention_heads, + num_key_value_heads=decoder_num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=decoder_intermediate_dim, + hidden_activation=hidden_activation, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + head_dim=decoder_head_dim, + attention_dropout=attention_dropout, + 
layer_type=decoder_layer_types[i], + sliding_window=sliding_window, + cross_attention_hidden_size=( + cross_attention_hidden_size or encoder_hidden_dim + ), + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=rope_max_wavelength, + use_query_key_norm=use_query_key_norm, + name=f"decoder_layer_{i}", + dtype=dtype, + ) + for i in range(decoder_num_layers) + ] + self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) + self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) + + # === Functional Model === + encoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_token_ids" + ) + encoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_padding_mask" + ) + decoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_token_ids" + ) + decoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_padding_mask" + ) + + # Encoder. + encoder_embeddings = self.token_embedding(encoder_token_id_input) + encoder_embeddings = encoder_embeddings * keras.ops.cast( + keras.ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype + ) + encoder_hidden_states = self.encoder_dropout(encoder_embeddings) + for layer in self.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, + padding_mask=encoder_padding_mask_input, + ) + encoder_output = self.encoder_norm(encoder_hidden_states) + encoder_output = self.encoder_dropout(encoder_output) + + # Decoder. 
+ decoder_embeddings = self.decoder_token_embedding( + decoder_token_id_input + ) + decoder_embeddings = decoder_embeddings * keras.ops.cast( + keras.ops.sqrt(decoder_hidden_dim), decoder_embeddings.dtype + ) + decoder_hidden_states = self.decoder_dropout(decoder_embeddings) + for layer in self.decoder_layers: + decoder_hidden_states, _ = layer( + (decoder_hidden_states, encoder_output), + self_attention_padding_mask=decoder_padding_mask_input, + cross_attention_padding_mask=encoder_padding_mask_input, + ) + decoder_output = self.decoder_norm(decoder_hidden_states) + decoder_output = self.decoder_dropout(decoder_output) + + super().__init__( + inputs={ + "encoder_token_ids": encoder_token_id_input, + "encoder_padding_mask": encoder_padding_mask_input, + "decoder_token_ids": decoder_token_id_input, + "decoder_padding_mask": decoder_padding_mask_input, + }, + outputs={ + "encoder_sequence_output": encoder_output, + "decoder_sequence_output": decoder_output, + }, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_intermediate_dim = encoder_intermediate_dim + self.encoder_num_layers = encoder_num_layers + self.encoder_num_attention_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + self.encoder_head_dim = encoder_head_dim + self.encoder_layer_types = encoder_layer_types + self.decoder_hidden_dim = decoder_hidden_dim + self.decoder_intermediate_dim = decoder_intermediate_dim + self.decoder_num_layers = decoder_num_layers + self.decoder_num_attention_heads = decoder_num_attention_heads + self.decoder_num_key_value_heads = decoder_num_key_value_heads + self.decoder_head_dim = decoder_head_dim + self.decoder_layer_types = decoder_layer_types + self.vocabulary_size = vocabulary_size + self.dropout_rate = dropout_rate + self.rms_norm_eps = rms_norm_eps + self.tie_word_embeddings = tie_word_embeddings + self.query_pre_attn_scalar = query_pre_attn_scalar + 
self.attention_bias = attention_bias + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.sliding_window = sliding_window + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or encoder_hidden_dim + ) + self.attn_logit_softcapping = attn_logit_softcapping + self.final_logit_softcapping = final_logit_softcapping + self.rope_max_wavelength = rope_max_wavelength + self.use_query_key_norm = use_query_key_norm + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "encoder_hidden_dim": self.encoder_hidden_dim, + "encoder_intermediate_dim": self.encoder_intermediate_dim, + "encoder_num_layers": self.encoder_num_layers, + "encoder_num_attention_heads": ( + self.encoder_num_attention_heads + ), + "encoder_num_key_value_heads": ( + self.encoder_num_key_value_heads + ), + "encoder_layer_types": self.encoder_layer_types, + "encoder_head_dim": self.encoder_head_dim, + "decoder_hidden_dim": self.decoder_hidden_dim, + "decoder_intermediate_dim": self.decoder_intermediate_dim, + "decoder_num_layers": self.decoder_num_layers, + "decoder_num_attention_heads": ( + self.decoder_num_attention_heads + ), + "decoder_num_key_value_heads": ( + self.decoder_num_key_value_heads + ), + "decoder_layer_types": self.decoder_layer_types, + "decoder_head_dim": self.decoder_head_dim, + "dropout_rate": self.dropout_rate, + "rms_norm_eps": self.rms_norm_eps, + "tie_word_embeddings": self.tie_word_embeddings, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "hidden_activation": self.hidden_activation, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "sliding_window": self.sliding_window, + "cross_attention_hidden_size": ( + self.cross_attention_hidden_size + ), + "attn_logit_softcapping": self.attn_logit_softcapping, + 
"final_logit_softcapping": (self.final_logit_softcapping), + "rope_max_wavelength": self.rope_max_wavelength, + "use_query_key_norm": self.use_query_key_norm, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py b/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py new file mode 100644 index 0000000000..f95d3498a4 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py @@ -0,0 +1,116 @@ +import keras +import pytest + +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.tests.test_case import TestCase + + +class T5Gemma2BackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 100, + "encoder_hidden_dim": 32, + "encoder_intermediate_dim": 64, + "encoder_num_layers": 2, + "encoder_num_attention_heads": 4, + "encoder_num_key_value_heads": 2, + "encoder_head_dim": 8, + "encoder_layer_types": [ + "sliding_attention", + "full_attention", + ], + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": [ + "sliding_attention", + "full_attention", + ], + "dropout_rate": 0.1, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "query_pre_attn_scalar": 1.0, + "attention_bias": False, + "hidden_activation": "gelu_approximate", + "sliding_window": 16, + "cross_attention_hidden_size": 32, + "attn_logit_softcapping": 50.0, + "rope_max_wavelength": 10000.0, + "initializer_range": 0.04, + "attention_dropout": 0.1, + "use_query_key_norm": True, + } + self.input_data = { + "encoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "encoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + "decoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "decoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + 
cls=T5Gemma2Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape={ + "encoder_sequence_output": (2, 16, 32), + "decoder_sequence_output": (2, 16, 32), + }, + ) + + def test_asymmetrical_backbone(self): + asym_kwargs = { + "vocabulary_size": 100, + "encoder_hidden_dim": 48, + "encoder_intermediate_dim": 96, + "encoder_num_layers": 3, + "encoder_num_attention_heads": 6, + "encoder_num_key_value_heads": 3, + "encoder_head_dim": 8, + "encoder_layer_types": ["full_attention"] * 3, + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": [ + "sliding_attention", + "full_attention", + ], + "sliding_window": 16, + "dropout_rate": 0.1, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "cross_attention_hidden_size": 48, + "use_query_key_norm": True, + } + self.run_backbone_test( + cls=T5Gemma2Backbone, + init_kwargs=asym_kwargs, + input_data=self.input_data, + expected_output_shape={ + "encoder_sequence_output": (2, 16, 48), + "decoder_sequence_output": (2, 16, 32), + }, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=T5Gemma2Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in T5Gemma2Backbone.presets: + self.run_preset_test( + cls=T5Gemma2Backbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py b/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py new file mode 100644 index 0000000000..4106ffad15 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py @@ -0,0 +1,346 @@ +import keras + +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma2.t5gemma2_attention import ( + T5Gemma2MergedAttention, +) 
from keras_hub.src.models.t5gemma2.t5gemma2_layers import T5Gemma2MLP


class T5Gemma2DecoderLayer(keras.layers.Layer):
    """Decoder layer for the T5Gemma2 model.

    This layer implements a single decoder block in the T5Gemma2
    architecture. Unlike T5Gemma1 which has separate self-attention and
    cross-attention sub-layers, T5Gemma2 uses a single
    `T5Gemma2MergedAttention` layer that fuses self-attention and
    cross-attention by concatenating their K/V pairs.

    Args:
        hidden_size: int, Dimensionality of hidden states.
        rms_norm_eps: float, Epsilon for RMS normalization.
        num_attention_heads: int, Number of attention heads.
        num_key_value_heads: int, Number of key-value heads for GQA.
        query_pre_attn_scalar: float, Scalar for query normalization.
        attention_bias: bool, Whether to include bias.
        intermediate_size: int, Intermediate size of the FFN.
        hidden_activation: str, Activation function for the FFN.
        dropout_rate: float, Dropout rate.
        head_dim: int, Dimensionality of each attention head.
        initializer_range: float, Range for the initializer.
        attention_dropout: float, Dropout for attention weights.
        layer_type: str, Either `"full_attention"` or
            `"sliding_attention"`.
        cross_attention_hidden_size: int, optional, Hidden size for
            cross-attention. Defaults to `hidden_size`.
        attn_logit_softcapping: float, optional, Softcapping value.
        sliding_window: int, optional, Window size for sliding
            attention.
        rope_max_wavelength: float, Maximum wavelength for RoPE.
        use_query_key_norm: bool, Whether to apply Q/K norm.
            Defaults to `True`.
        dtype: The dtype for computations. Defaults to `None`.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        hidden_size,
        rms_norm_eps,
        num_attention_heads,
        num_key_value_heads,
        query_pre_attn_scalar,
        attention_bias,
        intermediate_size,
        hidden_activation,
        dropout_rate,
        head_dim,
        initializer_range,
        attention_dropout,
        layer_type,
        cross_attention_hidden_size=None,
        attn_logit_softcapping=None,
        sliding_window=None,
        rope_max_wavelength=10000.0,
        use_query_key_norm=True,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.head_dim = head_dim
        self.hidden_size = hidden_size
        self.rms_norm_eps = rms_norm_eps
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.attention_bias = attention_bias
        self.intermediate_size = intermediate_size
        self.hidden_activation = hidden_activation
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout
        # Stored as `attention_type` internally but serialized back out as
        # `layer_type` in `get_config` to match the constructor argument.
        self.attention_type = layer_type
        self.sliding_window = sliding_window
        self.rope_max_wavelength = rope_max_wavelength
        self.cross_attention_hidden_size = cross_attention_hidden_size
        self.attn_logit_softcapping = attn_logit_softcapping
        self.use_query_key_norm = use_query_key_norm

        # Fail fast: a sliding layer without a window size would silently
        # produce a wrong mask in `_make_causal_mask`.
        if (
            self.attention_type == "sliding_attention"
            and self.sliding_window is None
        ):
            raise ValueError(
                "`sliding_window` must be set for `sliding_attention` "
                "layer type."
            )

        # Merged self+cross attention.
        self.merged_attn = T5Gemma2MergedAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            query_pre_attn_scalar=query_pre_attn_scalar,
            attention_bias=attention_bias,
            head_dim=self.head_dim,
            cross_attention_hidden_size=(
                cross_attention_hidden_size or hidden_size
            ),
            initializer_range=initializer_range,
            attention_dropout=attention_dropout,
            attn_logit_softcapping=attn_logit_softcapping,
            rope_max_wavelength=self.rope_max_wavelength,
            use_query_key_norm=use_query_key_norm,
            rms_norm_eps=rms_norm_eps,
            dtype=self.dtype_policy,
            name="merged_attention",
        )
        self.pre_self_attn_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_pre_self_attention_layernorm",
        )
        self.post_self_attn_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_post_self_attention_layernorm",
        )

        # MLP.
        self.mlp = T5Gemma2MLP(
            hidden_size,
            intermediate_size,
            hidden_activation,
            dropout_rate,
            initializer_range=initializer_range,
            dtype=self.dtype_policy,
            name="mlp",
        )
        self.pre_feedforward_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_pre_feedforward_layernorm",
        )
        self.post_feedforward_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_post_feedforward_layernorm",
        )

        self.dropout = keras.layers.Dropout(
            dropout_rate,
            dtype=self.dtype_policy,
            name="decoder_residual_dropout",
        )

    def build(self, input_shape):
        # `input_shape` is a pair: (decoder hidden states, encoder output).
        hidden_states_shape, encoder_hidden_states_shape = input_shape
        self.pre_self_attn_layernorm.build(hidden_states_shape)
        self.merged_attn.build(
            [hidden_states_shape, encoder_hidden_states_shape]
        )
        attn_output_shape, _ = self.merged_attn.compute_output_shape(
            [hidden_states_shape, encoder_hidden_states_shape]
        )
        self.post_self_attn_layernorm.build(attn_output_shape)
        self.dropout.build(attn_output_shape)
        self.pre_feedforward_layernorm.build(attn_output_shape)
        self.mlp.build(attn_output_shape)
        mlp_output_shape = self.mlp.compute_output_shape(attn_output_shape)
        self.post_feedforward_layernorm.build(mlp_output_shape)
        self.built = True

    def _make_causal_mask(
        self,
        hidden_states,
        padding_mask,
        cache=None,
        cache_update_index=None,
    ):
        """Creates a causal attention mask for self-attention."""
        if cache is not None:
            # Cached decoding: queries start at `cache_update_index` while
            # keys cover the whole cache length.
            q_len = keras.ops.shape(hidden_states)[1]
            kv_len = keras.ops.shape(cache)[2]
            q_indices = (
                keras.ops.arange(0, q_len, dtype="int32") + cache_update_index
            )
            kv_indices = keras.ops.arange(0, kv_len, dtype="int32")
        else:
            q_len = kv_len = keras.ops.shape(hidden_states)[1]
            q_indices = keras.ops.arange(0, q_len, dtype="int32")
            kv_indices = keras.ops.arange(0, kv_len, dtype="int32")
        causal_mask = kv_indices[None, :] <= q_indices[:, None]
        if self.attention_type == "sliding_attention":
            # NOTE(review): this bound is inclusive, so each query attends
            # to `sliding_window + 1` positions (kv in [q - window, q]) —
            # confirm against the reference implementation's window
            # semantics.
            sliding_mask = (
                q_indices[:, None] - self.sliding_window
            ) <= kv_indices[None, :]
            causal_mask = keras.ops.logical_and(causal_mask, sliding_mask)
        final_mask = causal_mask[None, None, :, :]
        if padding_mask is not None:
            padding_mask_slice = padding_mask[:, :kv_len]
            padding_mask_4d = padding_mask_slice[:, None, None, :]
            final_mask = keras.ops.logical_and(final_mask, padding_mask_4d)
        # Convert the boolean mask to a large-negative additive bias
        # (-1e9 on masked positions) applied to attention logits.
        return (1.0 - keras.ops.cast(final_mask, hidden_states.dtype)) * -1e9

    def _make_cross_attention_mask(self, hidden_states, padding_mask):
        """Creates a bidirectional mask for cross-attention."""
        if padding_mask is None:
            return None
        q_len = keras.ops.shape(hidden_states)[1]
        bidirectional_mask = padding_mask[:, None, None, :]
        # Broadcast to (batch, 1, q_len, enc_len).
        bidirectional_mask = keras.ops.broadcast_to(
            bidirectional_mask,
            (
                keras.ops.shape(hidden_states)[0],
                1,
                q_len,
                keras.ops.shape(padding_mask)[1],
            ),
        )
        additive_mask = (
            1.0 - keras.ops.cast(bidirectional_mask, hidden_states.dtype)
        ) * -1e9
        return additive_mask

    def call(
        self,
        inputs,
        self_attention_padding_mask=None,
        cross_attention_padding_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        """Forward pass of the decoder layer.

        Args:
            inputs: Tuple of (hidden_states, encoder_hidden_states).
            self_attention_padding_mask: Padding mask for decoder
                tokens.
            cross_attention_padding_mask: Padding mask for encoder
                tokens.
            cache: Tuple of (self_attn_cache, cross_attn_cache).
            cache_update_index: int, current position for caching.
            training: bool, training mode.

        Returns:
            Tuple of (hidden_states, updated_cache).
        """
        hidden_states, encoder_hidden_states = inputs
        self_attention_cache, cross_attention_cache = (
            cache if cache is not None else (None, None)
        )

        # Build the merged attention mask.
        self_attention_mask = self._make_causal_mask(
            hidden_states,
            self_attention_padding_mask,
            cache=self_attention_cache,
            cache_update_index=cache_update_index,
        )
        cross_attention_mask = self._make_cross_attention_mask(
            hidden_states, cross_attention_padding_mask
        )

        # Concatenate self and cross masks along the KV dimension, matching
        # the K/V concatenation done inside the merged attention layer.
        if cross_attention_mask is not None:
            merged_mask = keras.ops.concatenate(
                [self_attention_mask, cross_attention_mask], axis=-1
            )
        else:
            merged_mask = self_attention_mask

        # Merged attention: self + cross.
        residual = hidden_states
        hidden_states = self.pre_self_attn_layernorm(hidden_states)
        hidden_states, updated_cache = self.merged_attn(
            inputs=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=merged_mask,
            cache=(self_attention_cache, cross_attention_cache),
            cache_update_index=cache_update_index,
            training=training,
        )
        hidden_states = self.post_self_attn_layernorm(hidden_states)
        hidden_states = residual + self.dropout(
            hidden_states, training=training
        )

        # MLP.
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states, training=training)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + self.dropout(
            hidden_states, training=training
        )
        return hidden_states, updated_cache

    def compute_output_shape(self, input_shape):
        hidden_states_shape, encoder_hidden_states_shape = input_shape
        batch_size, dec_seq_len, _ = hidden_states_shape
        _, enc_seq_len, _ = encoder_hidden_states_shape
        # Caches hold stacked (key, value) pairs, hence the `2` axis.
        self_cache_shape = (
            batch_size,
            2,
            dec_seq_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        cross_cache_shape = (
            batch_size,
            2,
            enc_seq_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        return hidden_states_shape, (self_cache_shape, cross_cache_shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "rms_norm_eps": self.rms_norm_eps,
                "num_attention_heads": self.num_attention_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "query_pre_attn_scalar": self.query_pre_attn_scalar,
                "attention_bias": self.attention_bias,
                "intermediate_size": self.intermediate_size,
                "hidden_activation": self.hidden_activation,
                "dropout_rate": self.dropout_rate,
                "initializer_range": self.initializer_range,
                "attention_dropout": self.attention_dropout,
                "layer_type": self.attention_type,
                "sliding_window":
self.sliding_window, + "rope_max_wavelength": self.rope_max_wavelength, + "head_dim": self.head_dim, + "cross_attention_hidden_size": ( + self.cross_attention_hidden_size + ), + "attn_logit_softcapping": self.attn_logit_softcapping, + "use_query_key_norm": self.use_query_key_norm, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py b/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py new file mode 100644 index 0000000000..75e117afb4 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py @@ -0,0 +1,219 @@ +import keras + +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma2.t5gemma2_attention import T5Gemma2Attention +from keras_hub.src.models.t5gemma2.t5gemma2_layers import T5Gemma2MLP + + +class T5Gemma2EncoderLayer(keras.layers.Layer): + """Encoder layer for the T5Gemma2 model. + + This layer implements a single encoder block in the T5Gemma2 + architecture, comprising bidirectional self-attention and a + feed-forward network (MLP). It uses Gemma3-style Q/K normalization. + + Each encoder layer has an `attention_type` attribute that specifies + whether it uses `"full_attention"` or `"sliding_attention"`. The + backbone uses this to route the correct RoPE embeddings and + attention masks. + + Args: + hidden_size: int, Dimensionality of hidden states. + rms_norm_eps: float, Epsilon for RMS normalization. + num_attention_heads: int, Number of attention heads. + num_key_value_heads: int, Number of key-value heads for GQA. + query_pre_attn_scalar: float, Scalar for query normalization. + attention_bias: bool, Whether to include bias. + intermediate_size: int, Intermediate size of the FFN. + hidden_activation: str, Activation function for the FFN. + dropout_rate: float, Dropout rate. + initializer_range: float, Range for the initializer. + attention_dropout: float, Dropout for attention weights. 
+ layer_type: str, Either `"full_attention"` or + `"sliding_attention"`. + head_dim: int, Dimensionality of each attention head. + attn_logit_softcapping: float, optional, Softcapping value. + sliding_window: int, optional, Window size for sliding + attention. + rope_max_wavelength: float, Maximum wavelength for RoPE. + use_query_key_norm: bool, Whether to apply Q/K norm. + Defaults to `True`. + dtype: The dtype for computations. Defaults to `None`. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range, + attention_dropout, + layer_type, + head_dim, + attn_logit_softcapping=None, + sliding_window=None, + rope_max_wavelength=10000.0, + use_query_key_norm=True, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.attention_type = layer_type + self.sliding_window = sliding_window + self.rope_max_wavelength = rope_max_wavelength + self.head_dim = head_dim + self.attn_logit_softcapping = attn_logit_softcapping + self.use_query_key_norm = use_query_key_norm + + if ( + self.attention_type == "sliding_attention" + and self.sliding_window is None + ): + raise ValueError( + "`sliding_window` must be set for `sliding_attention` " + "layer type." 
+ ) + + self.self_attn = T5Gemma2Attention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + head_dim=self.head_dim, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=self.rope_max_wavelength, + use_query_key_norm=use_query_key_norm, + rms_norm_eps=rms_norm_eps, + dtype=self.dtype_policy, + name="self_attention", + ) + self.pre_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_self_attention_layernorm", + ) + self.post_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_self_attention_layernorm", + ) + + self.mlp = T5Gemma2MLP( + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=initializer_range, + dtype=self.dtype_policy, + name="mlp", + ) + self.pre_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_feedforward_layernorm", + ) + self.post_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_feedforward_layernorm", + ) + self.dropout = keras.layers.Dropout( + dropout_rate, + dtype=self.dtype_policy, + name="residual_dropout", + ) + + def build(self, input_shape): + self.pre_self_attn_layernorm.build(input_shape) + self.self_attn.build(input_shape) + attn_output_shape, _ = self.self_attn.compute_output_shape(input_shape) + self.post_self_attn_layernorm.build(attn_output_shape) + self.dropout.build(attn_output_shape) + self.pre_feedforward_layernorm.build(attn_output_shape) + self.mlp.build(attn_output_shape) + mlp_output_shape = self.mlp.compute_output_shape(attn_output_shape) + self.post_feedforward_layernorm.build(mlp_output_shape) + self.built = True + + def 
_make_attention_mask(self, hidden_states, padding_mask): + attention_mask = padding_mask[:, None, None, :] + additive_mask = ( + 1.0 - keras.ops.cast(attention_mask, hidden_states.dtype) + ) * -1e9 + return additive_mask + + def call( + self, + hidden_states, + padding_mask=None, + training=None, + ): + residual = hidden_states + attention_mask = self._make_attention_mask(hidden_states, padding_mask) + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + inputs=hidden_states, + attention_mask=attention_mask, + training=training, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, training=training) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + return hidden_states + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "head_dim": self.head_dim, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "layer_type": self.attention_type, + "sliding_window": self.sliding_window, + "rope_max_wavelength": self.rope_max_wavelength, + "attn_logit_softcapping": self.attn_logit_softcapping, + "use_query_key_norm": self.use_query_key_norm, + } + ) + return config diff --git 
import keras

from keras_hub.src.utils.keras_utils import clone_initializer


def t5gemma2_kernel_initializer(initializer_range=0.01):
    """Creates a RandomNormal initializer for T5Gemma2 kernels.

    Args:
        initializer_range: float, The standard deviation of the normal
            distribution. Defaults to `0.01`.

    Returns:
        keras.initializers.RandomNormal: A Keras RandomNormal initializer.
    """
    return keras.initializers.RandomNormal(mean=0.0, stddev=initializer_range)


class T5Gemma2MLP(keras.layers.Layer):
    """Multilayer Perceptron (MLP) block for the T5Gemma2 model.

    This layer implements the feed-forward part of a transformer block,
    consisting of a gated GELU activation and dropout, following the
    Gemma3 architecture pattern: `down_proj(act(gate_proj(x)) * up_proj(x))`.

    Args:
        hidden_size: int, The dimensionality of the input and output hidden
            states.
        intermediate_size: int, The dimensionality of the intermediate layer.
        hidden_activation: str, The activation function to use, e.g.,
            "gelu_approximate".
        dropout_rate: float, The dropout rate applied to the intermediate
            hidden states.
        initializer_range: float, The range for the random normal initializer
            for kernel weights. Defaults to `0.02`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        hidden_size,
        intermediate_size,
        hidden_activation,
        dropout_rate,
        initializer_range=0.02,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.hidden_activation = hidden_activation
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.kernel_initializer = t5gemma2_kernel_initializer(initializer_range)

        self.gate_proj = keras.layers.Dense(
            self.intermediate_size,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dtype=self.dtype_policy,
            name="gate_proj",
        )
        self.up_proj = keras.layers.Dense(
            self.intermediate_size,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dtype=self.dtype_policy,
            name="up_proj",
        )
        self.down_proj = keras.layers.Dense(
            self.hidden_size,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dtype=self.dtype_policy,
            name="down_proj",
        )
        # "gelu_approximate" is not a registered Keras activation name, so
        # map it explicitly onto tanh-approximated GELU.
        if self.hidden_activation == "gelu_approximate":
            self.act_fn = lambda x: keras.activations.gelu(x, approximate=True)
        else:
            self.act_fn = keras.activations.get(self.hidden_activation)
        self.dropout = keras.layers.Dropout(
            self.dropout_rate,
            dtype=self.dtype_policy,
            name="mlp_dropout",
        )

    def build(self, input_shape):
        self.gate_proj.build(input_shape)
        self.up_proj.build(input_shape)
        intermediate_shape = self.gate_proj.compute_output_shape(input_shape)
        self.dropout.build(intermediate_shape)
        self.down_proj.build(intermediate_shape)
        self.built = True

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, x, training=None):
        # Gated feed-forward: act(gate(x)) * up(x), dropout, then project
        # back down to `hidden_size`.
        hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        hidden_states = self.dropout(hidden_states, training=training)
        down_proj = self.down_proj(hidden_states)
        return down_proj

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "intermediate_size": self.intermediate_size,
                "hidden_activation": self.hidden_activation,
                "dropout_rate": self.dropout_rate,
                "initializer_range": self.initializer_range,
            }
        )
        return config
backbone_presets = {
    # Placeholder presets — will be populated when checkpoints
    # are made available.
}
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone
from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm_preprocessor import (
    T5Gemma2Seq2SeqLMPreprocessor,
)
from keras_hub.src.utils.tensor_utils import any_equal


@keras_hub_export("keras_hub.models.T5Gemma2Seq2SeqLM")
class T5Gemma2Seq2SeqLM(Seq2SeqLM):
    """An end-to-end T5Gemma2 model for seq2seq language modeling.

    A seq2seq language model (LM) is an encoder-decoder model which is
    used for conditional text generation. The encoder is given a
    "context" text (fed to the encoder), and the decoder predicts the
    next token based on both the encoder inputs and the previous tokens.

    T5Gemma2 extends T5Gemma1 by using Gemma3-based components, with
    merged self+cross attention in the decoder, Gemma3-style Q/K
    normalization, and per-layer-type RoPE.

    This model has a `generate()` method, which generates text based on
    a prompt.
The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. + + Args: + backbone: A `keras_hub.models.T5Gemma2Backbone` instance. + preprocessor: A `keras_hub.models.T5Gemma2Seq2SeqLMPreprocessor` + or `None`. Defaults to `None`. + + Examples: + + Use `generate()` to do text generation. + ```python + t5gemma2_lm = keras_hub.models.T5Gemma2Seq2SeqLM.from_preset( + "t5gemma2_270m_270m" + ) + t5gemma2_lm.generate( + "The quick brown fox jumped.", max_length=30 + ) + ``` + + Custom backbone and vocabulary. + ```python + tokenizer = keras_hub.models.T5Gemma2Tokenizer( + proto="proto.spm", + ) + preprocessor = keras_hub.models.T5Gemma2Seq2SeqLMPreprocessor( + tokenizer=tokenizer, + encoder_sequence_length=128, + decoder_sequence_length=128, + ) + backbone = keras_hub.models.T5Gemma2Backbone( + vocabulary_size=32000, + encoder_hidden_dim=256, + encoder_intermediate_dim=512, + encoder_num_layers=4, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_head_dim=64, + encoder_layer_types=["full_attention"] * 4, + decoder_hidden_dim=256, + decoder_intermediate_dim=512, + decoder_num_layers=4, + decoder_num_attention_heads=4, + decoder_num_key_value_heads=2, + decoder_head_dim=64, + decoder_layer_types=["full_attention"] * 4, + dropout_rate=0.1, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + ) + t5gemma2_lm = keras_hub.models.T5Gemma2Seq2SeqLM( + backbone=backbone, + preprocessor=preprocessor, + ) + ``` + """ + + backbone_cls = T5Gemma2Backbone + preprocessor_cls = T5Gemma2Seq2SeqLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + sequence_output = backbone(inputs)["decoder_sequence_output"] + logits = backbone.decoder_token_embedding(sequence_output, reverse=True) + if 
self.backbone.final_logit_softcapping is not None: + logits = logits / self.backbone.final_logit_softcapping + logits = keras.ops.tanh(logits) + logits = logits * self.backbone.final_logit_softcapping + super().__init__( + inputs=inputs, + outputs=logits, + **kwargs, + ) + + def call_encoder(self, token_ids, padding_mask): + """Process inputs through the encoder stack.""" + encoder_embeddings = self.backbone.token_embedding(token_ids) + encoder_embeddings *= keras.ops.cast( + keras.ops.sqrt(self.backbone.encoder_hidden_dim), + encoder_embeddings.dtype, + ) + encoder_hidden_states = self.backbone.encoder_dropout( + encoder_embeddings, training=False + ) + for layer in self.backbone.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, + padding_mask=padding_mask, + training=False, + ) + encoder_output = self.backbone.encoder_norm(encoder_hidden_states) + encoder_output = self.backbone.encoder_dropout( + encoder_output, training=False + ) + return encoder_output, padding_mask + + def call_decoder_with_cache( + self, + decoder_token_ids, + decoder_padding_mask, + cache, + cache_update_index, + encoder_output, + encoder_padding_mask, + ): + """Forward pass of the decoder with cache. + + `call_decoder_with_cache` adds an additional forward pass for + autoregressive inference. The cache stores previous key/value + tensors in the attention layers. + + Args: + decoder_token_ids: Dense int Tensor of shape + `(batch_size, max_length)`. + decoder_padding_mask: Dense int Tensor of shape + `(batch_size, max_length)`. + cache: Dense float Tensor, the cache of key/value states. + cache_update_index: int or int Tensor. + encoder_output: Dense float Tensor, encoder output. + encoder_padding_mask: Dense int Tensor. + + Returns: + Tuple of (logits, hidden_states, updated_cache). 
+ """ + self_attention_cache, cross_attention_cache = cache + hidden_states = self.backbone.decoder_token_embedding(decoder_token_ids) + hidden_states *= keras.ops.cast( + keras.ops.sqrt(self.backbone.decoder_hidden_dim), + hidden_states.dtype, + ) + hidden_states = self.backbone.decoder_dropout( + hidden_states, training=False + ) + updated_self_attention_caches = [] + updated_cross_attention_caches = [] + for i, layer in enumerate(self.backbone.decoder_layers): + layer_self_cache = ( + self_attention_cache[:, i, ...] + if self_attention_cache is not None + else None + ) + layer_cross_cache = ( + cross_attention_cache[:, i, ...] + if cross_attention_cache is not None + else None + ) + layer_cache = (layer_self_cache, layer_cross_cache) + hidden_states, updated_layer_cache = layer( + (hidden_states, encoder_output), + self_attention_padding_mask=decoder_padding_mask, + cross_attention_padding_mask=encoder_padding_mask, + cache=layer_cache, + cache_update_index=cache_update_index, + training=False, + ) + new_self_cache, new_cross_cache = updated_layer_cache + updated_self_attention_caches.append(new_self_cache) + updated_cross_attention_caches.append(new_cross_cache) + self_attention_cache = keras.ops.stack( + updated_self_attention_caches, axis=1 + ) + cross_attention_cache = keras.ops.stack( + updated_cross_attention_caches, axis=1 + ) + hidden_states = self.backbone.decoder_norm(hidden_states) + logits = self.backbone.decoder_token_embedding( + hidden_states, reverse=True + ) + if self.backbone.final_logit_softcapping is not None: + logits = logits / self.backbone.final_logit_softcapping + logits = keras.ops.tanh(logits) + logits = logits * self.backbone.final_logit_softcapping + return ( + logits, + hidden_states, + (self_attention_cache, cross_attention_cache), + ) + + def _build_cache( + self, + encoder_token_ids, + encoder_padding_mask, + decoder_token_ids, + decoder_padding_mask, + ): + """Build an empty cache for use with `call_with_cache()`.""" + 
encoder_output, encoder_padding_mask = self.call_encoder( + encoder_token_ids, encoder_padding_mask + ) + batch_size = keras.ops.shape(decoder_token_ids)[0] + num_layers = self.backbone.decoder_num_layers + num_kv_heads = self.backbone.decoder_num_key_value_heads + head_dim = self.backbone.decoder_head_dim + self_cache_shape = ( + batch_size, + num_layers, + 2, + keras.ops.shape(decoder_token_ids)[1], + num_kv_heads, + head_dim, + ) + self_attention_cache = keras.ops.zeros( + self_cache_shape, dtype=self.compute_dtype + ) + cross_attention_cache = None + _, hidden_states, cache = self.call_decoder_with_cache( + decoder_token_ids=decoder_token_ids, + decoder_padding_mask=decoder_padding_mask, + cache=(self_attention_cache, cross_attention_cache), + cache_update_index=0, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask, + ) + extra_cache_info = (encoder_output, encoder_padding_mask) + return hidden_states, cache, extra_cache_info + + def generate_step(self, inputs, stop_token_ids=None): + """A compilable generation function for a single batch. + + Args: + inputs: A dictionary with keys `"encoder_token_ids"`, + `"encoder_padding_mask"`, `"decoder_token_ids"`, and + `"decoder_padding_mask"`. + stop_token_ids: Tuple of end token ids to stop on. 
+ """ + encoder_token_ids = inputs["encoder_token_ids"] + encoder_padding_mask = inputs["encoder_padding_mask"] + decoder_token_ids = inputs["decoder_token_ids"] + decoder_padding_mask = inputs["decoder_padding_mask"] + hidden_states, cache, extra_cache_info = self._build_cache( + encoder_token_ids=encoder_token_ids, + encoder_padding_mask=encoder_padding_mask, + decoder_token_ids=decoder_token_ids, + decoder_padding_mask=decoder_padding_mask, + ) + encoder_output, encoder_padding_mask = extra_cache_info + row_lengths = keras.ops.sum( + keras.ops.cast(decoder_padding_mask, "int32"), axis=-1 + ) + index = keras.ops.min(row_lengths) + + def next(prompt, cache, index): + cache_update_index = index - 1 + batch_size = keras.ops.shape(prompt)[0] + prompt = keras.ops.slice( + prompt, [0, cache_update_index], [batch_size, 1] + ) + ( + logits, + _, + updated_cache, + ) = self.call_decoder_with_cache( + decoder_token_ids=prompt, + decoder_padding_mask=None, + cache_update_index=cache_update_index, + cache=cache, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask, + ) + return ( + keras.ops.squeeze(logits, axis=1), + None, + updated_cache, + ) + + decoder_token_ids = self.sampler( + next=next, + prompt=decoder_token_ids, + cache=cache, + index=index, + mask=decoder_padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + if stop_token_ids is not None: + end_locations = any_equal( + decoder_token_ids, + stop_token_ids, + keras.ops.logical_not(decoder_padding_mask), + ) + end_locations = keras.ops.cast(end_locations, "int32") + cumsum = keras.ops.cast( + keras.ops.cumsum(end_locations, axis=-1), "int32" + ) + overflow = cumsum - end_locations + decoder_padding_mask = keras.ops.logical_not( + keras.ops.cast(overflow, "bool") + ) + else: + decoder_padding_mask = keras.ops.ones_like( + decoder_token_ids, dtype="bool" + ) + + return { + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": 
decoder_padding_mask, + } diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py new file mode 100644 index 0000000000..298da62b88 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py @@ -0,0 +1,158 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.models.t5gemma2.t5gemma2_tokenizer import T5Gemma2Tokenizer +from keras_hub.src.utils.tensor_utils import preprocessing_function + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.T5Gemma2Seq2SeqLMPreprocessor") +class T5Gemma2Seq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): + """T5Gemma2 Seq2Seq LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.T5Gemma2Seq2SeqLM`. By default, it will take in + batches of strings, and return outputs in a + `(x, y, sample_weight)` format, where the `y` label is the next + token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this + preprocessor is attached to a `keras_hub.models.T5Gemma2Seq2SeqLM` + instance, these methods will be called implicitly in `generate()`. + + Args: + tokenizer: A `keras_hub.models.T5Gemma2Tokenizer` instance. + encoder_sequence_length: The length of the packed encoder inputs. + decoder_sequence_length: The length of the packed decoder inputs. + add_start_token: If `True`, prepend the start token. Defaults + to `False`. + add_end_token: If `True`, append the end token. Defaults to + `True`. 
+ """ + + backbone_cls = T5Gemma2Backbone + tokenizer_cls = T5Gemma2Tokenizer + + def __init__( + self, + tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + add_start_token=False, + add_end_token=True, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + encoder_sequence_length=encoder_sequence_length, + decoder_sequence_length=decoder_sequence_length, + **kwargs, + ) + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_inputs = self.tokenizer(x["encoder_text"]) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_inputs, + sequence_length=encoder_sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + decoder_inputs = self.tokenizer(x["decoder_text"]) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_inputs, + sequence_length=decoder_sequence_length + 1, + add_start_value=True, + add_end_value=self.add_end_token, + ) + x = { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids[..., :-1], + "decoder_padding_mask": decoder_padding_mask[..., :-1], + } + y = decoder_token_ids[..., 1:] + sample_weight = decoder_padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + @preprocessing_function + def generate_preprocess( + self, + x, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if not self.built: + self.build(None) + 
+ if isinstance(x, dict): + encoder_text = x["encoder_text"] + decoder_text = x["decoder_text"] + else: + encoder_text = x + decoder_text = tf.fill((tf.shape(encoder_text)[0],), "") + + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_token_ids = self.tokenizer(encoder_text) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_token_ids, + sequence_length=None, + add_start_value=self.add_start_token, + add_end_value=False, + ) + + decoder_token_ids = self.tokenizer(decoder_text) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_token_ids, + sequence_length=decoder_sequence_length, + add_start_value=True, + add_end_value=False, + ) + + return { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } + + def get_config(self): + config = super().get_config() + config.update( + { + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_test.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_test.py new file mode 100644 index 0000000000..7bf5ab6e28 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_test.py @@ -0,0 +1,98 @@ +import keras +import pytest + +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm import ( + T5Gemma2Seq2SeqLM, +) +from keras_hub.src.tests.test_case import TestCase + + +class T5Gemma2Seq2SeqLMTest(TestCase): + def setUp(self): + self.backbone_kwargs = { + "vocabulary_size": 100, + "encoder_hidden_dim": 32, + "encoder_intermediate_dim": 
64, + "encoder_num_layers": 2, + "encoder_num_attention_heads": 4, + "encoder_num_key_value_heads": 2, + "encoder_head_dim": 8, + "encoder_layer_types": [ + "full_attention", + "full_attention", + ], + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": [ + "full_attention", + "full_attention", + ], + "dropout_rate": 0.0, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "query_pre_attn_scalar": 1.0, + "attention_bias": False, + "hidden_activation": "gelu_approximate", + "cross_attention_hidden_size": 32, + "rope_max_wavelength": 10000.0, + "initializer_range": 0.04, + "use_query_key_norm": True, + } + self.input_data = { + "encoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "encoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + "decoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "decoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + } + + def test_seq2seq_lm_basics(self): + backbone = T5Gemma2Backbone(**self.backbone_kwargs) + lm = T5Gemma2Seq2SeqLM(backbone=backbone) + output = lm(self.input_data) + self.assertEqual(keras.ops.shape(output), (2, 16, 100)) + + def test_call_encoder(self): + backbone = T5Gemma2Backbone(**self.backbone_kwargs) + lm = T5Gemma2Seq2SeqLM(backbone=backbone) + encoder_output, padding_mask = lm.call_encoder( + self.input_data["encoder_token_ids"], + self.input_data["encoder_padding_mask"], + ) + self.assertEqual(keras.ops.shape(encoder_output), (2, 16, 32)) + + def test_build_cache(self): + backbone = T5Gemma2Backbone(**self.backbone_kwargs) + lm = T5Gemma2Seq2SeqLM(backbone=backbone) + hidden_states, cache, extra = lm._build_cache( + encoder_token_ids=self.input_data["encoder_token_ids"], + encoder_padding_mask=self.input_data["encoder_padding_mask"], + decoder_token_ids=self.input_data["decoder_token_ids"], + 
decoder_padding_mask=self.input_data["decoder_padding_mask"], + ) + self.assertEqual(keras.ops.shape(hidden_states), (2, 16, 32)) + self_attention_cache, cross_attention_cache = cache + # self_attention_cache: (batch, num_layers, 2, + # seq, kv_heads, head_dim) + self.assertEqual( + keras.ops.shape(self_attention_cache), (2, 2, 2, 16, 2, 8) + ) + # cross_attention_cache: (batch, num_layers, 2, + # enc_seq, kv_heads, head_dim) + self.assertEqual( + keras.ops.shape(cross_attention_cache), + (2, 2, 2, 16, 2, 8), + ) + + @pytest.mark.large + def test_saved_model(self): + backbone = T5Gemma2Backbone(**self.backbone_kwargs) + self.run_model_saving_test( + cls=T5Gemma2Seq2SeqLM, + init_kwargs={"backbone": backbone}, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_tokenizer.py b/keras_hub/src/models/t5gemma2/t5gemma2_tokenizer.py new file mode 100644 index 0000000000..824a7c6bcc --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_tokenizer.py @@ -0,0 +1,55 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( + SentencePieceTokenizer, +) + + +@keras_hub_export( + [ + "keras_hub.tokenizers.T5Gemma2Tokenizer", + "keras_hub.models.T5Gemma2Tokenizer", + ] +) +class T5Gemma2Tokenizer(SentencePieceTokenizer): + """T5Gemma2 tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences + and is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike + the underlying tokenizer, it will check for all special tokens needed + by T5Gemma2 models and provides a `from_preset()` method to + automatically download a matching vocabulary for a T5Gemma2 preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. 
+ + If input is a scalar string (rank == 0), the layer will output a + dense `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, + or a `bytes` object with a serialized SentencePiece proto. + + Examples: + + ```python + tokenizer = keras_hub.models.T5Gemma2Tokenizer.from_preset( + "t5gemma2_270m_270m" + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + ``` + """ + + backbone_cls = T5Gemma2Backbone + + def __init__(self, proto, **kwargs): + self._add_special_token("<bos>", "start_token") + self._add_special_token("<eos>", "end_token") + self._add_special_token("<pad>", "pad_token") + super().__init__(proto=proto, **kwargs) From 8a284d46aab6c7d7ed0800ab3f02a61275f4680f Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Mon, 2 Mar 2026 15:42:51 -0800 Subject: [PATCH 2/9] add t5gemma2 converter --- .../utils/transformers/convert_t5gemma2.py | 220 ++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 keras_hub/src/utils/transformers/convert_t5gemma2.py diff --git a/keras_hub/src/utils/transformers/convert_t5gemma2.py b/keras_hub/src/utils/transformers/convert_t5gemma2.py new file mode 100644 index 0000000000..7634ba9ce3 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_t5gemma2.py @@ -0,0 +1,220 @@ +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.utils.preset_utils import get_file + +backbone_cls = T5Gemma2Backbone + + +def convert_backbone_config(transformers_config): + """Convert a HuggingFace T5Gemma2 config to KerasHub backbone config.""" + encoder_config = transformers_config["encoder"] + decoder_config = transformers_config["decoder"] + + if decoder_config.get("hidden_activation") == "gelu_pytorch_tanh": + decoder_config["hidden_activation"] = "gelu_approximate" + if
encoder_config.get("hidden_activation") == "gelu_pytorch_tanh": + encoder_config["hidden_activation"] = "gelu_approximate" + + backbone_config = { + "vocabulary_size": decoder_config["vocab_size"], + "encoder_hidden_dim": encoder_config["hidden_size"], + "encoder_intermediate_dim": encoder_config["intermediate_size"], + "encoder_num_layers": encoder_config["num_hidden_layers"], + "encoder_num_attention_heads": encoder_config["num_attention_heads"], + "encoder_num_key_value_heads": encoder_config["num_key_value_heads"], + "encoder_head_dim": encoder_config["head_dim"], + "encoder_layer_types": encoder_config["layer_types"], + "decoder_hidden_dim": decoder_config["hidden_size"], + "decoder_intermediate_dim": decoder_config["intermediate_size"], + "decoder_num_layers": decoder_config["num_hidden_layers"], + "decoder_num_attention_heads": decoder_config["num_attention_heads"], + "decoder_num_key_value_heads": decoder_config["num_key_value_heads"], + "decoder_head_dim": decoder_config["head_dim"], + "decoder_layer_types": decoder_config["layer_types"], + "dropout_rate": decoder_config["dropout_rate"], + "rms_norm_eps": decoder_config["rms_norm_eps"], + "query_pre_attn_scalar": decoder_config["query_pre_attn_scalar"], + "tie_word_embeddings": transformers_config.get( + "tie_word_embeddings", True + ), + "attention_bias": decoder_config["attention_bias"], + "hidden_activation": decoder_config["hidden_activation"], + "initializer_range": decoder_config["initializer_range"], + "attention_dropout": decoder_config["attention_dropout"], + "sliding_window": decoder_config["sliding_window"], + "cross_attention_hidden_size": encoder_config["hidden_size"], + "attn_logit_softcapping": decoder_config["attn_logit_softcapping"], + "final_logit_softcapping": decoder_config["final_logit_softcapping"], + "rope_max_wavelength": decoder_config["rope_theta"], + "use_query_key_norm": True, + } + return backbone_config + + +def convert_weights(backbone, loader, transformers_config): + 
"""Convert T5Gemma2 weights from HuggingFace to KerasHub.""" + # Token embeddings. + loader.port_weight( + keras_variable=backbone.token_embedding.embeddings, + hf_weight_key="encoder.embed_tokens.weight", + ) + loader.port_weight( + keras_variable=backbone.decoder_token_embedding.embeddings, + hf_weight_key="decoder.embed_tokens.weight", + ) + + # Encoder. + loader.port_weight( + keras_variable=backbone.encoder_norm.scale, + hf_weight_key="encoder.norm.weight", + ) + for i in range(backbone.encoder_num_layers): + layer = backbone.get_layer(f"encoder_layer_{i}") + hf_prefix = f"encoder.layers.{i}" + + # Self-attention Q/K/V/O projections. + loader.port_weight( + keras_variable=layer.self_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # Q/K normalization (Gemma3-style). + loader.port_weight( + keras_variable=layer.self_attn.query_norm.scale, + hf_weight_key=f"{hf_prefix}.self_attn.q_norm.weight", + ) + loader.port_weight( + keras_variable=layer.self_attn.key_norm.scale, + hf_weight_key=f"{hf_prefix}.self_attn.k_norm.weight", + ) + + # MLP. 
+ loader.port_weight( + keras_variable=layer.mlp.gate_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.up_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.down_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight", + hook_fn=lambda w, s: w.T, + ) + + # Layer norms. + loader.port_weight( + keras_variable=layer.pre_self_attn_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.pre_self_attn_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.post_self_attn_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.post_self_attn_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.pre_feedforward_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.pre_feedforward_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.post_feedforward_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.post_feedforward_layernorm.weight"), + ) + + # Decoder. + loader.port_weight( + keras_variable=backbone.decoder_norm.scale, + hf_weight_key="decoder.norm.weight", + ) + for i in range(backbone.decoder_num_layers): + layer = backbone.get_layer(f"decoder_layer_{i}") + hf_prefix = f"decoder.layers.{i}" + + # Merged attention (self+cross uses a single self_attn layer). 
+ loader.port_weight( + keras_variable=layer.merged_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.merged_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.merged_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.merged_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # Q/K normalization. + loader.port_weight( + keras_variable=layer.merged_attn.query_norm.scale, + hf_weight_key=f"{hf_prefix}.self_attn.q_norm.weight", + ) + loader.port_weight( + keras_variable=layer.merged_attn.key_norm.scale, + hf_weight_key=f"{hf_prefix}.self_attn.k_norm.weight", + ) + + # MLP. + loader.port_weight( + keras_variable=layer.mlp.gate_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.up_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.down_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight", + hook_fn=lambda w, s: w.T, + ) + + # Layer norms (no cross-attn norms — merged into self_attn). 
+ loader.port_weight( + keras_variable=layer.pre_self_attn_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.pre_self_attn_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.post_self_attn_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.post_self_attn_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.pre_feedforward_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.pre_feedforward_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.post_feedforward_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.post_feedforward_layernorm.weight"), + ) + + +def convert_tokenizer(cls, preset, **kwargs): + """Convert a T5Gemma2 tokenizer.""" + return cls(get_file(preset, "tokenizer.model"), **kwargs) From 0118760d3fde57fe3292eb6ab60636fb305f63c1 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Mon, 2 Mar 2026 15:53:15 -0800 Subject: [PATCH 3/9] Add checkpoint conversion script --- .../src/utils/transformers/preset_loader.py | 3 + .../convert_t5gemma2_checkpoints.py | 409 ++++++++++++++++++ 2 files changed, 412 insertions(+) create mode 100644 tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 75bd8ab6d4..e7245110cd 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -29,6 +29,7 @@ from keras_hub.src.utils.transformers import convert_sam3 from keras_hub.src.utils.transformers import convert_smollm3 from keras_hub.src.utils.transformers import convert_t5gemma +from keras_hub.src.utils.transformers import convert_t5gemma2 from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -88,6 +89,8 @@ def __init__(self, preset, config): self.converter = convert_smollm3 elif model_type == "t5gemma": self.converter = convert_t5gemma + elif model_type == 
"t5gemma2": + self.converter = convert_t5gemma2 else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " diff --git a/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py new file mode 100644 index 0000000000..02c3d5cf18 --- /dev/null +++ b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py @@ -0,0 +1,409 @@ +import gc +import os +import random +import shutil + +import huggingface_hub +import keras +import numpy as np +import tensorflow as tf +import torch +import transformers +from absl import app +from absl import flags +from checkpoint_conversion_utils import get_md5_checksum + +import keras_hub +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm import ( + T5Gemma2Seq2SeqLM, +) +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm_preprocessor import ( + T5Gemma2Seq2SeqLMPreprocessor, +) + +random.seed(123) +torch.manual_seed(123) +device = torch.device("cpu") +torch.set_default_device(device) + + +PRESET_MAP = { + "t5gemma2_270m_270m": "google/t5gemma-2-270m-270m", + "t5gemma2_1b_1b": "google/t5gemma-2-1b-1b", + "t5gemma2_4b_4b": "google/t5gemma-2-4b-4b", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", + None, + f"Must be one of {','.join(PRESET_MAP.keys())}", +) + + +def convert_checkpoints(hf_model): + """Convert HuggingFace T5Gemma2 weights to KerasHub format.""" + print("\n-> Convert original weights to KerasHub format.") + print("\n-> Load KerasHub model.") + + encoder_config = hf_model.config.encoder + decoder_config = hf_model.config.decoder + if decoder_config.hidden_activation == "gelu_pytorch_tanh": + decoder_config.hidden_activation = "gelu_approximate" + if encoder_config.hidden_activation == "gelu_pytorch_tanh": + encoder_config.hidden_activation = "gelu_approximate" + + keras.config.set_floatx("float32") + keras_hub_model = keras_hub.models.T5Gemma2Backbone( + 
vocabulary_size=decoder_config.vocab_size, + encoder_hidden_dim=encoder_config.hidden_size, + encoder_intermediate_dim=encoder_config.intermediate_size, + encoder_num_layers=encoder_config.num_hidden_layers, + encoder_num_attention_heads=encoder_config.num_attention_heads, + encoder_num_key_value_heads=(encoder_config.num_key_value_heads), + encoder_head_dim=encoder_config.head_dim, + encoder_layer_types=encoder_config.layer_types, + decoder_hidden_dim=decoder_config.hidden_size, + decoder_intermediate_dim=decoder_config.intermediate_size, + decoder_num_layers=decoder_config.num_hidden_layers, + decoder_num_attention_heads=decoder_config.num_attention_heads, + decoder_num_key_value_heads=(decoder_config.num_key_value_heads), + decoder_head_dim=decoder_config.head_dim, + decoder_layer_types=decoder_config.layer_types, + dropout_rate=decoder_config.dropout_rate, + rms_norm_eps=decoder_config.rms_norm_eps, + query_pre_attn_scalar=decoder_config.query_pre_attn_scalar, + tie_word_embeddings=getattr( + hf_model.config, "tie_word_embeddings", True + ), + attention_bias=decoder_config.attention_bias, + hidden_activation=decoder_config.hidden_activation, + initializer_range=decoder_config.initializer_range, + attention_dropout=decoder_config.attention_dropout, + sliding_window=decoder_config.sliding_window, + cross_attention_hidden_size=encoder_config.hidden_size, + attn_logit_softcapping=decoder_config.attn_logit_softcapping, + final_logit_softcapping=decoder_config.final_logit_softcapping, + rope_max_wavelength=decoder_config.rope_theta, + use_query_key_norm=True, + dtype="float32", + ) + + hf_wts = hf_model.state_dict() + + # Token embeddings. + keras_hub_model.get_layer("encoder_token_embedding").embeddings.assign( + hf_wts["encoder.embed_tokens.weight"] + ) + keras_hub_model.get_layer("decoder_token_embedding").embeddings.assign( + hf_wts["decoder.embed_tokens.weight"] + ) + + # Encoder. 
+ enc_hdim = keras_hub_model.encoder_hidden_dim + enc_heads = keras_hub_model.encoder_num_attention_heads + enc_kv_heads = keras_hub_model.encoder_num_key_value_heads + enc_head_dim = keras_hub_model.encoder_head_dim + keras_hub_model.encoder_norm.scale.assign(hf_wts["encoder.norm.weight"]) + + for i in range(keras_hub_model.encoder_num_layers): + layer = keras_hub_model.get_layer(f"encoder_layer_{i}") + pfx = f"encoder.layers.{i}" + + # Self-attention Q/K/V/O. + layer.self_attn.query_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.q_proj.weight"] + .T.reshape(enc_hdim, enc_heads, enc_head_dim) + .numpy() + ) + layer.self_attn.key_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.k_proj.weight"] + .T.reshape(enc_hdim, enc_kv_heads, enc_head_dim) + .numpy() + ) + layer.self_attn.value_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.v_proj.weight"] + .T.reshape(enc_hdim, enc_kv_heads, enc_head_dim) + .numpy() + ) + layer.self_attn.output_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.o_proj.weight"] + .T.reshape(enc_heads, enc_head_dim, enc_hdim) + .numpy() + ) + + # Q/K normalization. + layer.self_attn.query_norm.scale.assign( + hf_wts[f"{pfx}.self_attn.q_norm.weight"] + ) + layer.self_attn.key_norm.scale.assign( + hf_wts[f"{pfx}.self_attn.k_norm.weight"] + ) + + # MLP. + layer.mlp.gate_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.gate_proj.weight"].T.numpy() + ) + layer.mlp.up_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.up_proj.weight"].T.numpy() + ) + layer.mlp.down_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.down_proj.weight"].T.numpy() + ) + + # Layer norms. 
+ layer.pre_self_attn_layernorm.scale.assign( + hf_wts[f"{pfx}.pre_self_attn_layernorm.weight"] + ) + layer.post_self_attn_layernorm.scale.assign( + hf_wts[f"{pfx}.post_self_attn_layernorm.weight"] + ) + layer.pre_feedforward_layernorm.scale.assign( + hf_wts[f"{pfx}.pre_feedforward_layernorm.weight"] + ) + layer.post_feedforward_layernorm.scale.assign( + hf_wts[f"{pfx}.post_feedforward_layernorm.weight"] + ) + + # Decoder. + dec_hdim = keras_hub_model.decoder_hidden_dim + dec_heads = keras_hub_model.decoder_num_attention_heads + dec_kv_heads = keras_hub_model.decoder_num_key_value_heads + dec_head_dim = keras_hub_model.decoder_head_dim + keras_hub_model.decoder_norm.scale.assign(hf_wts["decoder.norm.weight"]) + + for i in range(keras_hub_model.decoder_num_layers): + layer = keras_hub_model.get_layer(f"decoder_layer_{i}") + pfx = f"decoder.layers.{i}" + + # Merged attention (self+cross uses single self_attn in HF). + layer.merged_attn.query_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.q_proj.weight"] + .T.reshape(dec_hdim, dec_heads, dec_head_dim) + .numpy() + ) + layer.merged_attn.key_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.k_proj.weight"] + .T.reshape(dec_hdim, dec_kv_heads, dec_head_dim) + .numpy() + ) + layer.merged_attn.value_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.v_proj.weight"] + .T.reshape(dec_hdim, dec_kv_heads, dec_head_dim) + .numpy() + ) + layer.merged_attn.output_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.o_proj.weight"] + .T.reshape(dec_heads, dec_head_dim, dec_hdim) + .numpy() + ) + + # Q/K normalization. + layer.merged_attn.query_norm.scale.assign( + hf_wts[f"{pfx}.self_attn.q_norm.weight"] + ) + layer.merged_attn.key_norm.scale.assign( + hf_wts[f"{pfx}.self_attn.k_norm.weight"] + ) + + # MLP. 
+ layer.mlp.gate_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.gate_proj.weight"].T.numpy() + ) + layer.mlp.up_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.up_proj.weight"].T.numpy() + ) + layer.mlp.down_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.down_proj.weight"].T.numpy() + ) + + # Layer norms (no cross-attn norms — merged into self_attn). + layer.pre_self_attn_layernorm.scale.assign( + hf_wts[f"{pfx}.pre_self_attn_layernorm.weight"] + ) + layer.post_self_attn_layernorm.scale.assign( + hf_wts[f"{pfx}.post_self_attn_layernorm.weight"] + ) + layer.pre_feedforward_layernorm.scale.assign( + hf_wts[f"{pfx}.pre_feedforward_layernorm.weight"] + ) + layer.post_feedforward_layernorm.scale.assign( + hf_wts[f"{pfx}.post_feedforward_layernorm.weight"] + ) + + return keras_hub_model + + +def extract_vocab(hf_model_dir): + """Extract vocabulary from the downloaded HF model directory.""" + source_path = os.path.join(hf_model_dir, "tokenizer.model") + vocabulary_path = os.path.join(FLAGS.preset, "tokenizer.model") + print(f"\n-> Save KerasHub vocab to `{vocabulary_path}`.") + + shutil.copyfile(source_path, vocabulary_path) + + keras_hub_tokenizer = keras_hub.models.T5Gemma2Tokenizer( + proto=vocabulary_path + ) + + print("-> Print MD5 checksum of the vocab file.") + print( + f"`{vocabulary_path}` md5sum: ", + get_md5_checksum(vocabulary_path), + ) + + return keras_hub_tokenizer + + +def check_output( + keras_hub_tokenizer, + keras_hub_model, + hf_tokenizer, + hf_model, +): + """Check outputs of KerasHub and HuggingFace models match.""" + print("\n-> Check the outputs.") + enc_sample_text = [ + "cricket is awesome, easily the best sport in the world!" + ] + dec_sample_text = [ + "football is good too, but nowhere near as good as cricket." + ] + + # KerasHub. 
+ keras_hub_enc_token_ids = hf_tokenizer( + enc_sample_text, return_tensors="tf" + )["input_ids"] + keras_hub_dec_token_ids = hf_tokenizer( + dec_sample_text, return_tensors="tf" + )["input_ids"] + keras_hub_dec_token_ids = tf.concat( + [ + tf.constant([[keras_hub_tokenizer.start_token_id]]), + keras_hub_dec_token_ids, + ], + axis=-1, + ) + keras_hub_inputs = { + "encoder_token_ids": keras_hub_enc_token_ids, + "encoder_padding_mask": tf.ones_like(keras_hub_enc_token_ids), + "decoder_token_ids": keras_hub_dec_token_ids, + "decoder_padding_mask": tf.ones_like(keras_hub_dec_token_ids), + } + keras_hub_output = keras_hub_model.predict(keras_hub_inputs) + + # HF. + hf_enc_inputs = hf_tokenizer(enc_sample_text, return_tensors="pt") + hf_dec_inputs = hf_tokenizer(dec_sample_text, return_tensors="pt") + hf_decoder_input_ids = torch.cat( + [ + torch.tensor([[hf_tokenizer.bos_token_id]]), + hf_dec_inputs["input_ids"], + ], + dim=-1, + ) + hf_decoder_attention_mask = torch.cat( + [ + torch.ones(1, 1, dtype=torch.long), + hf_dec_inputs["attention_mask"], + ], + dim=-1, + ) + + hf_output = hf_model( + **hf_enc_inputs, + decoder_input_ids=hf_decoder_input_ids, + decoder_attention_mask=hf_decoder_attention_mask, + ) + + print("Encoder Outputs:") + print( + "KerasHub output:", + keras_hub_output["encoder_sequence_output"][0, 0, :10], + ) + print( + "HF output:", + hf_output.encoder_last_hidden_state[0, 0, :10], + ) + print( + "Difference:", + np.mean( + keras_hub_output["encoder_sequence_output"] + - hf_output.encoder_last_hidden_state.detach().numpy() + ), + ) + + print("Decoder Outputs:") + print( + "KerasHub output:", + keras_hub_output["decoder_sequence_output"][0, 0, :10], + ) + print("HF output:", hf_output.last_hidden_state[0, 0, :10]) + print( + "Difference:", + np.mean( + keras_hub_output["decoder_sequence_output"] + - hf_output.last_hidden_state.detach().numpy() + ), + ) + + +def main(_): + os.makedirs(FLAGS.preset, exist_ok=True) + + hf_model_name = 
PRESET_MAP[FLAGS.preset] + + print("\n-> Download HF model files.") + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=[ + "*.json", + "*.safetensors", + "tokenizer.model", + ], + ) + + print("\n-> Load HF model and HF tokenizer.") + hf_model = transformers.AutoModel.from_pretrained(hf_model_dir) + hf_model.eval() + hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_dir) + + keras_hub_model = convert_checkpoints(hf_model) + print("\n-> Load KerasHub tokenizer.") + keras_hub_tokenizer = extract_vocab(hf_model_dir) + + check_output( + keras_hub_tokenizer, + keras_hub_model, + hf_tokenizer, + hf_model, + ) + print("\n-> Releasing HF backbone from memory.") + del hf_model + gc.collect() + + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + tokenizer=keras_hub_tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + ) + keras_lm = T5Gemma2Seq2SeqLM( + backbone=keras_hub_model, + preprocessor=preprocessor, + dtype=keras_hub_model.dtype, + ) + keras_lm.compile(sampler="greedy") + + print(f"\n-> Saving T5Gemma2Seq2SeqLM preset to `{FLAGS.preset}`.") + keras_lm.save_to_preset(FLAGS.preset) + print("-> Preset saved successfully.") + + print("\n-> Testing preset loading.") + keras_lm = Seq2SeqLM.from_preset(FLAGS.preset) + print("-> Preset loading verified successfully.") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) From 199aa402a2a6c5bd7beabac5d86524be6c17022c Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Mon, 2 Mar 2026 23:26:20 -0800 Subject: [PATCH 4/9] Fix numerics mismatch --- .../src/models/t5gemma2/t5gemma2_attention.py | 6 + .../src/models/t5gemma2/t5gemma2_backbone.py | 129 ++++++++----- .../src/models/t5gemma2/t5gemma2_decoder.py | 3 + .../src/models/t5gemma2/t5gemma2_encoder.py | 3 + .../utils/transformers/convert_t5gemma2.py | 46 +++-- .../convert_t5gemma2_checkpoints.py | 176 ++++++++++++------ 6 files changed, 239 insertions(+), 124 
deletions(-) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_attention.py b/keras_hub/src/models/t5gemma2/t5gemma2_attention.py index 28b8d53e0f..c4a272a222 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_attention.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_attention.py @@ -79,6 +79,7 @@ def __init__( attention_dropout=0.0, attn_logit_softcapping=None, rope_max_wavelength=10000.0, + rope_scaling_factor=1.0, use_query_key_norm=True, rms_norm_eps=1e-6, dtype=None, @@ -102,6 +103,7 @@ def __init__( self.initializer_range = initializer_range self.attention_dropout = attention_dropout self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor self.use_query_key_norm = use_query_key_norm self.rms_norm_eps = rms_norm_eps self.num_key_value_groups = ( @@ -184,6 +186,7 @@ def build(self, input_shape): self.rotary_embedding = RotaryEmbedding( max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, sequence_axis=1, feature_axis=3, name="rotary_embedding", @@ -389,6 +392,7 @@ def __init__( attention_dropout=0.0, attn_logit_softcapping=None, rope_max_wavelength=10000.0, + rope_scaling_factor=1.0, use_query_key_norm=True, rms_norm_eps=1e-6, dtype=None, @@ -415,6 +419,7 @@ def __init__( self.initializer_range = initializer_range self.attention_dropout = attention_dropout self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor self.use_query_key_norm = use_query_key_norm self.rms_norm_eps = rms_norm_eps self.num_key_value_groups = ( @@ -505,6 +510,7 @@ def build(self, input_shape): self.rotary_embedding = RotaryEmbedding( max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, sequence_axis=1, feature_axis=3, name="rotary_embedding", diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py b/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py index 58bd719c51..f14eeb4263 100644 --- 
a/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py @@ -138,6 +138,7 @@ def __init__( attn_logit_softcapping=None, final_logit_softcapping=None, rope_max_wavelength=10000.0, + global_rope_scaling_factor=1.0, use_query_key_norm=True, dtype=None, **kwargs, @@ -160,59 +161,89 @@ def __init__( dtype=dtype, name="decoder_token_embedding", ) - self.encoder_layers = [ - T5Gemma2EncoderLayer( - hidden_size=encoder_hidden_dim, - rms_norm_eps=rms_norm_eps, - num_attention_heads=encoder_num_attention_heads, - num_key_value_heads=encoder_num_key_value_heads, - query_pre_attn_scalar=query_pre_attn_scalar, - attention_bias=attention_bias, - intermediate_size=encoder_intermediate_dim, - hidden_activation=hidden_activation, - head_dim=encoder_head_dim, - dropout_rate=dropout_rate, - initializer_range=initializer_range, - attention_dropout=attention_dropout, - layer_type=encoder_layer_types[i], - sliding_window=sliding_window, - attn_logit_softcapping=attn_logit_softcapping, - rope_max_wavelength=rope_max_wavelength, - use_query_key_norm=use_query_key_norm, - name=f"encoder_layer_{i}", - dtype=dtype, + self.encoder_layers = [] + for i in range(encoder_num_layers): + # Per-layer RoPE wavelength: 10K for sliding, 1M for global. + layer_rope = ( + rope_max_wavelength + if encoder_layer_types[i] == "sliding_attention" + else 1_000_000.0 + ) + # Per-layer RoPE scaling: 1.0 for sliding (default), + # global_rope_scaling_factor for full_attention (linear). 
+ layer_rope_factor = ( + 1.0 + if encoder_layer_types[i] == "sliding_attention" + else global_rope_scaling_factor + ) + self.encoder_layers.append( + T5Gemma2EncoderLayer( + hidden_size=encoder_hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=encoder_num_attention_heads, + num_key_value_heads=encoder_num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=encoder_intermediate_dim, + hidden_activation=hidden_activation, + head_dim=encoder_head_dim, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + layer_type=encoder_layer_types[i], + sliding_window=sliding_window, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=layer_rope, + rope_scaling_factor=layer_rope_factor, + use_query_key_norm=use_query_key_norm, + name=f"encoder_layer_{i}", + dtype=dtype, + ) ) - for i in range(encoder_num_layers) - ] self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) self.encoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) - self.decoder_layers = [ - T5Gemma2DecoderLayer( - hidden_size=decoder_hidden_dim, - rms_norm_eps=rms_norm_eps, - num_attention_heads=decoder_num_attention_heads, - num_key_value_heads=decoder_num_key_value_heads, - query_pre_attn_scalar=query_pre_attn_scalar, - attention_bias=attention_bias, - intermediate_size=decoder_intermediate_dim, - hidden_activation=hidden_activation, - dropout_rate=dropout_rate, - initializer_range=initializer_range, - head_dim=decoder_head_dim, - attention_dropout=attention_dropout, - layer_type=decoder_layer_types[i], - sliding_window=sliding_window, - cross_attention_hidden_size=( - cross_attention_hidden_size or encoder_hidden_dim - ), - attn_logit_softcapping=attn_logit_softcapping, - rope_max_wavelength=rope_max_wavelength, - use_query_key_norm=use_query_key_norm, - name=f"decoder_layer_{i}", - dtype=dtype, + self.decoder_layers = [] + for i in 
range(decoder_num_layers): + # Per-layer RoPE wavelength: 10K for sliding, 1M for global. + layer_rope = ( + rope_max_wavelength + if decoder_layer_types[i] == "sliding_attention" + else 1_000_000.0 + ) + # Per-layer RoPE scaling: 1.0 for sliding (default), + # global_rope_scaling_factor for full_attention (linear). + layer_rope_factor = ( + 1.0 + if decoder_layer_types[i] == "sliding_attention" + else global_rope_scaling_factor + ) + self.decoder_layers.append( + T5Gemma2DecoderLayer( + hidden_size=decoder_hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=decoder_num_attention_heads, + num_key_value_heads=decoder_num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=decoder_intermediate_dim, + hidden_activation=hidden_activation, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + head_dim=decoder_head_dim, + attention_dropout=attention_dropout, + layer_type=decoder_layer_types[i], + sliding_window=sliding_window, + cross_attention_hidden_size=( + cross_attention_hidden_size or encoder_hidden_dim + ), + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=layer_rope, + rope_scaling_factor=layer_rope_factor, + use_query_key_norm=use_query_key_norm, + name=f"decoder_layer_{i}", + dtype=dtype, + ) ) - for i in range(decoder_num_layers) - ] self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py b/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py index 4106ffad15..0776971f94 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py @@ -62,6 +62,7 @@ def __init__( attn_logit_softcapping=None, sliding_window=None, rope_max_wavelength=10000.0, + rope_scaling_factor=1.0, use_query_key_norm=True, dtype=None, **kwargs, @@ -82,6 +83,7 @@ def __init__( 
self.attention_type = layer_type self.sliding_window = sliding_window self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor self.cross_attention_hidden_size = cross_attention_hidden_size self.attn_logit_softcapping = attn_logit_softcapping self.use_query_key_norm = use_query_key_norm @@ -110,6 +112,7 @@ def __init__( attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, use_query_key_norm=use_query_key_norm, rms_norm_eps=rms_norm_eps, dtype=self.dtype_policy, diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py b/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py index 75e117afb4..0764f38c20 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py @@ -60,6 +60,7 @@ def __init__( attn_logit_softcapping=None, sliding_window=None, rope_max_wavelength=10000.0, + rope_scaling_factor=1.0, use_query_key_norm=True, dtype=None, **kwargs, @@ -79,6 +80,7 @@ def __init__( self.attention_type = layer_type self.sliding_window = sliding_window self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor self.head_dim = head_dim self.attn_logit_softcapping = attn_logit_softcapping self.use_query_key_norm = use_query_key_norm @@ -103,6 +105,7 @@ def __init__( attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, use_query_key_norm=use_query_key_norm, rms_norm_eps=rms_norm_eps, dtype=self.dtype_policy, diff --git a/keras_hub/src/utils/transformers/convert_t5gemma2.py b/keras_hub/src/utils/transformers/convert_t5gemma2.py index 7634ba9ce3..fe4bbb290f 100644 --- a/keras_hub/src/utils/transformers/convert_t5gemma2.py +++ b/keras_hub/src/utils/transformers/convert_t5gemma2.py @@ -6,23 +6,27 @@ def 
convert_backbone_config(transformers_config): """Convert a HuggingFace T5Gemma2 config to KerasHub backbone config.""" + # T5Gemma2EncoderConfig is Gemma3Config with text params at + # encoder["text_config"]; decoder is Gemma3TextConfig (flat). encoder_config = transformers_config["encoder"] + enc_text = encoder_config["text_config"] decoder_config = transformers_config["decoder"] - if decoder_config.get("hidden_activation") == "gelu_pytorch_tanh": - decoder_config["hidden_activation"] = "gelu_approximate" - if encoder_config.get("hidden_activation") == "gelu_pytorch_tanh": - encoder_config["hidden_activation"] = "gelu_approximate" + hidden_activation = decoder_config.get( + "hidden_activation", "gelu_pytorch_tanh" + ) + if hidden_activation == "gelu_pytorch_tanh": + hidden_activation = "gelu_approximate" backbone_config = { "vocabulary_size": decoder_config["vocab_size"], - "encoder_hidden_dim": encoder_config["hidden_size"], - "encoder_intermediate_dim": encoder_config["intermediate_size"], - "encoder_num_layers": encoder_config["num_hidden_layers"], - "encoder_num_attention_heads": encoder_config["num_attention_heads"], - "encoder_num_key_value_heads": encoder_config["num_key_value_heads"], - "encoder_head_dim": encoder_config["head_dim"], - "encoder_layer_types": encoder_config["layer_types"], + "encoder_hidden_dim": enc_text["hidden_size"], + "encoder_intermediate_dim": enc_text["intermediate_size"], + "encoder_num_layers": enc_text["num_hidden_layers"], + "encoder_num_attention_heads": enc_text["num_attention_heads"], + "encoder_num_key_value_heads": enc_text["num_key_value_heads"], + "encoder_head_dim": enc_text["head_dim"], + "encoder_layer_types": enc_text["layer_types"], "decoder_hidden_dim": decoder_config["hidden_size"], "decoder_intermediate_dim": decoder_config["intermediate_size"], "decoder_num_layers": decoder_config["num_hidden_layers"], @@ -37,14 +41,17 @@ def convert_backbone_config(transformers_config): "tie_word_embeddings", True ), 
"attention_bias": decoder_config["attention_bias"], - "hidden_activation": decoder_config["hidden_activation"], + "hidden_activation": hidden_activation, "initializer_range": decoder_config["initializer_range"], "attention_dropout": decoder_config["attention_dropout"], "sliding_window": decoder_config["sliding_window"], - "cross_attention_hidden_size": encoder_config["hidden_size"], + "cross_attention_hidden_size": enc_text["hidden_size"], "attn_logit_softcapping": decoder_config["attn_logit_softcapping"], "final_logit_softcapping": decoder_config["final_logit_softcapping"], - "rope_max_wavelength": decoder_config["rope_theta"], + "rope_max_wavelength": decoder_config.get("rope_theta", 10000.0), + "global_rope_scaling_factor": decoder_config.get("rope_parameters", {}) + .get("full_attention", {}) + .get("factor", 1.0), "use_query_key_norm": True, } return backbone_config @@ -53,23 +60,24 @@ def convert_backbone_config(transformers_config): def convert_weights(backbone, loader, transformers_config): """Convert T5Gemma2 weights from HuggingFace to KerasHub.""" # Token embeddings. + # Encoder embeds are under encoder.text_model.embed_tokens.* loader.port_weight( keras_variable=backbone.token_embedding.embeddings, - hf_weight_key="encoder.embed_tokens.weight", + hf_weight_key="encoder.text_model.embed_tokens.weight", ) loader.port_weight( keras_variable=backbone.decoder_token_embedding.embeddings, hf_weight_key="decoder.embed_tokens.weight", ) - # Encoder. + # Encoder (weights under encoder.text_model.*). loader.port_weight( keras_variable=backbone.encoder_norm.scale, - hf_weight_key="encoder.norm.weight", + hf_weight_key="encoder.text_model.norm.weight", ) for i in range(backbone.encoder_num_layers): layer = backbone.get_layer(f"encoder_layer_{i}") - hf_prefix = f"encoder.layers.{i}" + hf_prefix = f"encoder.text_model.layers.{i}" # Self-attention Q/K/V/O projections. 
loader.port_weight( @@ -138,7 +146,7 @@ def convert_weights(backbone, loader, transformers_config): hf_weight_key=(f"{hf_prefix}.post_feedforward_layernorm.weight"), ) - # Decoder. + # Decoder (weights directly under decoder.*). loader.port_weight( keras_variable=backbone.decoder_norm.scale, hf_weight_key="decoder.norm.weight", diff --git a/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py index 02c3d5cf18..827517d0ff 100644 --- a/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py @@ -2,11 +2,11 @@ import os import random import shutil +import traceback import huggingface_hub import keras import numpy as np -import tensorflow as tf import torch import transformers from absl import app @@ -26,8 +26,9 @@ torch.manual_seed(123) device = torch.device("cpu") torch.set_default_device(device) +torch.set_default_dtype(torch.float32) - +# Placeholder preset map — populate when checkpoints are released. PRESET_MAP = { "t5gemma2_270m_270m": "google/t5gemma-2-270m-270m", "t5gemma2_1b_1b": "google/t5gemma-2-1b-1b", @@ -47,23 +48,26 @@ def convert_checkpoints(hf_model): print("\n-> Convert original weights to KerasHub format.") print("\n-> Load KerasHub model.") + # T5Gemma2EncoderConfig is Gemma3Config with text params at + # encoder.text_config.*; decoder is Gemma3TextConfig (flat). 
encoder_config = hf_model.config.encoder + enc_text_config = encoder_config.text_config decoder_config = hf_model.config.decoder if decoder_config.hidden_activation == "gelu_pytorch_tanh": decoder_config.hidden_activation = "gelu_approximate" - if encoder_config.hidden_activation == "gelu_pytorch_tanh": - encoder_config.hidden_activation = "gelu_approximate" + if enc_text_config.hidden_activation == "gelu_pytorch_tanh": + enc_text_config.hidden_activation = "gelu_approximate" keras.config.set_floatx("float32") keras_hub_model = keras_hub.models.T5Gemma2Backbone( vocabulary_size=decoder_config.vocab_size, - encoder_hidden_dim=encoder_config.hidden_size, - encoder_intermediate_dim=encoder_config.intermediate_size, - encoder_num_layers=encoder_config.num_hidden_layers, - encoder_num_attention_heads=encoder_config.num_attention_heads, - encoder_num_key_value_heads=(encoder_config.num_key_value_heads), - encoder_head_dim=encoder_config.head_dim, - encoder_layer_types=encoder_config.layer_types, + encoder_hidden_dim=enc_text_config.hidden_size, + encoder_intermediate_dim=enc_text_config.intermediate_size, + encoder_num_layers=enc_text_config.num_hidden_layers, + encoder_num_attention_heads=(enc_text_config.num_attention_heads), + encoder_num_key_value_heads=(enc_text_config.num_key_value_heads), + encoder_head_dim=enc_text_config.head_dim, + encoder_layer_types=enc_text_config.layer_types, decoder_hidden_dim=decoder_config.hidden_size, decoder_intermediate_dim=decoder_config.intermediate_size, decoder_num_layers=decoder_config.num_hidden_layers, @@ -82,34 +86,44 @@ def convert_checkpoints(hf_model): initializer_range=decoder_config.initializer_range, attention_dropout=decoder_config.attention_dropout, sliding_window=decoder_config.sliding_window, - cross_attention_hidden_size=encoder_config.hidden_size, - attn_logit_softcapping=decoder_config.attn_logit_softcapping, - final_logit_softcapping=decoder_config.final_logit_softcapping, - 
rope_max_wavelength=decoder_config.rope_theta, + cross_attention_hidden_size=enc_text_config.hidden_size, + attn_logit_softcapping=(decoder_config.attn_logit_softcapping), + final_logit_softcapping=(decoder_config.final_logit_softcapping), + rope_max_wavelength=getattr(decoder_config, "rope_theta", 10000.0), + global_rope_scaling_factor=( + decoder_config.rope_parameters.get("full_attention", {}).get( + "factor", 1.0 + ) + ), use_query_key_norm=True, dtype="float32", ) hf_wts = hf_model.state_dict() + # Cast all weights to float32 (HF uses bfloat16). + hf_wts = {k: v.float() for k, v in hf_wts.items()} # Token embeddings. + # Encoder embeds are under encoder.text_model.embed_tokens.* keras_hub_model.get_layer("encoder_token_embedding").embeddings.assign( - hf_wts["encoder.embed_tokens.weight"] + hf_wts["encoder.text_model.embed_tokens.weight"].numpy() ) keras_hub_model.get_layer("decoder_token_embedding").embeddings.assign( - hf_wts["decoder.embed_tokens.weight"] + hf_wts["decoder.embed_tokens.weight"].numpy() ) - # Encoder. + # Encoder (weights under encoder.text_model.*). enc_hdim = keras_hub_model.encoder_hidden_dim enc_heads = keras_hub_model.encoder_num_attention_heads enc_kv_heads = keras_hub_model.encoder_num_key_value_heads enc_head_dim = keras_hub_model.encoder_head_dim - keras_hub_model.encoder_norm.scale.assign(hf_wts["encoder.norm.weight"]) + keras_hub_model.encoder_norm.scale.assign( + hf_wts["encoder.text_model.norm.weight"] + ) for i in range(keras_hub_model.encoder_num_layers): layer = keras_hub_model.get_layer(f"encoder_layer_{i}") - pfx = f"encoder.layers.{i}" + pfx = f"encoder.text_model.layers.{i}" # Self-attention Q/K/V/O. 
layer.self_attn.query_dense.kernel.assign( @@ -237,11 +251,26 @@ def convert_checkpoints(hf_model): def extract_vocab(hf_model_dir): """Extract vocabulary from the downloaded HF model directory.""" - source_path = os.path.join(hf_model_dir, "tokenizer.model") vocabulary_path = os.path.join(FLAGS.preset, "tokenizer.model") print(f"\n-> Save KerasHub vocab to `{vocabulary_path}`.") - shutil.copyfile(source_path, vocabulary_path) + # T5Gemma2 HF repos only have tokenizer.json (no tokenizer.model). + # The SentencePiece proto is the same as Gemma3's vocabulary. + source_path = os.path.join(hf_model_dir, "tokenizer.model") + if os.path.exists(source_path): + shutil.copyfile(source_path, vocabulary_path) + else: + # Download tokenizer.model from Gemma3 (same vocab). + print( + " tokenizer.model not found in HF repo. " + "Downloading from google/gemma-3-1b-pt..." + ) + gemma_dir = huggingface_hub.snapshot_download( + repo_id="google/gemma-3-1b-pt", + allow_patterns=["tokenizer.model"], + ) + gemma_proto = os.path.join(gemma_dir, "tokenizer.model") + shutil.copyfile(gemma_proto, vocabulary_path) keras_hub_tokenizer = keras_hub.models.T5Gemma2Tokenizer( proto=vocabulary_path @@ -263,6 +292,24 @@ def check_output( hf_model, ): """Check outputs of KerasHub and HuggingFace models match.""" + # Parameter count check. + # Note: HF model includes vision encoder (SigLIP) params + # that our text-only KerasHub backbone doesn't have. + print("\n-> Verify parameter counts.") + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + print(f"KerasHub params: {keras_hub_params:,}") + print(f"HF params: {hf_params:,}") + if keras_hub_params == hf_params: + print("-> Parameter counts match!") + else: + diff = hf_params - keras_hub_params + print( + f"-> Parameter count difference: {diff:,} " + f"(expected — HF includes vision encoder)" + ) + + # Output comparison. 
print("\n-> Check the outputs.") enc_sample_text = [ "cricket is awesome, easily the best sport in the world!" @@ -273,24 +320,26 @@ def check_output( # KerasHub. keras_hub_enc_token_ids = hf_tokenizer( - enc_sample_text, return_tensors="tf" + enc_sample_text, return_tensors="np" )["input_ids"] keras_hub_dec_token_ids = hf_tokenizer( - dec_sample_text, return_tensors="tf" + dec_sample_text, return_tensors="np" )["input_ids"] - keras_hub_dec_token_ids = tf.concat( + keras_hub_dec_token_ids = np.concatenate( [ - tf.constant([[keras_hub_tokenizer.start_token_id]]), + np.array([[keras_hub_tokenizer.start_token_id]]), keras_hub_dec_token_ids, ], axis=-1, ) keras_hub_inputs = { "encoder_token_ids": keras_hub_enc_token_ids, - "encoder_padding_mask": tf.ones_like(keras_hub_enc_token_ids), + "encoder_padding_mask": np.ones_like(keras_hub_enc_token_ids), "decoder_token_ids": keras_hub_dec_token_ids, - "decoder_padding_mask": tf.ones_like(keras_hub_dec_token_ids), + "decoder_padding_mask": np.ones_like(keras_hub_dec_token_ids), } + + print("\n--- Model Verification ---") keras_hub_output = keras_hub_model.predict(keras_hub_inputs) # HF. @@ -317,36 +366,39 @@ def check_output( decoder_attention_mask=hf_decoder_attention_mask, ) + # Encoder output comparison. + keras_enc_out = keras_hub_output["encoder_sequence_output"] + hf_enc_out = hf_output.encoder_last_hidden_state.detach().float().numpy() print("Encoder Outputs:") - print( - "KerasHub output:", - keras_hub_output["encoder_sequence_output"][0, 0, :10], - ) - print( - "HF output:", - hf_output.encoder_last_hidden_state[0, 0, :10], - ) - print( - "Difference:", - np.mean( - keras_hub_output["encoder_sequence_output"] - - hf_output.encoder_last_hidden_state.detach().numpy() - ), - ) - + print("KerasHub output:", keras_enc_out[0, 0, :10]) + print("HF output:", hf_enc_out[0, 0, :10]) + try: + np.testing.assert_allclose( + keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4 + ) + print("-> Encoder outputs match! 
(rtol=1e-4, atol=1e-4)") + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + # Decoder output comparison. + keras_dec_out = keras_hub_output["decoder_sequence_output"] + hf_dec_out = hf_output.last_hidden_state.detach().float().numpy() print("Decoder Outputs:") - print( - "KerasHub output:", - keras_hub_output["decoder_sequence_output"][0, 0, :10], - ) - print("HF output:", hf_output.last_hidden_state[0, 0, :10]) - print( - "Difference:", - np.mean( - keras_hub_output["decoder_sequence_output"] - - hf_output.last_hidden_state.detach().numpy() - ), - ) + print("KerasHub output:", keras_dec_out[0, 0, :10]) + print("HF output:", hf_dec_out[0, 0, :10]) + try: + np.testing.assert_allclose( + keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4 + ) + print("-> Decoder outputs match! (rtol=1e-4, atol=1e-4)") + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") def main(_): @@ -366,6 +418,18 @@ def main(_): print("\n-> Load HF model and HF tokenizer.") hf_model = transformers.AutoModel.from_pretrained(hf_model_dir) + hf_model.float() # Convert all params/buffers to float32. + # Fix embed_scale buffers: they were created in bfloat16 + # during init (non-persistent), so .float() preserves + # the bf16-rounded value. Re-create with true f32 precision. 
+ enc_hdim = hf_model.config.encoder.text_config.hidden_size + dec_hdim = hf_model.config.decoder.hidden_size + hf_model.encoder.text_model.embed_tokens.embed_scale = torch.tensor( + enc_hdim**0.5, dtype=torch.float32 + ) + hf_model.decoder.embed_tokens.embed_scale = torch.tensor( + dec_hdim**0.5, dtype=torch.float32 + ) hf_model.eval() hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_dir) From b27b7ea45f2d0e0e10083c178401863197759107 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Wed, 4 Mar 2026 12:27:15 -0800 Subject: [PATCH 5/9] Add vision tower and Numeric verification Fixes --- .../src/models/t5gemma2/t5gemma2_attention.py | 8 +- .../src/models/t5gemma2/t5gemma2_backbone.py | 180 ++++++++- .../src/models/t5gemma2/t5gemma2_decoder.py | 2 +- .../src/models/t5gemma2/t5gemma2_encoder.py | 26 +- .../t5gemma2/t5gemma2_image_converter.py | 14 + .../t5gemma2_seq_2_seq_lm_preprocessor.py | 9 + .../utils/transformers/convert_t5gemma2.py | 191 +++++++++- .../convert_t5gemma2_checkpoints.py | 342 +++++++++++++++++- 8 files changed, 735 insertions(+), 37 deletions(-) create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_image_converter.py diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_attention.py b/keras_hub/src/models/t5gemma2/t5gemma2_attention.py index c4a272a222..0cab5426e2 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_attention.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_attention.py @@ -3,8 +3,8 @@ import keras from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding -from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention -from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention +from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization from keras_hub.src.models.t5gemma2.t5gemma2_layers import ( t5gemma2_kernel_initializer, ) @@ -33,7 +33,7 @@ def repeat_kv(hidden_states, 
n_rep): ) -class T5Gemma2Attention(CachedGemmaAttention): +class T5Gemma2Attention(CachedGemma3Attention): """Self-attention layer for T5Gemma2 encoder and decoder. This layer performs self-attention with Rotary Positional Embeddings @@ -341,7 +341,7 @@ def get_config(self): return config -class T5Gemma2MergedAttention(CachedGemmaAttention): +class T5Gemma2MergedAttention(CachedGemma3Attention): """Merged self-attention and cross-attention for T5Gemma2 decoder. This layer fuses self-attention and cross-attention into a single diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py b/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py index f14eeb4263..02004726a3 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py @@ -1,9 +1,11 @@ import keras +from keras import ops from keras.layers import ReversibleEmbedding from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.gemma3.gemma3_layers import Gemma3InterleaveEmbeddings +from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization from keras_hub.src.models.t5gemma2.t5gemma2_decoder import T5Gemma2DecoderLayer from keras_hub.src.models.t5gemma2.t5gemma2_encoder import T5Gemma2EncoderLayer from keras_hub.src.models.t5gemma2.t5gemma2_layers import ( @@ -22,6 +24,11 @@ class T5Gemma2Backbone(Backbone): separate attention sublayers), Gemma3-style Q/K normalization, and per-layer-type sliding window attention patterns. + When a `vision_encoder` is provided, the model also accepts image + inputs. Images are processed by the vision encoder and the resulting + embeddings are interleaved into the encoder text embeddings at + positions marked by image placeholder tokens. + Args: vocabulary_size: int, The size of the vocabulary. encoder_hidden_dim: int, Encoder hidden dimensionality. 
@@ -61,8 +68,14 @@ class T5Gemma2Backbone(Backbone): softcapping. rope_max_wavelength: float, RoPE maximum wavelength. Defaults to `10000.0`. + global_rope_scaling_factor: float, RoPE scaling factor for + full attention layers. Defaults to `1.0`. use_query_key_norm: bool, Whether to use Gemma3-style Q/K normalization. Defaults to `True`. + vision_encoder: optional, A `Gemma3VisionEncoder` instance for + multimodal inputs. When `None`, the model is text-only. + eoi_token_index: int, Token index for the end-of-image token. + Defaults to `256000`. dtype: dtype for computations. Defaults to `None`. **kwargs: Additional keyword arguments. @@ -139,12 +152,20 @@ def __init__( final_logit_softcapping=None, rope_max_wavelength=10000.0, global_rope_scaling_factor=1.0, + encoder_rope_max_wavelength=None, + encoder_global_rope_scaling_factor=None, use_query_key_norm=True, + vision_encoder=None, + eoi_token_index=256000, dtype=None, **kwargs, ): self.kernel_initializer = t5gemma2_kernel_initializer(initializer_range) + # Determine if text-only. + self.vision_encoder = vision_encoder + text_only_model = vision_encoder is None + # === Layers === self.token_embedding = keras.layers.Embedding( input_dim=vocabulary_size, @@ -161,11 +182,42 @@ def __init__( dtype=dtype, name="decoder_token_embedding", ) + + # Vision interleaving layer (only when vision encoder is present). + if not text_only_model: + self.interleave_embeddings = Gemma3InterleaveEmbeddings( + num_vision_tokens_per_image=self.vision_encoder.num_vision_tokens_per_image, + dtype=dtype, + name="interleave_embeddings", + ) + # EOI (end-of-image) embeddings: learned vectors that + # replace the standard embedding at eoi_token_index. 
+ self.encoder_eoi_embedding = keras.Variable( + keras.ops.zeros((encoder_hidden_dim,)), + name="encoder_eoi_embedding", + ) + self.decoder_eoi_embedding = keras.Variable( + keras.ops.zeros((decoder_hidden_dim,)), + name="decoder_eoi_embedding", + ) + + # Encoder may have different RoPE config than decoder. + enc_rope = ( + encoder_rope_max_wavelength + if encoder_rope_max_wavelength is not None + else rope_max_wavelength + ) + enc_rope_factor = ( + encoder_global_rope_scaling_factor + if encoder_global_rope_scaling_factor is not None + else global_rope_scaling_factor + ) + self.encoder_layers = [] for i in range(encoder_num_layers): - # Per-layer RoPE wavelength: 10K for sliding, 1M for global. + # Per-layer RoPE wavelength: base for sliding, 1M for global. layer_rope = ( - rope_max_wavelength + enc_rope if encoder_layer_types[i] == "sliding_attention" else 1_000_000.0 ) @@ -174,7 +226,7 @@ def __init__( layer_rope_factor = ( 1.0 if encoder_layer_types[i] == "sliding_attention" - else global_rope_scaling_factor + else enc_rope_factor ) self.encoder_layers.append( T5Gemma2EncoderLayer( @@ -261,11 +313,50 @@ def __init__( shape=(None,), dtype="int32", name="decoder_padding_mask" ) + # Optional vision inputs. + if not text_only_model: + image_size = self.vision_encoder.image_size + image_input = keras.Input( + shape=(None, image_size, image_size, 3), + name="images", + ) + vision_indices_input = keras.Input( + shape=(None,), dtype="int32", name="vision_indices" + ) + # Encoder. encoder_embeddings = self.token_embedding(encoder_token_id_input) - encoder_embeddings = encoder_embeddings * keras.ops.cast( - keras.ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype + encoder_embeddings = encoder_embeddings * ops.cast( + ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype ) + + # Handle EOI embedding replacement. + if not text_only_model: + # Replace embeddings at eoi_token_index positions with the + # learned eoi_embedding (a separate parameter per HF design). 
+        # Blend via arithmetic masking (mask * eoi + (1 - mask) * emb);
+        # broadcasting handles shapes — no broadcast_to needed, which
+ if not text_only_model: + dec_eoi_mask = ops.cast( + ops.expand_dims( + ops.equal(decoder_token_id_input, eoi_token_index), + axis=-1, + ), + decoder_embeddings.dtype, + ) + decoder_embeddings = ( + dec_eoi_mask * self.decoder_eoi_embedding + + (1 - dec_eoi_mask) * decoder_embeddings + ) + decoder_hidden_states = self.decoder_dropout(decoder_embeddings) for layer in self.decoder_layers: decoder_hidden_states, _ = layer( @@ -292,13 +398,22 @@ def __init__( decoder_output = self.decoder_norm(decoder_hidden_states) decoder_output = self.decoder_dropout(decoder_output) + inputs = { + "encoder_token_ids": encoder_token_id_input, + "encoder_padding_mask": encoder_padding_mask_input, + "decoder_token_ids": decoder_token_id_input, + "decoder_padding_mask": decoder_padding_mask_input, + } + if not text_only_model: + inputs.update( + { + "images": image_input, + "vision_indices": vision_indices_input, + } + ) + super().__init__( - inputs={ - "encoder_token_ids": encoder_token_id_input, - "encoder_padding_mask": encoder_padding_mask_input, - "decoder_token_ids": decoder_token_id_input, - "decoder_padding_mask": decoder_padding_mask_input, - }, + inputs=inputs, outputs={ "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, @@ -338,7 +453,20 @@ def __init__( self.attn_logit_softcapping = attn_logit_softcapping self.final_logit_softcapping = final_logit_softcapping self.rope_max_wavelength = rope_max_wavelength + self.global_rope_scaling_factor = global_rope_scaling_factor + self.encoder_rope_max_wavelength = encoder_rope_max_wavelength + self.encoder_global_rope_scaling_factor = ( + encoder_global_rope_scaling_factor + ) self.use_query_key_norm = use_query_key_norm + self.eoi_token_index = eoi_token_index + self.text_only_model = text_only_model + + # Keep `num_vision_tokens_per_image` for easy access. 
+ if not text_only_model: + self.num_vision_tokens_per_image = ( + self.vision_encoder.num_vision_tokens_per_image + ) def get_config(self): config = super().get_config() @@ -382,7 +510,31 @@ def get_config(self): "attn_logit_softcapping": self.attn_logit_softcapping, "final_logit_softcapping": (self.final_logit_softcapping), "rope_max_wavelength": self.rope_max_wavelength, + "global_rope_scaling_factor": (self.global_rope_scaling_factor), + "encoder_rope_max_wavelength": ( + self.encoder_rope_max_wavelength + ), + "encoder_global_rope_scaling_factor": ( + self.encoder_global_rope_scaling_factor + ), "use_query_key_norm": self.use_query_key_norm, + "eoi_token_index": self.eoi_token_index, } ) + if self.vision_encoder is not None: + config["vision_encoder"] = keras.saving.serialize_keras_object( + self.vision_encoder + ) return config + + @classmethod + def from_config(cls, config): + vision_encoder = config.pop("vision_encoder", None) + if vision_encoder is not None and isinstance(vision_encoder, dict): + vision_encoder = keras.saving.deserialize_keras_object( + vision_encoder + ) + config["vision_encoder"] = vision_encoder + elif vision_encoder is not None: + config["vision_encoder"] = vision_encoder + return cls(**config) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py b/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py index 0776971f94..b56d708031 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py @@ -1,6 +1,6 @@ import keras -from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization from keras_hub.src.models.t5gemma2.t5gemma2_attention import ( T5Gemma2MergedAttention, ) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py b/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py index 0764f38c20..538bf25513 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py +++ 
b/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py @@ -1,6 +1,6 @@ import keras -from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization from keras_hub.src.models.t5gemma2.t5gemma2_attention import T5Gemma2Attention from keras_hub.src.models.t5gemma2.t5gemma2_layers import T5Gemma2MLP @@ -164,6 +164,30 @@ def _make_attention_mask(self, hidden_states, padding_mask): additive_mask = ( 1.0 - keras.ops.cast(attention_mask, hidden_states.dtype) ) * -1e9 + # Apply bidirectional sliding window for sliding_attention layers. + if ( + self.attention_type == "sliding_attention" + and self.sliding_window is not None + ): + seq_len = keras.ops.shape(hidden_states)[1] + # Build position indices for the sliding window mask. + q_idx = keras.ops.arange(seq_len)[:, None] # (S, 1) + kv_idx = keras.ops.arange(seq_len)[None, :] # (1, S) + dist = q_idx - kv_idx + # HF bidirectional window: + # left_window = (sliding_window + 1) // 2 + # right_window = sliding_window // 2 + 1 + left_w = (self.sliding_window + 1) // 2 + right_w = self.sliding_window // 2 + 1 + window_mask = ((dist >= 0) & (dist < left_w)) | ( + (dist < 0) & (-dist < right_w) + ) + # Expand to (1, 1, S, S) and convert to additive mask. 
+ window_mask = keras.ops.cast( + window_mask[None, None, :, :], hidden_states.dtype + ) + window_additive = (1.0 - window_mask) * -1e9 + additive_mask = additive_mask + window_additive return additive_mask def call( diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_image_converter.py b/keras_hub/src/models/t5gemma2/t5gemma2_image_converter.py new file mode 100644 index 0000000000..40ec3c6785 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_image_converter.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone + + +@keras_hub_export("keras_hub.layers.T5Gemma2ImageConverter") +class T5Gemma2ImageConverter(ImageConverter): + backbone_cls = T5Gemma2Backbone + + def __init__(self, **kwargs): + # Always do image preprocessing in float32 + kwargs.pop("dtype", None) + dtype = "float32" + super().__init__(dtype=dtype, **kwargs) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py index 298da62b88..ce25d3e207 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py @@ -27,10 +27,17 @@ class T5Gemma2Seq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): preprocessor is attached to a `keras_hub.models.T5Gemma2Seq2SeqLM` instance, these methods will be called implicitly in `generate()`. + When an `image_converter` is provided, the preprocessor also + supports multimodal inputs with images. Images are inserted into + the encoder sequence as placeholder tokens that the backbone's + vision encoder will replace with image embeddings. + Args: tokenizer: A `keras_hub.models.T5Gemma2Tokenizer` instance. encoder_sequence_length: The length of the packed encoder inputs. 
decoder_sequence_length: The length of the packed decoder inputs. + image_converter: A `keras_hub.layers.ImageConverter` instance, + or `None` for text-only. Defaults to `None`. add_start_token: If `True`, prepend the start token. Defaults to `False`. add_end_token: If `True`, append the end token. Defaults to @@ -45,6 +52,7 @@ def __init__( tokenizer, encoder_sequence_length=512, decoder_sequence_length=512, + image_converter=None, add_start_token=False, add_end_token=True, **kwargs, @@ -57,6 +65,7 @@ def __init__( ) self.add_start_token = add_start_token self.add_end_token = add_end_token + self.image_converter = image_converter @preprocessing_function def call( diff --git a/keras_hub/src/utils/transformers/convert_t5gemma2.py b/keras_hub/src/utils/transformers/convert_t5gemma2.py index fe4bbb290f..685d953cff 100644 --- a/keras_hub/src/utils/transformers/convert_t5gemma2.py +++ b/keras_hub/src/utils/transformers/convert_t5gemma2.py @@ -1,9 +1,31 @@ +import numpy as np + from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone from keras_hub.src.utils.preset_utils import get_file +from keras_hub.src.utils.preset_utils import load_json backbone_cls = T5Gemma2Backbone +def load_image_converter_config(preset, transformers_config): + """Load image converter config from HF preprocessor_config.json.""" + encoder_config = transformers_config.get("encoder", {}) + if "vision_config" not in encoder_config: + return None + preprocessor_config = load_json(preset, "preprocessor_config.json") + mean = preprocessor_config["image_mean"] + std = preprocessor_config["image_std"] + rescale_factor = preprocessor_config["rescale_factor"] + offset = [(-m / s) for m, s in zip(mean, std)] + scale = [(s * rescale_factor) for s in std] + image_size = encoder_config["vision_config"].get("image_size", 896) + return { + "image_size": (image_size, image_size), + "scale": scale, + "offset": offset, + } + + def convert_backbone_config(transformers_config): """Convert a 
HuggingFace T5Gemma2 config to KerasHub backbone config.""" # T5Gemma2EncoderConfig is Gemma3Config with text params at @@ -18,6 +40,30 @@ def convert_backbone_config(transformers_config): if hidden_activation == "gelu_pytorch_tanh": hidden_activation = "gelu_approximate" + # Vision encoder (optional). + vision_encoder = None + if "vision_config" in encoder_config: + from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( + Gemma3VisionEncoder, + ) + + vision_config = encoder_config["vision_config"] + vision_encoder = Gemma3VisionEncoder( + image_size=vision_config["image_size"], + patch_size=vision_config["patch_size"], + num_heads=vision_config["num_attention_heads"], + hidden_dim=vision_config["hidden_size"], + num_layers=vision_config["num_hidden_layers"], + intermediate_dim=vision_config["intermediate_size"], + output_dim=enc_text["hidden_size"], + pool_size=int( + vision_config["image_size"] + // vision_config["patch_size"] + // int(encoder_config.get("mm_tokens_per_image", 256) ** 0.5) + ), + layer_norm_epsilon=vision_config.get("layer_norm_eps", 1e-6), + ) + backbone_config = { "vocabulary_size": decoder_config["vocab_size"], "encoder_hidden_dim": enc_text["hidden_size"], @@ -52,15 +98,158 @@ def convert_backbone_config(transformers_config): "global_rope_scaling_factor": decoder_config.get("rope_parameters", {}) .get("full_attention", {}) .get("factor", 1.0), + "encoder_rope_max_wavelength": enc_text.get("rope_parameters", {}) + .get("sliding_attention", {}) + .get("rope_theta", None), + "encoder_global_rope_scaling_factor": enc_text.get( + "rope_parameters", {} + ) + .get("full_attention", {}) + .get("factor", None), "use_query_key_norm": True, + "vision_encoder": vision_encoder, + "eoi_token_index": transformers_config.get("eoi_token_index", 256000), } return backbone_config def convert_weights(backbone, loader, transformers_config): """Convert T5Gemma2 weights from HuggingFace to KerasHub.""" + + def transpose(x, shape): + return np.transpose(x) 
+ + # === Vision encoder weights === + vision_encoder = backbone.vision_encoder + if vision_encoder is not None: + image_encoder = vision_encoder.get_layer("image_encoder") + + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.patch_embedding.kernel, + hf_weight_key="encoder.vision_tower.vision_model.embeddings.patch_embedding.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.patch_embedding.bias, + hf_weight_key="encoder.vision_tower.vision_model.embeddings.patch_embedding.bias", + ) + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.position_embedding.embeddings, + hf_weight_key="encoder.vision_tower.vision_model.embeddings.position_embedding.weight", + ) + + for i in range(image_encoder.num_layers): + hf_vit = f"encoder.vision_tower.vision_model.encoder.layers.{i}" + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_1.gamma, + hf_weight_key=f"{hf_vit}.layer_norm1.weight", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_1.beta, + hf_weight_key=f"{hf_vit}.layer_norm1.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[ + i + ].attn.query_proj.kernel, + hf_weight_key=f"{hf_vit}.self_attn.q_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.query_proj.bias, + hf_weight_key=f"{hf_vit}.self_attn.q_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.key_proj.kernel, + hf_weight_key=f"{hf_vit}.self_attn.k_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.key_proj.bias, + hf_weight_key=f"{hf_vit}.self_attn.k_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[ + i + ].attn.value_proj.kernel, + hf_weight_key=f"{hf_vit}.self_attn.v_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( 
+ keras_variable=image_encoder.resblocks[i].attn.value_proj.bias, + hf_weight_key=f"{hf_vit}.self_attn.v_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.out_proj.kernel, + hf_weight_key=f"{hf_vit}.self_attn.out_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.out_proj.bias, + hf_weight_key=f"{hf_vit}.self_attn.out_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_2.gamma, + hf_weight_key=f"{hf_vit}.layer_norm2.weight", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_2.beta, + hf_weight_key=f"{hf_vit}.layer_norm2.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_1.kernel, + hf_weight_key=f"{hf_vit}.mlp.fc1.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_1.bias, + hf_weight_key=f"{hf_vit}.mlp.fc1.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_2.kernel, + hf_weight_key=f"{hf_vit}.mlp.fc2.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_2.bias, + hf_weight_key=f"{hf_vit}.mlp.fc2.bias", + ) + + loader.port_weight( + keras_variable=image_encoder.encoder_layer_norm.gamma, + hf_weight_key="encoder.vision_tower.vision_model.post_layernorm.weight", + ) + loader.port_weight( + keras_variable=image_encoder.encoder_layer_norm.beta, + hf_weight_key="encoder.vision_tower.vision_model.post_layernorm.bias", + ) + + # Multi-modal projector. 
+ loader.port_weight( + keras_variable=vision_encoder.get_layer( + "vision_output_encoder" + ).vision_soft_embedding_norm.scale, + hf_weight_key="encoder.multi_modal_projector.mm_soft_emb_norm.weight", + ) + loader.port_weight( + keras_variable=vision_encoder.get_layer( + "vision_output_encoder" + ).vision_input_projection.kernel, + hf_weight_key="encoder.multi_modal_projector.mm_input_projection_weight", + ) + + # EOI embeddings. + loader.port_weight( + keras_variable=backbone.encoder_eoi_embedding, + hf_weight_key="encoder.text_model.embed_tokens.eoi_embedding", + ) + loader.port_weight( + keras_variable=backbone.decoder_eoi_embedding, + hf_weight_key="decoder.embed_tokens.eoi_embedding", + ) + + # === Text encoder weights === # Token embeddings. - # Encoder embeds are under encoder.text_model.embed_tokens.* loader.port_weight( keras_variable=backbone.token_embedding.embeddings, hf_weight_key="encoder.text_model.embed_tokens.weight", diff --git a/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py index 827517d0ff..537e894c83 100644 --- a/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py @@ -7,11 +7,13 @@ import huggingface_hub import keras import numpy as np +import requests import torch import transformers from absl import app from absl import flags from checkpoint_conversion_utils import get_md5_checksum +from PIL import Image import keras_hub from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM @@ -28,7 +30,6 @@ torch.set_default_device(device) torch.set_default_dtype(torch.float32) -# Placeholder preset map — populate when checkpoints are released. 
PRESET_MAP = { "t5gemma2_270m_270m": "google/t5gemma-2-270m-270m", "t5gemma2_1b_1b": "google/t5gemma-2-1b-1b", @@ -48,8 +49,6 @@ def convert_checkpoints(hf_model): print("\n-> Convert original weights to KerasHub format.") print("\n-> Load KerasHub model.") - # T5Gemma2EncoderConfig is Gemma3Config with text params at - # encoder.text_config.*; decoder is Gemma3TextConfig (flat). encoder_config = hf_model.config.encoder enc_text_config = encoder_config.text_config decoder_config = hf_model.config.decoder @@ -58,6 +57,35 @@ def convert_checkpoints(hf_model): if enc_text_config.hidden_activation == "gelu_pytorch_tanh": enc_text_config.hidden_activation = "gelu_approximate" + # Vision encoder (optional — only present in multimodal models). + vision_encoder = None + has_vision = hasattr(encoder_config, "vision_config") and ( + encoder_config.vision_config is not None + ) + if has_vision: + from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( + Gemma3VisionEncoder, + ) + + vc = encoder_config.vision_config + mm_tokens = getattr(encoder_config, "mm_tokens_per_image", 256) + pool_size = int(vc.image_size // vc.patch_size // int(mm_tokens**0.5)) + vision_encoder = Gemma3VisionEncoder( + image_size=vc.image_size, + patch_size=vc.patch_size, + num_heads=vc.num_attention_heads, + hidden_dim=vc.hidden_size, + num_layers=vc.num_hidden_layers, + intermediate_dim=vc.intermediate_size, + output_dim=enc_text_config.hidden_size, + pool_size=pool_size, + layer_norm_epsilon=getattr(vc, "layer_norm_eps", 1e-6), + ) + print( + f" Vision encoder created: {vc.image_size}x{vc.image_size}, " + f"pool_size={pool_size}, mm_tokens={mm_tokens}" + ) + keras.config.set_floatx("float32") keras_hub_model = keras_hub.models.T5Gemma2Backbone( vocabulary_size=decoder_config.vocab_size, @@ -95,7 +123,19 @@ def convert_checkpoints(hf_model): "factor", 1.0 ) ), + encoder_rope_max_wavelength=( + enc_text_config.rope_parameters.get("sliding_attention", {}).get( + "rope_theta", None + ) + ), + 
encoder_global_rope_scaling_factor=( + enc_text_config.rope_parameters.get("full_attention", {}).get( + "factor", None + ) + ), use_query_key_norm=True, + vision_encoder=vision_encoder, + eoi_token_index=getattr(hf_model.config, "eoi_token_index", 256000), dtype="float32", ) @@ -112,6 +152,108 @@ def convert_checkpoints(hf_model): hf_wts["decoder.embed_tokens.weight"].numpy() ) + # Vision encoder weights. + if has_vision: + ve = keras_hub_model.vision_encoder + ie = ve.get_layer("image_encoder") + ie.vision_embeddings.patch_embedding.kernel.assign( + hf_wts[ + "encoder.vision_tower.vision_model." + "embeddings.patch_embedding.weight" + ] + .permute(2, 3, 1, 0) + .numpy() + ) + ie.vision_embeddings.patch_embedding.bias.assign( + hf_wts[ + "encoder.vision_tower.vision_model." + "embeddings.patch_embedding.bias" + ].numpy() + ) + ie.vision_embeddings.position_embedding.embeddings.assign( + hf_wts[ + "encoder.vision_tower.vision_model." + "embeddings.position_embedding.weight" + ].numpy() + ) + for vi in range(ie.num_layers): + vp = f"encoder.vision_tower.vision_model.encoder.layers.{vi}" + rb = ie.resblocks[vi] + rb.layer_norm_1.gamma.assign( + hf_wts[f"{vp}.layer_norm1.weight"].numpy() + ) + rb.layer_norm_1.beta.assign( + hf_wts[f"{vp}.layer_norm1.bias"].numpy() + ) + rb.attn.query_proj.kernel.assign( + hf_wts[f"{vp}.self_attn.q_proj.weight"].T.numpy() + ) + rb.attn.query_proj.bias.assign( + hf_wts[f"{vp}.self_attn.q_proj.bias"].numpy() + ) + rb.attn.key_proj.kernel.assign( + hf_wts[f"{vp}.self_attn.k_proj.weight"].T.numpy() + ) + rb.attn.key_proj.bias.assign( + hf_wts[f"{vp}.self_attn.k_proj.bias"].numpy() + ) + rb.attn.value_proj.kernel.assign( + hf_wts[f"{vp}.self_attn.v_proj.weight"].T.numpy() + ) + rb.attn.value_proj.bias.assign( + hf_wts[f"{vp}.self_attn.v_proj.bias"].numpy() + ) + rb.attn.out_proj.kernel.assign( + hf_wts[f"{vp}.self_attn.out_proj.weight"].T.numpy() + ) + rb.attn.out_proj.bias.assign( + hf_wts[f"{vp}.self_attn.out_proj.bias"].numpy() + ) + 
rb.layer_norm_2.gamma.assign( + hf_wts[f"{vp}.layer_norm2.weight"].numpy() + ) + rb.layer_norm_2.beta.assign( + hf_wts[f"{vp}.layer_norm2.bias"].numpy() + ) + rb.mlp_dense_1.kernel.assign( + hf_wts[f"{vp}.mlp.fc1.weight"].T.numpy() + ) + rb.mlp_dense_1.bias.assign(hf_wts[f"{vp}.mlp.fc1.bias"].numpy()) + rb.mlp_dense_2.kernel.assign( + hf_wts[f"{vp}.mlp.fc2.weight"].T.numpy() + ) + rb.mlp_dense_2.bias.assign(hf_wts[f"{vp}.mlp.fc2.bias"].numpy()) + ie.encoder_layer_norm.gamma.assign( + hf_wts[ + "encoder.vision_tower.vision_model.post_layernorm.weight" + ].numpy() + ) + ie.encoder_layer_norm.beta.assign( + hf_wts[ + "encoder.vision_tower.vision_model.post_layernorm.bias" + ].numpy() + ) + # Multi-modal projector. + vo = ve.get_layer("vision_output_encoder") + vo.vision_soft_embedding_norm.scale.assign( + hf_wts[ + "encoder.multi_modal_projector.mm_soft_emb_norm.weight" + ].numpy() + ) + vo.vision_input_projection.kernel.assign( + hf_wts[ + "encoder.multi_modal_projector.mm_input_projection_weight" + ].numpy() + ) + # EOI embeddings. + keras_hub_model.encoder_eoi_embedding.assign( + hf_wts["encoder.text_model.embed_tokens.eoi_embedding"].numpy() + ) + keras_hub_model.decoder_eoi_embedding.assign( + hf_wts["decoder.embed_tokens.eoi_embedding"].numpy() + ) + print(" Vision encoder weights loaded.") + # Encoder (weights under encoder.text_model.*). enc_hdim = keras_hub_model.encoder_hidden_dim enc_heads = keras_hub_model.encoder_num_attention_heads @@ -285,16 +427,18 @@ def extract_vocab(hf_model_dir): return keras_hub_tokenizer -def check_output( +def check_text_output( keras_hub_tokenizer, keras_hub_model, hf_tokenizer, hf_model, ): """Check outputs of KerasHub and HuggingFace models match.""" - # Parameter count check. - # Note: HF model includes vision encoder (SigLIP) params - # that our text-only KerasHub backbone doesn't have. + # Note: KerasHub counts encoder + decoder embeddings as separate + # weight matrices. 
HF shares a single nn.Embedding across + # encoder/decoder/lm_head, so counts it once. + print("\n--- Model Verification starts ---") + print("\n") print("\n-> Verify parameter counts.") keras_hub_params = keras_hub_model.count_params() hf_params = hf_model.num_parameters() @@ -303,14 +447,15 @@ def check_output( if keras_hub_params == hf_params: print("-> Parameter counts match!") else: - diff = hf_params - keras_hub_params + diff = keras_hub_params - hf_params print( f"-> Parameter count difference: {diff:,} " - f"(expected — HF includes vision encoder)" + f"(expected — KerasHub has separate encoder/decoder " + f"embeddings; HF shares a single nn.Embedding)" ) # Output comparison. - print("\n-> Check the outputs.") + print("\n-> ---- Text-only verification. ----\n") enc_sample_text = [ "cricket is awesome, easily the best sport in the world!" ] @@ -339,7 +484,23 @@ def check_output( "decoder_padding_mask": np.ones_like(keras_hub_dec_token_ids), } - print("\n--- Model Verification ---") + # If multimodal backbone, add dummy image/vision_indices inputs. + # Conv2D can't process batch=0, so pass 1 dummy (all-zeros) image. + # InterleaveEmbeddings restores the 0th index after scatter, so + # the dummy image embedding is effectively a no-op. + if keras_hub_model.vision_encoder is not None: + image_size = keras_hub_model.vision_encoder.image_size + num_vision_tokens = ( + keras_hub_model.vision_encoder.num_vision_tokens_per_image + ) + keras_hub_inputs["images"] = np.zeros( + (1, 1, image_size, image_size, 3), dtype="float32" + ) + # All indices point to 0; InterleaveEmbeddings restores + # the original embedding at position 0 after scattering. + keras_hub_inputs["vision_indices"] = np.zeros( + (1, num_vision_tokens), dtype="int32" + ) keras_hub_output = keras_hub_model.predict(keras_hub_inputs) # HF. 
@@ -401,6 +562,150 @@ def check_output( print("\n") +def check_multimodal_output( + keras_hub_model, + hf_model, + hf_model_dir, + hf_tokenizer, +): + """Check multimodal (text+image) outputs match between KerasHub and HF.""" + if keras_hub_model.vision_encoder is None: + print("\n-> Skipping multimodal check (text-only model).") + return + + print("\n-> ---- Multimodal (text+image) verification. ----\n") + + # Download a test image. + image_url = ( + "https://huggingface.co/datasets/huggingface/" + "documentation-images/resolve/main/bee.jpg" + ) + print(f" Downloading test image: {image_url}") + image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB") + + # HF side: use AutoProcessor for proper multimodal preprocessing. + hf_processor = transformers.AutoProcessor.from_pretrained(hf_model_dir) + enc_prompt = " Describe this image" + dec_prompt = "This image shows" + + # HF encoder inputs (with image). + hf_enc_inputs = hf_processor( + text=enc_prompt, images=image, return_tensors="pt" + ) + # HF decoder inputs. + hf_dec_inputs = hf_tokenizer(dec_prompt, return_tensors="pt") + hf_decoder_input_ids = torch.cat( + [ + torch.tensor([[hf_tokenizer.bos_token_id]]), + hf_dec_inputs["input_ids"], + ], + dim=-1, + ) + hf_decoder_attention_mask = torch.cat( + [ + torch.ones(1, 1, dtype=torch.long), + hf_dec_inputs["attention_mask"], + ], + dim=-1, + ) + + with torch.no_grad(): + hf_output = hf_model( + input_ids=hf_enc_inputs["input_ids"], + attention_mask=hf_enc_inputs["attention_mask"], + pixel_values=hf_enc_inputs["pixel_values"], + decoder_input_ids=hf_decoder_input_ids, + decoder_attention_mask=hf_decoder_attention_mask, + ) + + # Build KerasHub inputs from HF token_ids (same tokenizer). 
+ keras_enc_token_ids = hf_enc_inputs["input_ids"].numpy() + keras_enc_padding_mask = hf_enc_inputs["attention_mask"].numpy() + keras_dec_token_ids = hf_decoder_input_ids.numpy() + keras_dec_padding_mask = hf_decoder_attention_mask.numpy() + + # Transpose HF pixel_values (B,C,H,W) to KerasHub (B,1,H,W,C). + pixel_values = hf_enc_inputs["pixel_values"].numpy() + if pixel_values.ndim == 5: + pixel_values = np.transpose(pixel_values, (0, 1, 3, 4, 2)) + elif pixel_values.ndim == 4: + pixel_values = np.transpose(pixel_values, (0, 2, 3, 1)) + pixel_values = np.expand_dims(pixel_values, axis=1) + + # Find positions of image placeholder tokens for vision_indices. + image_token_id = hf_processor.tokenizer.convert_tokens_to_ids( + "" + ) + num_vision_tokens = ( + keras_hub_model.vision_encoder.num_vision_tokens_per_image + ) + # Find indices of image placeholder tokens. + token_ids_flat = keras_enc_token_ids[0] + vision_idx_list = np.where(token_ids_flat == image_token_id)[0].tolist() + + # Pad or truncate to num_vision_tokens. + if len(vision_idx_list) < num_vision_tokens: + vision_idx_list = vision_idx_list + [0] * ( + num_vision_tokens - len(vision_idx_list) + ) + vision_indices = np.array( + [vision_idx_list[:num_vision_tokens]], dtype="int32" + ) + + keras_hub_inputs = { + "encoder_token_ids": keras_enc_token_ids, + "encoder_padding_mask": keras_enc_padding_mask, + "decoder_token_ids": keras_dec_token_ids, + "decoder_padding_mask": keras_dec_padding_mask, + "images": pixel_values.astype("float32"), + "vision_indices": vision_indices, + } + + print("\n--- Multimodal Verification ---") + keras_hub_output = keras_hub_model.predict(keras_hub_inputs) + + # Relaxed tolerances: vision encoder compounds f32 drift across layers. + + # Encoder output comparison. 
+ keras_enc_out = keras_hub_output["encoder_sequence_output"] + hf_enc_out = hf_output.encoder_last_hidden_state.detach().float().numpy() + enc_abs_diff = np.abs(keras_enc_out - hf_enc_out) + print("Encoder Outputs (multimodal):") + print("KerasHub output:", keras_enc_out[0, 0, :10]) + print("HF output:", hf_enc_out[0, 0, :10]) + print(f"Mean absolute diff: {enc_abs_diff.mean():.6f}") + + try: + np.testing.assert_allclose( + keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4 + ) + print("-> Encoder outputs match! (rtol=1e-4, atol=1e-4)") + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + # Decoder output comparison. + keras_dec_out = keras_hub_output["decoder_sequence_output"] + hf_dec_out = hf_output.last_hidden_state.detach().float().numpy() + dec_abs_diff = np.abs(keras_dec_out - hf_dec_out) + print("Decoder Outputs (multimodal):") + print("KerasHub output:", keras_dec_out[0, 0, :10]) + print("HF output:", hf_dec_out[0, 0, :10]) + print(f"Mean absolute diff: {dec_abs_diff.mean():.6f}") + try: + np.testing.assert_allclose( + keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4 + ) + print("-> Decoder outputs match! (rtol=1e-4, atol=1e-4)") + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + def main(_): os.makedirs(FLAGS.preset, exist_ok=True) @@ -419,9 +724,7 @@ def main(_): print("\n-> Load HF model and HF tokenizer.") hf_model = transformers.AutoModel.from_pretrained(hf_model_dir) hf_model.float() # Convert all params/buffers to float32. - # Fix embed_scale buffers: they were created in bfloat16 - # during init (non-persistent), so .float() preserves - # the bf16-rounded value. Re-create with true f32 precision. + # Re-create embed_scale with true f32 precision (bf16-init artifact). 
enc_hdim = hf_model.config.encoder.text_config.hidden_size dec_hdim = hf_model.config.decoder.hidden_size hf_model.encoder.text_model.embed_tokens.embed_scale = torch.tensor( @@ -437,12 +740,20 @@ def main(_): print("\n-> Load KerasHub tokenizer.") keras_hub_tokenizer = extract_vocab(hf_model_dir) - check_output( + check_text_output( keras_hub_tokenizer, keras_hub_model, hf_tokenizer, hf_model, ) + + check_multimodal_output( + keras_hub_model, + hf_model, + hf_model_dir, + hf_tokenizer, + ) + print("\n-> Releasing HF backbone from memory.") del hf_model gc.collect() @@ -457,7 +768,6 @@ def main(_): preprocessor=preprocessor, dtype=keras_hub_model.dtype, ) - keras_lm.compile(sampler="greedy") print(f"\n-> Saving T5Gemma2Seq2SeqLM preset to `{FLAGS.preset}`.") keras_lm.save_to_preset(FLAGS.preset) From febfc0c5d604850bafa02b17cc79b04453107656 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 5 Mar 2026 14:49:40 -0800 Subject: [PATCH 6/9] Implement passing dummy image inputs when text-only inference for multimodel variants --- .../t5gemma2_seq_2_seq_lm_preprocessor.py | 41 ++++- .../utils/transformers/convert_t5gemma2.py | 43 +++--- .../convert_t5gemma2_checkpoints.py | 144 ++++++++++-------- 3 files changed, 144 insertions(+), 84 deletions(-) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py index ce25d3e207..07757eff85 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py @@ -1,4 +1,5 @@ import keras +import numpy as np from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor @@ -53,6 +54,8 @@ def __init__( encoder_sequence_length=512, decoder_sequence_length=512, image_converter=None, + image_size=None, + num_vision_tokens_per_image=None, add_start_token=False, 
add_end_token=True, **kwargs, @@ -66,6 +69,34 @@ def __init__( self.add_start_token = add_start_token self.add_end_token = add_end_token self.image_converter = image_converter + self._vision_image_size = image_size + self.num_vision_tokens_per_image = num_vision_tokens_per_image + + def _add_vision_inputs(self, x, batch_size): + """Add dummy image/vision_indices for multimodal text-only input. + + When a multimodal backbone (with vision encoder) is used for + text-only inference, the functional model still requires + `images` and `vision_indices` inputs. This method provides + dummy values that act as a no-op: InterleaveEmbeddings + restores position 0 after scattering zero-indexed updates. + """ + if self._vision_image_size is not None and "images" not in x: + x["images"] = np.zeros( + ( + batch_size, + 1, + self._vision_image_size, + self._vision_image_size, + 3, + ), + dtype="float32", + ) + x["vision_indices"] = np.zeros( + (batch_size, self.num_vision_tokens_per_image), + dtype="int32", + ) + return x @preprocessing_function def call( @@ -98,12 +129,14 @@ def call( add_start_value=True, add_end_value=self.add_end_token, ) + batch_size = tf.shape(encoder_token_ids)[0] x = { "encoder_token_ids": encoder_token_ids, "encoder_padding_mask": encoder_padding_mask, "decoder_token_ids": decoder_token_ids[..., :-1], "decoder_padding_mask": decoder_padding_mask[..., :-1], } + x = self._add_vision_inputs(x, batch_size) y = decoder_token_ids[..., 1:] sample_weight = decoder_padding_mask[..., 1:] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) @@ -149,12 +182,14 @@ def generate_preprocess( add_end_value=False, ) - return { + batch_size = tf.shape(encoder_token_ids)[0] + out = { "encoder_token_ids": encoder_token_ids, "encoder_padding_mask": encoder_padding_mask, "decoder_token_ids": decoder_token_ids, "decoder_padding_mask": decoder_padding_mask, } + return self._add_vision_inputs(out, batch_size) def get_config(self): config = super().get_config() @@ -162,6 
+197,10 @@ def get_config(self): { "add_start_token": self.add_start_token, "add_end_token": self.add_end_token, + "image_size": self._vision_image_size, + "num_vision_tokens_per_image": ( + self.num_vision_tokens_per_image + ), } ) return config diff --git a/keras_hub/src/utils/transformers/convert_t5gemma2.py b/keras_hub/src/utils/transformers/convert_t5gemma2.py index 685d953cff..626340566f 100644 --- a/keras_hub/src/utils/transformers/convert_t5gemma2.py +++ b/keras_hub/src/utils/transformers/convert_t5gemma2.py @@ -18,7 +18,7 @@ def load_image_converter_config(preset, transformers_config): rescale_factor = preprocessor_config["rescale_factor"] offset = [(-m / s) for m, s in zip(mean, std)] scale = [(s * rescale_factor) for s in std] - image_size = encoder_config["vision_config"].get("image_size", 896) + image_size = encoder_config["vision_config"]["image_size"] return { "image_size": (image_size, image_size), "scale": scale, @@ -34,9 +34,7 @@ def convert_backbone_config(transformers_config): enc_text = encoder_config["text_config"] decoder_config = transformers_config["decoder"] - hidden_activation = decoder_config.get( - "hidden_activation", "gelu_pytorch_tanh" - ) + hidden_activation = decoder_config["hidden_activation"] if hidden_activation == "gelu_pytorch_tanh": hidden_activation = "gelu_approximate" @@ -59,9 +57,9 @@ def convert_backbone_config(transformers_config): pool_size=int( vision_config["image_size"] // vision_config["patch_size"] - // int(encoder_config.get("mm_tokens_per_image", 256) ** 0.5) + // int(encoder_config["mm_tokens_per_image"] ** 0.5) ), - layer_norm_epsilon=vision_config.get("layer_norm_eps", 1e-6), + layer_norm_epsilon=vision_config["layer_norm_eps"], ) backbone_config = { @@ -83,9 +81,7 @@ def convert_backbone_config(transformers_config): "dropout_rate": decoder_config["dropout_rate"], "rms_norm_eps": decoder_config["rms_norm_eps"], "query_pre_attn_scalar": decoder_config["query_pre_attn_scalar"], - "tie_word_embeddings": 
transformers_config.get( - "tie_word_embeddings", True - ), + "tie_word_embeddings": transformers_config["tie_word_embeddings"], "attention_bias": decoder_config["attention_bias"], "hidden_activation": hidden_activation, "initializer_range": decoder_config["initializer_range"], @@ -94,21 +90,22 @@ def convert_backbone_config(transformers_config): "cross_attention_hidden_size": enc_text["hidden_size"], "attn_logit_softcapping": decoder_config["attn_logit_softcapping"], "final_logit_softcapping": decoder_config["final_logit_softcapping"], - "rope_max_wavelength": decoder_config.get("rope_theta", 10000.0), - "global_rope_scaling_factor": decoder_config.get("rope_parameters", {}) - .get("full_attention", {}) - .get("factor", 1.0), - "encoder_rope_max_wavelength": enc_text.get("rope_parameters", {}) - .get("sliding_attention", {}) - .get("rope_theta", None), - "encoder_global_rope_scaling_factor": enc_text.get( - "rope_parameters", {} - ) - .get("full_attention", {}) - .get("factor", None), - "use_query_key_norm": True, + "rope_max_wavelength": ( + decoder_config["rope_parameters"]["sliding_attention"]["rope_theta"] + ), + "global_rope_scaling_factor": ( + decoder_config["rope_parameters"]["full_attention"]["factor"] + ), + "encoder_rope_max_wavelength": ( + enc_text["rope_parameters"]["sliding_attention"]["rope_theta"] + ), + "encoder_global_rope_scaling_factor": ( + enc_text["rope_parameters"]["full_attention"]["factor"] + ), + # use_qk_norm may not be in config JSON; default True. 
+ "use_query_key_norm": enc_text.get("use_qk_norm", True), "vision_encoder": vision_encoder, - "eoi_token_index": transformers_config.get("eoi_token_index", 256000), + "eoi_token_index": transformers_config["eoi_token_index"], } return backbone_config diff --git a/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py index 537e894c83..771968c07a 100644 --- a/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py @@ -2,7 +2,6 @@ import os import random import shutil -import traceback import huggingface_hub import keras @@ -117,25 +116,23 @@ def convert_checkpoints(hf_model): cross_attention_hidden_size=enc_text_config.hidden_size, attn_logit_softcapping=(decoder_config.attn_logit_softcapping), final_logit_softcapping=(decoder_config.final_logit_softcapping), - rope_max_wavelength=getattr(decoder_config, "rope_theta", 10000.0), + rope_max_wavelength=( + decoder_config.rope_parameters["sliding_attention"]["rope_theta"] + ), global_rope_scaling_factor=( - decoder_config.rope_parameters.get("full_attention", {}).get( - "factor", 1.0 - ) + decoder_config.rope_parameters["full_attention"]["factor"] ), encoder_rope_max_wavelength=( - enc_text_config.rope_parameters.get("sliding_attention", {}).get( - "rope_theta", None - ) + enc_text_config.rope_parameters["sliding_attention"]["rope_theta"] ), encoder_global_rope_scaling_factor=( - enc_text_config.rope_parameters.get("full_attention", {}).get( - "factor", None - ) + enc_text_config.rope_parameters["full_attention"]["factor"] + ), + use_query_key_norm=any( + "q_norm" in k for k in hf_model.state_dict().keys() ), - use_query_key_norm=True, vision_encoder=vision_encoder, - eoi_token_index=getattr(hf_model.config, "eoi_token_index", 256000), + eoi_token_index=hf_model.config.eoi_token_index, dtype="float32", ) @@ -432,6 +429,7 @@ def check_text_output( keras_hub_model, hf_tokenizer, 
hf_model, + preprocessor, ): """Check outputs of KerasHub and HuggingFace models match.""" # Note: KerasHub counts encoder + decoder embeddings as separate @@ -463,7 +461,7 @@ def check_text_output( "football is good too, but nowhere near as good as cricket." ] - # KerasHub. + # KerasHub — build unpadded inputs (match HF's natural lengths). keras_hub_enc_token_ids = hf_tokenizer( enc_sample_text, return_tensors="np" )["input_ids"] @@ -483,24 +481,11 @@ def check_text_output( "decoder_token_ids": keras_hub_dec_token_ids, "decoder_padding_mask": np.ones_like(keras_hub_dec_token_ids), } - - # If multimodal backbone, add dummy image/vision_indices inputs. - # Conv2D can't process batch=0, so pass 1 dummy (all-zeros) image. - # InterleaveEmbeddings restores the 0th index after scatter, so - # the dummy image embedding is effectively a no-op. - if keras_hub_model.vision_encoder is not None: - image_size = keras_hub_model.vision_encoder.image_size - num_vision_tokens = ( - keras_hub_model.vision_encoder.num_vision_tokens_per_image - ) - keras_hub_inputs["images"] = np.zeros( - (1, 1, image_size, image_size, 3), dtype="float32" - ) - # All indices point to 0; InterleaveEmbeddings restores - # the original embedding at position 0 after scattering. - keras_hub_inputs["vision_indices"] = np.zeros( - (1, num_vision_tokens), dtype="int32" - ) + # For multimodal backbones, use preprocessor to add dummy + # image/vision_indices (the single source of truth). + keras_hub_inputs = preprocessor._add_vision_inputs( + keras_hub_inputs, batch_size=1 + ) keras_hub_output = keras_hub_model.predict(keras_hub_inputs) # HF. @@ -530,36 +515,52 @@ def check_text_output( # Encoder output comparison. 
keras_enc_out = keras_hub_output["encoder_sequence_output"] hf_enc_out = hf_output.encoder_last_hidden_state.detach().float().numpy() + enc_abs_diff = np.abs(keras_enc_out - hf_enc_out) + print() print("Encoder Outputs:") print("KerasHub output:", keras_enc_out[0, 0, :10]) print("HF output:", hf_enc_out[0, 0, :10]) + print(f"Mean absolute diff: {enc_abs_diff.mean():.6f}") try: np.testing.assert_allclose( keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4 ) print("-> Encoder outputs match! (rtol=1e-4, atol=1e-4)") - except AssertionError as err: - print("\n") - print(traceback.format_exc()) - print(err.args[0]) - print("\n") + except AssertionError: + mismatch = np.sum( + ~np.isclose(keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4) + ) + total = keras_enc_out.size + print( + f"-> Encoder outputs differ slightly beyond rtol=1e-4 " + f"(mismatched: {mismatch}/{total}, " + f"{mismatch / total * 100:.2f}%)" + ) # Decoder output comparison. keras_dec_out = keras_hub_output["decoder_sequence_output"] hf_dec_out = hf_output.last_hidden_state.detach().float().numpy() + dec_abs_diff = np.abs(keras_dec_out - hf_dec_out) + print() print("Decoder Outputs:") print("KerasHub output:", keras_dec_out[0, 0, :10]) print("HF output:", hf_dec_out[0, 0, :10]) + print(f"Mean absolute diff: {dec_abs_diff.mean():.6f}") try: np.testing.assert_allclose( keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4 ) print("-> Decoder outputs match! 
(rtol=1e-4, atol=1e-4)") - except AssertionError as err: - print("\n") - print(traceback.format_exc()) - print(err.args[0]) - print("\n") + except AssertionError: + mismatch = np.sum( + ~np.isclose(keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4) + ) + total = keras_dec_out.size + print( + f"-> Decoder outputs differ slightly beyond rtol=1e-4 " + f"(mismatched: {mismatch}/{total}, " + f"{mismatch / total * 100:.2f}%)" + ) def check_multimodal_output( @@ -664,32 +665,36 @@ def check_multimodal_output( print("\n--- Multimodal Verification ---") keras_hub_output = keras_hub_model.predict(keras_hub_inputs) - # Relaxed tolerances: vision encoder compounds f32 drift across layers. - # Encoder output comparison. keras_enc_out = keras_hub_output["encoder_sequence_output"] hf_enc_out = hf_output.encoder_last_hidden_state.detach().float().numpy() enc_abs_diff = np.abs(keras_enc_out - hf_enc_out) + print() print("Encoder Outputs (multimodal):") print("KerasHub output:", keras_enc_out[0, 0, :10]) print("HF output:", hf_enc_out[0, 0, :10]) print(f"Mean absolute diff: {enc_abs_diff.mean():.6f}") - try: np.testing.assert_allclose( keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4 ) print("-> Encoder outputs match! (rtol=1e-4, atol=1e-4)") - except AssertionError as err: - print("\n") - print(traceback.format_exc()) - print(err.args[0]) - print("\n") + except AssertionError: + mismatch = np.sum( + ~np.isclose(keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4) + ) + total = keras_enc_out.size + print( + f"-> Encoder outputs differ slightly beyond rtol=1e-4 " + f"(mismatched: {mismatch}/{total}, " + f"{mismatch / total * 100:.2f}%)" + ) # Decoder output comparison. 
keras_dec_out = keras_hub_output["decoder_sequence_output"] hf_dec_out = hf_output.last_hidden_state.detach().float().numpy() dec_abs_diff = np.abs(keras_dec_out - hf_dec_out) + print() print("Decoder Outputs (multimodal):") print("KerasHub output:", keras_dec_out[0, 0, :10]) print("HF output:", hf_dec_out[0, 0, :10]) @@ -699,11 +704,16 @@ def check_multimodal_output( keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4 ) print("-> Decoder outputs match! (rtol=1e-4, atol=1e-4)") - except AssertionError as err: - print("\n") - print(traceback.format_exc()) - print(err.args[0]) - print("\n") + except AssertionError: + mismatch = np.sum( + ~np.isclose(keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4) + ) + total = keras_dec_out.size + print( + f"-> Decoder outputs differ slightly beyond rtol=1e-4 " + f"(mismatched: {mismatch}/{total}, " + f"{mismatch / total * 100:.2f}%)" + ) def main(_): @@ -740,11 +750,30 @@ def main(_): print("\n-> Load KerasHub tokenizer.") keras_hub_tokenizer = extract_vocab(hf_model_dir) + # Create preprocessor to check_text_output can use it. 
+ preprocessor_kwargs = {} + if keras_hub_model.vision_encoder is not None: + preprocessor_kwargs.update( + { + "image_size": keras_hub_model.vision_encoder.image_size, + "num_vision_tokens_per_image": ( + keras_hub_model.vision_encoder.num_vision_tokens_per_image + ), + } + ) + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + tokenizer=keras_hub_tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + **preprocessor_kwargs, + ) + check_text_output( keras_hub_tokenizer, keras_hub_model, hf_tokenizer, hf_model, + preprocessor, ) check_multimodal_output( @@ -758,11 +787,6 @@ def main(_): del hf_model gc.collect() - preprocessor = T5Gemma2Seq2SeqLMPreprocessor( - tokenizer=keras_hub_tokenizer, - encoder_sequence_length=512, - decoder_sequence_length=512, - ) keras_lm = T5Gemma2Seq2SeqLM( backbone=keras_hub_model, preprocessor=preprocessor, From e4474541959c60864fab58e8bad6a364fc4cd1e6 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 5 Mar 2026 15:25:18 -0800 Subject: [PATCH 7/9] Add preprocessor test file --- keras_hub/api/layers/__init__.py | 3 + ...t5gemma2_seq_2_seq_lm_preprocessor_test.py | 134 ++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 31d645784c..1ec7a7d009 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -169,6 +169,9 @@ from keras_hub.src.models.siglip.siglip_image_converter import ( SigLIPImageConverter as SigLIPImageConverter, ) +from keras_hub.src.models.t5gemma2.t5gemma2_image_converter import ( + T5Gemma2ImageConverter as T5Gemma2ImageConverter, +) from keras_hub.src.models.vgg.vgg_image_converter import ( VGGImageConverter as VGGImageConverter, ) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py 
b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py new file mode 100644 index 0000000000..878dba9ff8 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py @@ -0,0 +1,134 @@ +import os + +import numpy as np +import pytest + +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm_preprocessor import ( + T5Gemma2Seq2SeqLMPreprocessor, +) +from keras_hub.src.models.t5gemma2.t5gemma2_tokenizer import T5Gemma2Tokenizer +from keras_hub.src.tests.test_case import TestCase + + +class T5Gemma2Seq2SeqLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = T5Gemma2Tokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma3_test_vocab.spm" + ) + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "encoder_sequence_length": 8, + "decoder_sequence_length": 8, + } + self.input_data = ( + { + "encoder_text": ["the quick brown fox"], + "decoder_text": ["the earth is round"], + }, + ) + + def test_preprocessor_basics(self): + preprocessor = T5Gemma2Seq2SeqLMPreprocessor(**self.init_kwargs) + output = preprocessor(*self.input_data) + x, y, sample_weight = output + + # Verify output keys. + self.assertIn("encoder_token_ids", x) + self.assertIn("encoder_padding_mask", x) + self.assertIn("decoder_token_ids", x) + self.assertIn("decoder_padding_mask", x) + + # Verify shapes. 
+ self.assertEqual(x["encoder_token_ids"].shape[-1], 8) + self.assertEqual(x["decoder_token_ids"].shape[-1], 8) + + def test_generate_preprocess(self): + preprocessor = T5Gemma2Seq2SeqLMPreprocessor(**self.init_kwargs) + input_data = { + "encoder_text": ["the quick brown fox"], + "decoder_text": ["the earth is round"], + } + output = preprocessor.generate_preprocess(input_data) + self.assertIn("encoder_token_ids", output) + self.assertIn("encoder_padding_mask", output) + self.assertIn("decoder_token_ids", output) + self.assertIn("decoder_padding_mask", output) + + def test_generate_postprocess(self): + preprocessor = T5Gemma2Seq2SeqLMPreprocessor(**self.init_kwargs) + input_data = { + "decoder_token_ids": [2, 9, 14, 10, 1], + "decoder_padding_mask": [1, 1, 1, 1, 1], + } + output = preprocessor.generate_postprocess(input_data) + self.assertIsInstance(output, str) + + def test_add_vision_inputs_multimodal(self): + """Multimodal preprocessor should add dummy vision inputs + when text-only is used for inference.""" + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + **self.init_kwargs, + image_size=64, + num_vision_tokens_per_image=16, + ) + x = { + "encoder_token_ids": np.ones((2, 8), dtype="int32"), + "encoder_padding_mask": np.ones((2, 8), dtype="int32"), + "decoder_token_ids": np.ones((2, 8), dtype="int32"), + "decoder_padding_mask": np.ones((2, 8), dtype="int32"), + } + result = preprocessor._add_vision_inputs(x, batch_size=2) + + # Should add dummy images and vision indices. + self.assertIn("images", result) + self.assertIn("vision_indices", result) + self.assertEqual(result["images"].shape, (2, 1, 64, 64, 3)) + self.assertEqual(result["vision_indices"].shape, (2, 16)) + # Dummy values should be all zeros. 
+ np.testing.assert_array_equal( + result["images"], np.zeros_like(result["images"]) + ) + np.testing.assert_array_equal( + result["vision_indices"], + np.zeros_like(result["vision_indices"]), + ) + + def test_add_vision_inputs_skips_when_images_present(self): + """Should not overwrite existing images.""" + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + **self.init_kwargs, + image_size=64, + num_vision_tokens_per_image=16, + ) + existing_images = np.ones((1, 1, 64, 64, 3), dtype="float32") + x = { + "encoder_token_ids": np.ones((1, 8), dtype="int32"), + "images": existing_images, + } + result = preprocessor._add_vision_inputs(x, batch_size=1) + np.testing.assert_array_equal(result["images"], existing_images) + + def test_serialization(self): + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + **self.init_kwargs, + image_size=128, + num_vision_tokens_per_image=32, + ) + config = preprocessor.get_config() + self.assertEqual(config["image_size"], 128) + self.assertEqual(config["num_vision_tokens_per_image"], 32) + self.assertEqual(config["add_start_token"], False) + self.assertEqual(config["add_end_token"], True) + self.assertEqual(config["encoder_sequence_length"], 8) + self.assertEqual(config["decoder_sequence_length"], 8) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in T5Gemma2Seq2SeqLMPreprocessor.presets: + self.run_preset_test( + cls=T5Gemma2Seq2SeqLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) From 30993c63e21fb42efcd49d7a268cdb73d6cdb4f8 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 5 Mar 2026 21:50:39 -0800 Subject: [PATCH 8/9] Fix backbone tests --- .../src/models/t5gemma2/t5gemma2_backbone_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py b/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py index f95d3498a4..29e94a1324 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py +++ 
b/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py @@ -64,11 +64,11 @@ def test_backbone_basics(self): def test_asymmetrical_backbone(self): asym_kwargs = { "vocabulary_size": 100, - "encoder_hidden_dim": 48, + "encoder_hidden_dim": 32, "encoder_intermediate_dim": 96, "encoder_num_layers": 3, - "encoder_num_attention_heads": 6, - "encoder_num_key_value_heads": 3, + "encoder_num_attention_heads": 4, + "encoder_num_key_value_heads": 2, "encoder_head_dim": 8, "encoder_layer_types": ["full_attention"] * 3, "decoder_hidden_dim": 32, @@ -85,7 +85,7 @@ def test_asymmetrical_backbone(self): "dropout_rate": 0.1, "rms_norm_eps": 1e-6, "tie_word_embeddings": True, - "cross_attention_hidden_size": 48, + "cross_attention_hidden_size": 32, "use_query_key_norm": True, } self.run_backbone_test( @@ -93,7 +93,7 @@ def test_asymmetrical_backbone(self): init_kwargs=asym_kwargs, input_data=self.input_data, expected_output_shape={ - "encoder_sequence_output": (2, 16, 48), + "encoder_sequence_output": (2, 16, 32), "decoder_sequence_output": (2, 16, 32), }, ) From 342bf0b7c207b0d7a1dfb10f77224ed6f07a3f1a Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 5 Mar 2026 22:39:05 -0800 Subject: [PATCH 9/9] Fix preprocessor test --- .../models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py index 878dba9ff8..1a64d7ef47 100644 --- a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py @@ -121,8 +121,6 @@ def test_serialization(self): self.assertEqual(config["num_vision_tokens_per_image"], 32) self.assertEqual(config["add_start_token"], False) self.assertEqual(config["add_end_token"], True) - self.assertEqual(config["encoder_sequence_length"], 8) - 
self.assertEqual(config["decoder_sequence_length"], 8) @pytest.mark.extra_large def test_all_presets(self):