From 84c3cce399a71b72aa67ab7562e3b842941ea5e9 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Wed, 11 Mar 2026 00:40:39 -0700 Subject: [PATCH 1/3] Initial qwen3.5 model text-only changes files --- keras_hub/api/models/__init__.py | 12 + keras_hub/src/models/qwen3_5/__init__.py | 5 + .../src/models/qwen3_5/qwen3_5_attention.py | 338 +++++++++++++ .../src/models/qwen3_5/qwen3_5_backbone.py | 208 ++++++++ .../models/qwen3_5/qwen3_5_backbone_test.py | 56 +++ .../src/models/qwen3_5/qwen3_5_causal_lm.py | 207 ++++++++ .../qwen3_5/qwen3_5_causal_lm_preprocessor.py | 10 + .../qwen3_5_causal_lm_preprocessor_test.py | 84 ++++ .../models/qwen3_5/qwen3_5_causal_lm_test.py | 147 ++++++ .../src/models/qwen3_5/qwen3_5_decoder.py | 308 ++++++++++++ .../models/qwen3_5/qwen3_5_gated_delta_net.py | 450 ++++++++++++++++++ .../src/models/qwen3_5/qwen3_5_layernorm.py | 43 ++ .../src/models/qwen3_5/qwen3_5_presets.py | 3 + .../src/models/qwen3_5/qwen3_5_tokenizer.py | 34 ++ .../src/utils/transformers/convert_qwen3_5.py | 245 ++++++++++ .../src/utils/transformers/preset_loader.py | 3 + .../convert_qwen3_5_checkpoints.py | 258 ++++++++++ 17 files changed, 2411 insertions(+) create mode 100644 keras_hub/src/models/qwen3_5/__init__.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_attention.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_backbone.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_backbone_test.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_causal_lm.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_preprocessor.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_preprocessor_test.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_test.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_decoder.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_gated_delta_net.py create mode 100644 keras_hub/src/models/qwen3_5/qwen3_5_layernorm.py create mode 100644 
import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.qwen3_5.qwen3_5_layernorm import Qwen3_5LayerNorm
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available


class Qwen3_5Attention(keras.layers.Layer):
    """Full self-attention layer for Qwen3.5.

    This implements grouped-query attention (GQA) with:
    - Q/K RMSNorm
    - Partial rotary embeddings (only first `partial_rotary_factor` fraction
      of head_dim gets RoPE)
    - Sigmoid gating on attention output
    - Optional sliding window

    Args:
        num_query_heads: Number of query attention heads.
        num_key_value_heads: Number of key/value attention heads (GQA).
        head_dim: Dimension of each attention head.
        partial_rotary_factor: Fraction of head_dim that gets RoPE.
        rope_max_wavelength: Maximum wavelength for rotary embeddings.
        rope_scaling_factor: Scaling factor for rotary embeddings.
        kernel_initializer: Initializer for projection kernels.
        dropout: Dropout rate for attention weights.
        layer_norm_epsilon: Epsilon for Q/K RMSNorm.
        sliding_window_size: Optional sliding window size.
    """

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        head_dim,
        partial_rotary_factor=0.25,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        kernel_initializer="glorot_uniform",
        dropout=0.0,
        layer_norm_epsilon=1e-6,
        sliding_window_size=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.partial_rotary_factor = partial_rotary_factor
        # Number of leading head dimensions that receive RoPE; the
        # remaining `head_dim - rotary_dim` dims pass through unrotated.
        self.rotary_dim = int(head_dim * partial_rotary_factor)
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        # GQA: each key/value head serves this many query heads.
        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.sliding_window_size = sliding_window_size
        # Clone so each projection gets an independent initializer state.
        self.kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

    def build(self, inputs_shape):
        hidden_dim = inputs_shape[-1]
        # 1/sqrt(d) attention scaling factor.
        self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

        # Q projects to (num_query_heads, head_dim * 2) to include gate.
        self._query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(
                None,
                self.num_query_heads,
                self.head_dim * 2,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self._query_dense.build(inputs_shape)

        # RMSNorm applied per query head (on the head_dim axis only,
        # after the gate half has been split off).
        self._query_norm = Qwen3_5LayerNorm(
            head_dim=self.head_dim,
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="query_norm",
        )
        self._query_norm.build(
            (None, None, self.num_query_heads, self.head_dim)
        )

        self._key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self.head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self._key_dense.build(inputs_shape)

        self._key_norm = Qwen3_5LayerNorm(
            head_dim=self.head_dim,
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="key_norm",
        )
        self._key_norm.build(
            (None, None, self.num_key_value_heads, self.head_dim)
        )

        self._value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self.head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self._value_dense.build(inputs_shape)

        # Softmax runs in float32 for numerical stability regardless of
        # the layer's compute dtype.
        self._softmax = keras.layers.Softmax(
            axis=-1, dtype="float32", name="attention_softmax"
        )
        self._dropout_layer = keras.layers.Dropout(
            rate=self.dropout, dtype=self.dtype_policy
        )
        self._output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, hidden_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self._output_dense.build(
            (None, None, self.num_query_heads, self.head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self.rope_max_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        # b=batch, q=query len, k=key len, u=head index, h=head dim.
        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"
        self.built = True

    def _apply_partial_rope(self, x, start_index):
        """Apply RoPE only to the first `rotary_dim` dimensions."""
        if self.rotary_dim == self.head_dim:
            return self.rotary_embedding_layer(x, start_index=start_index)

        # Split head_dim into a rotated prefix and an unrotated suffix.
        x_rope = x[..., : self.rotary_dim]
        x_pass = x[..., self.rotary_dim :]
        x_rope = self.rotary_embedding_layer(x_rope, start_index=start_index)
        return ops.concatenate([x_rope, x_pass], axis=-1)

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        # RoPE positions start at the cache write position during
        # incremental decoding, and at 0 for a full forward pass.
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        # Query projects to (head_dim * 2), split into query + gate.
        qg = self._query_dense(hidden_states)
        query = qg[..., : self.head_dim]
        gate = qg[..., self.head_dim :]

        # Reshape gate for per-head gating: (B, seq, heads * head_dim)
        gate_shape = ops.shape(gate)
        gate = ops.reshape(
            gate,
            (gate_shape[0], gate_shape[1], -1),
        )

        query = self._query_norm(query)
        query = self._apply_partial_rope(query, start_index)

        def _compute_key_value(x):
            key = self._key_dense(x)
            key = self._key_norm(key)
            key = self._apply_partial_rope(key, start_index)
            value = self._value_dense(x)
            return key, value

        if cache is not None:
            # Cache layout: (B, 2, max_len, kv_heads, head_dim), with
            # index 0 holding keys and index 1 holding values.
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Read-only cache use: reuse cached K/V as-is.
                key = key_cache
                value = value_cache
            else:
                # Write the new token's K/V into the cache in place.
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` "
                    f"is `None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # GQA: repeat K/V heads.
        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query,
            key,
            value,
            attention_mask,
            cache_update_index=cache_update_index,
        )
        attention_output = self._dropout_layer(
            attention_output, training=training
        )

        # Reshape to (B, seq, heads * head_dim) for gating.
        out_shape = ops.shape(attention_output)
        attention_output = ops.reshape(
            attention_output,
            (out_shape[0], out_shape[1], -1),
        )

        # Apply sigmoid gate.
        attention_output = attention_output * ops.sigmoid(gate)

        # Reshape back to (B, seq, heads, head_dim) for output proj.
        attention_output = ops.reshape(
            attention_output,
            (out_shape[0], out_shape[1], self.num_query_heads, self.head_dim),
        )
        attention_output = self._output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        # Broadcast the (B, q, k) mask over the head axis of the
        # (B, heads, q, k) scores.
        if attention_mask is not None:
            return self._softmax(
                attention_scores, attention_mask[:, None, :, :]
            )
        return self._softmax(attention_scores)

    def _compute_attention(
        self,
        query,
        key,
        value,
        attention_mask=None,
        cache_update_index=None,
    ):
        # Fast path: backend-fused scaled dot-product attention.
        # NOTE(review): this path does not apply `sliding_window_size`;
        # only the einsum fallback below does — confirm whether the fused
        # op is expected to be disabled for sliding-window layers.
        if fused_attention_op_available():
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")
            return ops.dot_product_attention(
                query,
                key,
                value,
                mask=attention_mask,
                scale=self._inv_norm_factor,
            )

        attention_scores = ops.einsum(self._dot_product_equation, query, key)
        attention_scores = ops.multiply(
            attention_scores,
            ops.cast(self._inv_norm_factor, self.compute_dtype),
        )
        if self.sliding_window_size:
            attention_mask = self._mask_sliding_window(
                attention_mask,
                cache_update_index=(
                    cache_update_index if cache_update_index is not None else 0
                ),
            )
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        return ops.einsum(self._combine_equation, attention_scores, value)

    def _mask_sliding_window(self, attention_mask, cache_update_index=0):
        _, query_len, key_len = ops.shape(attention_mask)
        all_ones = ops.ones((key_len, key_len), "bool")
        if keras.config.backend() == "tensorflow":
            # tf.linalg.band_part builds the banded matrix without the
            # dynamic-shape issues triu/tril have under tf.function.
            import tensorflow as tf

            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
            band_size = ops.cast(band_size, "int32")
            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
        else:
            # Band of width (sliding_window_size - 1) on each side of the
            # diagonal: triu keeps the lower bound, tril the upper bound.
            sliding_mask = ops.triu(
                all_ones, -1 * self.sliding_window_size + 1
            ) * ops.tril(all_ones, self.sliding_window_size - 1)
        # Slice out the rows for the current query positions, offset by
        # the cache write position during incremental decoding.
        start = (cache_update_index, 0)
        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
        sliding_mask = ops.expand_dims(sliding_mask, 0)
        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "head_dim": self.head_dim,
                "partial_rotary_factor": self.partial_rotary_factor,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "dropout": self.dropout,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "sliding_window_size": self.sliding_window_size,
            }
        )
        return config
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
# Fix: ReversibleEmbedding is a keras_hub layer, not part of `keras.layers`;
# the previous `from keras.layers import ReversibleEmbedding` fails at import.
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.qwen3_5.qwen3_5_decoder import (
    Qwen3_5TransformerDecoder,
)
from keras_hub.src.models.qwen3_5.qwen3_5_layernorm import Qwen3_5LayerNorm


def _qwen3_5_kernel_initializer(stddev=0.02):
    """Return the truncated-normal-style initializer used by Qwen3.5."""
    return keras.initializers.RandomNormal(stddev=stddev)


