diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 31d645784c..1ec7a7d009 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -169,6 +169,9 @@ from keras_hub.src.models.siglip.siglip_image_converter import ( SigLIPImageConverter as SigLIPImageConverter, ) +from keras_hub.src.models.t5gemma2.t5gemma2_image_converter import ( + T5Gemma2ImageConverter as T5Gemma2ImageConverter, +) from keras_hub.src.models.vgg.vgg_image_converter import ( VGGImageConverter as VGGImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index e2b5909706..7fa1616306 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -763,6 +763,18 @@ from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( T5GemmaTokenizer as T5GemmaTokenizer, ) +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import ( + T5Gemma2Backbone as T5Gemma2Backbone, +) +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm import ( + T5Gemma2Seq2SeqLM as T5Gemma2Seq2SeqLM, +) +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm_preprocessor import ( + T5Gemma2Seq2SeqLMPreprocessor as T5Gemma2Seq2SeqLMPreprocessor, +) +from keras_hub.src.models.t5gemma2.t5gemma2_tokenizer import ( + T5Gemma2Tokenizer as T5Gemma2Tokenizer, +) from keras_hub.src.models.task import Task as Task from keras_hub.src.models.text_classifier import TextClassifier as Classifier from keras_hub.src.models.text_classifier import ( diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index c558b5e4f5..6d8442ca52 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -115,6 +115,9 @@ from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( T5GemmaTokenizer as T5GemmaTokenizer, ) +from keras_hub.src.models.t5gemma2.t5gemma2_tokenizer import ( + T5Gemma2Tokenizer as T5Gemma2Tokenizer, +) from 
keras_hub.src.models.video_prism.video_prism_tokenizer import ( VideoPrismTokenizer as VideoPrismTokenizer, ) diff --git a/keras_hub/src/models/t5gemma2/__init__.py b/keras_hub/src/models/t5gemma2/__init__.py new file mode 100644 index 0000000000..94d56d0c9d --- /dev/null +++ b/keras_hub/src/models/t5gemma2/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.models.t5gemma2.t5gemma2_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, T5Gemma2Backbone) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_attention.py b/keras_hub/src/models/t5gemma2/t5gemma2_attention.py new file mode 100644 index 0000000000..0cab5426e2 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_attention.py @@ -0,0 +1,732 @@ +import inspect + +import keras + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention +from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization +from keras_hub.src.models.t5gemma2.t5gemma2_layers import ( + t5gemma2_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +def repeat_kv(hidden_states, n_rep): + """Repeats the key/value hidden states for Grouped Query Attention. + + Args: + hidden_states: Tensor with shape + `(batch, sequence_length, num_key_value_heads, head_dim)`. + n_rep: int, number of times to repeat. + + Returns: + Tensor with shape + `(batch, sequence_length, num_query_heads, head_dim)`. 
+ """ + if n_rep == 1: + return hidden_states + batch, slen, num_key_value_heads, head_dim = keras.ops.shape(hidden_states) + hidden_states = keras.ops.expand_dims(hidden_states, 3) + hidden_states = keras.ops.tile(hidden_states, (1, 1, 1, n_rep, 1)) + return keras.ops.reshape( + hidden_states, (batch, slen, num_key_value_heads * n_rep, head_dim) + ) + + +class T5Gemma2Attention(CachedGemma3Attention): + """Self-attention layer for T5Gemma2 encoder and decoder. + + This layer performs self-attention with Rotary Positional Embeddings + (RoPE), optional Q/K normalization (Gemma3-style), and optional + attention logit softcapping. Supports Grouped Query Attention (GQA). + + Used in `T5Gemma2EncoderLayer` for bidirectional self-attention + and can also be used in the decoder for self-attention. + + Args: + hidden_size: int, The dimensionality of the hidden states. + num_attention_heads: int, The number of attention heads. + num_key_value_heads: int, The number of key-value heads for GQA. + query_pre_attn_scalar: float, Scalar to multiply queries by + before attention. + attention_bias: bool, Whether to include bias in dense layers. + head_dim: int, The dimensionality of each attention head. + initializer_range: float, The range for the initializer. + Defaults to `0.02`. + attention_dropout: float, Dropout rate for attention weights. + Defaults to `0.0`. + attn_logit_softcapping: float, optional, Softcapping value. + Defaults to `None`. + rope_max_wavelength: float, Maximum wavelength for RoPE. + Defaults to `10000.0`. + use_query_key_norm: bool, Whether to apply RMS normalization on + query and key. Defaults to `True` (Gemma3-style). + rms_norm_eps: float, Epsilon for RMS normalization. + Defaults to `1e-6`. + dtype: The dtype for computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments. 
+ """ + + def __init__( + self, + hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + head_dim, + initializer_range=0.02, + attention_dropout=0.0, + attn_logit_softcapping=None, + rope_max_wavelength=10000.0, + rope_scaling_factor=1.0, + use_query_key_norm=True, + rms_norm_eps=1e-6, + dtype=None, + **kwargs, + ): + super().__init__( + head_dim=head_dim, + num_query_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + kernel_initializer=t5gemma2_kernel_initializer(initializer_range), + logit_soft_cap=attn_logit_softcapping, + dropout=attention_dropout, + query_head_dim_normalize=False, + use_sliding_window_attention=False, + dtype=dtype, + **kwargs, + ) + self.hidden_size = hidden_size + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.use_query_key_norm = use_query_key_norm + self.rms_norm_eps = rms_norm_eps + self.num_key_value_groups = ( + self.num_query_heads // self.num_key_value_heads + ) + self.scaling = self.query_pre_attn_scalar**-0.5 + + def build(self, input_shape): + self._kernel_initializer = t5gemma2_kernel_initializer( + self.initializer_range + ) + hidden_states_shape = input_shape + self.hidden_dim = hidden_states_shape[-1] + + self.query_dense = keras.layers.EinsumDense( + equation="btd,dnh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="nh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(hidden_states_shape) + + self.key_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + 
kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(hidden_states_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(hidden_states_shape) + + self.output_dense = keras.layers.EinsumDense( + equation="btnh,nhd->btd", + output_shape=(None, self.hidden_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="d" if self.attention_bias else None, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build( + ( + hidden_states_shape[0], + hidden_states_shape[1], + self.num_query_heads, + self.head_dim, + ) + ) + + # Q/K normalization (Gemma3-style). 
+ if self.use_query_key_norm: + self.query_norm = RMSNormalization( + epsilon=self.rms_norm_eps, + dtype=self.dtype_policy, + name="query_norm", + ) + self.query_norm.build( + self.query_dense.compute_output_shape(hidden_states_shape) + ) + self.key_norm = RMSNormalization( + epsilon=self.rms_norm_eps, + dtype=self.dtype_policy, + name="key_norm", + ) + self.key_norm.build( + self.key_dense.compute_output_shape(hidden_states_shape) + ) + + self.rotary_embedding = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, + sequence_axis=1, + feature_axis=3, + name="rotary_embedding", + dtype=self.dtype_policy, + ) + + self.dropout_layer = keras.layers.Dropout( + rate=self.attention_dropout, + dtype=self.dtype_policy, + ) + self.softmax = keras.layers.Softmax(axis=-1, dtype="float32") + self.built = True + + def _compute_attention_without_fused_op( + self, query_states, key_states, value_states, attention_mask, training + ): + attn_weights = keras.ops.einsum( + "btnh,bsnh->bnts", query_states, key_states + ) + attn_weights *= self.scaling + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.einsum( + "bnts,bsnh->btnh", attn_weights, value_states + ) + return attn_output + + def _compute_attention( + self, query_states, key_states, value_states, attention_mask, training + ): + if self._use_fused_attention_op(): + kwargs = {"bias": attention_mask} + if self.logit_soft_cap is not None: + sig = inspect.signature(keras.ops.dot_product_attention) + if "attn_logits_soft_cap" in sig.parameters: + kwargs["attn_logits_soft_cap"] = self.logit_soft_cap + 
return keras.ops.dot_product_attention( + query=query_states, + key=key_states, + value=value_states, + scale=self.scaling, + **kwargs, + ) + return self._compute_attention_without_fused_op( + query_states, + key_states, + value_states, + attention_mask, + training, + ) + + def call( + self, + inputs, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + hidden_states = inputs + query_states = self.query_dense(hidden_states) + key_states = self.key_dense(hidden_states) + value_states = self.value_dense(hidden_states) + + # Apply Q/K normalization. + if self.use_query_key_norm: + query_states = self.query_norm(query_states) + key_states = self.key_norm(key_states) + + # Apply RoPE. + start_index = 0 if cache_update_index is None else cache_update_index + query_states = self.rotary_embedding( + query_states, start_index=start_index + ) + key_states = self.rotary_embedding(key_states, start_index=start_index) + + # Handle caching for autoregressive generation. + if cache is not None: + if cache_update_index is None: + raise ValueError( + "Both `cache` and `cache_update_index` must be " + "passed for self-attention caching." + ) + key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...] + start = [0, cache_update_index, 0, 0] + key_states = keras.ops.slice_update(key_cache, start, key_states) + value_states = keras.ops.slice_update( + value_cache, start, value_states + ) + cache = keras.ops.stack((key_states, value_states), axis=1) + elif cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is `None`." + ) + else: + cache = keras.ops.stack((key_states, value_states), axis=1) + + # Repeat K/V for GQA. 
+ key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = self._compute_attention( + query_states, + key_states, + value_states, + attention_mask, + training, + ) + attn_output = self.output_dense(attn_output) + return attn_output, cache + + def compute_output_shape(self, input_shape): + hidden_states_shape = input_shape + attn_output_shape = hidden_states_shape + kv_len = hidden_states_shape[1] + cache_shape = ( + hidden_states_shape[0], + 2, + kv_len, + self.num_key_value_heads, + self.head_dim, + ) + return attn_output_shape, cache_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "head_dim": self.head_dim, + "num_attention_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "attn_logit_softcapping": self.logit_soft_cap, + "rope_max_wavelength": self.rope_max_wavelength, + "use_query_key_norm": self.use_query_key_norm, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config + + +class T5Gemma2MergedAttention(CachedGemma3Attention): + """Merged self-attention and cross-attention for T5Gemma2 decoder. + + This layer fuses self-attention and cross-attention into a single + attention computation. The decoder's Q/K/V are computed from the + decoder hidden states (self-attention), while additional K/V are + computed from the encoder hidden states (cross-attention). The + self-attention and cross-attention K/V are concatenated, and a + single attention computation is performed over the merged K/V. + + This merged approach is the key architectural difference between + T5Gemma2 and T5Gemma1. + + Args: + hidden_size: int, Dimensionality of the decoder hidden states. 
+ num_attention_heads: int, Number of attention heads. + num_key_value_heads: int, Number of key-value heads for GQA. + query_pre_attn_scalar: float, Scalar for query normalization. + attention_bias: bool, Whether to include bias. + head_dim: int, Dimensionality of each attention head. + cross_attention_hidden_size: int, optional, Hidden size of the + encoder states. Defaults to `hidden_size`. + initializer_range: float, Range for the initializer. + Defaults to `0.02`. + attention_dropout: float, Dropout rate. + Defaults to `0.0`. + attn_logit_softcapping: float, optional, Softcapping value. + Defaults to `None`. + rope_max_wavelength: float, Maximum wavelength for RoPE. + Defaults to `10000.0`. + use_query_key_norm: bool, Whether to apply Q/K norm. + Defaults to `True`. + rms_norm_eps: float, Epsilon for RMS normalization. + Defaults to `1e-6`. + dtype: The dtype for computations. Defaults to `None`. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + head_dim, + cross_attention_hidden_size=None, + initializer_range=0.02, + attention_dropout=0.0, + attn_logit_softcapping=None, + rope_max_wavelength=10000.0, + rope_scaling_factor=1.0, + use_query_key_norm=True, + rms_norm_eps=1e-6, + dtype=None, + **kwargs, + ): + super().__init__( + head_dim=head_dim, + num_query_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + kernel_initializer=t5gemma2_kernel_initializer(initializer_range), + logit_soft_cap=attn_logit_softcapping, + dropout=attention_dropout, + query_head_dim_normalize=False, + use_sliding_window_attention=False, + dtype=dtype, + **kwargs, + ) + self.hidden_size = hidden_size + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or hidden_size + ) + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.initializer_range = initializer_range + 
self.attention_dropout = attention_dropout + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.use_query_key_norm = use_query_key_norm + self.rms_norm_eps = rms_norm_eps + self.num_key_value_groups = ( + self.num_query_heads // self.num_key_value_heads + ) + self.scaling = self.query_pre_attn_scalar**-0.5 + + def build(self, input_shape): + # Only decoder_shape needed; K/V projections are shared. + if isinstance(input_shape, (list, tuple)): + decoder_shape = input_shape[0] + else: + decoder_shape = input_shape + + self._kernel_initializer = t5gemma2_kernel_initializer( + self.initializer_range + ) + self.hidden_dim = decoder_shape[-1] + + # Q projection from decoder hidden states. + self.query_dense = keras.layers.EinsumDense( + equation="btd,dnh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="nh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(decoder_shape) + + # K/V projections shared for self-attn and cross-attn. + self.key_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(decoder_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(decoder_shape) + + # Output projection. 
+ self.output_dense = keras.layers.EinsumDense( + equation="btnh,nhd->btd", + output_shape=(None, self.hidden_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="d" if self.attention_bias else None, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build( + ( + decoder_shape[0], + decoder_shape[1], + self.num_query_heads, + self.head_dim, + ) + ) + + # Q/K normalization (Gemma3-style). + if self.use_query_key_norm: + self.query_norm = RMSNormalization( + epsilon=self.rms_norm_eps, + dtype=self.dtype_policy, + name="query_norm", + ) + self.query_norm.build( + self.query_dense.compute_output_shape(decoder_shape) + ) + self.key_norm = RMSNormalization( + epsilon=self.rms_norm_eps, + dtype=self.dtype_policy, + name="key_norm", + ) + self.key_norm.build( + self.key_dense.compute_output_shape(decoder_shape) + ) + + self.rotary_embedding = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, + sequence_axis=1, + feature_axis=3, + name="rotary_embedding", + dtype=self.dtype_policy, + ) + + self.dropout_layer = keras.layers.Dropout( + rate=self.attention_dropout, + dtype=self.dtype_policy, + ) + self.softmax = keras.layers.Softmax(axis=-1, dtype="float32") + self.built = True + + def _compute_attention_without_fused_op( + self, query_states, key_states, value_states, attention_mask, training + ): + attn_weights = keras.ops.einsum( + "btnh,bsnh->bnts", query_states, key_states + ) + attn_weights *= self.scaling + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.einsum( + 
"bnts,bsnh->btnh", attn_weights, value_states + ) + return attn_output + + def _compute_attention( + self, query_states, key_states, value_states, attention_mask, training + ): + if self._use_fused_attention_op(): + kwargs = {"bias": attention_mask} + if self.logit_soft_cap is not None: + sig = inspect.signature(keras.ops.dot_product_attention) + if "attn_logits_soft_cap" in sig.parameters: + kwargs["attn_logits_soft_cap"] = self.logit_soft_cap + return keras.ops.dot_product_attention( + query=query_states, + key=key_states, + value=value_states, + scale=self.scaling, + **kwargs, + ) + return self._compute_attention_without_fused_op( + query_states, + key_states, + value_states, + attention_mask, + training, + ) + + def call( + self, + inputs, + encoder_hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Forward pass of merged self+cross attention. + + Args: + inputs: Decoder hidden states, shape + `(batch, decoder_seq_len, hidden_dim)`. + encoder_hidden_states: Encoder output, shape + `(batch, encoder_seq_len, hidden_dim)`. + attention_mask: Merged attention mask, shape + `(batch, 1, decoder_seq_len, decoder_kv_len + encoder_len)`. + This is the concatenation of the causal self-attention + mask and the bidirectional cross-attention mask. + cache: Tuple of (self_attn_cache, cross_attn_cache) or None. + cache_update_index: int, the current position in the + sequence for caching. + training: bool, whether in training mode. + + Returns: + Tuple of (attn_output, updated_cache). + """ + hidden_states = inputs + self_attention_cache, cross_attention_cache = ( + cache if cache is not None else (None, None) + ) + + # Self-attention Q/K/V from decoder hidden states. + query_states = self.query_dense(hidden_states) + key_states = self.key_dense(hidden_states) + value_states = self.value_dense(hidden_states) + + # Apply Q/K normalization. 
+ if self.use_query_key_norm: + query_states = self.query_norm(query_states) + key_states = self.key_norm(key_states) + + # Apply RoPE to self-attention Q/K only (not cross-attention). + start_index = 0 if cache_update_index is None else cache_update_index + query_states = self.rotary_embedding( + query_states, start_index=start_index + ) + key_states = self.rotary_embedding(key_states, start_index=start_index) + + # Update self-attention cache. + if self_attention_cache is not None: + key_cache_self = self_attention_cache[:, 0, ...] + value_cache_self = self_attention_cache[:, 1, ...] + start = [0, cache_update_index, 0, 0] + key_states = keras.ops.slice_update( + key_cache_self, start, key_states + ) + value_states = keras.ops.slice_update( + value_cache_self, start, value_states + ) + updated_self_cache = keras.ops.stack( + (key_states, value_states), axis=1 + ) + else: + updated_self_cache = keras.ops.stack( + (key_states, value_states), axis=1 + ) + + # Cross-attention K/V from encoder hidden states. + if cross_attention_cache is not None: + # Reuse cached encoder K/V. + cross_key_states = cross_attention_cache[:, 0, ...] + cross_value_states = cross_attention_cache[:, 1, ...] + updated_cross_cache = cross_attention_cache + else: + cross_key_states = self.key_dense(encoder_hidden_states) + cross_value_states = self.value_dense(encoder_hidden_states) + # Apply K normalization to cross-attention keys. + if self.use_query_key_norm: + cross_key_states = self.key_norm(cross_key_states) + updated_cross_cache = keras.ops.stack( + (cross_key_states, cross_value_states), axis=1 + ) + + # Merge self-attention and cross-attention K/V. + merged_key_states = keras.ops.concatenate( + [key_states, cross_key_states], axis=1 + ) + merged_value_states = keras.ops.concatenate( + [value_states, cross_value_states], axis=1 + ) + + # Repeat K/V for GQA. 
+ merged_key_states = repeat_kv( + merged_key_states, self.num_key_value_groups + ) + merged_value_states = repeat_kv( + merged_value_states, self.num_key_value_groups + ) + + attn_output = self._compute_attention( + query_states, + merged_key_states, + merged_value_states, + attention_mask, + training, + ) + attn_output = self.output_dense(attn_output) + updated_cache = (updated_self_cache, updated_cross_cache) + return attn_output, updated_cache + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, (list, tuple)): + decoder_shape, encoder_shape = input_shape + else: + decoder_shape = input_shape + encoder_shape = input_shape + attn_output_shape = decoder_shape + dec_kv_len = decoder_shape[1] + enc_kv_len = encoder_shape[1] + self_cache_shape = ( + decoder_shape[0], + 2, + dec_kv_len, + self.num_key_value_heads, + self.head_dim, + ) + cross_cache_shape = ( + decoder_shape[0], + 2, + enc_kv_len, + self.num_key_value_heads, + self.head_dim, + ) + return attn_output_shape, (self_cache_shape, cross_cache_shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "head_dim": self.head_dim, + "num_attention_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "cross_attention_hidden_size": ( + self.cross_attention_hidden_size + ), + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "attn_logit_softcapping": self.logit_soft_cap, + "rope_max_wavelength": self.rope_max_wavelength, + "use_query_key_norm": self.use_query_key_norm, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py b/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py new file mode 100644 index 0000000000..02004726a3 --- /dev/null +++ 
b/keras_hub/src/models/t5gemma2/t5gemma2_backbone.py @@ -0,0 +1,540 @@ +import keras +from keras import ops +from keras.layers import ReversibleEmbedding + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.gemma3.gemma3_layers import Gemma3InterleaveEmbeddings +from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization +from keras_hub.src.models.t5gemma2.t5gemma2_decoder import T5Gemma2DecoderLayer +from keras_hub.src.models.t5gemma2.t5gemma2_encoder import T5Gemma2EncoderLayer +from keras_hub.src.models.t5gemma2.t5gemma2_layers import ( + t5gemma2_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +@keras_hub_export("keras_hub.models.T5Gemma2Backbone") +class T5Gemma2Backbone(Backbone): + """T5Gemma2 backbone model. + + This class implements the encoder-decoder backbone of the T5Gemma2 + model. T5Gemma2 is based on Gemma3 and features merged + self+cross attention in the decoder (unlike T5Gemma1 which used + separate attention sublayers), Gemma3-style Q/K normalization, + and per-layer-type sliding window attention patterns. + + When a `vision_encoder` is provided, the model also accepts image + inputs. Images are processed by the vision encoder and the resulting + embeddings are interleaved into the encoder text embeddings at + positions marked by image placeholder tokens. + + Args: + vocabulary_size: int, The size of the vocabulary. + encoder_hidden_dim: int, Encoder hidden dimensionality. + encoder_intermediate_dim: int, Encoder FFN intermediate size. + encoder_num_layers: int, Number of encoder layers. + encoder_num_attention_heads: int, Encoder attention heads. + encoder_num_key_value_heads: int, Encoder KV heads for GQA. + encoder_head_dim: int, Encoder head dimensionality. + encoder_layer_types: list of str, Attention layer types for + each encoder layer (`"full_attention"` or + `"sliding_attention"`). 
+ decoder_hidden_dim: int, Decoder hidden dimensionality. + decoder_intermediate_dim: int, Decoder FFN intermediate size. + decoder_num_layers: int, Number of decoder layers. + decoder_num_attention_heads: int, Decoder attention heads. + decoder_num_key_value_heads: int, Decoder KV heads for GQA. + decoder_head_dim: int, Decoder head dimensionality. + decoder_layer_types: list of str, Attention layer types for + each decoder layer. + dropout_rate: float, Dropout rate. Defaults to `0.0`. + rms_norm_eps: float, RMS normalization epsilon. Defaults to + `1e-6`. + query_pre_attn_scalar: float, Query scalar. Defaults to `1.0`. + attention_bias: bool, Attention bias. Defaults to `False`. + hidden_activation: str, FFN activation. Defaults to + `"gelu_approximate"`. + tie_word_embeddings: bool, Tie input/output embeddings. + Defaults to `True`. + initializer_range: float, Initializer range. Defaults to + `0.02`. + attention_dropout: float, Attention dropout. Defaults to `0.0`. + sliding_window: int, optional, Sliding window size. + cross_attention_hidden_size: int, optional, Cross-attention + hidden size. Defaults to `encoder_hidden_dim`. + attn_logit_softcapping: float, optional, Attention softcapping. + final_logit_softcapping: float, optional, Final logit + softcapping. + rope_max_wavelength: float, RoPE maximum wavelength. + Defaults to `10000.0`. + global_rope_scaling_factor: float, RoPE scaling factor for + full attention layers. Defaults to `1.0`. + use_query_key_norm: bool, Whether to use Gemma3-style Q/K + normalization. Defaults to `True`. + vision_encoder: optional, A `Gemma3VisionEncoder` instance for + multimodal inputs. When `None`, the model is text-only. + eoi_token_index: int, Token index for the end-of-image token. + Defaults to `256000`. + dtype: dtype for computations. Defaults to `None`. + **kwargs: Additional keyword arguments. 
+ + Examples: + ```python + import numpy as np + from keras_hub.models import T5Gemma2Backbone + + input_data = { + "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"), + "encoder_padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32" + ), + "decoder_token_ids": np.ones(shape=(1, 8), dtype="int32"), + "decoder_padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1]], dtype="int32" + ), + } + + model = T5Gemma2Backbone( + vocabulary_size=32000, + encoder_hidden_dim=256, + encoder_intermediate_dim=512, + encoder_num_layers=4, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_head_dim=64, + encoder_layer_types=["full_attention"] * 4, + decoder_hidden_dim=256, + decoder_intermediate_dim=512, + decoder_num_layers=4, + decoder_num_attention_heads=4, + decoder_num_key_value_heads=2, + decoder_head_dim=64, + decoder_layer_types=["full_attention"] * 4, + dropout_rate=0.1, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + encoder_hidden_dim, + encoder_intermediate_dim, + encoder_num_layers, + encoder_num_attention_heads, + encoder_num_key_value_heads, + encoder_head_dim, + encoder_layer_types, + decoder_hidden_dim, + decoder_intermediate_dim, + decoder_num_layers, + decoder_num_attention_heads, + decoder_num_key_value_heads, + decoder_head_dim, + decoder_layer_types, + dropout_rate=0.0, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + tie_word_embeddings=True, + initializer_range=0.02, + attention_dropout=0.0, + sliding_window=None, + cross_attention_hidden_size=None, + attn_logit_softcapping=None, + final_logit_softcapping=None, + rope_max_wavelength=10000.0, + global_rope_scaling_factor=1.0, + encoder_rope_max_wavelength=None, + encoder_global_rope_scaling_factor=None, + use_query_key_norm=True, + 
vision_encoder=None, + eoi_token_index=256000, + dtype=None, + **kwargs, + ): + self.kernel_initializer = t5gemma2_kernel_initializer(initializer_range) + + # Determine if text-only. + self.vision_encoder = vision_encoder + text_only_model = vision_encoder is None + + # === Layers === + self.token_embedding = keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=encoder_hidden_dim, + embeddings_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, + name="encoder_token_embedding", + ) + self.decoder_token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=decoder_hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, + name="decoder_token_embedding", + ) + + # Vision interleaving layer (only when vision encoder is present). + if not text_only_model: + self.interleave_embeddings = Gemma3InterleaveEmbeddings( + num_vision_tokens_per_image=self.vision_encoder.num_vision_tokens_per_image, + dtype=dtype, + name="interleave_embeddings", + ) + # EOI (end-of-image) embeddings: learned vectors that + # replace the standard embedding at eoi_token_index. + self.encoder_eoi_embedding = keras.Variable( + keras.ops.zeros((encoder_hidden_dim,)), + name="encoder_eoi_embedding", + ) + self.decoder_eoi_embedding = keras.Variable( + keras.ops.zeros((decoder_hidden_dim,)), + name="decoder_eoi_embedding", + ) + + # Encoder may have different RoPE config than decoder. + enc_rope = ( + encoder_rope_max_wavelength + if encoder_rope_max_wavelength is not None + else rope_max_wavelength + ) + enc_rope_factor = ( + encoder_global_rope_scaling_factor + if encoder_global_rope_scaling_factor is not None + else global_rope_scaling_factor + ) + + self.encoder_layers = [] + for i in range(encoder_num_layers): + # Per-layer RoPE wavelength: base for sliding, 1M for global. 
+ layer_rope = ( + enc_rope + if encoder_layer_types[i] == "sliding_attention" + else 1_000_000.0 + ) + # Per-layer RoPE scaling: 1.0 for sliding (default), + # global_rope_scaling_factor for full_attention (linear). + layer_rope_factor = ( + 1.0 + if encoder_layer_types[i] == "sliding_attention" + else enc_rope_factor + ) + self.encoder_layers.append( + T5Gemma2EncoderLayer( + hidden_size=encoder_hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=encoder_num_attention_heads, + num_key_value_heads=encoder_num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=encoder_intermediate_dim, + hidden_activation=hidden_activation, + head_dim=encoder_head_dim, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + layer_type=encoder_layer_types[i], + sliding_window=sliding_window, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=layer_rope, + rope_scaling_factor=layer_rope_factor, + use_query_key_norm=use_query_key_norm, + name=f"encoder_layer_{i}", + dtype=dtype, + ) + ) + self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) + self.encoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) + self.decoder_layers = [] + for i in range(decoder_num_layers): + # Per-layer RoPE wavelength: 10K for sliding, 1M for global. + layer_rope = ( + rope_max_wavelength + if decoder_layer_types[i] == "sliding_attention" + else 1_000_000.0 + ) + # Per-layer RoPE scaling: 1.0 for sliding (default), + # global_rope_scaling_factor for full_attention (linear). 
+ layer_rope_factor = ( + 1.0 + if decoder_layer_types[i] == "sliding_attention" + else global_rope_scaling_factor + ) + self.decoder_layers.append( + T5Gemma2DecoderLayer( + hidden_size=decoder_hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=decoder_num_attention_heads, + num_key_value_heads=decoder_num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=decoder_intermediate_dim, + hidden_activation=hidden_activation, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + head_dim=decoder_head_dim, + attention_dropout=attention_dropout, + layer_type=decoder_layer_types[i], + sliding_window=sliding_window, + cross_attention_hidden_size=( + cross_attention_hidden_size or encoder_hidden_dim + ), + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=layer_rope, + rope_scaling_factor=layer_rope_factor, + use_query_key_norm=use_query_key_norm, + name=f"decoder_layer_{i}", + dtype=dtype, + ) + ) + self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) + self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) + + # === Functional Model === + encoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_token_ids" + ) + encoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_padding_mask" + ) + decoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_token_ids" + ) + decoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_padding_mask" + ) + + # Optional vision inputs. + if not text_only_model: + image_size = self.vision_encoder.image_size + image_input = keras.Input( + shape=(None, image_size, image_size, 3), + name="images", + ) + vision_indices_input = keras.Input( + shape=(None,), dtype="int32", name="vision_indices" + ) + + # Encoder. 
+ encoder_embeddings = self.token_embedding(encoder_token_id_input) + encoder_embeddings = encoder_embeddings * ops.cast( + ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype + ) + + # Handle EOI embedding replacement. + if not text_only_model: + # Replace embeddings at eoi_token_index positions with the + # learned eoi_embedding (a separate parameter per HF design). + # Use ops.where with automatic broadcasting (no broadcast_to + # needed — avoids issues with symbolic shapes during tracing). + eoi_mask = ops.cast( + ops.expand_dims( + ops.equal(encoder_token_id_input, eoi_token_index), + axis=-1, + ), + encoder_embeddings.dtype, + ) + encoder_embeddings = ( + eoi_mask * self.encoder_eoi_embedding + + (1 - eoi_mask) * encoder_embeddings + ) + + # Interleave vision embeddings if images are provided. + if not text_only_model: + img_embeddings = self.vision_encoder(image_input) + encoder_embeddings = self.interleave_embeddings( + image_embeddings=img_embeddings, + text_embeddings=encoder_embeddings, + vision_indices=vision_indices_input, + ) + + encoder_hidden_states = self.encoder_dropout(encoder_embeddings) + for layer in self.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, + padding_mask=encoder_padding_mask_input, + ) + encoder_output = self.encoder_norm(encoder_hidden_states) + encoder_output = self.encoder_dropout(encoder_output) + + # Decoder. + decoder_embeddings = self.decoder_token_embedding( + decoder_token_id_input + ) + decoder_embeddings = decoder_embeddings * ops.cast( + ops.sqrt(decoder_hidden_dim), decoder_embeddings.dtype + ) + + # Handle EOI embedding replacement in decoder. 
+ if not text_only_model: + dec_eoi_mask = ops.cast( + ops.expand_dims( + ops.equal(decoder_token_id_input, eoi_token_index), + axis=-1, + ), + decoder_embeddings.dtype, + ) + decoder_embeddings = ( + dec_eoi_mask * self.decoder_eoi_embedding + + (1 - dec_eoi_mask) * decoder_embeddings + ) + + decoder_hidden_states = self.decoder_dropout(decoder_embeddings) + for layer in self.decoder_layers: + decoder_hidden_states, _ = layer( + (decoder_hidden_states, encoder_output), + self_attention_padding_mask=decoder_padding_mask_input, + cross_attention_padding_mask=encoder_padding_mask_input, + ) + decoder_output = self.decoder_norm(decoder_hidden_states) + decoder_output = self.decoder_dropout(decoder_output) + + inputs = { + "encoder_token_ids": encoder_token_id_input, + "encoder_padding_mask": encoder_padding_mask_input, + "decoder_token_ids": decoder_token_id_input, + "decoder_padding_mask": decoder_padding_mask_input, + } + if not text_only_model: + inputs.update( + { + "images": image_input, + "vision_indices": vision_indices_input, + } + ) + + super().__init__( + inputs=inputs, + outputs={ + "encoder_sequence_output": encoder_output, + "decoder_sequence_output": decoder_output, + }, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_intermediate_dim = encoder_intermediate_dim + self.encoder_num_layers = encoder_num_layers + self.encoder_num_attention_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + self.encoder_head_dim = encoder_head_dim + self.encoder_layer_types = encoder_layer_types + self.decoder_hidden_dim = decoder_hidden_dim + self.decoder_intermediate_dim = decoder_intermediate_dim + self.decoder_num_layers = decoder_num_layers + self.decoder_num_attention_heads = decoder_num_attention_heads + self.decoder_num_key_value_heads = decoder_num_key_value_heads + self.decoder_head_dim = decoder_head_dim + self.decoder_layer_types = 
decoder_layer_types + self.vocabulary_size = vocabulary_size + self.dropout_rate = dropout_rate + self.rms_norm_eps = rms_norm_eps + self.tie_word_embeddings = tie_word_embeddings + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.sliding_window = sliding_window + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or encoder_hidden_dim + ) + self.attn_logit_softcapping = attn_logit_softcapping + self.final_logit_softcapping = final_logit_softcapping + self.rope_max_wavelength = rope_max_wavelength + self.global_rope_scaling_factor = global_rope_scaling_factor + self.encoder_rope_max_wavelength = encoder_rope_max_wavelength + self.encoder_global_rope_scaling_factor = ( + encoder_global_rope_scaling_factor + ) + self.use_query_key_norm = use_query_key_norm + self.eoi_token_index = eoi_token_index + self.text_only_model = text_only_model + + # Keep `num_vision_tokens_per_image` for easy access. 
+ if not text_only_model: + self.num_vision_tokens_per_image = ( + self.vision_encoder.num_vision_tokens_per_image + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "encoder_hidden_dim": self.encoder_hidden_dim, + "encoder_intermediate_dim": self.encoder_intermediate_dim, + "encoder_num_layers": self.encoder_num_layers, + "encoder_num_attention_heads": ( + self.encoder_num_attention_heads + ), + "encoder_num_key_value_heads": ( + self.encoder_num_key_value_heads + ), + "encoder_layer_types": self.encoder_layer_types, + "encoder_head_dim": self.encoder_head_dim, + "decoder_hidden_dim": self.decoder_hidden_dim, + "decoder_intermediate_dim": self.decoder_intermediate_dim, + "decoder_num_layers": self.decoder_num_layers, + "decoder_num_attention_heads": ( + self.decoder_num_attention_heads + ), + "decoder_num_key_value_heads": ( + self.decoder_num_key_value_heads + ), + "decoder_layer_types": self.decoder_layer_types, + "decoder_head_dim": self.decoder_head_dim, + "dropout_rate": self.dropout_rate, + "rms_norm_eps": self.rms_norm_eps, + "tie_word_embeddings": self.tie_word_embeddings, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "hidden_activation": self.hidden_activation, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "sliding_window": self.sliding_window, + "cross_attention_hidden_size": ( + self.cross_attention_hidden_size + ), + "attn_logit_softcapping": self.attn_logit_softcapping, + "final_logit_softcapping": (self.final_logit_softcapping), + "rope_max_wavelength": self.rope_max_wavelength, + "global_rope_scaling_factor": (self.global_rope_scaling_factor), + "encoder_rope_max_wavelength": ( + self.encoder_rope_max_wavelength + ), + "encoder_global_rope_scaling_factor": ( + self.encoder_global_rope_scaling_factor + ), + "use_query_key_norm": self.use_query_key_norm, + "eoi_token_index": 
self.eoi_token_index, + } + ) + if self.vision_encoder is not None: + config["vision_encoder"] = keras.saving.serialize_keras_object( + self.vision_encoder + ) + return config + + @classmethod + def from_config(cls, config): + vision_encoder = config.pop("vision_encoder", None) + if vision_encoder is not None and isinstance(vision_encoder, dict): + vision_encoder = keras.saving.deserialize_keras_object( + vision_encoder + ) + config["vision_encoder"] = vision_encoder + elif vision_encoder is not None: + config["vision_encoder"] = vision_encoder + return cls(**config) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py b/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py new file mode 100644 index 0000000000..29e94a1324 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_backbone_test.py @@ -0,0 +1,116 @@ +import keras +import pytest + +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.tests.test_case import TestCase + + +class T5Gemma2BackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 100, + "encoder_hidden_dim": 32, + "encoder_intermediate_dim": 64, + "encoder_num_layers": 2, + "encoder_num_attention_heads": 4, + "encoder_num_key_value_heads": 2, + "encoder_head_dim": 8, + "encoder_layer_types": [ + "sliding_attention", + "full_attention", + ], + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": [ + "sliding_attention", + "full_attention", + ], + "dropout_rate": 0.1, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "query_pre_attn_scalar": 1.0, + "attention_bias": False, + "hidden_activation": "gelu_approximate", + "sliding_window": 16, + "cross_attention_hidden_size": 32, + "attn_logit_softcapping": 50.0, + "rope_max_wavelength": 10000.0, + "initializer_range": 0.04, + "attention_dropout": 
0.1, + "use_query_key_norm": True, + } + self.input_data = { + "encoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "encoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + "decoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "decoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=T5Gemma2Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape={ + "encoder_sequence_output": (2, 16, 32), + "decoder_sequence_output": (2, 16, 32), + }, + ) + + def test_asymmetrical_backbone(self): + asym_kwargs = { + "vocabulary_size": 100, + "encoder_hidden_dim": 32, + "encoder_intermediate_dim": 96, + "encoder_num_layers": 3, + "encoder_num_attention_heads": 4, + "encoder_num_key_value_heads": 2, + "encoder_head_dim": 8, + "encoder_layer_types": ["full_attention"] * 3, + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": [ + "sliding_attention", + "full_attention", + ], + "sliding_window": 16, + "dropout_rate": 0.1, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "cross_attention_hidden_size": 32, + "use_query_key_norm": True, + } + self.run_backbone_test( + cls=T5Gemma2Backbone, + init_kwargs=asym_kwargs, + input_data=self.input_data, + expected_output_shape={ + "encoder_sequence_output": (2, 16, 32), + "decoder_sequence_output": (2, 16, 32), + }, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=T5Gemma2Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in T5Gemma2Backbone.presets: + self.run_preset_test( + cls=T5Gemma2Backbone, + preset=preset, + input_data=self.input_data, + ) diff --git 
a/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py b/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py new file mode 100644 index 0000000000..b56d708031 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_decoder.py @@ -0,0 +1,349 @@ +import keras + +from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization +from keras_hub.src.models.t5gemma2.t5gemma2_attention import ( + T5Gemma2MergedAttention, +) +from keras_hub.src.models.t5gemma2.t5gemma2_layers import T5Gemma2MLP + + +class T5Gemma2DecoderLayer(keras.layers.Layer): + """Decoder layer for the T5Gemma2 model. + + This layer implements a single decoder block in the T5Gemma2 + architecture. Unlike T5Gemma1 which has separate self-attention and + cross-attention sub-layers, T5Gemma2 uses a single + `T5Gemma2MergedAttention` layer that fuses self-attention and + cross-attention by concatenating their K/V pairs. + + Args: + hidden_size: int, Dimensionality of hidden states. + rms_norm_eps: float, Epsilon for RMS normalization. + num_attention_heads: int, Number of attention heads. + num_key_value_heads: int, Number of key-value heads for GQA. + query_pre_attn_scalar: float, Scalar for query normalization. + attention_bias: bool, Whether to include bias. + intermediate_size: int, Intermediate size of the FFN. + hidden_activation: str, Activation function for the FFN. + dropout_rate: float, Dropout rate. + head_dim: int, Dimensionality of each attention head. + initializer_range: float, Range for the initializer. + attention_dropout: float, Dropout for attention weights. + layer_type: str, Either `"full_attention"` or + `"sliding_attention"`. + cross_attention_hidden_size: int, optional, Hidden size for + cross-attention. Defaults to `hidden_size`. + attn_logit_softcapping: float, optional, Softcapping value. + sliding_window: int, optional, Window size for sliding + attention. + rope_max_wavelength: float, Maximum wavelength for RoPE. + use_query_key_norm: bool, Whether to apply Q/K norm. 
+ Defaults to `True`. + dtype: The dtype for computations. Defaults to `None`. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + intermediate_size, + hidden_activation, + dropout_rate, + head_dim, + initializer_range, + attention_dropout, + layer_type, + cross_attention_hidden_size=None, + attn_logit_softcapping=None, + sliding_window=None, + rope_max_wavelength=10000.0, + rope_scaling_factor=1.0, + use_query_key_norm=True, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.head_dim = head_dim + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.attention_type = layer_type + self.sliding_window = sliding_window + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.cross_attention_hidden_size = cross_attention_hidden_size + self.attn_logit_softcapping = attn_logit_softcapping + self.use_query_key_norm = use_query_key_norm + + if ( + self.attention_type == "sliding_attention" + and self.sliding_window is None + ): + raise ValueError( + "`sliding_window` must be set for `sliding_attention` " + "layer type." + ) + + # Merged self+cross attention. 
+ self.merged_attn = T5Gemma2MergedAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + head_dim=self.head_dim, + cross_attention_hidden_size=( + cross_attention_hidden_size or hidden_size + ), + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, + use_query_key_norm=use_query_key_norm, + rms_norm_eps=rms_norm_eps, + dtype=self.dtype_policy, + name="merged_attention", + ) + self.pre_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_pre_self_attention_layernorm", + ) + self.post_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_post_self_attention_layernorm", + ) + + # MLP. 
+ self.mlp = T5Gemma2MLP( + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=initializer_range, + dtype=self.dtype_policy, + name="mlp", + ) + self.pre_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_pre_feedforward_layernorm", + ) + self.post_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_post_feedforward_layernorm", + ) + + self.dropout = keras.layers.Dropout( + dropout_rate, + dtype=self.dtype_policy, + name="decoder_residual_dropout", + ) + + def build(self, input_shape): + hidden_states_shape, encoder_hidden_states_shape = input_shape + self.pre_self_attn_layernorm.build(hidden_states_shape) + self.merged_attn.build( + [hidden_states_shape, encoder_hidden_states_shape] + ) + attn_output_shape, _ = self.merged_attn.compute_output_shape( + [hidden_states_shape, encoder_hidden_states_shape] + ) + self.post_self_attn_layernorm.build(attn_output_shape) + self.dropout.build(attn_output_shape) + self.pre_feedforward_layernorm.build(attn_output_shape) + self.mlp.build(attn_output_shape) + mlp_output_shape = self.mlp.compute_output_shape(attn_output_shape) + self.post_feedforward_layernorm.build(mlp_output_shape) + self.built = True + + def _make_causal_mask( + self, + hidden_states, + padding_mask, + cache=None, + cache_update_index=None, + ): + """Creates a causal attention mask for self-attention.""" + if cache is not None: + q_len = keras.ops.shape(hidden_states)[1] + kv_len = keras.ops.shape(cache)[2] + q_indices = ( + keras.ops.arange(0, q_len, dtype="int32") + cache_update_index + ) + kv_indices = keras.ops.arange(0, kv_len, dtype="int32") + else: + q_len = kv_len = keras.ops.shape(hidden_states)[1] + q_indices = keras.ops.arange(0, q_len, dtype="int32") + kv_indices = keras.ops.arange(0, kv_len, dtype="int32") + causal_mask = kv_indices[None, :] <= q_indices[:, None] + if self.attention_type == 
"sliding_attention": + sliding_mask = ( + q_indices[:, None] - self.sliding_window + ) <= kv_indices[None, :] + causal_mask = keras.ops.logical_and(causal_mask, sliding_mask) + final_mask = causal_mask[None, None, :, :] + if padding_mask is not None: + padding_mask_slice = padding_mask[:, :kv_len] + padding_mask_4d = padding_mask_slice[:, None, None, :] + final_mask = keras.ops.logical_and(final_mask, padding_mask_4d) + return (1.0 - keras.ops.cast(final_mask, hidden_states.dtype)) * -1e9 + + def _make_cross_attention_mask(self, hidden_states, padding_mask): + """Creates a bidirectional mask for cross-attention.""" + if padding_mask is None: + return None + q_len = keras.ops.shape(hidden_states)[1] + bidirectional_mask = padding_mask[:, None, None, :] + # Broadcast to (batch, 1, q_len, enc_len). + bidirectional_mask = keras.ops.broadcast_to( + bidirectional_mask, + ( + keras.ops.shape(hidden_states)[0], + 1, + q_len, + keras.ops.shape(padding_mask)[1], + ), + ) + additive_mask = ( + 1.0 - keras.ops.cast(bidirectional_mask, hidden_states.dtype) + ) * -1e9 + return additive_mask + + def call( + self, + inputs, + self_attention_padding_mask=None, + cross_attention_padding_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Forward pass of the decoder layer. + + Args: + inputs: Tuple of (hidden_states, encoder_hidden_states). + self_attention_padding_mask: Padding mask for decoder + tokens. + cross_attention_padding_mask: Padding mask for encoder + tokens. + cache: Tuple of (self_attn_cache, cross_attn_cache). + cache_update_index: int, current position for caching. + training: bool, training mode. + + Returns: + Tuple of (hidden_states, updated_cache). + """ + hidden_states, encoder_hidden_states = inputs + self_attention_cache, cross_attention_cache = ( + cache if cache is not None else (None, None) + ) + + # Build the merged attention mask. 
+ self_attention_mask = self._make_causal_mask( + hidden_states, + self_attention_padding_mask, + cache=self_attention_cache, + cache_update_index=cache_update_index, + ) + cross_attention_mask = self._make_cross_attention_mask( + hidden_states, cross_attention_padding_mask + ) + + # Concatenate self and cross masks along the KV dimension. + if cross_attention_mask is not None: + merged_mask = keras.ops.concatenate( + [self_attention_mask, cross_attention_mask], axis=-1 + ) + else: + merged_mask = self_attention_mask + + # Merged attention: self + cross. + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, updated_cache = self.merged_attn( + inputs=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=merged_mask, + cache=(self_attention_cache, cross_attention_cache), + cache_update_index=cache_update_index, + training=training, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + + # MLP. 
+ residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, training=training) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + return hidden_states, updated_cache + + def compute_output_shape(self, input_shape): + hidden_states_shape, encoder_hidden_states_shape = input_shape + batch_size, dec_seq_len, _ = hidden_states_shape + _, enc_seq_len, _ = encoder_hidden_states_shape + self_cache_shape = ( + batch_size, + 2, + dec_seq_len, + self.num_key_value_heads, + self.head_dim, + ) + cross_cache_shape = ( + batch_size, + 2, + enc_seq_len, + self.num_key_value_heads, + self.head_dim, + ) + return hidden_states_shape, (self_cache_shape, cross_cache_shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "layer_type": self.attention_type, + "sliding_window": self.sliding_window, + "rope_max_wavelength": self.rope_max_wavelength, + "head_dim": self.head_dim, + "cross_attention_hidden_size": ( + self.cross_attention_hidden_size + ), + "attn_logit_softcapping": self.attn_logit_softcapping, + "use_query_key_norm": self.use_query_key_norm, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py b/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py new file mode 100644 index 0000000000..538bf25513 --- /dev/null +++ 
b/keras_hub/src/models/t5gemma2/t5gemma2_encoder.py @@ -0,0 +1,246 @@ +import keras + +from keras_hub.src.models.gemma3.gemma3_layers import RMSNormalization +from keras_hub.src.models.t5gemma2.t5gemma2_attention import T5Gemma2Attention +from keras_hub.src.models.t5gemma2.t5gemma2_layers import T5Gemma2MLP + + +class T5Gemma2EncoderLayer(keras.layers.Layer): + """Encoder layer for the T5Gemma2 model. + + This layer implements a single encoder block in the T5Gemma2 + architecture, comprising bidirectional self-attention and a + feed-forward network (MLP). It uses Gemma3-style Q/K normalization. + + Each encoder layer has an `attention_type` attribute that specifies + whether it uses `"full_attention"` or `"sliding_attention"`. The + backbone uses this to route the correct RoPE embeddings and + attention masks. + + Args: + hidden_size: int, Dimensionality of hidden states. + rms_norm_eps: float, Epsilon for RMS normalization. + num_attention_heads: int, Number of attention heads. + num_key_value_heads: int, Number of key-value heads for GQA. + query_pre_attn_scalar: float, Scalar for query normalization. + attention_bias: bool, Whether to include bias. + intermediate_size: int, Intermediate size of the FFN. + hidden_activation: str, Activation function for the FFN. + dropout_rate: float, Dropout rate. + initializer_range: float, Range for the initializer. + attention_dropout: float, Dropout for attention weights. + layer_type: str, Either `"full_attention"` or + `"sliding_attention"`. + head_dim: int, Dimensionality of each attention head. + attn_logit_softcapping: float, optional, Softcapping value. + sliding_window: int, optional, Window size for sliding + attention. + rope_max_wavelength: float, Maximum wavelength for RoPE. + use_query_key_norm: bool, Whether to apply Q/K norm. + Defaults to `True`. + dtype: The dtype for computations. Defaults to `None`. + **kwargs: Additional keyword arguments. 
+ """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range, + attention_dropout, + layer_type, + head_dim, + attn_logit_softcapping=None, + sliding_window=None, + rope_max_wavelength=10000.0, + rope_scaling_factor=1.0, + use_query_key_norm=True, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.attention_type = layer_type + self.sliding_window = sliding_window + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.head_dim = head_dim + self.attn_logit_softcapping = attn_logit_softcapping + self.use_query_key_norm = use_query_key_norm + + if ( + self.attention_type == "sliding_attention" + and self.sliding_window is None + ): + raise ValueError( + "`sliding_window` must be set for `sliding_attention` " + "layer type." 
+ ) + + self.self_attn = T5Gemma2Attention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + head_dim=self.head_dim, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, + use_query_key_norm=use_query_key_norm, + rms_norm_eps=rms_norm_eps, + dtype=self.dtype_policy, + name="self_attention", + ) + self.pre_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_self_attention_layernorm", + ) + self.post_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_self_attention_layernorm", + ) + + self.mlp = T5Gemma2MLP( + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=initializer_range, + dtype=self.dtype_policy, + name="mlp", + ) + self.pre_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_feedforward_layernorm", + ) + self.post_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_feedforward_layernorm", + ) + self.dropout = keras.layers.Dropout( + dropout_rate, + dtype=self.dtype_policy, + name="residual_dropout", + ) + + def build(self, input_shape): + self.pre_self_attn_layernorm.build(input_shape) + self.self_attn.build(input_shape) + attn_output_shape, _ = self.self_attn.compute_output_shape(input_shape) + self.post_self_attn_layernorm.build(attn_output_shape) + self.dropout.build(attn_output_shape) + self.pre_feedforward_layernorm.build(attn_output_shape) + self.mlp.build(attn_output_shape) + mlp_output_shape = self.mlp.compute_output_shape(attn_output_shape) + 
self.post_feedforward_layernorm.build(mlp_output_shape) + self.built = True + + def _make_attention_mask(self, hidden_states, padding_mask): + attention_mask = padding_mask[:, None, None, :] + additive_mask = ( + 1.0 - keras.ops.cast(attention_mask, hidden_states.dtype) + ) * -1e9 + # Apply bidirectional sliding window for sliding_attention layers. + if ( + self.attention_type == "sliding_attention" + and self.sliding_window is not None + ): + seq_len = keras.ops.shape(hidden_states)[1] + # Build position indices for the sliding window mask. + q_idx = keras.ops.arange(seq_len)[:, None] # (S, 1) + kv_idx = keras.ops.arange(seq_len)[None, :] # (1, S) + dist = q_idx - kv_idx + # HF bidirectional window: + # left_window = (sliding_window + 1) // 2 + # right_window = sliding_window // 2 + 1 + left_w = (self.sliding_window + 1) // 2 + right_w = self.sliding_window // 2 + 1 + window_mask = ((dist >= 0) & (dist < left_w)) | ( + (dist < 0) & (-dist < right_w) + ) + # Expand to (1, 1, S, S) and convert to additive mask. 
+ window_mask = keras.ops.cast( + window_mask[None, None, :, :], hidden_states.dtype + ) + window_additive = (1.0 - window_mask) * -1e9 + additive_mask = additive_mask + window_additive + return additive_mask + + def call( + self, + hidden_states, + padding_mask=None, + training=None, + ): + residual = hidden_states + attention_mask = self._make_attention_mask(hidden_states, padding_mask) + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + inputs=hidden_states, + attention_mask=attention_mask, + training=training, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, training=training) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + return hidden_states + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "head_dim": self.head_dim, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "layer_type": self.attention_type, + "sliding_window": self.sliding_window, + "rope_max_wavelength": self.rope_max_wavelength, + "attn_logit_softcapping": self.attn_logit_softcapping, + "use_query_key_norm": self.use_query_key_norm, + } + ) + return config diff --git 
a/keras_hub/src/models/t5gemma2/t5gemma2_image_converter.py b/keras_hub/src/models/t5gemma2/t5gemma2_image_converter.py new file mode 100644 index 0000000000..40ec3c6785 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_image_converter.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone + + +@keras_hub_export("keras_hub.layers.T5Gemma2ImageConverter") +class T5Gemma2ImageConverter(ImageConverter): + backbone_cls = T5Gemma2Backbone + + def __init__(self, **kwargs): + # Always do image preprocessing in float32 + kwargs.pop("dtype", None) + dtype = "float32" + super().__init__(dtype=dtype, **kwargs) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_layers.py b/keras_hub/src/models/t5gemma2/t5gemma2_layers.py new file mode 100644 index 0000000000..3329af7e06 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_layers.py @@ -0,0 +1,118 @@ +import keras + +from keras_hub.src.utils.keras_utils import clone_initializer + + +def t5gemma2_kernel_initializer(initializer_range=0.01): + """Creates a RandomNormal initializer for T5Gemma2 kernels. + + Args: + initializer_range: float, The standard deviation of the normal + distribution. Defaults to `0.01`. + + Returns: + keras.initializers.RandomNormal: A Keras RandomNormal initializer. + """ + return keras.initializers.RandomNormal(mean=0.0, stddev=initializer_range) + + +class T5Gemma2MLP(keras.layers.Layer): + """Multilayer Perceptron (MLP) block for the T5Gemma2 model. + + This layer implements the feed-forward part of a transformer block, + consisting of a gated GELU activation and dropout, following the + Gemma3 architecture pattern. + + Args: + hidden_size: int, The dimensionality of the input and output hidden + states. + intermediate_size: int, The dimensionality of the intermediate layer. 
+ hidden_activation: str, The activation function to use, e.g., + "gelu_approximate". + dropout_rate: float, The dropout rate applied to the intermediate + hidden states. + initializer_range: float, The range for the random normal initializer + for kernel weights. Defaults to `0.02`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=0.02, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.kernel_initializer = t5gemma2_kernel_initializer(initializer_range) + + self.gate_proj = keras.layers.Dense( + self.intermediate_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, + name="gate_proj", + ) + self.up_proj = keras.layers.Dense( + self.intermediate_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, + name="up_proj", + ) + self.down_proj = keras.layers.Dense( + self.hidden_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, + name="down_proj", + ) + if self.hidden_activation == "gelu_approximate": + self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) + else: + self.act_fn = keras.activations.get(self.hidden_activation) + self.dropout = keras.layers.Dropout( + self.dropout_rate, + dtype=self.dtype_policy, + name="mlp_dropout", + ) + + def build(self, input_shape): + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + intermediate_shape = 
self.gate_proj.compute_output_shape(input_shape) + self.dropout.build(intermediate_shape) + self.down_proj.build(intermediate_shape) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x, training=None): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states, training=training) + down_proj = self.down_proj(hidden_states) + return down_proj + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_presets.py b/keras_hub/src/models/t5gemma2/t5gemma2_presets.py new file mode 100644 index 0000000000..52a66ceed6 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_presets.py @@ -0,0 +1,4 @@ +backbone_presets = { + # Placeholder presets — will be populated when checkpoints + # are made available. +} diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm.py new file mode 100644 index 0000000000..025228aa61 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm.py @@ -0,0 +1,331 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm_preprocessor import ( + T5Gemma2Seq2SeqLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.T5Gemma2Seq2SeqLM") +class T5Gemma2Seq2SeqLM(Seq2SeqLM): + """An end-to-end T5Gemma2 model for seq2seq language modeling. 
+ + A seq2seq language model (LM) is an encoder-decoder model which is + used for conditional text generation. The encoder is given a + "context" text (fed to the encoder), and the decoder predicts the + next token based on both the encoder inputs and the previous tokens. + + T5Gemma2 extends T5Gemma by using Gemma3-based components, with + merged self+cross attention in the decoder, Gemma3-style Q/K + normalization, and per-layer-type RoPE. + + This model has a `generate()` method, which generates text based on + a prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. + + Args: + backbone: A `keras_hub.models.T5Gemma2Backbone` instance. + preprocessor: A `keras_hub.models.T5Gemma2Seq2SeqLMPreprocessor` + or `None`. Defaults to `None`. + + Examples: + + Use `generate()` to do text generation. + ```python + t5gemma2_lm = keras_hub.models.T5Gemma2Seq2SeqLM.from_preset( + "t5gemma2_270m_270m" + ) + t5gemma2_lm.generate( + "The quick brown fox jumped.", max_length=30 + ) + ``` + + Custom backbone and vocabulary.
+ ```python + tokenizer = keras_hub.models.T5Gemma2Tokenizer( + proto="proto.spm", + ) + preprocessor = keras_hub.models.T5Gemma2Seq2SeqLMPreprocessor( + tokenizer=tokenizer, + encoder_sequence_length=128, + decoder_sequence_length=128, + ) + backbone = keras_hub.models.T5Gemma2Backbone( + vocabulary_size=32000, + encoder_hidden_dim=256, + encoder_intermediate_dim=512, + encoder_num_layers=4, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_head_dim=64, + encoder_layer_types=["full_attention"] * 4, + decoder_hidden_dim=256, + decoder_intermediate_dim=512, + decoder_num_layers=4, + decoder_num_attention_heads=4, + decoder_num_key_value_heads=2, + decoder_head_dim=64, + decoder_layer_types=["full_attention"] * 4, + dropout_rate=0.1, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + ) + t5gemma2_lm = keras_hub.models.T5Gemma2Seq2SeqLM( + backbone=backbone, + preprocessor=preprocessor, + ) + ``` + """ + + backbone_cls = T5Gemma2Backbone + preprocessor_cls = T5Gemma2Seq2SeqLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + sequence_output = backbone(inputs)["decoder_sequence_output"] + logits = backbone.decoder_token_embedding(sequence_output, reverse=True) + if self.backbone.final_logit_softcapping is not None: + logits = logits / self.backbone.final_logit_softcapping + logits = keras.ops.tanh(logits) + logits = logits * self.backbone.final_logit_softcapping + super().__init__( + inputs=inputs, + outputs=logits, + **kwargs, + ) + + def call_encoder(self, token_ids, padding_mask): + """Process inputs through the encoder stack.""" + encoder_embeddings = self.backbone.token_embedding(token_ids) + encoder_embeddings *= keras.ops.cast( + keras.ops.sqrt(self.backbone.encoder_hidden_dim), + encoder_embeddings.dtype, + ) 
+ encoder_hidden_states = self.backbone.encoder_dropout( + encoder_embeddings, training=False + ) + for layer in self.backbone.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, + padding_mask=padding_mask, + training=False, + ) + encoder_output = self.backbone.encoder_norm(encoder_hidden_states) + encoder_output = self.backbone.encoder_dropout( + encoder_output, training=False + ) + return encoder_output, padding_mask + + def call_decoder_with_cache( + self, + decoder_token_ids, + decoder_padding_mask, + cache, + cache_update_index, + encoder_output, + encoder_padding_mask, + ): + """Forward pass of the decoder with cache. + + `call_decoder_with_cache` adds an additional forward pass for + autoregressive inference. The cache stores previous key/value + tensors in the attention layers. + + Args: + decoder_token_ids: Dense int Tensor of shape + `(batch_size, max_length)`. + decoder_padding_mask: Dense int Tensor of shape + `(batch_size, max_length)`. + cache: Dense float Tensor, the cache of key/value states. + cache_update_index: int or int Tensor. + encoder_output: Dense float Tensor, encoder output. + encoder_padding_mask: Dense int Tensor. + + Returns: + Tuple of (logits, hidden_states, updated_cache). + """ + self_attention_cache, cross_attention_cache = cache + hidden_states = self.backbone.decoder_token_embedding(decoder_token_ids) + hidden_states *= keras.ops.cast( + keras.ops.sqrt(self.backbone.decoder_hidden_dim), + hidden_states.dtype, + ) + hidden_states = self.backbone.decoder_dropout( + hidden_states, training=False + ) + updated_self_attention_caches = [] + updated_cross_attention_caches = [] + for i, layer in enumerate(self.backbone.decoder_layers): + layer_self_cache = ( + self_attention_cache[:, i, ...] + if self_attention_cache is not None + else None + ) + layer_cross_cache = ( + cross_attention_cache[:, i, ...] 
+ if cross_attention_cache is not None + else None + ) + layer_cache = (layer_self_cache, layer_cross_cache) + hidden_states, updated_layer_cache = layer( + (hidden_states, encoder_output), + self_attention_padding_mask=decoder_padding_mask, + cross_attention_padding_mask=encoder_padding_mask, + cache=layer_cache, + cache_update_index=cache_update_index, + training=False, + ) + new_self_cache, new_cross_cache = updated_layer_cache + updated_self_attention_caches.append(new_self_cache) + updated_cross_attention_caches.append(new_cross_cache) + self_attention_cache = keras.ops.stack( + updated_self_attention_caches, axis=1 + ) + cross_attention_cache = keras.ops.stack( + updated_cross_attention_caches, axis=1 + ) + hidden_states = self.backbone.decoder_norm(hidden_states) + logits = self.backbone.decoder_token_embedding( + hidden_states, reverse=True + ) + if self.backbone.final_logit_softcapping is not None: + logits = logits / self.backbone.final_logit_softcapping + logits = keras.ops.tanh(logits) + logits = logits * self.backbone.final_logit_softcapping + return ( + logits, + hidden_states, + (self_attention_cache, cross_attention_cache), + ) + + def _build_cache( + self, + encoder_token_ids, + encoder_padding_mask, + decoder_token_ids, + decoder_padding_mask, + ): + """Build an empty cache for use with `call_with_cache()`.""" + encoder_output, encoder_padding_mask = self.call_encoder( + encoder_token_ids, encoder_padding_mask + ) + batch_size = keras.ops.shape(decoder_token_ids)[0] + num_layers = self.backbone.decoder_num_layers + num_kv_heads = self.backbone.decoder_num_key_value_heads + head_dim = self.backbone.decoder_head_dim + self_cache_shape = ( + batch_size, + num_layers, + 2, + keras.ops.shape(decoder_token_ids)[1], + num_kv_heads, + head_dim, + ) + self_attention_cache = keras.ops.zeros( + self_cache_shape, dtype=self.compute_dtype + ) + cross_attention_cache = None + _, hidden_states, cache = self.call_decoder_with_cache( + 
decoder_token_ids=decoder_token_ids, + decoder_padding_mask=decoder_padding_mask, + cache=(self_attention_cache, cross_attention_cache), + cache_update_index=0, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask, + ) + extra_cache_info = (encoder_output, encoder_padding_mask) + return hidden_states, cache, extra_cache_info + + def generate_step(self, inputs, stop_token_ids=None): + """A compilable generation function for a single batch. + + Args: + inputs: A dictionary with keys `"encoder_token_ids"`, + `"encoder_padding_mask"`, `"decoder_token_ids"`, and + `"decoder_padding_mask"`. + stop_token_ids: Tuple of end token ids to stop on. + """ + encoder_token_ids = inputs["encoder_token_ids"] + encoder_padding_mask = inputs["encoder_padding_mask"] + decoder_token_ids = inputs["decoder_token_ids"] + decoder_padding_mask = inputs["decoder_padding_mask"] + hidden_states, cache, extra_cache_info = self._build_cache( + encoder_token_ids=encoder_token_ids, + encoder_padding_mask=encoder_padding_mask, + decoder_token_ids=decoder_token_ids, + decoder_padding_mask=decoder_padding_mask, + ) + encoder_output, encoder_padding_mask = extra_cache_info + row_lengths = keras.ops.sum( + keras.ops.cast(decoder_padding_mask, "int32"), axis=-1 + ) + index = keras.ops.min(row_lengths) + + def next(prompt, cache, index): + cache_update_index = index - 1 + batch_size = keras.ops.shape(prompt)[0] + prompt = keras.ops.slice( + prompt, [0, cache_update_index], [batch_size, 1] + ) + ( + logits, + _, + updated_cache, + ) = self.call_decoder_with_cache( + decoder_token_ids=prompt, + decoder_padding_mask=None, + cache_update_index=cache_update_index, + cache=cache, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask, + ) + return ( + keras.ops.squeeze(logits, axis=1), + None, + updated_cache, + ) + + decoder_token_ids = self.sampler( + next=next, + prompt=decoder_token_ids, + cache=cache, + index=index, + mask=decoder_padding_mask, + 
stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + if stop_token_ids is not None: + end_locations = any_equal( + decoder_token_ids, + stop_token_ids, + keras.ops.logical_not(decoder_padding_mask), + ) + end_locations = keras.ops.cast(end_locations, "int32") + cumsum = keras.ops.cast( + keras.ops.cumsum(end_locations, axis=-1), "int32" + ) + overflow = cumsum - end_locations + decoder_padding_mask = keras.ops.logical_not( + keras.ops.cast(overflow, "bool") + ) + else: + decoder_padding_mask = keras.ops.ones_like( + decoder_token_ids, dtype="bool" + ) + + return { + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py new file mode 100644 index 0000000000..07757eff85 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor.py @@ -0,0 +1,206 @@ +import keras +import numpy as np + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.models.t5gemma2.t5gemma2_tokenizer import T5Gemma2Tokenizer +from keras_hub.src.utils.tensor_utils import preprocessing_function + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.T5Gemma2Seq2SeqLMPreprocessor") +class T5Gemma2Seq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): + """T5Gemma2 Seq2Seq LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.T5Gemma2Seq2SeqLM`. By default, it will take in + batches of strings, and return outputs in a + `(x, y, sample_weight)` format, where the `y` label is the next + token id in the `x` sequence. 
+ + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this + preprocessor is attached to a `keras_hub.models.T5Gemma2Seq2SeqLM` + instance, these methods will be called implicitly in `generate()`. + + When an `image_converter` is provided, the preprocessor also + supports multimodal inputs with images. Images are inserted into + the encoder sequence as placeholder tokens that the backbone's + vision encoder will replace with image embeddings. + + Args: + tokenizer: A `keras_hub.models.T5Gemma2Tokenizer` instance. + encoder_sequence_length: The length of the packed encoder inputs. + decoder_sequence_length: The length of the packed decoder inputs. + image_converter: A `keras_hub.layers.ImageConverter` instance, + or `None` for text-only. Defaults to `None`. + add_start_token: If `True`, prepend the start token. Defaults + to `False`. + add_end_token: If `True`, append the end token. Defaults to + `True`. + """ + + backbone_cls = T5Gemma2Backbone + tokenizer_cls = T5Gemma2Tokenizer + + def __init__( + self, + tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + image_converter=None, + image_size=None, + num_vision_tokens_per_image=None, + add_start_token=False, + add_end_token=True, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + encoder_sequence_length=encoder_sequence_length, + decoder_sequence_length=decoder_sequence_length, + **kwargs, + ) + self.add_start_token = add_start_token + self.add_end_token = add_end_token + self.image_converter = image_converter + self._vision_image_size = image_size + self.num_vision_tokens_per_image = num_vision_tokens_per_image + + def _add_vision_inputs(self, x, batch_size): + """Add dummy image/vision_indices for multimodal text-only input. + + When a multimodal backbone (with vision encoder) is used for + text-only inference, the functional model still requires + `images` and `vision_indices` inputs. 
This method provides + dummy values that act as a no-op: InterleaveEmbeddings + restores position 0 after scattering zero-indexed updates. + """ + if self._vision_image_size is not None and "images" not in x: + x["images"] = np.zeros( + ( + batch_size, + 1, + self._vision_image_size, + self._vision_image_size, + 3, + ), + dtype="float32", + ) + x["vision_indices"] = np.zeros( + (batch_size, self.num_vision_tokens_per_image), + dtype="int32", + ) + return x + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_inputs = self.tokenizer(x["encoder_text"]) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_inputs, + sequence_length=encoder_sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + decoder_inputs = self.tokenizer(x["decoder_text"]) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_inputs, + sequence_length=decoder_sequence_length + 1, + add_start_value=True, + add_end_value=self.add_end_token, + ) + batch_size = tf.shape(encoder_token_ids)[0] + x = { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids[..., :-1], + "decoder_padding_mask": decoder_padding_mask[..., :-1], + } + x = self._add_vision_inputs(x, batch_size) + y = decoder_token_ids[..., 1:] + sample_weight = decoder_padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + @preprocessing_function + def generate_preprocess( + self, + x, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, 
+ sequence_length=None, + ): + if not self.built: + self.build(None) + + if isinstance(x, dict): + encoder_text = x["encoder_text"] + decoder_text = x["decoder_text"] + else: + encoder_text = x + decoder_text = tf.fill((tf.shape(encoder_text)[0],), "") + + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_token_ids = self.tokenizer(encoder_text) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_token_ids, + sequence_length=None, + add_start_value=self.add_start_token, + add_end_value=False, + ) + + decoder_token_ids = self.tokenizer(decoder_text) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_token_ids, + sequence_length=decoder_sequence_length, + add_start_value=True, + add_end_value=False, + ) + + batch_size = tf.shape(encoder_token_ids)[0] + out = { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } + return self._add_vision_inputs(out, batch_size) + + def get_config(self): + config = super().get_config() + config.update( + { + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + "image_size": self._vision_image_size, + "num_vision_tokens_per_image": ( + self.num_vision_tokens_per_image + ), + } + ) + return config diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py new file mode 100644 index 0000000000..1a64d7ef47 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_preprocessor_test.py @@ -0,0 +1,132 @@ +import os + +import numpy as np +import pytest + +from 
keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm_preprocessor import ( + T5Gemma2Seq2SeqLMPreprocessor, +) +from keras_hub.src.models.t5gemma2.t5gemma2_tokenizer import T5Gemma2Tokenizer +from keras_hub.src.tests.test_case import TestCase + + +class T5Gemma2Seq2SeqLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = T5Gemma2Tokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma3_test_vocab.spm" + ) + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "encoder_sequence_length": 8, + "decoder_sequence_length": 8, + } + self.input_data = ( + { + "encoder_text": ["the quick brown fox"], + "decoder_text": ["the earth is round"], + }, + ) + + def test_preprocessor_basics(self): + preprocessor = T5Gemma2Seq2SeqLMPreprocessor(**self.init_kwargs) + output = preprocessor(*self.input_data) + x, y, sample_weight = output + + # Verify output keys. + self.assertIn("encoder_token_ids", x) + self.assertIn("encoder_padding_mask", x) + self.assertIn("decoder_token_ids", x) + self.assertIn("decoder_padding_mask", x) + + # Verify shapes. 
+ self.assertEqual(x["encoder_token_ids"].shape[-1], 8) + self.assertEqual(x["decoder_token_ids"].shape[-1], 8) + + def test_generate_preprocess(self): + preprocessor = T5Gemma2Seq2SeqLMPreprocessor(**self.init_kwargs) + input_data = { + "encoder_text": ["the quick brown fox"], + "decoder_text": ["the earth is round"], + } + output = preprocessor.generate_preprocess(input_data) + self.assertIn("encoder_token_ids", output) + self.assertIn("encoder_padding_mask", output) + self.assertIn("decoder_token_ids", output) + self.assertIn("decoder_padding_mask", output) + + def test_generate_postprocess(self): + preprocessor = T5Gemma2Seq2SeqLMPreprocessor(**self.init_kwargs) + input_data = { + "decoder_token_ids": [2, 9, 14, 10, 1], + "decoder_padding_mask": [1, 1, 1, 1, 1], + } + output = preprocessor.generate_postprocess(input_data) + self.assertIsInstance(output, str) + + def test_add_vision_inputs_multimodal(self): + """Multimodal preprocessor should add dummy vision inputs + when text-only is used for inference.""" + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + **self.init_kwargs, + image_size=64, + num_vision_tokens_per_image=16, + ) + x = { + "encoder_token_ids": np.ones((2, 8), dtype="int32"), + "encoder_padding_mask": np.ones((2, 8), dtype="int32"), + "decoder_token_ids": np.ones((2, 8), dtype="int32"), + "decoder_padding_mask": np.ones((2, 8), dtype="int32"), + } + result = preprocessor._add_vision_inputs(x, batch_size=2) + + # Should add dummy images and vision indices. + self.assertIn("images", result) + self.assertIn("vision_indices", result) + self.assertEqual(result["images"].shape, (2, 1, 64, 64, 3)) + self.assertEqual(result["vision_indices"].shape, (2, 16)) + # Dummy values should be all zeros. 
+ np.testing.assert_array_equal( + result["images"], np.zeros_like(result["images"]) + ) + np.testing.assert_array_equal( + result["vision_indices"], + np.zeros_like(result["vision_indices"]), + ) + + def test_add_vision_inputs_skips_when_images_present(self): + """Should not overwrite existing images.""" + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + **self.init_kwargs, + image_size=64, + num_vision_tokens_per_image=16, + ) + existing_images = np.ones((1, 1, 64, 64, 3), dtype="float32") + x = { + "encoder_token_ids": np.ones((1, 8), dtype="int32"), + "images": existing_images, + } + result = preprocessor._add_vision_inputs(x, batch_size=1) + np.testing.assert_array_equal(result["images"], existing_images) + + def test_serialization(self): + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + **self.init_kwargs, + image_size=128, + num_vision_tokens_per_image=32, + ) + config = preprocessor.get_config() + self.assertEqual(config["image_size"], 128) + self.assertEqual(config["num_vision_tokens_per_image"], 32) + self.assertEqual(config["add_start_token"], False) + self.assertEqual(config["add_end_token"], True) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in T5Gemma2Seq2SeqLMPreprocessor.presets: + self.run_preset_test( + cls=T5Gemma2Seq2SeqLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_test.py b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_test.py new file mode 100644 index 0000000000..7bf5ab6e28 --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_seq_2_seq_lm_test.py @@ -0,0 +1,98 @@ +import keras +import pytest + +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm import ( + T5Gemma2Seq2SeqLM, +) +from keras_hub.src.tests.test_case import TestCase + + +class T5Gemma2Seq2SeqLMTest(TestCase): + def setUp(self): + self.backbone_kwargs = { + "vocabulary_size": 
100, + "encoder_hidden_dim": 32, + "encoder_intermediate_dim": 64, + "encoder_num_layers": 2, + "encoder_num_attention_heads": 4, + "encoder_num_key_value_heads": 2, + "encoder_head_dim": 8, + "encoder_layer_types": [ + "full_attention", + "full_attention", + ], + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": [ + "full_attention", + "full_attention", + ], + "dropout_rate": 0.0, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "query_pre_attn_scalar": 1.0, + "attention_bias": False, + "hidden_activation": "gelu_approximate", + "cross_attention_hidden_size": 32, + "rope_max_wavelength": 10000.0, + "initializer_range": 0.04, + "use_query_key_norm": True, + } + self.input_data = { + "encoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "encoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + "decoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "decoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + } + + def test_seq2seq_lm_basics(self): + backbone = T5Gemma2Backbone(**self.backbone_kwargs) + lm = T5Gemma2Seq2SeqLM(backbone=backbone) + output = lm(self.input_data) + self.assertEqual(keras.ops.shape(output), (2, 16, 100)) + + def test_call_encoder(self): + backbone = T5Gemma2Backbone(**self.backbone_kwargs) + lm = T5Gemma2Seq2SeqLM(backbone=backbone) + encoder_output, padding_mask = lm.call_encoder( + self.input_data["encoder_token_ids"], + self.input_data["encoder_padding_mask"], + ) + self.assertEqual(keras.ops.shape(encoder_output), (2, 16, 32)) + + def test_build_cache(self): + backbone = T5Gemma2Backbone(**self.backbone_kwargs) + lm = T5Gemma2Seq2SeqLM(backbone=backbone) + hidden_states, cache, extra = lm._build_cache( + encoder_token_ids=self.input_data["encoder_token_ids"], + encoder_padding_mask=self.input_data["encoder_padding_mask"], + 
decoder_token_ids=self.input_data["decoder_token_ids"], + decoder_padding_mask=self.input_data["decoder_padding_mask"], + ) + self.assertEqual(keras.ops.shape(hidden_states), (2, 16, 32)) + self_attention_cache, cross_attention_cache = cache + # self_attention_cache: (batch, num_layers, 2, + # seq, kv_heads, head_dim) + self.assertEqual( + keras.ops.shape(self_attention_cache), (2, 2, 2, 16, 2, 8) + ) + # cross_attention_cache: (batch, num_layers, 2, + # enc_seq, kv_heads, head_dim) + self.assertEqual( + keras.ops.shape(cross_attention_cache), + (2, 2, 2, 16, 2, 8), + ) + + @pytest.mark.large + def test_saved_model(self): + backbone = T5Gemma2Backbone(**self.backbone_kwargs) + self.run_model_saving_test( + cls=T5Gemma2Seq2SeqLM, + init_kwargs={"backbone": backbone}, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma2/t5gemma2_tokenizer.py b/keras_hub/src/models/t5gemma2/t5gemma2_tokenizer.py new file mode 100644 index 0000000000..824a7c6bcc --- /dev/null +++ b/keras_hub/src/models/t5gemma2/t5gemma2_tokenizer.py @@ -0,0 +1,55 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( + SentencePieceTokenizer, +) + + +@keras_hub_export( + [ + "keras_hub.tokenizers.T5Gemma2Tokenizer", + "keras_hub.models.T5Gemma2Tokenizer", + ] +) +class T5Gemma2Tokenizer(SentencePieceTokenizer): + """T5Gemma2 tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences + and is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike + the underlying tokenizer, it will check for all special tokens needed + by T5Gemma2 models and provides a `from_preset()` method to + automatically download a matching vocabulary for a T5Gemma2 preset. 
+ + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a + dense `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, + or a `bytes` object with a serialized SentencePiece proto. + + Examples: + + ```python + tokenizer = keras_hub.models.T5Gemma2Tokenizer.from_preset( + "t5gemma2_270m_270m" + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + ``` + """ + + backbone_cls = T5Gemma2Backbone + + def __init__(self, proto, **kwargs): + self._add_special_token("<bos>", "start_token") + self._add_special_token("<eos>", "end_token") + self._add_special_token("<pad>", "pad_token") + super().__init__(proto=proto, **kwargs) diff --git a/keras_hub/src/utils/transformers/convert_t5gemma2.py b/keras_hub/src/utils/transformers/convert_t5gemma2.py new file mode 100644 index 0000000000..626340566f --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_t5gemma2.py @@ -0,0 +1,414 @@ +import numpy as np + +from keras_hub.src.models.t5gemma2.t5gemma2_backbone import T5Gemma2Backbone +from keras_hub.src.utils.preset_utils import get_file +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = T5Gemma2Backbone + + +def load_image_converter_config(preset, transformers_config): + """Load image converter config from HF preprocessor_config.json.""" + encoder_config = transformers_config.get("encoder", {}) + if "vision_config" not in encoder_config: + return None + preprocessor_config = load_json(preset, "preprocessor_config.json") + mean = preprocessor_config["image_mean"] + std = preprocessor_config["image_std"] + rescale_factor = preprocessor_config["rescale_factor"] + offset = [(-m / s) for m, s in zip(mean, 
std)] + scale = [(s * rescale_factor) for s in std] + image_size = encoder_config["vision_config"]["image_size"] + return { + "image_size": (image_size, image_size), + "scale": scale, + "offset": offset, + } + + +def convert_backbone_config(transformers_config): + """Convert a HuggingFace T5Gemma2 config to KerasHub backbone config.""" + # T5Gemma2EncoderConfig is Gemma3Config with text params at + # encoder["text_config"]; decoder is Gemma3TextConfig (flat). + encoder_config = transformers_config["encoder"] + enc_text = encoder_config["text_config"] + decoder_config = transformers_config["decoder"] + + hidden_activation = decoder_config["hidden_activation"] + if hidden_activation == "gelu_pytorch_tanh": + hidden_activation = "gelu_approximate" + + # Vision encoder (optional). + vision_encoder = None + if "vision_config" in encoder_config: + from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( + Gemma3VisionEncoder, + ) + + vision_config = encoder_config["vision_config"] + vision_encoder = Gemma3VisionEncoder( + image_size=vision_config["image_size"], + patch_size=vision_config["patch_size"], + num_heads=vision_config["num_attention_heads"], + hidden_dim=vision_config["hidden_size"], + num_layers=vision_config["num_hidden_layers"], + intermediate_dim=vision_config["intermediate_size"], + output_dim=enc_text["hidden_size"], + pool_size=int( + vision_config["image_size"] + // vision_config["patch_size"] + // int(encoder_config["mm_tokens_per_image"] ** 0.5) + ), + layer_norm_epsilon=vision_config["layer_norm_eps"], + ) + + backbone_config = { + "vocabulary_size": decoder_config["vocab_size"], + "encoder_hidden_dim": enc_text["hidden_size"], + "encoder_intermediate_dim": enc_text["intermediate_size"], + "encoder_num_layers": enc_text["num_hidden_layers"], + "encoder_num_attention_heads": enc_text["num_attention_heads"], + "encoder_num_key_value_heads": enc_text["num_key_value_heads"], + "encoder_head_dim": enc_text["head_dim"], + "encoder_layer_types": 
enc_text["layer_types"], + "decoder_hidden_dim": decoder_config["hidden_size"], + "decoder_intermediate_dim": decoder_config["intermediate_size"], + "decoder_num_layers": decoder_config["num_hidden_layers"], + "decoder_num_attention_heads": decoder_config["num_attention_heads"], + "decoder_num_key_value_heads": decoder_config["num_key_value_heads"], + "decoder_head_dim": decoder_config["head_dim"], + "decoder_layer_types": decoder_config["layer_types"], + "dropout_rate": decoder_config["dropout_rate"], + "rms_norm_eps": decoder_config["rms_norm_eps"], + "query_pre_attn_scalar": decoder_config["query_pre_attn_scalar"], + "tie_word_embeddings": transformers_config["tie_word_embeddings"], + "attention_bias": decoder_config["attention_bias"], + "hidden_activation": hidden_activation, + "initializer_range": decoder_config["initializer_range"], + "attention_dropout": decoder_config["attention_dropout"], + "sliding_window": decoder_config["sliding_window"], + "cross_attention_hidden_size": enc_text["hidden_size"], + "attn_logit_softcapping": decoder_config["attn_logit_softcapping"], + "final_logit_softcapping": decoder_config["final_logit_softcapping"], + "rope_max_wavelength": ( + decoder_config["rope_parameters"]["sliding_attention"]["rope_theta"] + ), + "global_rope_scaling_factor": ( + decoder_config["rope_parameters"]["full_attention"]["factor"] + ), + "encoder_rope_max_wavelength": ( + enc_text["rope_parameters"]["sliding_attention"]["rope_theta"] + ), + "encoder_global_rope_scaling_factor": ( + enc_text["rope_parameters"]["full_attention"]["factor"] + ), + # use_qk_norm may not be in config JSON; default True. 
+ "use_query_key_norm": enc_text.get("use_qk_norm", True), + "vision_encoder": vision_encoder, + "eoi_token_index": transformers_config["eoi_token_index"], + } + return backbone_config + + +def convert_weights(backbone, loader, transformers_config): + """Convert T5Gemma2 weights from HuggingFace to KerasHub.""" + + def transpose(x, shape): + return np.transpose(x) + + # === Vision encoder weights === + vision_encoder = backbone.vision_encoder + if vision_encoder is not None: + image_encoder = vision_encoder.get_layer("image_encoder") + + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.patch_embedding.kernel, + hf_weight_key="encoder.vision_tower.vision_model.embeddings.patch_embedding.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.patch_embedding.bias, + hf_weight_key="encoder.vision_tower.vision_model.embeddings.patch_embedding.bias", + ) + loader.port_weight( + keras_variable=image_encoder.vision_embeddings.position_embedding.embeddings, + hf_weight_key="encoder.vision_tower.vision_model.embeddings.position_embedding.weight", + ) + + for i in range(image_encoder.num_layers): + hf_vit = f"encoder.vision_tower.vision_model.encoder.layers.{i}" + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_1.gamma, + hf_weight_key=f"{hf_vit}.layer_norm1.weight", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_1.beta, + hf_weight_key=f"{hf_vit}.layer_norm1.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[ + i + ].attn.query_proj.kernel, + hf_weight_key=f"{hf_vit}.self_attn.q_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.query_proj.bias, + hf_weight_key=f"{hf_vit}.self_attn.q_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.key_proj.kernel, + 
hf_weight_key=f"{hf_vit}.self_attn.k_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.key_proj.bias, + hf_weight_key=f"{hf_vit}.self_attn.k_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[ + i + ].attn.value_proj.kernel, + hf_weight_key=f"{hf_vit}.self_attn.v_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.value_proj.bias, + hf_weight_key=f"{hf_vit}.self_attn.v_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.out_proj.kernel, + hf_weight_key=f"{hf_vit}.self_attn.out_proj.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].attn.out_proj.bias, + hf_weight_key=f"{hf_vit}.self_attn.out_proj.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_2.gamma, + hf_weight_key=f"{hf_vit}.layer_norm2.weight", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].layer_norm_2.beta, + hf_weight_key=f"{hf_vit}.layer_norm2.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_1.kernel, + hf_weight_key=f"{hf_vit}.mlp.fc1.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_1.bias, + hf_weight_key=f"{hf_vit}.mlp.fc1.bias", + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_2.kernel, + hf_weight_key=f"{hf_vit}.mlp.fc2.weight", + hook_fn=transpose, + ) + loader.port_weight( + keras_variable=image_encoder.resblocks[i].mlp_dense_2.bias, + hf_weight_key=f"{hf_vit}.mlp.fc2.bias", + ) + + loader.port_weight( + keras_variable=image_encoder.encoder_layer_norm.gamma, + hf_weight_key="encoder.vision_tower.vision_model.post_layernorm.weight", + ) + loader.port_weight( + keras_variable=image_encoder.encoder_layer_norm.beta, + hf_weight_key="encoder.vision_tower.vision_model.post_layernorm.bias", 
+ ) + + # Multi-modal projector. + loader.port_weight( + keras_variable=vision_encoder.get_layer( + "vision_output_encoder" + ).vision_soft_embedding_norm.scale, + hf_weight_key="encoder.multi_modal_projector.mm_soft_emb_norm.weight", + ) + loader.port_weight( + keras_variable=vision_encoder.get_layer( + "vision_output_encoder" + ).vision_input_projection.kernel, + hf_weight_key="encoder.multi_modal_projector.mm_input_projection_weight", + ) + + # EOI embeddings. + loader.port_weight( + keras_variable=backbone.encoder_eoi_embedding, + hf_weight_key="encoder.text_model.embed_tokens.eoi_embedding", + ) + loader.port_weight( + keras_variable=backbone.decoder_eoi_embedding, + hf_weight_key="decoder.embed_tokens.eoi_embedding", + ) + + # === Text encoder weights === + # Token embeddings. + loader.port_weight( + keras_variable=backbone.token_embedding.embeddings, + hf_weight_key="encoder.text_model.embed_tokens.weight", + ) + loader.port_weight( + keras_variable=backbone.decoder_token_embedding.embeddings, + hf_weight_key="decoder.embed_tokens.weight", + ) + + # Encoder (weights under encoder.text_model.*). + loader.port_weight( + keras_variable=backbone.encoder_norm.scale, + hf_weight_key="encoder.text_model.norm.weight", + ) + for i in range(backbone.encoder_num_layers): + layer = backbone.get_layer(f"encoder_layer_{i}") + hf_prefix = f"encoder.text_model.layers.{i}" + + # Self-attention Q/K/V/O projections. 
+ loader.port_weight( + keras_variable=layer.self_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # Q/K normalization (Gemma3-style). + loader.port_weight( + keras_variable=layer.self_attn.query_norm.scale, + hf_weight_key=f"{hf_prefix}.self_attn.q_norm.weight", + ) + loader.port_weight( + keras_variable=layer.self_attn.key_norm.scale, + hf_weight_key=f"{hf_prefix}.self_attn.k_norm.weight", + ) + + # MLP. + loader.port_weight( + keras_variable=layer.mlp.gate_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.up_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.down_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight", + hook_fn=lambda w, s: w.T, + ) + + # Layer norms. 
+ loader.port_weight( + keras_variable=layer.pre_self_attn_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.pre_self_attn_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.post_self_attn_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.post_self_attn_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.pre_feedforward_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.pre_feedforward_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.post_feedforward_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.post_feedforward_layernorm.weight"), + ) + + # Decoder (weights directly under decoder.*). + loader.port_weight( + keras_variable=backbone.decoder_norm.scale, + hf_weight_key="decoder.norm.weight", + ) + for i in range(backbone.decoder_num_layers): + layer = backbone.get_layer(f"decoder_layer_{i}") + hf_prefix = f"decoder.layers.{i}" + + # Merged attention (self+cross uses a single self_attn layer). + loader.port_weight( + keras_variable=layer.merged_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.merged_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.merged_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.merged_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # Q/K normalization. + loader.port_weight( + keras_variable=layer.merged_attn.query_norm.scale, + hf_weight_key=f"{hf_prefix}.self_attn.q_norm.weight", + ) + loader.port_weight( + keras_variable=layer.merged_attn.key_norm.scale, + hf_weight_key=f"{hf_prefix}.self_attn.k_norm.weight", + ) + + # MLP. 
+ loader.port_weight( + keras_variable=layer.mlp.gate_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.up_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.down_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight", + hook_fn=lambda w, s: w.T, + ) + + # Layer norms (no cross-attn norms — merged into self_attn). + loader.port_weight( + keras_variable=layer.pre_self_attn_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.pre_self_attn_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.post_self_attn_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.post_self_attn_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.pre_feedforward_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.pre_feedforward_layernorm.weight"), + ) + loader.port_weight( + keras_variable=layer.post_feedforward_layernorm.scale, + hf_weight_key=(f"{hf_prefix}.post_feedforward_layernorm.weight"), + ) + + +def convert_tokenizer(cls, preset, **kwargs): + """Convert a T5Gemma2 tokenizer.""" + return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 75bd8ab6d4..e7245110cd 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -29,6 +29,7 @@ from keras_hub.src.utils.transformers import convert_sam3 from keras_hub.src.utils.transformers import convert_smollm3 from keras_hub.src.utils.transformers import convert_t5gemma +from keras_hub.src.utils.transformers import convert_t5gemma2 from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -88,6 +89,8 @@ def __init__(self, preset, config): self.converter = 
convert_smollm3 elif model_type == "t5gemma": self.converter = convert_t5gemma + elif model_type == "t5gemma2": + self.converter = convert_t5gemma2 else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " diff --git a/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py new file mode 100644 index 0000000000..771968c07a --- /dev/null +++ b/tools/checkpoint_conversion/convert_t5gemma2_checkpoints.py @@ -0,0 +1,807 @@ +import gc +import os +import random +import shutil + +import huggingface_hub +import keras +import numpy as np +import requests +import torch +import transformers +from absl import app +from absl import flags +from checkpoint_conversion_utils import get_md5_checksum +from PIL import Image + +import keras_hub +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm import ( + T5Gemma2Seq2SeqLM, +) +from keras_hub.src.models.t5gemma2.t5gemma2_seq_2_seq_lm_preprocessor import ( + T5Gemma2Seq2SeqLMPreprocessor, +) + +random.seed(123) +torch.manual_seed(123) +device = torch.device("cpu") +torch.set_default_device(device) +torch.set_default_dtype(torch.float32) + +PRESET_MAP = { + "t5gemma2_270m_270m": "google/t5gemma-2-270m-270m", + "t5gemma2_1b_1b": "google/t5gemma-2-1b-1b", + "t5gemma2_4b_4b": "google/t5gemma-2-4b-4b", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", + None, + f"Must be one of {','.join(PRESET_MAP.keys())}", +) + + +def convert_checkpoints(hf_model): + """Convert HuggingFace T5Gemma2 weights to KerasHub format.""" + print("\n-> Convert original weights to KerasHub format.") + print("\n-> Load KerasHub model.") + + encoder_config = hf_model.config.encoder + enc_text_config = encoder_config.text_config + decoder_config = hf_model.config.decoder + if decoder_config.hidden_activation == "gelu_pytorch_tanh": + decoder_config.hidden_activation = "gelu_approximate" + if 
enc_text_config.hidden_activation == "gelu_pytorch_tanh": + enc_text_config.hidden_activation = "gelu_approximate" + + # Vision encoder (optional — only present in multimodal models). + vision_encoder = None + has_vision = hasattr(encoder_config, "vision_config") and ( + encoder_config.vision_config is not None + ) + if has_vision: + from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( + Gemma3VisionEncoder, + ) + + vc = encoder_config.vision_config + mm_tokens = getattr(encoder_config, "mm_tokens_per_image", 256) + pool_size = int(vc.image_size // vc.patch_size // int(mm_tokens**0.5)) + vision_encoder = Gemma3VisionEncoder( + image_size=vc.image_size, + patch_size=vc.patch_size, + num_heads=vc.num_attention_heads, + hidden_dim=vc.hidden_size, + num_layers=vc.num_hidden_layers, + intermediate_dim=vc.intermediate_size, + output_dim=enc_text_config.hidden_size, + pool_size=pool_size, + layer_norm_epsilon=getattr(vc, "layer_norm_eps", 1e-6), + ) + print( + f" Vision encoder created: {vc.image_size}x{vc.image_size}, " + f"pool_size={pool_size}, mm_tokens={mm_tokens}" + ) + + keras.config.set_floatx("float32") + keras_hub_model = keras_hub.models.T5Gemma2Backbone( + vocabulary_size=decoder_config.vocab_size, + encoder_hidden_dim=enc_text_config.hidden_size, + encoder_intermediate_dim=enc_text_config.intermediate_size, + encoder_num_layers=enc_text_config.num_hidden_layers, + encoder_num_attention_heads=(enc_text_config.num_attention_heads), + encoder_num_key_value_heads=(enc_text_config.num_key_value_heads), + encoder_head_dim=enc_text_config.head_dim, + encoder_layer_types=enc_text_config.layer_types, + decoder_hidden_dim=decoder_config.hidden_size, + decoder_intermediate_dim=decoder_config.intermediate_size, + decoder_num_layers=decoder_config.num_hidden_layers, + decoder_num_attention_heads=decoder_config.num_attention_heads, + decoder_num_key_value_heads=(decoder_config.num_key_value_heads), + decoder_head_dim=decoder_config.head_dim, + 
decoder_layer_types=decoder_config.layer_types, + dropout_rate=decoder_config.dropout_rate, + rms_norm_eps=decoder_config.rms_norm_eps, + query_pre_attn_scalar=decoder_config.query_pre_attn_scalar, + tie_word_embeddings=getattr( + hf_model.config, "tie_word_embeddings", True + ), + attention_bias=decoder_config.attention_bias, + hidden_activation=decoder_config.hidden_activation, + initializer_range=decoder_config.initializer_range, + attention_dropout=decoder_config.attention_dropout, + sliding_window=decoder_config.sliding_window, + cross_attention_hidden_size=enc_text_config.hidden_size, + attn_logit_softcapping=(decoder_config.attn_logit_softcapping), + final_logit_softcapping=(decoder_config.final_logit_softcapping), + rope_max_wavelength=( + decoder_config.rope_parameters["sliding_attention"]["rope_theta"] + ), + global_rope_scaling_factor=( + decoder_config.rope_parameters["full_attention"]["factor"] + ), + encoder_rope_max_wavelength=( + enc_text_config.rope_parameters["sliding_attention"]["rope_theta"] + ), + encoder_global_rope_scaling_factor=( + enc_text_config.rope_parameters["full_attention"]["factor"] + ), + use_query_key_norm=any( + "q_norm" in k for k in hf_model.state_dict().keys() + ), + vision_encoder=vision_encoder, + eoi_token_index=hf_model.config.eoi_token_index, + dtype="float32", + ) + + hf_wts = hf_model.state_dict() + # Cast all weights to float32 (HF uses bfloat16). + hf_wts = {k: v.float() for k, v in hf_wts.items()} + + # Token embeddings. + # Encoder embeds are under encoder.text_model.embed_tokens.* + keras_hub_model.get_layer("encoder_token_embedding").embeddings.assign( + hf_wts["encoder.text_model.embed_tokens.weight"].numpy() + ) + keras_hub_model.get_layer("decoder_token_embedding").embeddings.assign( + hf_wts["decoder.embed_tokens.weight"].numpy() + ) + + # Vision encoder weights. 
+ if has_vision: + ve = keras_hub_model.vision_encoder + ie = ve.get_layer("image_encoder") + ie.vision_embeddings.patch_embedding.kernel.assign( + hf_wts[ + "encoder.vision_tower.vision_model." + "embeddings.patch_embedding.weight" + ] + .permute(2, 3, 1, 0) + .numpy() + ) + ie.vision_embeddings.patch_embedding.bias.assign( + hf_wts[ + "encoder.vision_tower.vision_model." + "embeddings.patch_embedding.bias" + ].numpy() + ) + ie.vision_embeddings.position_embedding.embeddings.assign( + hf_wts[ + "encoder.vision_tower.vision_model." + "embeddings.position_embedding.weight" + ].numpy() + ) + for vi in range(ie.num_layers): + vp = f"encoder.vision_tower.vision_model.encoder.layers.{vi}" + rb = ie.resblocks[vi] + rb.layer_norm_1.gamma.assign( + hf_wts[f"{vp}.layer_norm1.weight"].numpy() + ) + rb.layer_norm_1.beta.assign( + hf_wts[f"{vp}.layer_norm1.bias"].numpy() + ) + rb.attn.query_proj.kernel.assign( + hf_wts[f"{vp}.self_attn.q_proj.weight"].T.numpy() + ) + rb.attn.query_proj.bias.assign( + hf_wts[f"{vp}.self_attn.q_proj.bias"].numpy() + ) + rb.attn.key_proj.kernel.assign( + hf_wts[f"{vp}.self_attn.k_proj.weight"].T.numpy() + ) + rb.attn.key_proj.bias.assign( + hf_wts[f"{vp}.self_attn.k_proj.bias"].numpy() + ) + rb.attn.value_proj.kernel.assign( + hf_wts[f"{vp}.self_attn.v_proj.weight"].T.numpy() + ) + rb.attn.value_proj.bias.assign( + hf_wts[f"{vp}.self_attn.v_proj.bias"].numpy() + ) + rb.attn.out_proj.kernel.assign( + hf_wts[f"{vp}.self_attn.out_proj.weight"].T.numpy() + ) + rb.attn.out_proj.bias.assign( + hf_wts[f"{vp}.self_attn.out_proj.bias"].numpy() + ) + rb.layer_norm_2.gamma.assign( + hf_wts[f"{vp}.layer_norm2.weight"].numpy() + ) + rb.layer_norm_2.beta.assign( + hf_wts[f"{vp}.layer_norm2.bias"].numpy() + ) + rb.mlp_dense_1.kernel.assign( + hf_wts[f"{vp}.mlp.fc1.weight"].T.numpy() + ) + rb.mlp_dense_1.bias.assign(hf_wts[f"{vp}.mlp.fc1.bias"].numpy()) + rb.mlp_dense_2.kernel.assign( + hf_wts[f"{vp}.mlp.fc2.weight"].T.numpy() + ) + 
rb.mlp_dense_2.bias.assign(hf_wts[f"{vp}.mlp.fc2.bias"].numpy()) + ie.encoder_layer_norm.gamma.assign( + hf_wts[ + "encoder.vision_tower.vision_model.post_layernorm.weight" + ].numpy() + ) + ie.encoder_layer_norm.beta.assign( + hf_wts[ + "encoder.vision_tower.vision_model.post_layernorm.bias" + ].numpy() + ) + # Multi-modal projector. + vo = ve.get_layer("vision_output_encoder") + vo.vision_soft_embedding_norm.scale.assign( + hf_wts[ + "encoder.multi_modal_projector.mm_soft_emb_norm.weight" + ].numpy() + ) + vo.vision_input_projection.kernel.assign( + hf_wts[ + "encoder.multi_modal_projector.mm_input_projection_weight" + ].numpy() + ) + # EOI embeddings. + keras_hub_model.encoder_eoi_embedding.assign( + hf_wts["encoder.text_model.embed_tokens.eoi_embedding"].numpy() + ) + keras_hub_model.decoder_eoi_embedding.assign( + hf_wts["decoder.embed_tokens.eoi_embedding"].numpy() + ) + print(" Vision encoder weights loaded.") + + # Encoder (weights under encoder.text_model.*). + enc_hdim = keras_hub_model.encoder_hidden_dim + enc_heads = keras_hub_model.encoder_num_attention_heads + enc_kv_heads = keras_hub_model.encoder_num_key_value_heads + enc_head_dim = keras_hub_model.encoder_head_dim + keras_hub_model.encoder_norm.scale.assign( + hf_wts["encoder.text_model.norm.weight"] + ) + + for i in range(keras_hub_model.encoder_num_layers): + layer = keras_hub_model.get_layer(f"encoder_layer_{i}") + pfx = f"encoder.text_model.layers.{i}" + + # Self-attention Q/K/V/O. 
+ layer.self_attn.query_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.q_proj.weight"] + .T.reshape(enc_hdim, enc_heads, enc_head_dim) + .numpy() + ) + layer.self_attn.key_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.k_proj.weight"] + .T.reshape(enc_hdim, enc_kv_heads, enc_head_dim) + .numpy() + ) + layer.self_attn.value_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.v_proj.weight"] + .T.reshape(enc_hdim, enc_kv_heads, enc_head_dim) + .numpy() + ) + layer.self_attn.output_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.o_proj.weight"] + .T.reshape(enc_heads, enc_head_dim, enc_hdim) + .numpy() + ) + + # Q/K normalization. + layer.self_attn.query_norm.scale.assign( + hf_wts[f"{pfx}.self_attn.q_norm.weight"] + ) + layer.self_attn.key_norm.scale.assign( + hf_wts[f"{pfx}.self_attn.k_norm.weight"] + ) + + # MLP. + layer.mlp.gate_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.gate_proj.weight"].T.numpy() + ) + layer.mlp.up_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.up_proj.weight"].T.numpy() + ) + layer.mlp.down_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.down_proj.weight"].T.numpy() + ) + + # Layer norms. + layer.pre_self_attn_layernorm.scale.assign( + hf_wts[f"{pfx}.pre_self_attn_layernorm.weight"] + ) + layer.post_self_attn_layernorm.scale.assign( + hf_wts[f"{pfx}.post_self_attn_layernorm.weight"] + ) + layer.pre_feedforward_layernorm.scale.assign( + hf_wts[f"{pfx}.pre_feedforward_layernorm.weight"] + ) + layer.post_feedforward_layernorm.scale.assign( + hf_wts[f"{pfx}.post_feedforward_layernorm.weight"] + ) + + # Decoder. 
+ dec_hdim = keras_hub_model.decoder_hidden_dim + dec_heads = keras_hub_model.decoder_num_attention_heads + dec_kv_heads = keras_hub_model.decoder_num_key_value_heads + dec_head_dim = keras_hub_model.decoder_head_dim + keras_hub_model.decoder_norm.scale.assign(hf_wts["decoder.norm.weight"]) + + for i in range(keras_hub_model.decoder_num_layers): + layer = keras_hub_model.get_layer(f"decoder_layer_{i}") + pfx = f"decoder.layers.{i}" + + # Merged attention (self+cross uses single self_attn in HF). + layer.merged_attn.query_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.q_proj.weight"] + .T.reshape(dec_hdim, dec_heads, dec_head_dim) + .numpy() + ) + layer.merged_attn.key_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.k_proj.weight"] + .T.reshape(dec_hdim, dec_kv_heads, dec_head_dim) + .numpy() + ) + layer.merged_attn.value_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.v_proj.weight"] + .T.reshape(dec_hdim, dec_kv_heads, dec_head_dim) + .numpy() + ) + layer.merged_attn.output_dense.kernel.assign( + hf_wts[f"{pfx}.self_attn.o_proj.weight"] + .T.reshape(dec_heads, dec_head_dim, dec_hdim) + .numpy() + ) + + # Q/K normalization. + layer.merged_attn.query_norm.scale.assign( + hf_wts[f"{pfx}.self_attn.q_norm.weight"] + ) + layer.merged_attn.key_norm.scale.assign( + hf_wts[f"{pfx}.self_attn.k_norm.weight"] + ) + + # MLP. + layer.mlp.gate_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.gate_proj.weight"].T.numpy() + ) + layer.mlp.up_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.up_proj.weight"].T.numpy() + ) + layer.mlp.down_proj.kernel.assign( + hf_wts[f"{pfx}.mlp.down_proj.weight"].T.numpy() + ) + + # Layer norms (no cross-attn norms — merged into self_attn). 
+ layer.pre_self_attn_layernorm.scale.assign( + hf_wts[f"{pfx}.pre_self_attn_layernorm.weight"] + ) + layer.post_self_attn_layernorm.scale.assign( + hf_wts[f"{pfx}.post_self_attn_layernorm.weight"] + ) + layer.pre_feedforward_layernorm.scale.assign( + hf_wts[f"{pfx}.pre_feedforward_layernorm.weight"] + ) + layer.post_feedforward_layernorm.scale.assign( + hf_wts[f"{pfx}.post_feedforward_layernorm.weight"] + ) + + return keras_hub_model + + +def extract_vocab(hf_model_dir): + """Extract vocabulary from the downloaded HF model directory.""" + vocabulary_path = os.path.join(FLAGS.preset, "tokenizer.model") + print(f"\n-> Save KerasHub vocab to `{vocabulary_path}`.") + + # T5Gemma2 HF repos only have tokenizer.json (no tokenizer.model). + # The SentencePiece proto is the same as Gemma3's vocabulary. + source_path = os.path.join(hf_model_dir, "tokenizer.model") + if os.path.exists(source_path): + shutil.copyfile(source_path, vocabulary_path) + else: + # Download tokenizer.model from Gemma3 (same vocab). + print( + " tokenizer.model not found in HF repo. " + "Downloading from google/gemma-3-1b-pt..." + ) + gemma_dir = huggingface_hub.snapshot_download( + repo_id="google/gemma-3-1b-pt", + allow_patterns=["tokenizer.model"], + ) + gemma_proto = os.path.join(gemma_dir, "tokenizer.model") + shutil.copyfile(gemma_proto, vocabulary_path) + + keras_hub_tokenizer = keras_hub.models.T5Gemma2Tokenizer( + proto=vocabulary_path + ) + + print("-> Print MD5 checksum of the vocab file.") + print( + f"`{vocabulary_path}` md5sum: ", + get_md5_checksum(vocabulary_path), + ) + + return keras_hub_tokenizer + + +def check_text_output( + keras_hub_tokenizer, + keras_hub_model, + hf_tokenizer, + hf_model, + preprocessor, +): + """Check outputs of KerasHub and HuggingFace models match.""" + # Note: KerasHub counts encoder + decoder embeddings as separate + # weight matrices. HF shares a single nn.Embedding across + # encoder/decoder/lm_head, so counts it once. 
+ print("\n--- Model Verification starts ---") + print("\n") + print("\n-> Verify parameter counts.") + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + print(f"KerasHub params: {keras_hub_params:,}") + print(f"HF params: {hf_params:,}") + if keras_hub_params == hf_params: + print("-> Parameter counts match!") + else: + diff = keras_hub_params - hf_params + print( + f"-> Parameter count difference: {diff:,} " + f"(expected — KerasHub has separate encoder/decoder " + f"embeddings; HF shares a single nn.Embedding)" + ) + + # Output comparison. + print("\n-> ---- Text-only verification. ----\n") + enc_sample_text = [ + "cricket is awesome, easily the best sport in the world!" + ] + dec_sample_text = [ + "football is good too, but nowhere near as good as cricket." + ] + + # KerasHub — build unpadded inputs (match HF's natural lengths). + keras_hub_enc_token_ids = hf_tokenizer( + enc_sample_text, return_tensors="np" + )["input_ids"] + keras_hub_dec_token_ids = hf_tokenizer( + dec_sample_text, return_tensors="np" + )["input_ids"] + keras_hub_dec_token_ids = np.concatenate( + [ + np.array([[keras_hub_tokenizer.start_token_id]]), + keras_hub_dec_token_ids, + ], + axis=-1, + ) + keras_hub_inputs = { + "encoder_token_ids": keras_hub_enc_token_ids, + "encoder_padding_mask": np.ones_like(keras_hub_enc_token_ids), + "decoder_token_ids": keras_hub_dec_token_ids, + "decoder_padding_mask": np.ones_like(keras_hub_dec_token_ids), + } + # For multimodal backbones, use preprocessor to add dummy + # image/vision_indices (the single source of truth). + keras_hub_inputs = preprocessor._add_vision_inputs( + keras_hub_inputs, batch_size=1 + ) + keras_hub_output = keras_hub_model.predict(keras_hub_inputs) + + # HF. 
+    hf_enc_inputs = hf_tokenizer(enc_sample_text, return_tensors="pt")
+    hf_dec_inputs = hf_tokenizer(dec_sample_text, return_tensors="pt")
+    # Prepend BOS to the decoder ids (mirrors the KerasHub start token).
+    hf_decoder_input_ids = torch.cat(
+        [
+            torch.tensor([[hf_tokenizer.bos_token_id]]),
+            hf_dec_inputs["input_ids"],
+        ],
+        dim=-1,
+    )
+    hf_decoder_attention_mask = torch.cat(
+        [
+            torch.ones(1, 1, dtype=torch.long),
+            hf_dec_inputs["attention_mask"],
+        ],
+        dim=-1,
+    )
+
+    # Text-only forward pass through the HF encoder-decoder.
+    hf_output = hf_model(
+        **hf_enc_inputs,
+        decoder_input_ids=hf_decoder_input_ids,
+        decoder_attention_mask=hf_decoder_attention_mask,
+    )
+
+    # Encoder output comparison.
+    keras_enc_out = keras_hub_output["encoder_sequence_output"]
+    hf_enc_out = hf_output.encoder_last_hidden_state.detach().float().numpy()
+    enc_abs_diff = np.abs(keras_enc_out - hf_enc_out)
+    print()
+    print("Encoder Outputs:")
+    print("KerasHub output:", keras_enc_out[0, 0, :10])
+    print("HF output:", hf_enc_out[0, 0, :10])
+    print(f"Mean absolute diff: {enc_abs_diff.mean():.6f}")
+    # Report (rather than raise) when outputs differ beyond tolerance,
+    # so the conversion script still completes.
+    try:
+        np.testing.assert_allclose(
+            keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4
+        )
+        print("-> Encoder outputs match! (rtol=1e-4, atol=1e-4)")
+    except AssertionError:
+        mismatch = np.sum(
+            ~np.isclose(keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4)
+        )
+        total = keras_enc_out.size
+        print(
+            f"-> Encoder outputs differ slightly beyond rtol=1e-4 "
+            f"(mismatched: {mismatch}/{total}, "
+            f"{mismatch / total * 100:.2f}%)"
+        )
+
+    # Decoder output comparison.
+    keras_dec_out = keras_hub_output["decoder_sequence_output"]
+    hf_dec_out = hf_output.last_hidden_state.detach().float().numpy()
+    dec_abs_diff = np.abs(keras_dec_out - hf_dec_out)
+    print()
+    print("Decoder Outputs:")
+    print("KerasHub output:", keras_dec_out[0, 0, :10])
+    print("HF output:", hf_dec_out[0, 0, :10])
+    print(f"Mean absolute diff: {dec_abs_diff.mean():.6f}")
+    try:
+        np.testing.assert_allclose(
+            keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4
+        )
+        print("-> Decoder outputs match! 
(rtol=1e-4, atol=1e-4)") + except AssertionError: + mismatch = np.sum( + ~np.isclose(keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4) + ) + total = keras_dec_out.size + print( + f"-> Decoder outputs differ slightly beyond rtol=1e-4 " + f"(mismatched: {mismatch}/{total}, " + f"{mismatch / total * 100:.2f}%)" + ) + + +def check_multimodal_output( + keras_hub_model, + hf_model, + hf_model_dir, + hf_tokenizer, +): + """Check multimodal (text+image) outputs match between KerasHub and HF.""" + if keras_hub_model.vision_encoder is None: + print("\n-> Skipping multimodal check (text-only model).") + return + + print("\n-> ---- Multimodal (text+image) verification. ----\n") + + # Download a test image. + image_url = ( + "https://huggingface.co/datasets/huggingface/" + "documentation-images/resolve/main/bee.jpg" + ) + print(f" Downloading test image: {image_url}") + image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB") + + # HF side: use AutoProcessor for proper multimodal preprocessing. + hf_processor = transformers.AutoProcessor.from_pretrained(hf_model_dir) + enc_prompt = " Describe this image" + dec_prompt = "This image shows" + + # HF encoder inputs (with image). + hf_enc_inputs = hf_processor( + text=enc_prompt, images=image, return_tensors="pt" + ) + # HF decoder inputs. + hf_dec_inputs = hf_tokenizer(dec_prompt, return_tensors="pt") + hf_decoder_input_ids = torch.cat( + [ + torch.tensor([[hf_tokenizer.bos_token_id]]), + hf_dec_inputs["input_ids"], + ], + dim=-1, + ) + hf_decoder_attention_mask = torch.cat( + [ + torch.ones(1, 1, dtype=torch.long), + hf_dec_inputs["attention_mask"], + ], + dim=-1, + ) + + with torch.no_grad(): + hf_output = hf_model( + input_ids=hf_enc_inputs["input_ids"], + attention_mask=hf_enc_inputs["attention_mask"], + pixel_values=hf_enc_inputs["pixel_values"], + decoder_input_ids=hf_decoder_input_ids, + decoder_attention_mask=hf_decoder_attention_mask, + ) + + # Build KerasHub inputs from HF token_ids (same tokenizer). 
+ keras_enc_token_ids = hf_enc_inputs["input_ids"].numpy() + keras_enc_padding_mask = hf_enc_inputs["attention_mask"].numpy() + keras_dec_token_ids = hf_decoder_input_ids.numpy() + keras_dec_padding_mask = hf_decoder_attention_mask.numpy() + + # Transpose HF pixel_values (B,C,H,W) to KerasHub (B,1,H,W,C). + pixel_values = hf_enc_inputs["pixel_values"].numpy() + if pixel_values.ndim == 5: + pixel_values = np.transpose(pixel_values, (0, 1, 3, 4, 2)) + elif pixel_values.ndim == 4: + pixel_values = np.transpose(pixel_values, (0, 2, 3, 1)) + pixel_values = np.expand_dims(pixel_values, axis=1) + + # Find positions of image placeholder tokens for vision_indices. + image_token_id = hf_processor.tokenizer.convert_tokens_to_ids( + "" + ) + num_vision_tokens = ( + keras_hub_model.vision_encoder.num_vision_tokens_per_image + ) + # Find indices of image placeholder tokens. + token_ids_flat = keras_enc_token_ids[0] + vision_idx_list = np.where(token_ids_flat == image_token_id)[0].tolist() + + # Pad or truncate to num_vision_tokens. + if len(vision_idx_list) < num_vision_tokens: + vision_idx_list = vision_idx_list + [0] * ( + num_vision_tokens - len(vision_idx_list) + ) + vision_indices = np.array( + [vision_idx_list[:num_vision_tokens]], dtype="int32" + ) + + keras_hub_inputs = { + "encoder_token_ids": keras_enc_token_ids, + "encoder_padding_mask": keras_enc_padding_mask, + "decoder_token_ids": keras_dec_token_ids, + "decoder_padding_mask": keras_dec_padding_mask, + "images": pixel_values.astype("float32"), + "vision_indices": vision_indices, + } + + print("\n--- Multimodal Verification ---") + keras_hub_output = keras_hub_model.predict(keras_hub_inputs) + + # Encoder output comparison. 
+    keras_enc_out = keras_hub_output["encoder_sequence_output"]
+    hf_enc_out = hf_output.encoder_last_hidden_state.detach().float().numpy()
+    enc_abs_diff = np.abs(keras_enc_out - hf_enc_out)
+    print()
+    print("Encoder Outputs (multimodal):")
+    print("KerasHub output:", keras_enc_out[0, 0, :10])
+    print("HF output:", hf_enc_out[0, 0, :10])
+    print(f"Mean absolute diff: {enc_abs_diff.mean():.6f}")
+    # Report (rather than raise) when outputs differ beyond tolerance,
+    # so the conversion script still completes.
+    try:
+        np.testing.assert_allclose(
+            keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4
+        )
+        print("-> Encoder outputs match! (rtol=1e-4, atol=1e-4)")
+    except AssertionError:
+        mismatch = np.sum(
+            ~np.isclose(keras_enc_out, hf_enc_out, rtol=1e-4, atol=1e-4)
+        )
+        total = keras_enc_out.size
+        print(
+            f"-> Encoder outputs differ slightly beyond rtol=1e-4 "
+            f"(mismatched: {mismatch}/{total}, "
+            f"{mismatch / total * 100:.2f}%)"
+        )
+
+    # Decoder output comparison.
+    keras_dec_out = keras_hub_output["decoder_sequence_output"]
+    hf_dec_out = hf_output.last_hidden_state.detach().float().numpy()
+    dec_abs_diff = np.abs(keras_dec_out - hf_dec_out)
+    print()
+    print("Decoder Outputs (multimodal):")
+    print("KerasHub output:", keras_dec_out[0, 0, :10])
+    print("HF output:", hf_dec_out[0, 0, :10])
+    print(f"Mean absolute diff: {dec_abs_diff.mean():.6f}")
+    try:
+        np.testing.assert_allclose(
+            keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4
+        )
+        print("-> Decoder outputs match! 
(rtol=1e-4, atol=1e-4)") + except AssertionError: + mismatch = np.sum( + ~np.isclose(keras_dec_out, hf_dec_out, rtol=1e-4, atol=1e-4) + ) + total = keras_dec_out.size + print( + f"-> Decoder outputs differ slightly beyond rtol=1e-4 " + f"(mismatched: {mismatch}/{total}, " + f"{mismatch / total * 100:.2f}%)" + ) + + +def main(_): + os.makedirs(FLAGS.preset, exist_ok=True) + + hf_model_name = PRESET_MAP[FLAGS.preset] + + print("\n-> Download HF model files.") + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=[ + "*.json", + "*.safetensors", + "tokenizer.model", + ], + ) + + print("\n-> Load HF model and HF tokenizer.") + hf_model = transformers.AutoModel.from_pretrained(hf_model_dir) + hf_model.float() # Convert all params/buffers to float32. + # Re-create embed_scale with true f32 precision (bf16-init artifact). + enc_hdim = hf_model.config.encoder.text_config.hidden_size + dec_hdim = hf_model.config.decoder.hidden_size + hf_model.encoder.text_model.embed_tokens.embed_scale = torch.tensor( + enc_hdim**0.5, dtype=torch.float32 + ) + hf_model.decoder.embed_tokens.embed_scale = torch.tensor( + dec_hdim**0.5, dtype=torch.float32 + ) + hf_model.eval() + hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_dir) + + keras_hub_model = convert_checkpoints(hf_model) + print("\n-> Load KerasHub tokenizer.") + keras_hub_tokenizer = extract_vocab(hf_model_dir) + + # Create preprocessor to check_text_output can use it. 
+ preprocessor_kwargs = {} + if keras_hub_model.vision_encoder is not None: + preprocessor_kwargs.update( + { + "image_size": keras_hub_model.vision_encoder.image_size, + "num_vision_tokens_per_image": ( + keras_hub_model.vision_encoder.num_vision_tokens_per_image + ), + } + ) + preprocessor = T5Gemma2Seq2SeqLMPreprocessor( + tokenizer=keras_hub_tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + **preprocessor_kwargs, + ) + + check_text_output( + keras_hub_tokenizer, + keras_hub_model, + hf_tokenizer, + hf_model, + preprocessor, + ) + + check_multimodal_output( + keras_hub_model, + hf_model, + hf_model_dir, + hf_tokenizer, + ) + + print("\n-> Releasing HF backbone from memory.") + del hf_model + gc.collect() + + keras_lm = T5Gemma2Seq2SeqLM( + backbone=keras_hub_model, + preprocessor=preprocessor, + dtype=keras_hub_model.dtype, + ) + + print(f"\n-> Saving T5Gemma2Seq2SeqLM preset to `{FLAGS.preset}`.") + keras_lm.save_to_preset(FLAGS.preset) + print("-> Preset saved successfully.") + + print("\n-> Testing preset loading.") + keras_lm = Seq2SeqLM.from_preset(FLAGS.preset) + print("-> Preset loading verified successfully.") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)