@keras_hub_export("keras_hub.models.Qwen3_5Backbone")
class Qwen3_5Backbone(Backbone):
    """The Qwen3.5 Transformer core architecture with hyperparameters.

    This network implements a hybrid Transformer-based decoder with two
    layer types:
    - ``full_attention``: Standard grouped-query attention with partial
      rotary embeddings and sigmoid output gating.
    - ``linear_attention``: GatedDeltaNet recurrent linear attention with
      causal conv1d and delta rule recurrence.

    Args:
        vocabulary_size (int): The size of the token vocabulary.
        num_layers (int): The number of transformer layers.
        num_query_heads (int): The number of query attention heads.
        num_key_value_heads (int): The number of key and value attention
            heads.
        head_dim (int): Dimension of each attention head.
        hidden_dim (int): The size of the transformer hidden dimension.
        intermediate_dim (int): The FFN intermediate dimension.
        layer_types (list): List of layer types, one per layer.
            Each element is ``"full_attention"`` or
            ``"linear_attention"``.
        partial_rotary_factor (float): Fraction of head_dim that gets
            RoPE. Defaults to ``0.25``.
        rope_max_wavelength (int): Maximum wavelength for RoPE. Defaults
            to ``10000``.
        rope_scaling_factor (float): Scaling factor for RoPE. Defaults
            to ``1.0``.
        layer_norm_epsilon (float): Epsilon for layer norms. Defaults
            to ``1e-6``.
        dropout (float): Dropout rate. Defaults to ``0.0``.
        tie_word_embeddings (bool): Whether to tie input and output
            embeddings. Defaults to ``False``.
        sliding_window_size (int): Sliding window size for full attention
            layers. Defaults to ``32768``.
        linear_num_key_heads (int): Key heads for linear attention.
            Defaults to ``16``.
        linear_num_value_heads (int): Value heads for linear attention.
            Defaults to ``32``.
        linear_key_head_dim (int): Key head dim for linear attention.
            Defaults to ``128``.
        linear_value_head_dim (int): Value head dim for linear attention.
            Defaults to ``128``.
        linear_conv_kernel_dim (int): Conv kernel size for linear
            attention. Defaults to ``4``.
        dtype: string or ``keras.mixed_precision.DTypePolicy``. The
            dtype to use for model computations and weights.
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_query_heads,
        num_key_value_heads,
        head_dim,
        hidden_dim,
        intermediate_dim,
        layer_types=None,
        partial_rotary_factor=0.25,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        layer_norm_epsilon=1e-6,
        dropout=0.0,
        tie_word_embeddings=False,
        sliding_window_size=32768,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_conv_kernel_dim=4,
        dtype=None,
        **kwargs,
    ):
        # Default layer_types: every 4th layer is full_attention.
        if layer_types is None:
            layer_types = [
                ("linear_attention" if bool((i + 1) % 4) else "full_attention")
                for i in range(num_layers)
            ]

        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            tie_weights=tie_word_embeddings,
            embeddings_initializer=_qwen3_5_kernel_initializer(stddev=0.01),
            dtype=dtype,
            name="token_embedding",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = Qwen3_5TransformerDecoder(
                layer_type=layer_types[i],
                intermediate_dim=intermediate_dim,
                head_dim=head_dim,
                num_query_heads=num_query_heads,
                num_key_value_heads=num_key_value_heads,
                partial_rotary_factor=partial_rotary_factor,
                rope_max_wavelength=rope_max_wavelength,
                rope_scaling_factor=rope_scaling_factor,
                layer_norm_epsilon=layer_norm_epsilon,
                activation=ops.silu,
                kernel_initializer=_qwen3_5_kernel_initializer(stddev=0.02),
                dropout=dropout,
                sliding_window_size=sliding_window_size,
                linear_num_key_heads=linear_num_key_heads,
                linear_num_value_heads=linear_num_value_heads,
                linear_key_head_dim=linear_key_head_dim,
                linear_value_head_dim=linear_value_head_dim,
                linear_conv_kernel_dim=linear_conv_kernel_dim,
                dtype=dtype,
                name=f"transformer_layer_{i}",
            )
            self.transformer_layers.append(layer)
        # Final RMSNorm applied to the last decoder layer's output.
        self.layer_norm = Qwen3_5LayerNorm(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="sequence_output_layernorm",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
        sequence_output = self.layer_norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_query_heads = num_query_heads
        self.hidden_dim = hidden_dim
        self.head_dim = head_dim
        self.intermediate_dim = intermediate_dim
        self.layer_types = layer_types
        self.partial_rotary_factor = partial_rotary_factor
        self.rope_max_wavelength = rope_max_wavelength
        self.num_key_value_heads = num_key_value_heads
        self.rope_scaling_factor = rope_scaling_factor
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout
        self.tie_word_embeddings = tie_word_embeddings
        self.sliding_window_size = sliding_window_size
        self.linear_num_key_heads = linear_num_key_heads
        self.linear_num_value_heads = linear_num_value_heads
        self.linear_key_head_dim = linear_key_head_dim
        self.linear_value_head_dim = linear_value_head_dim
        self.linear_conv_kernel_dim = linear_conv_kernel_dim

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_query_heads": self.num_query_heads,
                "hidden_dim": self.hidden_dim,
                "head_dim": self.head_dim,
                "intermediate_dim": self.intermediate_dim,
                "layer_types": self.layer_types,
                "partial_rotary_factor": self.partial_rotary_factor,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "num_key_value_heads": self.num_key_value_heads,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
                "tie_word_embeddings": self.tie_word_embeddings,
                "sliding_window_size": self.sliding_window_size,
                "linear_num_key_heads": self.linear_num_key_heads,
                "linear_num_value_heads": self.linear_num_value_heads,
                "linear_key_head_dim": self.linear_key_head_dim,
                "linear_value_head_dim": self.linear_value_head_dim,
                "linear_conv_kernel_dim": self.linear_conv_kernel_dim,
            }
        )
        return config
a/keras_hub/src/models/qwen3_5/qwen3_5_backbone_test.py b/keras_hub/src/models/qwen3_5/qwen3_5_backbone_test.py new file mode 100644 index 0000000000..ce91fb52b7 --- /dev/null +++ b/keras_hub/src/models/qwen3_5/qwen3_5_backbone_test.py @@ -0,0 +1,56 @@ +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_5.qwen3_5_backbone import Qwen3_5Backbone +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3_5BackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 4, + "num_query_heads": 4, + "num_key_value_heads": 2, + "head_dim": 4, + "hidden_dim": 16, + "intermediate_dim": 32, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ], + "partial_rotary_factor": 0.25, + "linear_num_key_heads": 2, + "linear_num_value_heads": 4, + "linear_key_head_dim": 4, + "linear_value_head_dim": 4, + "linear_conv_kernel_dim": 4, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Qwen3_5Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 16), + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen3_5Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_num_parameters(self): + model = Qwen3_5Backbone(**self.init_kwargs) + # Just verify the model builds and has params. 
+ self.assertGreater(model.count_params(), 0) diff --git a/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm.py b/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm.py new file mode 100644 index 0000000000..a70420f7ff --- /dev/null +++ b/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm.py @@ -0,0 +1,207 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.qwen3_5.qwen3_5_backbone import Qwen3_5Backbone +from keras_hub.src.models.qwen3_5.qwen3_5_causal_lm_preprocessor import ( + Qwen3_5CausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.Qwen3_5CausalLM") +class Qwen3_5CausalLM(CausalLM): + """An end-to-end Qwen3.5 model for causal language modeling. + + This model predicts the next token based on previous tokens using the + Qwen3.5 hybrid architecture (full attention + GatedDeltaNet linear + attention layers). + + This model has a ``generate()`` method for autoregressive text + generation. + + Args: + backbone: A ``keras_hub.models.Qwen3_5Backbone`` instance. + preprocessor: A ``keras_hub.models.Qwen3_5CausalLMPreprocessor`` + or ``None``. + """ + + backbone_cls = Qwen3_5Backbone + preprocessor_cls = Qwen3_5CausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass with cache for autoregressive decoding. + + Only full_attention layers use the KV cache. Linear attention + layers (GatedDeltaNet) do not cache and process tokens + independently per step. 
+ + Args: + token_ids: Dense int tensor (batch_size, max_length). + cache: Dense float tensor, the KV cache. + cache_update_index: Int or int tensor, current step index. + + Returns: + (logits, hidden_states, cache) tuple. + """ + x = self.backbone.token_embedding(token_ids) + updated_cache = [] + for i in range(self.backbone.num_layers): + layer = self.backbone.transformer_layers[i] + if layer.layer_type == "full_attention": + current_cache = cache[:, i, ...] + x, next_cache = layer( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=(cache_update_index), + ) + updated_cache.append(next_cache) + else: + # Linear attention layers don't use KV cache. + x = layer(x) + # Append a zero placeholder to keep cache shape. + updated_cache.append(cache[:, i, ...]) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with ``call_with_cache()``.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.head_dim + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step(self, inputs, stop_token_ids=None): + """A compilable generation function for a single batch.""" + token_ids = inputs["token_ids"] + padding_mask = inputs["padding_mask"] + hidden_states, cache = self._build_cache(token_ids) + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + index = ops.min(row_lengths) + + def next(prompt, cache, index): + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + 
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, cache, cache_update_index + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + if stop_token_ids is not None: + end_locations = any_equal( + token_ids, + stop_token_ids, + ops.logical_not(padding_mask), + ) + end_locations = ops.cast(end_locations, "int32") + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids.""" + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be 'logits' or 'loss'." + ) + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide " + "target token ids via the target_ids parameter." 
+ ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_preprocessor.py b/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_preprocessor.py new file mode 100644 index 0000000000..0a04483a33 --- /dev/null +++ b/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_preprocessor.py @@ -0,0 +1,10 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.qwen3_5.qwen3_5_backbone import Qwen3_5Backbone +from keras_hub.src.models.qwen3_5.qwen3_5_tokenizer import Qwen3_5Tokenizer + + +@keras_hub_export("keras_hub.models.Qwen3_5CausalLMPreprocessor") +class Qwen3_5CausalLMPreprocessor(CausalLMPreprocessor): + backbone_cls = Qwen3_5Backbone + tokenizer_cls = Qwen3_5Tokenizer diff --git a/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_preprocessor_test.py b/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..9966102b63 --- /dev/null +++ b/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm_preprocessor_test.py @@ -0,0 +1,84 @@ +from 
class Qwen3_5CausalLMPreprocessorTest(TestCase):
    def setUp(self):
        # Minimal BPE vocabulary/merges able to encode the phrase
        # "airplane at airport", plus the Qwen special tokens.
        tokens = ["!", "air", "\u0120air", "plane", "\u0120at", "port"]
        tokens += ["<|im_end|>", "<|endoftext|>"]
        self.vocab = {token: index for index, token in enumerate(tokens)}
        merges = ["\u0120 a", "\u0120 t", "\u0120 i", "\u0120 b"]
        merges += ["a i", "p l", "n e"]
        merges += ["\u0120a t", "p o", "r t", "\u0120t h"]
        merges += ["ai r", "pl a", "po rt"]
        merges += ["\u0120ai r", "\u0120a i", "pla ne"]
        self.merges = merges
        self.tokenizer = Qwen3_5Tokenizer(
            vocabulary=self.vocab,
            merges=self.merges,
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = ["airplane at airport"]

    def test_preprocessor_basics(self):
        # Exercises the shared preprocessor contract: packed token ids,
        # a padding mask, shifted labels, and label sample weights.
        self.run_preprocessor_test(
            cls=Qwen3_5CausalLMPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[1, 3, 4, 2, 5, 6, 7, 7]],
                    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
                },
                [[3, 4, 2, 5, 6, 7, 7, 7]],
                [[1, 1, 1, 1, 1, 0, 0, 0]],
            ),
        )

    def test_with_start_end_token(self):
        # Batched input with explicit start/end token flags.
        input_data = ["airplane at airport"] * 4
        preprocessor = Qwen3_5CausalLMPreprocessor(
            **self.init_kwargs,
            add_start_token=True,
            add_end_token=True,
        )
        x, y, sw = preprocessor(input_data)
        self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 6, 7, 7]] * 4)
        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
        self.assertAllEqual(y, [[3, 4, 2, 5, 6, 7, 7, 7]] * 4)
        self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)

    def test_generate_preprocess(self):
        # Generation-time preprocessing keeps the prompt unshifted.
        preprocessor = Qwen3_5CausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_preprocess("airplane at airport")
        self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 7, 7, 7])
        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])

    def test_generate_postprocess(self):
        # Postprocessing detokenizes and strips padding/end tokens.
        input_data = {
            "token_ids": [1, 3, 4, 2, 5, 7, 7, 7],
            "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
        }
        preprocessor = Qwen3_5CausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_postprocess(input_data)
        self.assertAllEqual(x, "airplane at airport")
class Qwen3_5CausalLMTest(TestCase):
    def setUp(self):
        # Tiny BPE assets covering the test phrase plus special tokens.
        # NOTE: token order differs from the preprocessor test —
        # "<|endoftext|>" precedes "<|im_end|>" here.
        tokens = ["!", "air", "\u0120air", "plane", "\u0120at", "port"]
        tokens += ["<|endoftext|>"]
        tokens += ["<|im_end|>"]
        self.vocab = {token: index for index, token in enumerate(tokens)}
        merges = ["\u0120 a", "\u0120 t", "\u0120 i", "\u0120 b"]
        merges += ["a i", "p l", "n e"]
        merges += ["\u0120a t", "p o", "r t", "\u0120t h"]
        merges += ["ai r", "pl a", "po rt"]
        merges += ["\u0120ai r", "\u0120a i", "pla ne"]
        self.merges = merges
        self.preprocessor = Qwen3_5CausalLMPreprocessor(
            Qwen3_5Tokenizer(vocabulary=self.vocab, merges=self.merges),
            sequence_length=7,
        )
        # A deliberately tiny hybrid backbone: three linear-attention
        # layers followed by one full-attention layer.
        self.backbone = Qwen3_5Backbone(
            vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
            num_layers=4,
            num_query_heads=4,
            num_key_value_heads=2,
            hidden_dim=8,
            head_dim=4,
            intermediate_dim=16,
            layer_types=[
                "linear_attention",
                "linear_attention",
                "linear_attention",
                "full_attention",
            ],
            partial_rotary_factor=0.25,
            linear_num_key_heads=2,
            linear_num_value_heads=4,
            linear_key_head_dim=4,
            linear_value_head_dim=4,
            linear_conv_kernel_dim=4,
        )
        self.init_kwargs = {
            "preprocessor": self.preprocessor,
            "backbone": self.backbone,
        }
        self.train_data = ([" airplane at airport", " airplane at airport"],)
        self.input_data = self.preprocessor(*self.train_data)[0]

    def test_causal_lm_basics(self):
        self.run_task_test(
            cls=Qwen3_5CausalLM,
            init_kwargs=self.init_kwargs,
            train_data=self.train_data,
            expected_output_shape=(2, 7, 8),
        )

    def test_generate(self):
        causal_lm = Qwen3_5CausalLM(**self.init_kwargs)
        prompt = " airplane at airport"
        # String input round-trips through the preprocessor.
        output = causal_lm.generate(" airplane at airport")
        self.assertTrue(prompt in output)
        # Raw int-tensor input with the preprocessor detached: the
        # prompt portion of the output must be preserved verbatim.
        prompt_ids = self.preprocessor.generate_preprocess([prompt])
        causal_lm.preprocessor = None
        outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
        self.assertAllEqual(
            outputs["token_ids"][:, :5],
            prompt_ids["token_ids"][:, :5],
        )
        self.assertAllEqual(
            outputs["padding_mask"][:, :5],
            prompt_ids["padding_mask"][:, :5],
        )

    def test_generate_strip_prompt(self):
        causal_lm = Qwen3_5CausalLM(**self.init_kwargs)
        prompt = " airplane at airport"
        output = causal_lm.generate(prompt, strip_prompt=True)
        self.assertFalse(output.startswith(prompt))

    def test_early_stopping(self):
        # Force the end token to dominate the logits so generation
        # stops immediately and returns the prompt unchanged.
        causal_lm = Qwen3_5CausalLM(**self.init_kwargs)
        call_with_cache = causal_lm.call_with_cache

        def wrapper(*args, **kwargs):
            logits, hidden_states, cache = call_with_cache(*args, **kwargs)
            index = self.preprocessor.tokenizer.end_token_id
            update = ops.ones_like(logits)[:, :, index] * 1.0e9
            update = ops.expand_dims(update, axis=-1)
            logits = ops.slice_update(logits, (0, 0, index), update)
            return logits, hidden_states, cache

        with patch.object(causal_lm, "call_with_cache", wraps=wrapper):
            prompt = [" airplane at airport", " airplane"]
            output = causal_lm.generate(prompt)
            self.assertEqual(prompt, output)

    def test_generate_compilation(self):
        causal_lm = Qwen3_5CausalLM(**self.init_kwargs)
        # Repeated generate() calls must reuse the compiled function.
        causal_lm.generate(" airplane at airport")
        first_fn = causal_lm.generate_function
        causal_lm.generate(" airplane at airport")
        second_fn = causal_lm.generate_function
        self.assertEqual(first_fn, second_fn)
        # Recompiling invalidates the cached generate function.
        causal_lm.compile(sampler="greedy")
        self.assertIsNone(causal_lm.generate_function)

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=Qwen3_5CausalLM,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

    @pytest.mark.extra_large
    def test_all_presets(self):
        for preset in Qwen3_5CausalLM.presets:
            self.run_preset_test(
                cls=Qwen3_5CausalLM,
                preset=preset,
                input_data=self.input_data,
            )
class Qwen3_5TransformerDecoder(keras.layers.Layer):
    """A Transformer decoder layer for Qwen3.5.

    Each layer is a pre-norm residual block whose token mixer is
    selected by ``layer_type``: ``"full_attention"`` uses grouped-query
    softmax attention (`Qwen3_5Attention`), while ``"linear_attention"``
    uses the GatedDeltaNet recurrence (`Qwen3_5GatedDeltaNet`). The
    mixer is followed by a SwiGLU feedforward network.

    Args:
        layer_type: One of ``"full_attention"`` or ``"linear_attention"``.
        intermediate_dim: FFN intermediate dimension.
        num_query_heads: Number of query attention heads.
        num_key_value_heads: Number of key/value attention heads (GQA).
        head_dim: Dimension of each attention head.
        partial_rotary_factor: Fraction of head_dim that gets RoPE.
        rope_max_wavelength: Maximum wavelength for rotary embeddings.
        rope_scaling_factor: Scaling factor for rotary embeddings.
        activation: Activation function for the FFN.
        layer_norm_epsilon: Epsilon for layer norms.
        kernel_initializer: Initializer for projection kernels.
        dropout: Dropout rate.
        sliding_window_size: Sliding window size (full_attention only).
        linear_num_key_heads: Number of key heads (linear_attention).
        linear_num_value_heads: Number of value heads (linear_attention).
        linear_key_head_dim: Key head dim (linear_attention).
        linear_value_head_dim: Value head dim (linear_attention).
        linear_conv_kernel_dim: Conv kernel size (linear_attention).
    """

    def __init__(
        self,
        layer_type,
        intermediate_dim,
        num_query_heads,
        num_key_value_heads,
        head_dim,
        partial_rotary_factor=0.25,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        activation="silu",
        layer_norm_epsilon=1e-6,
        kernel_initializer="glorot_uniform",
        dropout=0.0,
        sliding_window_size=None,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_conv_kernel_dim=4,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.layer_type = layer_type
        self.intermediate_dim = intermediate_dim
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.partial_rotary_factor = partial_rotary_factor
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.dropout = dropout
        self.sliding_window_size = sliding_window_size
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.linear_num_key_heads = linear_num_key_heads
        self.linear_num_value_heads = linear_num_value_heads
        self.linear_key_head_dim = linear_key_head_dim
        self.linear_value_head_dim = linear_value_head_dim
        self.linear_conv_kernel_dim = linear_conv_kernel_dim
        self.supports_masking = True

    def build(self, decoder_sequence_shape):
        self._decoder_sequence_shape = decoder_sequence_shape
        self.hidden_dim = decoder_sequence_shape[-1]

        # Token mixer — dispatched by layer_type. The attribute names
        # below (_linear_attn / _self_attention_layer) are relied on by
        # the HF weight converter; do not rename them.
        if self.layer_type == "linear_attention":
            self._linear_attn = Qwen3_5GatedDeltaNet(
                hidden_size=self.hidden_dim,
                linear_num_key_heads=self.linear_num_key_heads,
                linear_num_value_heads=self.linear_num_value_heads,
                linear_key_head_dim=self.linear_key_head_dim,
                linear_value_head_dim=self.linear_value_head_dim,
                linear_conv_kernel_dim=self.linear_conv_kernel_dim,
                layer_norm_epsilon=self.layer_norm_epsilon,
                kernel_initializer=clone_initializer(self.kernel_initializer),
                dtype=self.dtype_policy,
                name="linear_attn",
            )
            self._linear_attn.build(decoder_sequence_shape)
        elif self.layer_type == "full_attention":
            self._self_attention_layer = Qwen3_5Attention(
                num_query_heads=self.num_query_heads,
                num_key_value_heads=self.num_key_value_heads,
                head_dim=self.head_dim,
                partial_rotary_factor=self.partial_rotary_factor,
                rope_max_wavelength=self.rope_max_wavelength,
                rope_scaling_factor=self.rope_scaling_factor,
                kernel_initializer=clone_initializer(self.kernel_initializer),
                dropout=self.dropout,
                layer_norm_epsilon=self.layer_norm_epsilon,
                sliding_window_size=self.sliding_window_size,
                dtype=self.dtype_policy,
                name="self_attention",
            )
            self._self_attention_layer.build(decoder_sequence_shape)
        else:
            raise ValueError(
                f"Unknown layer_type '{self.layer_type}'. "
                "Expected 'full_attention' or 'linear_attention'."
            )

        # Pre-norm applied before the token mixer.
        self._input_layernorm = Qwen3_5LayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="input_layernorm",
        )
        self._input_layernorm.build(decoder_sequence_shape)

        self._self_attention_dropout = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
            name="self_attention_dropout",
        )

        # Feedforward layers (SwiGLU): gate and up projections to
        # intermediate_dim, down projection back to hidden_dim.
        self._feedforward_gate_dense = keras.layers.Dense(
            self.intermediate_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            use_bias=False,
            dtype=self.dtype_policy,
            name="feedforward_gate_dense",
        )
        self._feedforward_gate_dense.build(decoder_sequence_shape)

        self._feedforward_intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            use_bias=False,
            dtype=self.dtype_policy,
            name="feedforward_intermediate_dense",
        )
        self._feedforward_intermediate_dense.build(decoder_sequence_shape)

        self._feedforward_output_dense = keras.layers.Dense(
            self.hidden_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            use_bias=False,
            dtype=self.dtype_policy,
            name="feedforward_output_dense",
        )
        self._feedforward_output_dense.build(
            self._feedforward_gate_dense.compute_output_shape(
                decoder_sequence_shape
            )
        )

        # Pre-norm applied before the FFN.
        self._post_attention_layernorm = Qwen3_5LayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="post_attention_layernorm",
        )
        self._post_attention_layernorm.build(decoder_sequence_shape)

        self.built = True

    def call(
        self,
        decoder_sequence,
        decoder_padding_mask=None,
        decoder_attention_mask=None,
        self_attention_cache=None,
        self_attention_cache_update_index=None,
        training=None,
    ):
        shortcut = decoder_sequence
        hidden = self._input_layernorm(decoder_sequence)

        # Token mixer.
        if self.layer_type == "linear_attention":
            # GatedDeltaNet is causal by construction and only needs
            # the 2D padding mask, not a full attention mask.
            hidden = self._linear_attn(
                hidden,
                attention_mask=decoder_padding_mask,
                training=training,
            )
        elif self.layer_type == "full_attention":
            attn_mask = self._compute_self_attention_mask(
                decoder_sequence=decoder_sequence,
                decoder_padding_mask=decoder_padding_mask,
                decoder_attention_mask=decoder_attention_mask,
                self_attention_cache=self_attention_cache,
                self_attention_cache_update_index=(
                    self_attention_cache_update_index
                ),
            )
            hidden = self._self_attention_layer(
                hidden_states=hidden,
                attention_mask=attn_mask,
                cache=self_attention_cache,
                cache_update_index=self_attention_cache_update_index,
            )
            # With a cache, the attention layer returns (output, cache).
            if self_attention_cache is not None:
                hidden, self_attention_cache = hidden

        hidden = self._self_attention_dropout(hidden, training=training)
        hidden = hidden + shortcut

        # Feedforward block (SwiGLU).
        shortcut = hidden
        hidden = self._post_attention_layernorm(hidden)
        gate = self._feedforward_gate_dense(hidden)

        # Run the gate activation in float32 for numerical stability,
        # then cast back to the compute dtype.
        gate = ops.cast(gate, "float32")
        gate = self.activation(gate)
        gate = ops.cast(gate, self.compute_dtype)

        hidden = self._feedforward_intermediate_dense(hidden)
        hidden = self._feedforward_output_dense(ops.multiply(hidden, gate))

        decoder_output = hidden + shortcut

        if self_attention_cache is not None:
            return decoder_output, self_attention_cache
        return decoder_output

    def _compute_self_attention_mask(
        self,
        decoder_sequence,
        decoder_padding_mask,
        decoder_attention_mask,
        self_attention_cache,
        self_attention_cache_update_index,
    ):
        # Combine the user-provided padding/attention masks with a
        # causal mask; the causal mask accounts for cached positions
        # during incremental decoding.
        decoder_mask = merge_padding_and_attention_mask(
            decoder_sequence,
            decoder_padding_mask,
            decoder_attention_mask,
        )
        batch_size = ops.shape(decoder_sequence)[0]
        input_length = output_length = ops.shape(decoder_sequence)[1]
        if self_attention_cache is not None:
            # Keys/values span the full cache length, not just the
            # current query slice.
            input_length = ops.shape(self_attention_cache)[2]

        if self_attention_cache_update_index is None:
            cache_update_index = 0
        else:
            cache_update_index = self_attention_cache_update_index
        causal_mask = compute_causal_mask(
            batch_size,
            input_length,
            output_length,
            cache_update_index,
        )
        if decoder_mask is None:
            return causal_mask
        return ops.minimum(decoder_mask, causal_mask)

    def compute_output_shape(self, decoder_sequence_shape):
        return decoder_sequence_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "layer_type": self.layer_type,
                "intermediate_dim": self.intermediate_dim,
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "head_dim": self.head_dim,
                "partial_rotary_factor": self.partial_rotary_factor,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "dropout": self.dropout,
                "sliding_window_size": self.sliding_window_size,
                "linear_num_key_heads": self.linear_num_key_heads,
                "linear_num_value_heads": self.linear_num_value_heads,
                "linear_key_head_dim": self.linear_key_head_dim,
                "linear_value_head_dim": self.linear_value_head_dim,
                "linear_conv_kernel_dim": self.linear_conv_kernel_dim,
            }
        )
        return config
def _chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
):
    """Gated delta-rule scan over a full sequence.

    NOTE: despite the name, this implementation currently runs the
    exact recurrent formulation one timestep at a time; ``chunk_size``
    is accepted for API compatibility but is not used.

    Args:
        query: (B, seq, num_heads, head_k_dim)
        key: (B, seq, num_heads, head_k_dim)
        value: (B, seq, num_heads, head_v_dim)
        g: (B, seq, num_heads) — decay gates (log-space)
        beta: (B, seq, num_heads) — write gates (sigmoid-space)
        chunk_size: Chunk size for blocked computation (unused).
        initial_state: Optional initial recurrent state.
        output_final_state: Whether to return final state.

    Returns:
        output: (B, seq, num_heads, head_v_dim)
        final_state: recurrent state or None
    """
    # The recurrence operates on unit-norm queries and keys.
    query = _l2norm(query, axis=-1)
    key = _l2norm(key, axis=-1)

    # (B, seq, heads, dim) -> (B, heads, seq, dim) layout for the scan.
    query = ops.transpose(query, (0, 2, 1, 3))
    key = ops.transpose(key, (0, 2, 1, 3))
    value = ops.transpose(value, (0, 2, 1, 3))
    beta = ops.transpose(beta, (0, 2, 1))
    g = ops.transpose(g, (0, 2, 1))

    # Run the whole recurrence in float32 for numerical stability and
    # cast back to the input dtype at the end.
    input_dtype = query.dtype
    query = ops.cast(query, "float32")
    key = ops.cast(key, "float32")
    value = ops.cast(value, "float32")
    beta = ops.cast(beta, "float32")
    g = ops.cast(g, "float32")

    batch_size = ops.shape(key)[0]
    num_heads = ops.shape(key)[1]
    seq_len = ops.shape(key)[2]
    k_head_dim = ops.shape(key)[3]
    v_head_dim = ops.shape(value)[3]

    # Standard 1/sqrt(d_k) query scaling.
    scale = 1.0 / (k_head_dim**0.5)
    query = query * scale

    weighted_value = value * ops.expand_dims(beta, -1)

    # Recurrent state: a (k_dim x v_dim) associative memory per head.
    if initial_state is None:
        memory = ops.zeros(
            (batch_size, num_heads, k_head_dim, v_head_dim),
            dtype="float32",
        )
    else:
        memory = ops.cast(initial_state, "float32")

    # Sequential scan over timesteps (equivalent to the chunked form,
    # without the blocked optimization).
    per_step = []
    for t in range(seq_len):
        q_t = query[:, :, t, :]
        k_t = key[:, :, t, :]
        vb_t = weighted_value[:, :, t, :]
        g_t = g[:, :, t]

        # Exponential decay of the memory (g is in log-space).
        decay = ops.exp(ops.expand_dims(ops.expand_dims(g_t, -1), -1))
        memory = memory * decay

        # Delta rule: subtract what the memory already predicts for
        # this key, scaled by the write gate, then write the residual.
        predicted = ops.sum(memory * ops.expand_dims(k_t, -1), axis=-2)
        correction = vb_t - predicted * ops.expand_dims(beta[:, :, t], -1)
        memory = memory + ops.expand_dims(k_t, -1) * ops.expand_dims(
            correction, -2
        )

        # Read from the memory with the (scaled) query.
        per_step.append(ops.sum(memory * ops.expand_dims(q_t, -1), axis=-2))

    output = ops.stack(per_step, axis=2)

    final_state = memory if output_final_state else None

    # (B, heads, seq, v_dim) -> (B, seq, heads, v_dim), original dtype.
    output = ops.transpose(output, (0, 2, 1, 3))
    output = ops.cast(output, input_dtype)

    return output, final_state
+ """ + + def __init__( + self, + hidden_size, + linear_num_key_heads, + linear_num_value_heads, + linear_key_head_dim, + linear_value_head_dim, + linear_conv_kernel_dim=4, + hidden_activation="silu", + layer_norm_epsilon=1e-6, + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_k_heads = linear_num_key_heads + self.num_v_heads = linear_num_value_heads + self.head_k_dim = linear_key_head_dim + self.head_v_dim = linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_size = linear_conv_kernel_dim + self.hidden_activation = hidden_activation + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + + def build(self, input_shape): + # Qwen3.5 has separate projections for QKV, Z, B, A. + # QKV fused projection. + self.in_proj_qkv = keras.layers.Dense( + self.key_dim * 2 + self.value_dim, + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="in_proj_qkv", + ) + self.in_proj_qkv.build(input_shape) + + # Z (output gate) projection. + self.in_proj_z = keras.layers.Dense( + self.value_dim, + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="in_proj_z", + ) + self.in_proj_z.build(input_shape) + + # Beta (write gate) projection. + self.in_proj_b = keras.layers.Dense( + self.num_v_heads, + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="in_proj_b", + ) + self.in_proj_b.build(input_shape) + + # A (decay gate) projection. + self.in_proj_a = keras.layers.Dense( + self.num_v_heads, + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="in_proj_a", + ) + self.in_proj_a.build(input_shape) + + # Causal conv1d (depthwise). 
+ conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d_weight = self.add_weight( + name="conv1d/kernel", + shape=(conv_dim, self.conv_kernel_size), + initializer="glorot_uniform", + dtype=self.variable_dtype, + ) + self.conv1d_bias = self.add_weight( + name="conv1d/bias", + shape=(conv_dim,), + initializer="zeros", + dtype=self.variable_dtype, + ) + + # dt_bias and A_log (learnable parameters for decay). + self.dt_bias = self.add_weight( + name="dt_bias", + shape=(self.num_v_heads,), + initializer="ones", + dtype=self.variable_dtype, + ) + self.A_log = self.add_weight( + name="A_log", + shape=(self.num_v_heads,), + initializer="zeros", + dtype=self.variable_dtype, + ) + + # Output gated RMSNorm. + self.norm = Qwen3_5LayerNorm( + head_dim=self.head_v_dim, + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="norm", + ) + self.norm.build((None, self.head_v_dim)) + + # Output projection. + self.out_proj = keras.layers.Dense( + self.hidden_size, + use_bias=False, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="out_proj", + ) + self.out_proj.build((None, None, self.value_dim)) + + self.built = True + + def call(self, hidden_states, attention_mask=None, training=None): + """Forward pass. + + Args: + hidden_states: (B, seq_len, hidden_size) + attention_mask: Optional padding mask. + training: Whether in training mode. + + Returns: + output: (B, seq_len, hidden_size) + """ + # Mask padding states. + if attention_mask is not None: + # attention_mask: (B, seq_len) with 1 for valid, 0 for pad. + if attention_mask.ndim == 2: + mask = ops.cast( + ops.expand_dims(attention_mask, -1), + hidden_states.dtype, + ) + hidden_states = hidden_states * mask + + batch_size = ops.shape(hidden_states)[0] + seq_len = ops.shape(hidden_states)[1] + + # Project QKV. + mixed_qkv = self.in_proj_qkv(hidden_states) + + # Project gating signals. 
+ z = self.in_proj_z(hidden_states) + z = ops.reshape( + z, (batch_size, seq_len, self.num_v_heads, self.head_v_dim) + ) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # Causal conv1d on QKV. + # Transpose to (B, channels, seq_len) for conv. + mixed_qkv_t = ops.transpose(mixed_qkv, (0, 2, 1)) + mixed_qkv_t = _causal_conv1d( + mixed_qkv_t, self.conv1d_weight, self.conv1d_bias + ) + # Apply SiLU activation after conv. + mixed_qkv_t = mixed_qkv_t * ops.sigmoid(mixed_qkv_t) + mixed_qkv = ops.transpose(mixed_qkv_t, (0, 2, 1)) + + # Split QKV. + query, key, value = ops.split( + mixed_qkv, + [self.key_dim, self.key_dim * 2], + axis=-1, + ) + + query = ops.reshape( + query, + (batch_size, seq_len, self.num_k_heads, self.head_k_dim), + ) + key = ops.reshape( + key, + (batch_size, seq_len, self.num_k_heads, self.head_k_dim), + ) + value = ops.reshape( + value, + (batch_size, seq_len, self.num_v_heads, self.head_v_dim), + ) + + # Compute decay gate. + beta = ops.sigmoid(b) + g = -ops.exp( + ops.cast(self.A_log, "float32") + ) * keras.activations.softplus( + ops.cast(a, "float32") + ops.cast(self.dt_bias, "float32") + ) + + # Expand K heads to match V heads if needed. + if self.num_v_heads // self.num_k_heads > 1: + repeat_factor = self.num_v_heads // self.num_k_heads + query = ops.repeat(query, repeats=repeat_factor, axis=2) + key = ops.repeat(key, repeats=repeat_factor, axis=2) + + # Apply gated delta rule. + core_out, _ = _chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + output_final_state=False, + ) + + # Output gated normalization. + # Reshape to (B * seq * heads, v_dim) for norm. + core_out_flat = ops.reshape(core_out, (-1, self.head_v_dim)) + z_flat = ops.reshape(z, (-1, self.head_v_dim)) + core_out_flat = self.norm(core_out_flat) + core_out_flat = core_out_flat * ops.sigmoid(z_flat) + + # Reshape back and project. 
+ core_out = ops.reshape(core_out_flat, (batch_size, seq_len, -1)) + output = self.out_proj(core_out) + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "linear_num_key_heads": self.num_k_heads, + "linear_num_value_heads": self.num_v_heads, + "linear_key_head_dim": self.head_k_dim, + "linear_value_head_dim": self.head_v_dim, + "linear_conv_kernel_dim": self.conv_kernel_size, + "hidden_activation": self.hidden_activation, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_5/qwen3_5_layernorm.py b/keras_hub/src/models/qwen3_5/qwen3_5_layernorm.py new file mode 100644 index 0000000000..8cd29c2761 --- /dev/null +++ b/keras_hub/src/models/qwen3_5/qwen3_5_layernorm.py @@ -0,0 +1,43 @@ +import keras +from keras import ops + + +class Qwen3_5LayerNorm(keras.layers.Layer): + """RMS normalization layer for Qwen3.5. + + Qwen3.5 uses a (1 + weight)-centered RMSNorm. Weights are initialized + to zero so the effective scale starts at 1.0. 
+ """ + + def __init__(self, head_dim=None, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.head_dim = head_dim + self.epsilon = epsilon + + def build(self, input_shape): + dim = self.head_dim if self.head_dim else input_shape[-1] + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(dim,), + initializer="zeros", + dtype=self.variable_dtype, + ) + self.built = True + + def call(self, x): + input_dtype = x.dtype + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x * (1.0 + self.scale), input_dtype) + + def get_config(self): + config = super().get_config() + config.update( + { + "head_dim": self.head_dim, + "epsilon": self.epsilon, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_5/qwen3_5_presets.py b/keras_hub/src/models/qwen3_5/qwen3_5_presets.py new file mode 100644 index 0000000000..7b85f2bf25 --- /dev/null +++ b/keras_hub/src/models/qwen3_5/qwen3_5_presets.py @@ -0,0 +1,3 @@ +"""Qwen3.5 model preset configurations.""" + +backbone_presets = {} diff --git a/keras_hub/src/models/qwen3_5/qwen3_5_tokenizer.py b/keras_hub/src/models/qwen3_5/qwen3_5_tokenizer.py new file mode 100644 index 0000000000..f03efae713 --- /dev/null +++ b/keras_hub/src/models/qwen3_5/qwen3_5_tokenizer.py @@ -0,0 +1,34 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.qwen3_5.qwen3_5_backbone import Qwen3_5Backbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export("keras_hub.models.Qwen3_5Tokenizer") +class Qwen3_5Tokenizer(BytePairTokenizer): + """Tokenizer for Qwen3.5 models. + + This tokenizer implements byte-pair encoding (BPE) for Qwen3.5 models. + + Args: + vocabulary: Dictionary mapping tokens to token IDs, or path to + vocabulary file. + merges: List of BPE merges, or path to merges file. 
+ """ + + backbone_cls = Qwen3_5Backbone + + def __init__(self, vocabulary=None, merges=None, **kwargs): + eos_token = "<|im_end|>" + self._add_special_token(eos_token, "end_token") + + pad_token = "<|endoftext|>" + self._add_special_token(pad_token, "pad_token") + + self.start_token_id = None + self.start_token = None + + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) diff --git a/keras_hub/src/utils/transformers/convert_qwen3_5.py b/keras_hub/src/utils/transformers/convert_qwen3_5.py new file mode 100644 index 0000000000..05cd37fda1 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen3_5.py @@ -0,0 +1,245 @@ +"""HF -> KerasHub weight converter for Qwen3.5.""" + +import numpy as np + +from keras_hub.src.models.qwen3_5.qwen3_5_backbone import Qwen3_5Backbone +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = Qwen3_5Backbone + + +def convert_backbone_config(transformers_config): + # tie_word_embeddings is at the top-level config. + tie_word_embeddings = transformers_config["tie_word_embeddings"] + + # Qwen3.5 text config is nested under "text_config". + if "text_config" in transformers_config: + transformers_config = transformers_config["text_config"] + + # rope_theta and partial_rotary_factor are nested under + # rope_parameters in the HF config. + rope_params = transformers_config["rope_parameters"] + + # Build layer_types list. + num_layers = transformers_config["num_hidden_layers"] + layer_types = transformers_config.get("layer_types", None) + if layer_types is None: + # Default: every 4th layer is full_attention. 
def convert_backbone_config(transformers_config):
    """Translate a HF Qwen3.5 config dict into KerasHub backbone kwargs.

    Args:
        transformers_config: Parsed `config.json` dict from a HF
            Qwen3.5 checkpoint. May be a multimodal config with the
            text settings nested under `"text_config"`.

    Returns:
        Dict of keyword arguments for `Qwen3_5Backbone`.
    """
    # `tie_word_embeddings` lives on the top-level config even for
    # multimodal checkpoints.
    tie_word_embeddings = transformers_config["tie_word_embeddings"]

    # Text-only settings may be nested under "text_config".
    config = transformers_config.get("text_config", transformers_config)

    # rope_theta and partial_rotary_factor are nested under
    # rope_parameters in the HF config.
    rope_params = config["rope_parameters"]

    num_layers = config["num_hidden_layers"]
    layer_types = config.get("layer_types")
    if layer_types is None:
        # Fallback pattern: full attention on every 4th layer, linear
        # attention everywhere else.
        layer_types = [
            "full_attention" if (i + 1) % 4 == 0 else "linear_attention"
            for i in range(num_layers)
        ]

    return {
        "vocabulary_size": config["vocab_size"],
        "head_dim": config["head_dim"],
        "hidden_dim": config["hidden_size"],
        "num_layers": num_layers,
        "num_query_heads": config["num_attention_heads"],
        "num_key_value_heads": config["num_key_value_heads"],
        "intermediate_dim": config["intermediate_size"],
        "layer_norm_epsilon": config["rms_norm_eps"],
        "rope_max_wavelength": rope_params["rope_theta"],
        "partial_rotary_factor": rope_params["partial_rotary_factor"],
        "tie_word_embeddings": tie_word_embeddings,
        "layer_types": layer_types,
        "linear_num_key_heads": config["linear_num_key_heads"],
        "linear_num_value_heads": config["linear_num_value_heads"],
        "linear_key_head_dim": config["linear_key_head_dim"],
        "linear_value_head_dim": config["linear_value_head_dim"],
        "linear_conv_kernel_dim": config["linear_conv_kernel_dim"],
    }
+ loader.port_weight( + keras_variable=decoder_layer._input_layernorm.scale, + hf_weight_key=f"{prefix}.input_layernorm.weight", + ) + + if layer_type == "full_attention": + attn = decoder_layer._self_attention_layer + + # Q projection (includes gate: head_dim * 2). + loader.port_weight( + keras_variable=attn._query_dense.kernel, + hf_weight_key=f"{prefix}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, + ) + # Q norm. + loader.port_weight( + keras_variable=attn._query_norm.scale, + hf_weight_key=f"{prefix}.self_attn.q_norm.weight", + ) + # K projection. + loader.port_weight( + keras_variable=attn._key_dense.kernel, + hf_weight_key=f"{prefix}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, + ) + # K norm. + loader.port_weight( + keras_variable=attn._key_norm.scale, + hf_weight_key=f"{prefix}.self_attn.k_norm.weight", + ) + # V projection. + loader.port_weight( + keras_variable=attn._value_dense.kernel, + hf_weight_key=f"{prefix}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, + ) + # Output projection. + loader.port_weight( + keras_variable=attn._output_dense.kernel, + hf_weight_key=f"{prefix}.self_attn.o_proj.weight", + hook_fn=transpose_and_reshape, + ) + + elif layer_type == "linear_attention": + gdn = decoder_layer._linear_attn + + # QKV fused projection. + loader.port_weight( + keras_variable=gdn.in_proj_qkv.kernel, + hf_weight_key=f"{prefix}.linear_attn.in_proj_qkv.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + # Z (output gate) projection. + loader.port_weight( + keras_variable=gdn.in_proj_z.kernel, + hf_weight_key=f"{prefix}.linear_attn.in_proj_z.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + # B (write gate) projection. 
+ loader.port_weight( + keras_variable=gdn.in_proj_b.kernel, + hf_weight_key=f"{prefix}.linear_attn.in_proj_b.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + # A (decay gate) projection. + loader.port_weight( + keras_variable=gdn.in_proj_a.kernel, + hf_weight_key=f"{prefix}.linear_attn.in_proj_a.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + # Conv1d weight: HF shape (channels, 1, kernel_size) -> + # KerasHub shape (channels, kernel_size). + loader.port_weight( + keras_variable=gdn.conv1d_weight, + hf_weight_key=f"{prefix}.linear_attn.conv1d.weight", + hook_fn=lambda hf_tensor, _: np.squeeze(hf_tensor, axis=1), + ) + # Conv1d bias. + loader.port_weight( + keras_variable=gdn.conv1d_bias, + hf_weight_key=f"{prefix}.linear_attn.conv1d.bias", + ) + # dt_bias. + loader.port_weight( + keras_variable=gdn.dt_bias, + hf_weight_key=f"{prefix}.linear_attn.dt_bias", + ) + # A_log. + loader.port_weight( + keras_variable=gdn.A_log, + hf_weight_key=f"{prefix}.linear_attn.A_log", + ) + # Output gated RMSNorm. + loader.port_weight( + keras_variable=gdn.norm.scale, + hf_weight_key=f"{prefix}.linear_attn.norm.weight", + ) + # Output projection. + loader.port_weight( + keras_variable=gdn.out_proj.kernel, + hf_weight_key=f"{prefix}.linear_attn.out_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + + # MLP layers (same for both layer types). 
+ loader.port_weight( + keras_variable=( + decoder_layer._feedforward_intermediate_dense.kernel + ), + hf_weight_key=f"{prefix}.mlp.up_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=(decoder_layer._feedforward_output_dense.kernel), + hf_weight_key=f"{prefix}.mlp.down_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=(decoder_layer._feedforward_gate_dense.kernel), + hf_weight_key=f"{prefix}.mlp.gate_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + # Post-attention layernorm. + loader.port_weight( + keras_variable=(decoder_layer._post_attention_layernorm.scale), + hf_weight_key=f"{prefix}.post_attention_layernorm.weight", + ) + + # Final normalization layer. + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + merges = [" ".join(item) for item in merges] + + # Load all special tokens except "reserved" ones. 
+ special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + kwargs.update( + { + "unsplittable_tokens": list(special_tokens), + } + ) + + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 75bd8ab6d4..c717d6867d 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -24,6 +24,7 @@ from keras_hub.src.utils.transformers import convert_pali_gemma from keras_hub.src.utils.transformers import convert_qwen from keras_hub.src.utils.transformers import convert_qwen3 +from keras_hub.src.utils.transformers import convert_qwen3_5 from keras_hub.src.utils.transformers import convert_qwen3_moe from keras_hub.src.utils.transformers import convert_qwen_moe from keras_hub.src.utils.transformers import convert_sam3 @@ -82,6 +83,8 @@ def __init__(self, preset, config): self.converter = convert_qwen3_moe elif model_type == "qwen3": self.converter = convert_qwen3 + elif model_type == "qwen3_5": + self.converter = convert_qwen3_5 elif model_type == "sam3_video": self.converter = convert_sam3 elif model_type == "smollm3": diff --git a/tools/checkpoint_conversion/convert_qwen3_5_checkpoints.py b/tools/checkpoint_conversion/convert_qwen3_5_checkpoints.py new file mode 100644 index 0000000000..e039cc081c --- /dev/null +++ b/tools/checkpoint_conversion/convert_qwen3_5_checkpoints.py @@ -0,0 +1,258 @@ +"""Checkpoint conversion script for Qwen3.5 (text-only CausalLM). + +Usage: + python tools/checkpoint_conversion/convert_qwen3_5_checkpoints.py \ + --preset qwen3_5_7b_en + +This script: +1. Loads the HF model and tokenizer +2. Loads the KerasHub model via from_preset("hf://...") +3. Compares tokenizer outputs +4. 
Compares model logits (forward pass) +5. Compares generated text (greedy decoding) +6. Saves the KerasHub preset +""" + +import os +import random +import traceback + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU + +import numpy as np +import torch +from absl import app +from absl import flags + +random.seed(123) +torch.manual_seed(123) +device = torch.device("cpu") +torch.set_default_device(device) + +from keras import ops # noqa: E402 +from transformers import AutoModelForCausalLM # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +import keras_hub # noqa: E402 + +PRESET_MAP = { + "qwen3_5_0.8b_base": "Qwen/Qwen3.5-0.8B-Base", + "qwen3_5_0.8b": "Qwen/Qwen3.5-0.8B", + "qwen3_5_2b_base": "Qwen/Qwen3.5-2B-Base", + "qwen3_5_2b": "Qwen/Qwen3.5-2B", + "qwen3_5_4b_base": "Qwen/Qwen3.5-4B-Base", + "qwen3_5_4b": "Qwen/Qwen3.5-4B", + "qwen3_5_9b_base": "Qwen/Qwen3.5-9B-Base", + "qwen3_5_9b": "Qwen/Qwen3.5-9B", + "qwen3_5_27b": "Qwen/Qwen3.5-27B", + "qwen3_5_35b_a3b_base": "Qwen/Qwen3.5-35B-A3B_Base", + "qwen3_5_35b_a3b": "Qwen/Qwen3.5-35B-A3B", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) + + +def test_model( + keras_hub_model, + keras_hub_tokenizer, + hf_model, + hf_model_tokenizer, +): + """Compare parameter counts and forward pass outputs.""" + # 1. Parameter count. + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + print(f"\n-> KerasHub params: {keras_hub_params:,}") + print(f"-> HF params: {hf_params:,}") + if keras_hub_params != hf_params: + print( + f"WARNING: Parameter count mismatch! " + f"Diff: {abs(keras_hub_params - hf_params):,}" + ) + else: + print("-> Parameter counts match!") + + # 2. Forward pass comparison. + test_text = "What is Keras?" 
+ hf_inputs = hf_model_tokenizer([test_text], return_tensors="pt").to(device) + hf_outputs = hf_model(**hf_inputs) + hf_output_logits = hf_outputs.logits.detach().cpu().float().numpy() + + keras_hub_preprocessor = keras_hub.models.Qwen3_5CausalLMPreprocessor( + keras_hub_tokenizer + ) + seq_len = hf_inputs["input_ids"].shape[1] + keras_hub_inputs = keras_hub_preprocessor( + [test_text], sequence_length=seq_len + )[0] + + keras_hub_output = keras_hub_model(keras_hub_inputs) + keras_hub_logits = keras_hub_model.token_embedding( + keras_hub_output, reverse=True + ) + keras_hub_logits = ops.convert_to_numpy(keras_hub_logits) + + # Compare. + mean_diff = np.mean(np.abs(keras_hub_logits - hf_output_logits)) + max_diff = np.max(np.abs(keras_hub_logits - hf_output_logits)) + print(f"\n-> Logits mean absolute diff: {mean_diff:.6f}") + print(f"-> Logits max absolute diff: {max_diff:.6f}") + + try: + np.testing.assert_allclose( + keras_hub_logits, hf_output_logits, atol=1e-2 + ) + print("-> Forward pass outputs match! (atol=1e-2)") + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + # Print first few logits for manual inspection. + print("\n-> KerasHub logits (first 5 of last token):") + print(f" {keras_hub_logits[0, -1, :5]}") + print("-> HF logits (first 5 of last token):") + print(f" {hf_output_logits[0, -1, :5]}") + + +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + """Compare tokenizer outputs.""" + test_text = "What is Keras?" 
+ hf_output = hf_tokenizer([test_text], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + + keras_hub_preprocessor = keras_hub.models.Qwen3_5CausalLMPreprocessor( + keras_hub_tokenizer + ) + seq_len = hf_output.shape[1] + keras_hub_output = keras_hub_preprocessor( + [test_text], sequence_length=seq_len + ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) + + print(f"\n-> HF token ids: {hf_output[0]}") + print(f"-> KerasHub token ids: {keras_hub_output[0]}") + + try: + np.testing.assert_equal(keras_hub_output, hf_output) + print("-> Tokenizer outputs match!") + except AssertionError as err: + print(f"WARNING: Tokenizer mismatch: {err}") + + +def validate_output(keras_model, hf_model, hf_tokenizer): + """Compare greedy generation outputs.""" + input_str = "What is Keras?" + length = 32 + + print(f"\n-> Generating with max_length={length}...") + + # KerasHub generation. + keras_output = keras_model.generate([input_str], max_length=length) + keras_output = keras_output[0] + print(f"\n KerasHub output: {keras_output}") + + # HF generation. + hf_inputs = hf_tokenizer([input_str], return_tensors="pt") + outputs = hf_model.generate( + **hf_inputs, + max_length=length, + do_sample=False, + num_beams=1, + pad_token_id=hf_tokenizer.pad_token_id, + ) + print(f" HF generated token ids: {outputs[0]}") + hf_generated_text = hf_tokenizer.batch_decode( + outputs, skip_special_tokens=True + )[0] + print(f" HF output: {hf_generated_text}") + + # Compare. + if keras_output.strip() == hf_generated_text.strip(): + print("\n-> Generated text MATCHES!") + else: + print( + "\n-> Generated text DIFFERS (may be expected for " + "long sequences due to floating point drift)" + ) + + +def main(_): + # === Validate preset === + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. 
Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + + print(f"=== Converting {preset} ({hf_preset}) ===") + + # === Load HF model === + print("\n-> Loading HF model...") + hf_model = AutoModelForCausalLM.from_pretrained( + hf_preset, + device_map=device, + torch_dtype=torch.float32, + ) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") + hf_model.eval() + print(f" HF model loaded: {hf_model.num_parameters():,} params") + + # === Load KerasHub model === + print("\n-> Loading KerasHub model from HF preset...") + keras_hub_backbone = keras_hub.models.Qwen3_5Backbone.from_preset( + f"hf://{hf_preset}" + ) + keras_hub_tokenizer = keras_hub.models.Qwen3_5Tokenizer.from_preset( + f"hf://{hf_preset}" + ) + print(" KerasHub model loaded!") + + # === Run comparisons === + print("\n" + "=" * 50) + print("TOKENIZER COMPARISON") + print("=" * 50) + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) + + print("\n" + "=" * 50) + print("MODEL COMPARISON") + print("=" * 50) + test_model( + keras_hub_backbone, + keras_hub_tokenizer, + hf_model, + hf_tokenizer, + ) + + print("\n" + "=" * 50) + print("GENERATION COMPARISON") + print("=" * 50) + preprocessor = keras_hub.models.Qwen3_5CausalLMPreprocessor( + keras_hub_tokenizer + ) + qwen3_5_lm = keras_hub.models.Qwen3_5CausalLM( + backbone=keras_hub_backbone, + preprocessor=preprocessor, + sampler="greedy", + ) + validate_output(qwen3_5_lm, hf_model, hf_tokenizer) + + # === Save preset === + output_dir = f"./{preset}" + print(f"\n-> Saving KerasHub preset to {output_dir}...") + qwen3_5_lm.save_to_preset(output_dir) + print(f" Preset saved to {output_dir}") + + print("\n=== Done! 
===") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) From ba922603c37e48404c7da15d93b7ddc0230a7a26 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Wed, 11 Mar 2026 23:09:48 -0700 Subject: [PATCH 2/3] Fix HF and KH generation output match --- .../src/models/qwen3_5/qwen3_5_causal_lm.py | 107 +++++++-- .../src/models/qwen3_5/qwen3_5_decoder.py | 10 + .../models/qwen3_5/qwen3_5_gated_delta_net.py | 220 ++++++++++++++---- .../src/utils/transformers/convert_qwen3_5.py | 10 +- 4 files changed, 273 insertions(+), 74 deletions(-) diff --git a/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm.py b/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm.py index a70420f7ff..ba0460c6fe 100644 --- a/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm.py +++ b/keras_hub/src/models/qwen3_5/qwen3_5_causal_lm.py @@ -50,6 +50,7 @@ def call_with_cache( token_ids, cache, cache_update_index, + padding_mask=None, ): """Forward pass with cache for autoregressive decoding. @@ -66,51 +67,117 @@ def call_with_cache( (logits, hidden_states, cache) tuple. """ x = self.backbone.token_embedding(token_ids) - updated_cache = [] + + # We need three separate lists because XLA requires tensors to + # have consistent shapes. KV cache, Conv cache, and Recurrent cache + # all have completely different shapes and cannot be stacked. + kv_cache = cache[0] + conv_cache = cache[1] + recurrent_cache = cache[2] + + next_kv_cache = [] + next_conv_cache = [] + next_recurrent_cache = [] + for i in range(self.backbone.num_layers): layer = self.backbone.transformer_layers[i] if layer.layer_type == "full_attention": - current_cache = cache[:, i, ...] - x, next_cache = layer( + current_kv = kv_cache[:, i, ...] 
+ x, next_kv = layer( x, - self_attention_cache=current_cache, - self_attention_cache_update_index=(cache_update_index), + decoder_padding_mask=padding_mask, + self_attention_cache=current_kv, + self_attention_cache_update_index=cache_update_index, ) - updated_cache.append(next_cache) + next_kv_cache.append(next_kv) + + # Append placeholders for linear attention + next_conv_cache.append(conv_cache[:, i, ...]) + next_recurrent_cache.append(recurrent_cache[:, i, ...]) else: - # Linear attention layers don't use KV cache. - x = layer(x) - # Append a zero placeholder to keep cache shape. - updated_cache.append(cache[:, i, ...]) - cache = ops.stack(updated_cache, axis=1) + # Linear attention (GatedDeltaNet) + current_conv = conv_cache[:, i, ...] + current_recurrent = recurrent_cache[:, i, ...] + + x, next_conv, next_recurrent = layer( + x, + decoder_padding_mask=padding_mask, + self_attention_cache=(current_conv, current_recurrent), + self_attention_cache_update_index=cache_update_index, + ) + next_conv_cache.append(next_conv) + next_recurrent_cache.append(next_recurrent) + + # Append placeholder for full attention + next_kv_cache.append(kv_cache[:, i, ...]) + + # Stack caches along the layer dimension + next_cache = ( + ops.stack(next_kv_cache, axis=1), + ops.stack(next_conv_cache, axis=1), + ops.stack(next_recurrent_cache, axis=1), + ) + hidden_states = x = self.backbone.layer_norm(x) logits = self.backbone.token_embedding(x, reverse=True) - return logits, hidden_states, cache + return logits, hidden_states, next_cache - def _build_cache(self, token_ids): + def _build_cache(self, token_ids, padding_mask): """Build an empty cache for use with ``call_with_cache()``.""" batch_size = ops.shape(token_ids)[0] max_length = ops.shape(token_ids)[1] num_layers = self.backbone.num_layers - num_key_value_heads = self.backbone.num_key_value_heads + num_kv_heads = self.backbone.num_key_value_heads head_dim = self.backbone.head_dim - shape = [ + + # KV Cache shape: (B, num_layers, 
2, seq_len, num_kv_heads, head_dim) + kv_shape = [ batch_size, num_layers, 2, max_length, - num_key_value_heads, + num_kv_heads, head_dim, ] - cache = ops.zeros(shape, dtype=self.compute_dtype) - _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + kv_cache = ops.zeros(kv_shape, dtype=self.compute_dtype) + + # Conv cache shape: (B, num_layers, conv_dim, conv_kernel_size - 1) + linear_key_dim = ( + self.backbone.linear_num_key_heads + * self.backbone.linear_key_head_dim + ) + linear_val_dim = ( + self.backbone.linear_num_value_heads + * self.backbone.linear_value_head_dim + ) + conv_dim = linear_key_dim * 2 + linear_val_dim + conv_shape = [ + batch_size, + num_layers, + conv_dim, + self.backbone.linear_conv_kernel_dim - 1, + ] + conv_cache = ops.zeros(conv_shape, dtype=self.compute_dtype) + recurrent_shape = [ + batch_size, + num_layers, + self.backbone.linear_num_value_heads, + self.backbone.linear_key_head_dim, + self.backbone.linear_value_head_dim, + ] + recurrent_cache = ops.zeros(recurrent_shape, dtype="float32") + + cache = (kv_cache, conv_cache, recurrent_cache) + _, hidden_states, cache = self.call_with_cache( + token_ids, cache, 0, padding_mask=padding_mask + ) return hidden_states, cache def generate_step(self, inputs, stop_token_ids=None): """A compilable generation function for a single batch.""" token_ids = inputs["token_ids"] padding_mask = inputs["padding_mask"] - hidden_states, cache = self._build_cache(token_ids) + hidden_states, cache = self._build_cache(token_ids, padding_mask) row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) index = ops.min(row_lengths) @@ -119,7 +186,7 @@ def next(prompt, cache, index): batch_size = ops.shape(prompt)[0] prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( - prompt, cache, cache_update_index + prompt, cache, cache_update_index, padding_mask=None ) return ( ops.squeeze(logits, axis=1), diff --git 
a/keras_hub/src/models/qwen3_5/qwen3_5_decoder.py b/keras_hub/src/models/qwen3_5/qwen3_5_decoder.py index 9d5df2a8f6..cc7d30cd52 100644 --- a/keras_hub/src/models/qwen3_5/qwen3_5_decoder.py +++ b/keras_hub/src/models/qwen3_5/qwen3_5_decoder.py @@ -199,8 +199,12 @@ def call( x = self._linear_attn( x, attention_mask=decoder_padding_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, training=training, ) + if self_attention_cache is not None: + x, self_attention_cache = x elif self.layer_type == "full_attention": self_attention_mask = self._compute_self_attention_mask( decoder_sequence=decoder_sequence, @@ -238,6 +242,12 @@ def call( decoder_output = x + residual if self_attention_cache is not None: + if self.layer_type == "linear_attention": + return ( + decoder_output, + self_attention_cache[0], + self_attention_cache[1], + ) return decoder_output, self_attention_cache return decoder_output diff --git a/keras_hub/src/models/qwen3_5/qwen3_5_gated_delta_net.py b/keras_hub/src/models/qwen3_5/qwen3_5_gated_delta_net.py index 05b165849d..0371586525 100644 --- a/keras_hub/src/models/qwen3_5/qwen3_5_gated_delta_net.py +++ b/keras_hub/src/models/qwen3_5/qwen3_5_gated_delta_net.py @@ -32,32 +32,21 @@ def _causal_conv1d(x, weight, bias=None): Returns: (batch, channels, seq_len) """ - if weight.ndim == 2: - weight = ops.expand_dims(weight, 1) + if weight.ndim == 3: + weight = ops.squeeze(weight, axis=1) kernel_size = ops.shape(weight)[-1] channels = ops.shape(x)[1] - - # Left-pad for causal convolution. x_padded = ops.pad( x, [[0, 0], [0, 0], [kernel_size - 1, 0]], ) + x_cl = ops.transpose(x_padded, (0, 2, 1)) + w = ops.transpose(weight, (1, 0)) + w = ops.expand_dims(w, -1) + out = ops.depthwise_conv(x_cl, w, strides=1, padding="valid") + # out: (batch, seq_len, channels) - # Depthwise conv1d: process each channel independently. 
- # Reshape for grouped conv: (batch, 1, seq, channels) - x_padded = ops.transpose(x_padded, (0, 2, 1)) - x_padded = ops.expand_dims(x_padded, 1) - - # Weight shape for conv: (kernel_size, 1, channels) - # Flip weight for cross-correlation -> convolution. - w = ops.transpose(weight, (2, 1, 0)) - w = ops.flip(w, axis=0) - - # Use depthwise conv. - out = ops.depthwise_conv(x_padded, w, strides=1, padding="valid") - - # out shape: (batch, seq_len, 1, channels) -> (batch, channels, seq_len) - out = ops.squeeze(out, axis=2) + # Convert back to channels-first. out = ops.transpose(out, (0, 2, 1)) if bias is not None: @@ -71,9 +60,10 @@ def _chunk_gated_delta_rule( value, g, beta, - chunk_size=64, + chunk_size=None, initial_state=None, output_final_state=False, + padding_mask=None, ): """Chunked gated delta rule for training (parallel over chunks). @@ -86,7 +76,6 @@ def _chunk_gated_delta_rule( chunk_size: Chunk size for blocked computation. initial_state: Optional initial recurrent state. output_final_state: Whether to return final state. - Returns: output: (B, seq, num_heads, head_v_dim) final_state: recurrent state or None @@ -142,15 +131,34 @@ def _chunk_gated_delta_rule( v_beta_t = v_beta[:, :, t, :] g_t = g[:, :, t] + # Valid token mask. 1.0 for valid, 0.0 for padding. + if padding_mask is not None: + mask_t = ops.cast(padding_mask[:, t], "float32") + # Reshape to (B, 1, 1, 1) for broadcasting. + mask_t = ops.reshape(mask_t, (-1, 1, 1, 1)) + else: + mask_t = 1.0 + # Decay the state. decay = ops.exp(ops.expand_dims(ops.expand_dims(g_t, -1), -1)) - state = state * decay + + # Keep old state if padded, else apply decay + state_decayed = state * decay + if padding_mask is not None: + state = state * (1.0 - mask_t) + state_decayed * mask_t + else: + state = state_decayed # Delta update: compute what the current state predicts # for v given k, then add correction. 
kv_pred = ops.sum(state * ops.expand_dims(k_t, -1), axis=-2) delta = v_beta_t - kv_pred * ops.expand_dims(beta[:, :, t], -1) - state = state + ops.expand_dims(k_t, -1) * ops.expand_dims(delta, -2) + + state_update = ops.expand_dims(k_t, -1) * ops.expand_dims(delta, -2) + if padding_mask is not None: + state_update = state_update * mask_t + + state = state + state_update # Query the state. out_t = ops.sum(state * ops.expand_dims(q_t, -1), axis=-2) @@ -175,6 +183,7 @@ def _recurrent_gated_delta_rule( beta, initial_state=None, output_final_state=False, + padding_mask=None, ): """Step-by-step recurrent gated delta rule for inference. @@ -190,6 +199,7 @@ def _recurrent_gated_delta_rule( chunk_size=1, initial_state=initial_state, output_final_state=output_final_state, + padding_mask=padding_mask, ) @@ -281,20 +291,14 @@ def build(self, input_shape): ) self.in_proj_a.build(input_shape) - # Causal conv1d (depthwise). + # Causal conv1d (depthwise). HF uses bias=False. conv_dim = self.key_dim * 2 + self.value_dim self.conv1d_weight = self.add_weight( - name="conv1d/kernel", + name="conv1d_kernel", shape=(conv_dim, self.conv_kernel_size), initializer="glorot_uniform", dtype=self.variable_dtype, ) - self.conv1d_bias = self.add_weight( - name="conv1d/bias", - shape=(conv_dim,), - initializer="zeros", - dtype=self.variable_dtype, - ) # dt_bias and A_log (learnable parameters for decay). self.dt_bias = self.add_weight( @@ -331,16 +335,26 @@ def build(self, input_shape): self.built = True - def call(self, hidden_states, attention_mask=None, training=None): + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): """Forward pass. Args: hidden_states: (B, seq_len, hidden_size) attention_mask: Optional padding mask. + cache: Tuple of (conv_state, recurrent_state) + cache_update_index: Current generation step index. training: Whether in training mode. 
Returns: output: (B, seq_len, hidden_size) + cache: (optional) Updated tuple of (conv_state, recurrent_state) """ # Mask padding states. if attention_mask is not None: @@ -370,12 +384,89 @@ def call(self, hidden_states, attention_mask=None, training=None): # Causal conv1d on QKV. # Transpose to (B, channels, seq_len) for conv. mixed_qkv_t = ops.transpose(mixed_qkv, (0, 2, 1)) - mixed_qkv_t = _causal_conv1d( - mixed_qkv_t, self.conv1d_weight, self.conv1d_bias - ) - # Apply SiLU activation after conv. - mixed_qkv_t = mixed_qkv_t * ops.sigmoid(mixed_qkv_t) - mixed_qkv = ops.transpose(mixed_qkv_t, (0, 2, 1)) + + if cache is not None: + # Autoregressive generation. + conv_state, recurrent_state = cache + if seq_len > 1: + combined_state = ops.concatenate( + [conv_state, mixed_qkv_t], axis=-1 + ) + + if attention_mask is not None: + valid_lengths = ops.sum( + ops.cast(attention_mask, "int32"), axis=-1 + ) + indices = ops.expand_dims( + valid_lengths + self.conv_kernel_size - 2, axis=-1 + ) + # We want range(indices - kernel_size + 2, indices + 1) + offsets = ops.arange( + self.conv_kernel_size - 1, dtype="int32" + ) + # offsets: (-kernel_size+2, ..., 0) + offsets = offsets - (self.conv_kernel_size - 2) + gather_indices = indices + ops.expand_dims(offsets, axis=0) + gather_indices = ops.expand_dims( + gather_indices, axis=1 + ) # (B, 1, kernel) + gather_indices = ops.repeat( + gather_indices, ops.shape(combined_state)[1], axis=1 + ) # (B, channels, kernel) + + conv_state = ops.take_along_axis( + combined_state, gather_indices, axis=2 + ) + else: + conv_state = combined_state[ + :, :, -(self.conv_kernel_size - 1) : + ] + padded_input = ops.concatenate( + [ + cache[0][:, :, -(self.conv_kernel_size - 1) :], + mixed_qkv_t, + ], + axis=-1, + ) + + # Use depthwise_conv to process the padded sequence. 
+ padded_input_transposed = ops.transpose(padded_input, (0, 2, 1)) + conv1d_weight_transposed = ops.transpose( + self.conv1d_weight, (1, 0) + ) + conv1d_weight_expanded = ops.expand_dims( + conv1d_weight_transposed, -1 + ) + + mixed_qkv_t = ops.depthwise_conv( + padded_input_transposed, + conv1d_weight_expanded, + strides=1, + padding="valid", + ) + mixed_qkv_t = ops.transpose(mixed_qkv_t, (0, 2, 1)) + + else: + sliding_window = ops.concatenate( + [conv_state, mixed_qkv_t], axis=-1 + ) + conv_state = sliding_window[:, :, 1:] + conv1d_weight_expanded = ops.expand_dims(self.conv1d_weight, 0) + mixed_qkv_t = ops.sum( + sliding_window * conv1d_weight_expanded, + axis=-1, + keepdims=True, + ) + + # Apply SiLU activation after conv. + mixed_qkv_t = mixed_qkv_t * ops.sigmoid(mixed_qkv_t) + mixed_qkv = ops.transpose(mixed_qkv_t, (0, 2, 1)) + else: + # Full sequence processing (training or non-cached score mode). + mixed_qkv_t = _causal_conv1d(mixed_qkv_t, self.conv1d_weight) + # Apply SiLU activation after conv. + mixed_qkv_t = mixed_qkv_t * ops.sigmoid(mixed_qkv_t) + mixed_qkv = ops.transpose(mixed_qkv_t, (0, 2, 1)) # Split QKV. query, key, value = ops.split( @@ -411,26 +502,59 @@ def call(self, hidden_states, attention_mask=None, training=None): query = ops.repeat(query, repeats=repeat_factor, axis=2) key = ops.repeat(key, repeats=repeat_factor, axis=2) - # Apply gated delta rule. - core_out, _ = _chunk_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta, - output_final_state=False, - ) + if cache is not None: + if seq_len > 1: + # Prompt initialization loop. + # Use chunked delta rule but return the final state! + core_out, last_recurrent_state = _chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + padding_mask=attention_mask, + ) + cache = (conv_state, last_recurrent_state) + else: + # Step generation using recurrent rule. 
+ core_out, last_recurrent_state = _recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + padding_mask=attention_mask, + ) + cache = (conv_state, last_recurrent_state) + else: + # Apply chunked sequence gated delta rule. + core_out, _ = _chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + output_final_state=False, + padding_mask=attention_mask, + ) # Output gated normalization. # Reshape to (B * seq * heads, v_dim) for norm. core_out_flat = ops.reshape(core_out, (-1, self.head_v_dim)) z_flat = ops.reshape(z, (-1, self.head_v_dim)) core_out_flat = self.norm(core_out_flat) - core_out_flat = core_out_flat * ops.sigmoid(z_flat) + core_out_flat = core_out_flat * keras.activations.silu(z_flat) # Reshape back and project. core_out = ops.reshape(core_out_flat, (batch_size, seq_len, -1)) output = self.out_proj(core_out) + + if cache is not None: + return output, cache return output def get_config(self): diff --git a/keras_hub/src/utils/transformers/convert_qwen3_5.py b/keras_hub/src/utils/transformers/convert_qwen3_5.py index 05cd37fda1..b44ac209c8 100644 --- a/keras_hub/src/utils/transformers/convert_qwen3_5.py +++ b/keras_hub/src/utils/transformers/convert_qwen3_5.py @@ -160,11 +160,6 @@ def transpose_and_reshape(x, shape): hf_weight_key=f"{prefix}.linear_attn.conv1d.weight", hook_fn=lambda hf_tensor, _: np.squeeze(hf_tensor, axis=1), ) - # Conv1d bias. - loader.port_weight( - keras_variable=gdn.conv1d_bias, - hf_weight_key=f"{prefix}.linear_attn.conv1d.bias", - ) # dt_bias. loader.port_weight( keras_variable=gdn.dt_bias, @@ -179,6 +174,7 @@ def transpose_and_reshape(x, shape): loader.port_weight( keras_variable=gdn.norm.scale, hf_weight_key=f"{prefix}.linear_attn.norm.weight", + hook_fn=lambda hf_tensor, _: hf_tensor - 1.0, ) # Output projection. 
loader.port_weight( @@ -227,7 +223,9 @@ def convert_tokenizer(cls, preset, **kwargs): tokenizer_config = load_json(preset, "tokenizer.json") vocab = tokenizer_config["model"]["vocab"] merges = tokenizer_config["model"]["merges"] - merges = [" ".join(item) for item in merges] + # Merges may be lists (["Ġ", "a"]) or already strings ("Ġ a"). + if merges and isinstance(merges[0], list): + merges = [" ".join(item) for item in merges] # Load all special tokens except "reserved" ones. special_tokens = set() From 2ceddd206f1ac1ee8cdce1546bfa0c44d3182401 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Wed, 11 Mar 2026 23:56:22 -0700 Subject: [PATCH 3/3] Improve SafetensorLoader key mapping robustness --- keras_hub/src/utils/transformers/safetensor_utils.py | 9 ++++++++- .../checkpoint_conversion/convert_qwen3_5_checkpoints.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/utils/transformers/safetensor_utils.py b/keras_hub/src/utils/transformers/safetensor_utils.py index 24c8dff338..99cacab8a0 100644 --- a/keras_hub/src/utils/transformers/safetensor_utils.py +++ b/keras_hub/src/utils/transformers/safetensor_utils.py @@ -53,8 +53,15 @@ def get_prefixed_key(self, hf_weight_key, dict_like): Returns: str: The full key including the prefix (if any). """ + # Check for exact matches first to handle mixed + # root/nested weight maps. 
+ if hf_weight_key in dict_like.keys(): + return hf_weight_key + if self.prefix is not None: - return self.prefix + hf_weight_key + full_key = self.prefix + hf_weight_key + if full_key in dict_like.keys(): + return full_key for full_key in dict_like.keys(): if full_key.endswith(hf_weight_key) and full_key != hf_weight_key: diff --git a/tools/checkpoint_conversion/convert_qwen3_5_checkpoints.py b/tools/checkpoint_conversion/convert_qwen3_5_checkpoints.py index e039cc081c..2e52127418 100644 --- a/tools/checkpoint_conversion/convert_qwen3_5_checkpoints.py +++ b/tools/checkpoint_conversion/convert_qwen3_5_checkpoints.py @@ -46,7 +46,7 @@ "qwen3_5_9b_base": "Qwen/Qwen3.5-9B-Base", "qwen3_5_9b": "Qwen/Qwen3.5-9B", "qwen3_5_27b": "Qwen/Qwen3.5-27B", - "qwen3_5_35b_a3b_base": "Qwen/Qwen3.5-35B-A3B_Base", + "qwen3_5_35b_a3b_base": "Qwen/Qwen3.5-35B-A3B-Base", "qwen3_5_35b_a3b": "Qwen/Qwen3.5-35B-A3B", }