diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 7ed9b7060d..14dfd03199 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -123,6 +123,12 @@ from keras_hub.src.models.parseq.parseq_image_converter import ( PARSeqImageConverter as PARSeqImageConverter, ) +from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_converter import ( + Qwen3OmniAudioConverter as Qwen3OmniAudioConverter, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_image_converter import ( + Qwen3OmniImageConverter as Qwen3OmniImageConverter, +) from keras_hub.src.models.resnet.resnet_image_converter import ( ResNetImageConverter as ResNetImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 5d1126a40d..12ded4ffdd 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -566,6 +566,21 @@ from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm_preprocessor import ( Qwen3MoeCausalLMPreprocessor as Qwen3MoeCausalLMPreprocessor, ) +from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_encoder import ( + Qwen3OmniAudioEncoder as Qwen3OmniAudioEncoder, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone as Qwen3OmniBackbone, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_causal_lm import ( + Qwen3OmniCausalLM as Qwen3OmniCausalLM, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_causal_lm_preprocessor import ( + Qwen3OmniCausalLMPreprocessor as Qwen3OmniCausalLMPreprocessor, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_vision_encoder import ( + Qwen3OmniVisionEncoder as Qwen3OmniVisionEncoder, +) from keras_hub.src.models.qwen_moe.qwen_moe_backbone import ( QwenMoeBackbone as QwenMoeBackbone, ) diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index d3fe455d7c..6ac2409035 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -84,6 +84,9 @@ from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import ( Qwen3MoeTokenizer as Qwen3MoeTokenizer, ) +from keras_hub.src.models.qwen3_omni.qwen3_omni_tokenizer import ( + Qwen3OmniTokenizer as Qwen3OmniTokenizer, +) from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import ( QwenMoeTokenizer as QwenMoeTokenizer, ) diff --git a/keras_hub/src/models/qwen3_omni/__init__.py b/keras_hub/src/models/qwen3_omni/__init__.py new file mode 100644 index 0000000000..cccffcc1c6 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/__init__.py @@ -0,0 +1,7 @@ +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, Qwen3OmniBackbone) diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_attention.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_attention.py new file mode 100644 index 0000000000..8c906414ab --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_attention.py @@ -0,0 +1,423 @@ +import math + +import keras +from keras import ops + +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.models.qwen3_omni.qwen3_omni_rope import ( + MultimodalRotaryEmbedding, +) +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.utils.keras_utils import fused_attention_op_available + + +class 
Qwen3OmniAttention(keras.layers.Layer): + """Multi-head attention with Multimodal RoPE for Qwen3-Omni. + + This attention layer implements: + - Grouped Query Attention (GQA) for efficiency + - Multimodal Rotary Position Embedding (M-RoPE) for multimodal inputs + - Query-Key normalization for training stability + - Optional sliding window attention + + The M-RoPE divides the head dimension into 3 sections (24, 20, 20) for + text, temporal, and spatial position encodings respectively. + + Args: + num_query_heads: int. Number of query heads. + num_key_value_heads: int. Number of key/value heads (for GQA). + head_dim: int. The dimension of each attention head. + mrope_section: tuple of 3 ints. Dimension allocation for M-RoPE + (text, temporal, spatial). Defaults to (24, 20, 20). + rope_max_wavelength: int. Maximum wavelength for M-RoPE. + Defaults to 1000000. + rope_scaling_factor: float. Scaling factor for M-RoPE. Defaults to 1.0. + rope_attention_scaling: float. Attention scaling factor applied by + M-RoPE. Defaults to 1.0. + kernel_initializer: Initializer for kernel weights. + dropout: float. Dropout rate for attention weights. + layer_norm_epsilon: float. Epsilon for layer normalization. + sliding_window_size: int or None. Size of sliding window. + Defaults to None. + **kwargs: Additional keyword arguments to pass to the layer. + """ + + def __init__( + self, + num_query_heads, + num_key_value_heads, + head_dim=None, + mrope_section=(24, 20, 20), + rope_max_wavelength=1000000, + rope_scaling_factor=1.0, + rope_attention_scaling=1.0, + kernel_initializer="glorot_uniform", + dropout=0.0, + layer_norm_epsilon=1e-6, + sliding_window_size=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.mrope_section = tuple(mrope_section) + self.dropout = dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.num_key_value_groups = num_query_heads // num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.rope_attention_scaling = rope_attention_scaling + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + self.sliding_window_size = sliding_window_size + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + hidden_dim = inputs_shape[-1] + if not self.head_dim: + self.head_dim = hidden_dim // self.num_query_heads + + self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + + # Query projection with EinsumDense + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self._query_dense.build(inputs_shape) + + # Query normalization (QK norm) + self._query_dense_layer_norm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + head_dim=self.head_dim, + name="query_dense_layernorm", + ) + self._query_dense_layer_norm.build(inputs_shape) + + # Key projection + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self._key_dense.build(inputs_shape) + + # Key normalization (QK norm) + self._key_dense_layer_norm = Qwen3MoeLayerNorm(
epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + head_dim=self.head_dim, + name="key_dense_layernorm", + ) + self._key_dense_layer_norm.build(inputs_shape) + + # Value projection + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self._value_dense.build(inputs_shape) + + # Softmax and dropout + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + # Output projection + self._output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self._output_dense.build( + (None, None, self.num_query_heads, self.head_dim) + ) + + # Multimodal RoPE + self.multimodal_rotary_embedding = MultimodalRotaryEmbedding( + mrope_section=self.mrope_section, + max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, + attention_scaling=self.rope_attention_scaling, + dtype=self.dtype_policy, + ) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + self.built = True + + def call( + self, + hidden_states, + position_ids=None, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Forward pass of multi-head attention with M-RoPE. + + Args: + hidden_states: Input tensor of shape (batch, seq_len, hidden_dim). + position_ids: Position IDs of shape (3, batch, seq_len) + for M-RoPE, where position_ids[0] = text positions, + position_ids[1] = temporal positions, + position_ids[2] = spatial positions. If None, creates + default sequential positions. + attention_mask: Attention mask of shape (batch, seq_len, seq_len). + cache: Optional cached key and value tensors. + cache_update_index: Index for cache update. + training: Boolean indicating training mode. + + Returns: + attention_output: Output tensor after applying attention. + cache: Updated cache tensors (if cache is provided). + """ + batch_size = ops.shape(hidden_states)[0] + seq_len = ops.shape(hidden_states)[1] + + if position_ids is None: + text_positions = ops.arange(seq_len, dtype="int32") + text_positions = ops.expand_dims(text_positions, axis=0) + text_positions = ops.repeat(text_positions, batch_size, axis=0) + position_ids = ops.stack( + [text_positions, text_positions, text_positions], axis=0 + ) + + # Project to Q, K, V + query = self._query_dense(hidden_states) + query = self._query_dense_layer_norm(query) + + def _compute_key_value(x): + key = self._key_dense(x) + key = self._key_dense_layer_norm(key) + value = self._value_dense(x) + return key, value + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] 
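+ # Axis 1 of the per-layer cache holds keys at index 0 and values at + # index 1. When `cache_update_index` is None the cached entries are + # reused as-is; otherwise fresh key/value rows are computed and + # spliced in at the update index.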
+ if cache_update_index is None: + # Use cached keys/values (already have M-RoPE applied) + key = key_cache + value = value_cache + # Still need to apply M-RoPE to query + query, _ = ( + self.multimodal_rotary_embedding.apply_multimodal_rotary_embedding( + query, + key[:, :1, :, :], + position_ids, + ) + ) + else: + key_update, value_update = _compute_key_value(hidden_states) + + query, key_update = ( + self.multimodal_rotary_embedding.apply_multimodal_rotary_embedding( + query, key_update, position_ids + ) + ) + + # Update cache with new key/value + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + key, value = _compute_key_value(hidden_states) + query, key = ( + self.multimodal_rotary_embedding.apply_multimodal_rotary_embedding( + query, key, position_ids + ) + ) + + # Repeat K, V for GQA: (batch, seq_len, num_kv_heads, head_dim) + # -> (batch, seq_len, num_query_heads, head_dim) + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, + key, + value, + attention_mask, + cache_update_index=cache_update_index, + ) + + attention_output = self._dropout_layer( + attention_output, training=training + ) + attention_output = self._output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + """Applies softmax with optional masking. + Args: + attention_scores: Attention score tensor. + attention_mask: Optional mask tensor. + + Returns: + Masked softmax attention weights. + """ + if attention_mask is not None: + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) + + def _compute_attention( + self, query, key, value, attention_mask=None, cache_update_index=None + ): + """Computes attention using query, key, and value tensors. + + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + attention_mask: Optional mask tensor. + cache_update_index: Index for sliding window computation. + + Returns: + attention_output: Output tensor after applying attention. 
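+ + Uses the fused `ops.dot_product_attention` op when it is available + on the current backend, falling back to an explicit einsum + implementation otherwise.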
+ """ + # Apply sliding window mask if configured + if self.sliding_window_size: + if attention_mask is None: + query_len = ops.shape(query)[1] + key_len = ops.shape(key)[1] + + if cache_update_index is not None: + causal_mask = ops.arange(key_len) <= ( + cache_update_index + query_len - 1 + ) + causal_mask = ops.cast(causal_mask, dtype="bool") + attention_mask = ops.reshape(causal_mask, (1, key_len)) + attention_mask = ops.broadcast_to( + attention_mask, (query_len, key_len) + ) + else: + attention_mask = ops.tril( + ops.ones((query_len, key_len), dtype="bool") + ) + attention_mask = ops.expand_dims(attention_mask, 0) + attention_mask = self._mask_sliding_window( + attention_mask, + cache_update_index=cache_update_index + if cache_update_index is not None + else 0, + ) + + if fused_attention_op_available(): + if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + attention_output = ops.dot_product_attention( + query, + key, + value, + mask=attention_mask, + scale=self._inv_norm_factor, + ) + return attention_output + + attention_scores = ops.einsum(self._dot_product_equation, query, key) + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + + return attention_output + + def _mask_sliding_window(self, attention_mask, cache_update_index=0): + """Creates and combines a sliding window mask with the attention mask. + Args: + attention_mask: Original attention mask. + cache_update_index: Starting index for the sliding window. + + Returns: + Combined attention mask with sliding window constraints. 
+ """ + _, query_len, key_len = ops.shape(attention_mask) + all_ones = ops.ones((key_len, key_len), "bool") + + if keras.config.backend() == "tensorflow": + # TODO carried over from qwen3moe + import tensorflow as tf + + band_size = ops.minimum(key_len, self.sliding_window_size - 1) + band_size = ops.cast(band_size, "int32") + sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size) + else: + sliding_mask = ops.triu( + all_ones, -1 * self.sliding_window_size + 1 + ) * ops.tril(all_ones, self.sliding_window_size - 1) + + start = (cache_update_index, 0) + sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len)) + sliding_mask = ops.expand_dims(sliding_mask, 0) + return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool")) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + "mrope_section": self.mrope_section, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "rope_attention_scaling": self.rope_attention_scaling, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + "sliding_window_size": self.sliding_window_size, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_converter.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_converter.py new file mode 100644 index 0000000000..4f07b7e1ce --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_converter.py @@ -0,0 +1,203 @@ +import numpy as np +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) + + +@keras_hub_export("keras_hub.layers.Qwen3OmniAudioConverter") +class Qwen3OmniAudioConverter(AudioConverter): + """Audio preprocessing for Qwen3-Omni. + + Converts raw audio to log-mel spectrogram features compatible with the + Qwen3-Omni audio encoder. This uses log-mel spectrogram via STFT + feature extraction. + + Args: + num_mels: int. The number of mel-frequency filters. Defaults to + `128`. + num_fft_bins: int. The size of the Fourier Transform in STFT. + Defaults to `400`. + stride: int. The distance between neighboring sliding window + frames while computing STFT. Defaults to `160`. + sampling_rate: int. The sample rate of the audio. Defaults to + `16000`. + max_audio_length: int. The length of each audio chunk in + seconds. The input audio tensor will be padded/trimmed to + `max_audio_length * sampling_rate`. Defaults to `300`. 
+ + Examples: + ```python + converter = keras_hub.layers.Qwen3OmniAudioConverter.from_preset( + "qwen3_omni_instruct" + ) + audio = np.ones((8000,), dtype="float32") + mel_features = converter(audio) + ``` + """ + + backbone_cls = Qwen3OmniBackbone + + def __init__( + self, + num_mels=128, + num_fft_bins=400, + stride=160, + sampling_rate=16000, + max_audio_length=300, + **kwargs, + ): + super().__init__(**kwargs) + + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + self.built = True + + self.num_mels = num_mels + self.num_fft_bins = num_fft_bins + self.stride = stride + self.sampling_rate = sampling_rate + self.max_audio_length = max_audio_length + self.num_samples = self.sampling_rate * self.max_audio_length + self.mel_filters = self._get_mel_filters() + + def audio_shape(self): + """Returns the preprocessed size of a single audio sample.""" + return (self.max_audio_length, self.num_mels) + + def _get_mel_filters(self): + """Computes the Mel filter bank weights. + + Returns: + A numpy array of shape `(num_fft_bins // 2 + 1, num_mels)` + containing the Mel filter bank weights. + """ + dtype = np.float32 + weights = np.zeros( + (self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype + ) + + # Center freqs of each FFT bin and mel bands. + fftfreqs = np.fft.rfftfreq( + n=self.num_fft_bins, d=1.0 / self.sampling_rate + ) + min_mel = 0.0 + max_mel = 45.245640471924965 + + mels = np.linspace(min_mel, max_mel, self.num_mels + 2) + mels = np.asanyarray(mels) + + # Linear scale. + f_min = 0.0 + f_sp = 200.0 / 3 + freqs = f_min + f_sp * mels + + # Nonlinear (log) scale. + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = np.log(6.4) / 27.0 + + log_t = mels >= min_log_mel + freqs[log_t] = min_log_hz * np.exp( + logstep * (mels[log_t] - min_log_mel) + ) + + mel_f = freqs + fdiff = np.diff(mel_f) + ramps = np.subtract.outer(mel_f, fftfreqs) + + for i in range(self.num_mels): + lower = -ramps[i] / fdiff[i] + upper = ramps[i + 2] / fdiff[i + 1] + weights[i] = np.maximum(0, np.minimum(lower, upper)) + + # scale to approx constant energy per channel. + enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels]) + weights *= enorm[:, np.newaxis] + + # Transpose to (num_fft_bins // 2 + 1, num_mels). + return np.transpose(weights) + + def _extract_audio_features(self, audio): + """Compute log-mel spectrogram from audio waveform. + + Uses `keras.ops.stft` with `center=True` which internally applies + reflection padding, matching the Whisper feature extraction pipeline. + + Args: + audio: Float tensor of shape (batch, num_samples). + + Returns: + Log-mel spectrogram of shape (batch, num_frames, num_mels). + """ + audio = ops.cast(audio, self.compute_dtype) + + real, imag = ops.stft( + audio, + sequence_length=self.num_fft_bins, + sequence_stride=self.stride, + fft_length=self.num_fft_bins, + window="hann", + center=True, + ) + + magnitudes = ops.square(real[:, :-1, :]) + ops.square(imag[:, :-1, :]) + + # Apply mel filter bank. + mel_filters = ops.cast( + ops.convert_to_tensor(self.mel_filters), self.compute_dtype + ) + mel_spec = ops.matmul(magnitudes, mel_filters) + + # Log-mel spectrogram with numerical stability. + mel_spec = ops.maximum(mel_spec, 1e-10) + log_spec = ops.log(mel_spec) / ops.log( + ops.cast(ops.convert_to_tensor(10.0), self.compute_dtype) + ) + + # Dynamic range compression. 
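+ # Floor each example at 8.0 below its own log-mel maximum, the same + # clamp used in the Whisper feature extractor.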
+ max_val = ops.max(log_spec, axis=(1, 2)) + max_val_minus_eight = ops.expand_dims(max_val - 8.0, axis=(1, 2)) + log_spec = ops.maximum(log_spec, max_val_minus_eight) + + # Normalization. + log_spec = (log_spec + 4.0) / 4.0 + + return log_spec + + def call(self, audio): + audio = ops.convert_to_tensor(audio, dtype=self.compute_dtype) + + rank_1_input = len(ops.shape(audio)) == 1 + if rank_1_input: + audio = ops.expand_dims(audio, 0) + + current_len = ops.shape(audio)[-1] + if current_len < self.num_samples: + pad_width = [[0, 0], [0, self.num_samples - current_len]] + audio = ops.pad(audio, pad_width, mode="constant") + else: + audio = audio[:, : self.num_samples] + + log_spec = self._extract_audio_features(audio) + + if rank_1_input: + log_spec = ops.squeeze(log_spec, axis=0) + + return log_spec + + def get_config(self): + config = super().get_config() + config.update( + { + "num_mels": self.num_mels, + "num_fft_bins": self.num_fft_bins, + "stride": self.stride, + "sampling_rate": self.sampling_rate, + "max_audio_length": self.max_audio_length, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_converter_test.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_converter_test.py new file mode 100644 index 0000000000..1de9990a79 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_converter_test.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest + +from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_converter import ( + Qwen3OmniAudioConverter, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniAudioConverterTest(TestCase): + def setUp(self): + self.init_kwargs = { + "num_mels": 128, + "num_fft_bins": 400, + "stride": 100, + "sampling_rate": 100, + "max_audio_length": 5, + } + + def test_converter_output_shape(self): + converter = Qwen3OmniAudioConverter(**self.init_kwargs) + audio = np.ones((2,), dtype="float32") + output = converter(audio) + self.assertEqual(output.shape[-1], 128) + + def test_converter_batch(self): + converter = Qwen3OmniAudioConverter(**self.init_kwargs) + audio = np.ones((2, 25), dtype="float32") + output = converter(audio) + self.assertEqual(output.shape[0], 2) + self.assertEqual(output.shape[-1], 128) + + def test_config_serialization(self): + converter = Qwen3OmniAudioConverter(**self.init_kwargs) + config = converter.get_config() + self.assertEqual(config["num_mels"], 128) + self.assertEqual(config["num_fft_bins"], 400) + self.assertEqual(config["stride"], 100) + self.assertEqual(config["sampling_rate"], 100) + self.assertEqual(config["max_audio_length"], 5) + + def test_default_values(self): + converter = Qwen3OmniAudioConverter() + self.assertEqual(converter.num_mels, 128) + self.assertEqual(converter.num_fft_bins, 400) + self.assertEqual(converter.stride, 160) + self.assertEqual(converter.sampling_rate, 16000) + self.assertEqual(converter.max_audio_length, 300) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Qwen3OmniAudioConverter.presets: + self.run_preset_test( + cls=Qwen3OmniAudioConverter, + preset=preset, + input_data=np.ones((800,), dtype="float32"), + ) diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_encoder.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_encoder.py new file mode 100644 index 0000000000..af16d69fdb --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_encoder.py @@ -0,0 +1,435 @@ +import math + +import keras +import numpy as np +from keras import layers +from keras import ops + +from 
keras_hub.src.api_export import keras_hub_export + + +def _create_sinusoidal_positions(length, channels, max_timescale=10000): + """Create fixed sinusoidal positional embeddings.""" + + if channels % 2 != 0: + raise ValueError( + "Sinusoidal position embeddings require even channels. " + f"Received channels={channels}" + ) + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = np.exp( + -log_timescale_increment * np.arange(channels // 2, dtype=np.float32) + ) + scaled_time = ( + np.arange(length, dtype=np.float32)[:, np.newaxis] + * inv_timescales[np.newaxis, :] + ) + + positional_embedding = np.concatenate( + [np.sin(scaled_time), np.cos(scaled_time)], axis=1 + ) + return positional_embedding + + +@keras_hub_export("keras_hub.models.Qwen3OmniAudioEncoder") +class Qwen3OmniAudioEncoder(keras.layers.Layer): + """Audio encoder for Qwen3-Omni + + This encoder processes mel-spectrogram audio features. It includes: + + - Convolutional downsampling (3 Conv2D layers, 8x total reduction) + - Fixed sinusoidal positional embeddings + - Transformer encoder layers + - Output projection to match text model dimension + + Args: + num_mel_bins: int. The number of mel frequency bins. Defaults to `128`. + d_model: int. The model dimension (hidden size). Defaults to `1280`. + encoder_layers: int. The number of transformer encoder layers. + Defaults to `32`. + encoder_attention_heads: int. The number of attention heads. + Defaults to `20`. + encoder_ffn_dim: int. The feed-forward network dimension. + Defaults to `5120`. + output_dim: int. The output projection dimension (should match text + model hidden dimension). Defaults to `2048`. + downsample_hidden_size: int. The hidden size for convolutional + downsampling layers. Defaults to `480`. + max_source_positions: int. The maximum sequence length after + downsampling. Defaults to `1500`. + scale_embedding: bool. Whether to scale embeddings by sqrt(d_model). + Defaults to `False`. + activation_function: string. The activation function name. + Defaults to `"gelu"`. + dropout: float. The dropout rate. Defaults to `0.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the model's computations and weights. Note that some + computations, such as softmax and layer normalization will always + be done at float32 precision regardless of dtype. 
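+ + The three stride-2 convolutions reduce the time axis by 8x overall, + so the 3000-frame input in the example below maps to 375 encoder + positions.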
+ + Example: + ```python + import numpy as np + import keras_hub + + # Mel-spectrogram input (batch_size, time_steps, mel_bins) + input_features = np.random.uniform(size=(1, 3000, 128)) + + # Audio encoder + audio_encoder = keras_hub.models.Qwen3OmniAudioEncoder( + num_mel_bins=128, + d_model=1280, + encoder_layers=32, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + output_dim=2048, + ) + output = audio_encoder({"input_features": input_features}) + ``` + """ + + def __init__( + self, + num_mel_bins=128, + d_model=1280, + encoder_layers=32, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + output_dim=2048, + downsample_hidden_size=480, + max_source_positions=1500, + scale_embedding=False, + activation_function="gelu", + dropout=0.0, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + + self.num_mel_bins = num_mel_bins + self.d_model = d_model + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.output_dim = output_dim + self.downsample_hidden_size = downsample_hidden_size + self.max_source_positions = max_source_positions + self.scale_embedding = scale_embedding + self.activation_function = activation_function + self.dropout = dropout + + self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0 + + # === Convolutional downsampling layers === + self.conv2d1 = layers.Conv2D( + downsample_hidden_size, + kernel_size=3, + strides=2, + padding="same", + dtype=dtype, + name="conv2d1", + ) + self.conv2d2 = layers.Conv2D( + downsample_hidden_size, + kernel_size=3, + strides=2, + padding="same", + dtype=dtype, + name="conv2d2", + ) + self.conv2d3 = layers.Conv2D( + downsample_hidden_size, + kernel_size=3, + strides=2, + padding="same", + dtype=dtype, + name="conv2d3", + ) + + self.conv_out = layers.Dense( + d_model, + use_bias=False, + dtype=dtype, + name="conv_out", + ) + + self._positional_embedding_np = _create_sinusoidal_positions( + max_source_positions, d_model + ) + + # === Transformer encoder layers === + self.encoder_transformer_layers = [ + Qwen3OmniAudioEncoderLayer( + embed_dim=d_model, + num_heads=encoder_attention_heads, + ffn_dim=encoder_ffn_dim, + activation=activation_function, + dropout=dropout, + dtype=dtype, + name=f"encoder_layer_{i}", + ) + for i in range(encoder_layers) + ] + + # === Post-encoder normalization === + self.ln_post = layers.LayerNormalization( + dtype=dtype, + name="layer_norm", + ) + + # === Output Projection === + self.proj1 = layers.Dense( + d_model, + use_bias=True, + dtype=dtype, + name="proj1", + ) + self.proj_activation = layers.Activation( + activation_function, dtype=dtype + ) + self.proj2 = layers.Dense( + output_dim, + use_bias=True, + dtype=dtype, + name="proj2", + ) + self.dropout_layer = layers.Dropout( + dropout, dtype=dtype, name="dropout" + ) + + def _call_with_inputs(self, input_features, training=False): + """Encode mel-spectrogram features into output embeddings. + + Args: + input_features: Tensor with shape + `(batch_size, time_steps, num_mel_bins)`. + training: bool. Whether the model is in training mode. + + Returns: + Tensor with shape `(batch_size, seq_len, output_dim)`. 
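+ Here `seq_len` is approximately `time_steps // 8` after the + three stride-2 convolutions.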
+ """ + # Apply convolutional downsampling + # Input: (batch, time, mel_bins) -> (batch, time, mel_bins, 1) + hidden_states = ops.expand_dims(input_features, axis=-1) + + hidden_states = self.conv2d1(hidden_states, training=training) + hidden_states = ops.gelu(hidden_states) + + hidden_states = self.conv2d2(hidden_states, training=training) + hidden_states = ops.gelu(hidden_states) + + hidden_states = self.conv2d3(hidden_states, training=training) + hidden_states = ops.gelu(hidden_states) + + # Flatten spatial dimensions and project + batch_size = ops.shape(hidden_states)[0] + seq_len = ops.shape(hidden_states)[1] + hidden_states = ops.transpose(hidden_states, [0, 1, 3, 2]) + hidden_states = ops.reshape(hidden_states, [batch_size, seq_len, -1]) + hidden_states = self.conv_out(hidden_states) + + # Scale embeddings + hidden_states = hidden_states * self.embed_scale + + # Add position embeddings + pos_embed_tensor = ops.convert_to_tensor( + self._positional_embedding_np, dtype=self.compute_dtype + ) + positions = pos_embed_tensor[:seq_len, :] + hidden_states = hidden_states + positions + + # Apply transformer encoder layers + for encoder_layer in self.encoder_transformer_layers: + hidden_states = encoder_layer( + hidden_states, + training=training, + ) + + # Post-encoder normalization + hidden_states = self.ln_post(hidden_states) + + # Output projection + hidden_states = self.proj1(hidden_states) + hidden_states = self.proj_activation(hidden_states) + hidden_states = self.proj2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, training=training) + + return hidden_states + + def call(self, inputs, training=False): + """Process a dict of inputs through the audio encoder. + + Args: + inputs: dict. A dictionary with `"input_features"` key containing + mel-spectrogram tensor with shape + `(batch_size, time_steps, num_mel_bins)`. + training: bool. Whether the model is in training mode. + + Returns: + Tensor with shape `(batch_size, seq_len, output_dim)`. + """ + return self._call_with_inputs( + inputs["input_features"], training=training + ) + + def compute_output_spec(self, input_spec, **kwargs): + """Compute output shape for symbolic tracing.""" + input_features_spec = input_spec["input_features"] + batch_size = input_features_spec.shape[0] + seq_len = ( + input_features_spec.shape[1] // 8 + if input_features_spec.shape[1] + else None + ) + return keras.KerasTensor( + shape=(batch_size, seq_len, self.output_dim), + dtype=input_features_spec.dtype, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_mel_bins": self.num_mel_bins, + "d_model": self.d_model, + "encoder_layers": self.encoder_layers, + "encoder_attention_heads": self.encoder_attention_heads, + "encoder_ffn_dim": self.encoder_ffn_dim, + "output_dim": self.output_dim, + "downsample_hidden_size": self.downsample_hidden_size, + "max_source_positions": self.max_source_positions, + "scale_embedding": self.scale_embedding, + "activation_function": self.activation_function, + "dropout": self.dropout, + } + ) + return config + + +class Qwen3OmniAudioEncoderLayer(layers.Layer): + """Audio encoder transformer layer for Qwen3-Omni. + + A pre-norm transformer encoder layer with multi-head self-attention + and a feed-forward network. + + Args: + embed_dim: int. The embedding dimension (d_model). + num_heads: int. The number of attention heads. + ffn_dim: int. The dimension of the feed-forward network. + activation: string. The activation function name. Defaults to `"gelu"`. + dropout: float. 
The dropout rate. Defaults to `0.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the layer's computations and weights. + """ + + def __init__( + self, + embed_dim, + num_heads, + ffn_dim, + activation="gelu", + dropout=0.0, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.embed_dim = embed_dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=embed_dim // num_heads, + dropout=dropout, + dtype=dtype, + name="self_attn", + ) + self.self_attn_layer_norm = layers.LayerNormalization( + epsilon=1e-5, + dtype=dtype, + name="self_attn_layer_norm", + ) + + self.fc1 = layers.Dense( + ffn_dim, + dtype=dtype, + name="fc1", + ) + self.activation_fn = layers.Activation(activation, dtype=dtype) + self.fc2 = layers.Dense( + embed_dim, + dtype=dtype, + name="fc2", + ) + self.final_layer_norm = layers.LayerNormalization( + epsilon=1e-5, + dtype=dtype, + name="final_layer_norm", + ) + + self.dropout_layer = layers.Dropout(dropout, dtype=dtype) + + def build(self, input_shape): + self.self_attn.build(input_shape, input_shape) + self.self_attn_layer_norm.build(input_shape) + self.fc1.build(input_shape) + self.fc2.build((input_shape[0], input_shape[1], self.ffn_dim)) + self.final_layer_norm.build(input_shape) + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + training=False, + ): + """Forward pass of the audio encoder layer. + + Args: + hidden_states: Tensor. The input hidden states with shape + `(batch_size, sequence_length, embed_dim)`. + attention_mask: Tensor or None. The attention mask. + training: bool. Whether the layer is in training mode. + + Returns: + Tensor with shape `(batch_size, sequence_length, embed_dim)`. 
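+ + Both sub-blocks are pre-norm: layer normalization runs before the + attention and feed-forward computations, with a residual add after + each.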
+ """ + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + query=hidden_states, + value=hidden_states, + key=hidden_states, + attention_mask=attention_mask, + training=training, + ) + hidden_states = self.dropout_layer(hidden_states, training=training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.dropout_layer(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, training=training) + hidden_states = residual + hidden_states + + return hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + "ffn_dim": self.ffn_dim, + "activation": self.activation_fn.get_config()["activation"], + "dropout": self.dropout_layer.rate, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_encoder_test.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_encoder_test.py new file mode 100644 index 0000000000..7912f38849 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_audio_encoder_test.py @@ -0,0 +1,56 @@ +import numpy as np + +from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_encoder import ( + Qwen3OmniAudioEncoder, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_encoder import ( + Qwen3OmniAudioEncoderLayer, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniAudioEncoderTest(TestCase): + def test_encoder_layer_output_shape(self): + layer = Qwen3OmniAudioEncoderLayer( + embed_dim=32, + num_heads=4, + ffn_dim=64, + dtype="float32", + ) + hidden_states = np.random.rand(1, 10, 32).astype("float32") + output = layer(hidden_states) + self.assertEqual(output.shape, (1, 10, 32)) + + def test_encoder_output_shape(self): + encoder = Qwen3OmniAudioEncoder( + num_mel_bins=80, + d_model=32, + encoder_layers=2, + encoder_attention_heads=4, + encoder_ffn_dim=64, + output_dim=16, + max_source_positions=100, + scale_embedding=False, + dtype="float32", + ) + input_features = np.random.rand(1, 160, 80).astype("float32") + output = encoder({"input_features": input_features}) + self.assertEqual(output.shape[-1], 16) + + def test_encoder_config_roundtrip(self): + encoder = Qwen3OmniAudioEncoder( + num_mel_bins=80, + d_model=32, + encoder_layers=2, + encoder_attention_heads=4, + encoder_ffn_dim=64, + output_dim=16, + max_source_positions=100, + scale_embedding=False, + dtype="float32", + ) + config = encoder.get_config() + restored = Qwen3OmniAudioEncoder.from_config(config) + self.assertEqual(restored.d_model, 32) + self.assertEqual(restored.output_dim, 16) + self.assertEqual(restored.encoder_layers, 2) diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_backbone.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_backbone.py new file mode 100644 index 0000000000..031523a93d --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_backbone.py @@ -0,0 +1,365 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from 
keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.models.qwen3_omni.qwen3_omni_decoder import ( + Qwen3OmniTransformerDecoder, +) + + +def _qwen3_omni_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export("keras_hub.models.Qwen3OmniBackbone") +class Qwen3OmniBackbone(Backbone): + """Qwen3-Omni multimodal Transformer backbone. + + This backbone implements the base Transformer network for the Qwen3-Omni + model. It includes embedding lookups and transformer layers with a Mixture + of Experts (MoE) architecture, using Multimodal Rotary Position Embedding + (M-RoPE) for multimodal fusion. Audio/vision encoders can be optionally + provided for multimodal operation. + + The default constructor gives a fully customizable, randomly initialized + Qwen3-Omni model with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the `from_preset` + constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_query_heads: int. The number of heads for the query projections in + the attention layer. + num_key_value_heads: int. The number of heads for the key and value + projections in the attention layer. + hidden_dim: int. The size of the transformer hidden state at the end of + each transformer layer. + intermediate_dim: int. The output dimension of the first Dense layer in + the feedforward network for each transformer. + moe_intermediate_dim: int. The intermediate dimension for each expert + in the MoE feedforward network. + head_dim: int. The size of each attention head. + num_experts: int. The number of experts in each MoE layer. + num_experts_per_tok: int. The number of top experts to select for each + token in the MoE layer. + mrope_section: tuple. M-RoPE section dimensions + (text, temporal, spatial). Must sum to head_dim // 2. + rope_max_wavelength: int. Max wavelength for RoPE. + rope_scaling_factor: float. Scaling factor for RoPE. + rope_attention_scaling: float. Attention scaling for RoPE. + layer_norm_epsilon: float. The epsilon value used for every layer norm + in the transformer model. + dropout: float. Dropout probability for the transformer encoder. + tie_word_embeddings: bool. Whether to tie input/output embeddings. + sliding_window_size: int or None. Size of sliding attention window. + norm_topk_prob: bool. Whether to normalize top-k probabilities. + decoder_sparse_step: int. Sparse step for MoE layers. + router_aux_loss_coefficient: float. Auxiliary loss coefficient. + mlp_only_layers: list of int or None. Layers to use dense FFN instead + of MoE. + audio_encoder: Qwen3OmniAudioEncoder or None. Pre-instantiated audio + encoder. + vision_encoder: Qwen3OmniVisionEncoder or None. Pre-instantiated vision + encoder. + image_token_id: int. Token ID for image placeholders. + video_token_id: int. Token ID for video placeholders. + audio_token_id: int. Token ID for audio placeholders. + dtype: str or `keras.mixed_precision.DTypePolicy`. The dtype to use for + the model's computations and weights. + + Example: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Randomly initialized Qwen3-Omni decoder with custom config. 
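+ # `mrope_section` must sum to head_dim // 2: 24 + 20 + 20 = 64 = 128 // 2.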
+ model = keras_hub.models.Qwen3OmniBackbone( + vocabulary_size=152064, + num_layers=48, + num_query_heads=32, + num_key_value_heads=4, + hidden_dim=2048, + intermediate_dim=768, + moe_intermediate_dim=768, + head_dim=128, + num_experts=128, + num_experts_per_tok=8, + mrope_section=(24, 20, 20), + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + head_dim, + hidden_dim, + intermediate_dim, + moe_intermediate_dim, + num_experts, + num_experts_per_tok, + mrope_section, + rope_max_wavelength=1000000, + rope_scaling_factor=1.0, + rope_attention_scaling=1.0, + layer_norm_epsilon=1e-6, + dropout=0.0, + tie_word_embeddings=False, + norm_topk_prob=True, + decoder_sparse_step=1, + sliding_window_size=None, + router_aux_loss_coefficient=0.001, + mlp_only_layers=None, + audio_encoder=None, + vision_encoder=None, + image_token_id=None, + video_token_id=None, + audio_token_id=None, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=_qwen3_omni_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + + if not mlp_only_layers: + mlp_only_layers = [] + + self.transformer_layers = [] + for i in range(num_layers): + is_sparse_mlp = ( + (i not in mlp_only_layers) + and num_experts > 0 + and (i + 1) % decoder_sparse_step == 0 + ) + layer = Qwen3OmniTransformerDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + moe_intermediate_dim=moe_intermediate_dim, + head_dim=head_dim, + num_experts=num_experts, + top_k=num_experts_per_tok, + norm_top_k_prob=norm_topk_prob, + mrope_section=mrope_section, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + rope_attention_scaling=rope_attention_scaling, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_qwen3_omni_kernel_initializer(stddev=0.02), + dropout=dropout, + dtype=dtype, + sliding_window_size=sliding_window_size, + router_aux_loss_coefficient=router_aux_loss_coefficient, + is_sparse_mlp=is_sparse_mlp, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = Qwen3MoeLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer( + x, + position_ids=None, + decoder_padding_mask=padding_mask_input, + ) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + self.audio_encoder = audio_encoder + self.vision_encoder = vision_encoder + + # === Config === + self.vocabulary_size = vocabulary_size + self.head_dim = head_dim + self.intermediate_dim = intermediate_dim + self.moe_intermediate_dim = moe_intermediate_dim + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + 
self.hidden_dim = hidden_dim + self.norm_topk_prob = norm_topk_prob + self.decoder_sparse_step = decoder_sparse_step + self.router_aux_loss_coefficient = router_aux_loss_coefficient + self.mlp_only_layers = mlp_only_layers or [] + self.mrope_section = mrope_section + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.rope_attention_scaling = rope_attention_scaling + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.tie_word_embeddings = tie_word_embeddings + self.sliding_window_size = sliding_window_size + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.audio_token_id = audio_token_id + + def call(self, inputs, training=False): + token_ids = inputs["token_ids"] + padding_mask = inputs["padding_mask"] + audio_features = inputs.get("audio_features", None) + pixel_values = inputs.get("pixel_values", None) + grid_thw = inputs.get("grid_thw", None) + + x = self._compute_embeddings( + token_ids, audio_features, pixel_values, grid_thw + ) + for transformer_layer in self.transformer_layers: + x = transformer_layer( + x, + position_ids=None, + decoder_padding_mask=padding_mask, + training=training, + ) + return self.layer_norm(x) + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "head_dim": self.head_dim, + "intermediate_dim": self.intermediate_dim, + "moe_intermediate_dim": self.moe_intermediate_dim, + "num_experts": self.num_experts, + "num_experts_per_tok": self.num_experts_per_tok, + "mrope_section": self.mrope_section, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "rope_attention_scaling": self.rope_attention_scaling, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "tie_word_embeddings": self.tie_word_embeddings, + "norm_topk_prob": self.norm_topk_prob, + "decoder_sparse_step": self.decoder_sparse_step, + "sliding_window_size": self.sliding_window_size, + "router_aux_loss_coefficient": self.router_aux_loss_coefficient, + "mlp_only_layers": self.mlp_only_layers, + "image_token_id": self.image_token_id, + "video_token_id": self.video_token_id, + "audio_token_id": self.audio_token_id, + "audio_encoder": ( + keras.saving.serialize_keras_object(self.audio_encoder) + if self.audio_encoder is not None + else None + ), + "vision_encoder": ( + keras.saving.serialize_keras_object(self.vision_encoder) + if self.vision_encoder is not None + else None + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if config.get("audio_encoder") is not None and isinstance( + config["audio_encoder"], dict + ): + config["audio_encoder"] = keras.layers.deserialize( + config["audio_encoder"] + ) + if config.get("vision_encoder") is not None and isinstance( + config["vision_encoder"], dict + ): + config["vision_encoder"] = keras.layers.deserialize( + config["vision_encoder"] + ) + return super().from_config(config) + + def _compute_embeddings( + self, + token_ids, + audio_features=None, + pixel_values=None, + grid_thw=None, + ): + inputs_embeds = self.token_embedding(token_ids) + + if audio_features is not None and self.audio_encoder is not None: + audio_embeds = self.audio_encoder( + {"input_features": audio_features} + ) + audio_mask = ops.equal( + 
ops.cast(token_ids, "int32"), self.audio_token_id + ) + inputs_embeds = self._masked_scatter( + inputs_embeds, audio_mask, audio_embeds + ) + + if pixel_values is not None and self.vision_encoder is not None: + vision_outputs = self.vision_encoder( + {"pixel_values": pixel_values, "grid_thw": grid_thw} + ) + visual_embeds = vision_outputs["pooler_output"] + image_mask = ops.equal( + ops.cast(token_ids, "int32"), self.image_token_id + ) + video_mask = ops.equal( + ops.cast(token_ids, "int32"), self.video_token_id + ) + visual_mask = ops.logical_or(image_mask, video_mask) + inputs_embeds = self._masked_scatter( + inputs_embeds, visual_mask, visual_embeds + ) + + return inputs_embeds + + def _masked_scatter(self, target, mask, source): + """Replace embeddings at masked positions with source embeddings.""" + mask_expanded = ops.cast(ops.expand_dims(mask, -1), target.dtype) + mask_int = ops.cast(mask, "int32") + cumsum = ops.cumsum(mask_int, axis=1) + source_indices = ops.maximum(cumsum - 1, 0) + source_indices_expanded = ops.expand_dims(source_indices, -1) + scattered_values = ops.take_along_axis( + source, source_indices_expanded, axis=1 + ) + result = target * (1 - mask_expanded) + scattered_values * mask_expanded + return result diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_backbone_test.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_backbone_test.py new file mode 100644 index 0000000000..ef4c21e278 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_backbone_test.py @@ -0,0 +1,186 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from keras import ops + +from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_encoder import ( + Qwen3OmniAudioEncoder, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_vision_encoder import ( + Qwen3OmniVisionEncoder, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniBackboneTest(TestCase, parameterized.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 8 + + # === Vision Encoder === + vision_encoder = Qwen3OmniVisionEncoder( + depth=2, + hidden_size=32, + num_heads=4, + intermediate_size=64, + patch_size=2, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=16, + num_position_embeddings=49, + deepstack_visual_indexes=[0], + dtype="float32", + ) + + # === Audio Encoder === + audio_encoder = Qwen3OmniAudioEncoder( + num_mel_bins=80, + d_model=32, + encoder_layers=2, + encoder_attention_heads=4, + encoder_ffn_dim=64, + output_dim=16, + max_source_positions=100, + scale_embedding=False, + dtype="float32", + ) + + # === Text + Vision + Audio Backbone === + self.init_kwargs = { + "vocabulary_size": 20, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": 16, + "intermediate_dim": 32, + "head_dim": 4, + "moe_intermediate_dim": 16, + "num_experts": 4, + "num_experts_per_tok": 2, + "norm_topk_prob": True, + "decoder_sparse_step": 1, + "mrope_section": (1, 1, 0), + "layer_norm_epsilon": 1e-6, + "rope_max_wavelength": 10000, + "rope_scaling_factor": 1.0, + "dropout": 0.0, + "tie_word_embeddings": False, + "sliding_window_size": None, + "router_aux_loss_coefficient": 0.001, + "mlp_only_layers": [], + "vision_encoder": vision_encoder, + "audio_encoder": audio_encoder, + "image_token_id": 5, + "video_token_id": 6, + "audio_token_id": 7, + "dtype": "float32", + } + + self.input_data = { + "token_ids": ops.ones( + (self.batch_size, 
self.sequence_length), dtype="int32" + ), + "padding_mask": ops.ones( + (self.batch_size, self.sequence_length), dtype="int32" + ), + } + + # === Text-Only Backbone === + self.text_init_kwargs = self.init_kwargs.copy() + self.text_init_kwargs["vision_encoder"] = None + self.text_init_kwargs["audio_encoder"] = None + + @parameterized.named_parameters( + ("multimodal", "multimodal"), ("text_only", "text_only") + ) + def test_backbone_basics(self, backbone_type): + if backbone_type == "multimodal": + init_kwargs = self.init_kwargs + else: + init_kwargs = self.text_init_kwargs + + self.run_backbone_test( + cls=Qwen3OmniBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + expected_output_shape=( + self.batch_size, + self.sequence_length, + 16, + ), + run_quantization_check=(backbone_type == "text_only"), + ) + + @parameterized.named_parameters( + ("multimodal", "multimodal"), ("text_only", "text_only") + ) + @pytest.mark.large + def test_saved_model(self, backbone_type): + if backbone_type == "multimodal": + init_kwargs = self.init_kwargs + else: + init_kwargs = self.text_init_kwargs + + self.run_model_saving_test( + cls=Qwen3OmniBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = Qwen3OmniBackbone(**self.init_kwargs) + self.assertEqual(len(model.transformer_layers), 2) + self.assertIsNotNone(model.vision_encoder) + self.assertIsNotNone(model.audio_encoder) + + text_model = Qwen3OmniBackbone(**self.text_init_kwargs) + self.assertIsNone(text_model.vision_encoder) + self.assertIsNone(text_model.audio_encoder) + + def test_auxiliary_loss(self): + model = Qwen3OmniBackbone(**self.init_kwargs) + _ = model(self.input_data, training=True) + self.assertTrue( + len(model.losses) > 0, "Auxiliary losses should be present" + ) + for loss in model.losses: + self.assertGreater(loss, 0.0, "Auxiliary loss should be positive") + + def test_vision_fusion_forward(self): + model = Qwen3OmniBackbone(**self.init_kwargs) + + token_ids = np.ones((1, self.sequence_length), dtype="int32") + token_ids[0, 3] = 5 + padding_mask = np.ones((1, self.sequence_length), dtype="int32") + pixel_values = ( + np.random.RandomState(0).randn(1, 2, 4, 4, 3).astype("float32") + ) + grid_thw = np.array([[1, 2, 2]], dtype="int32") + + output_fused = model( + { + "token_ids": token_ids, + "padding_mask": padding_mask, + "pixel_values": pixel_values, + "grid_thw": grid_thw, + } + ) + self.assertEqual(ops.shape(output_fused), (1, self.sequence_length, 16)) + + output_text = model( + { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + ) + self.assertEqual(ops.shape(output_text), (1, self.sequence_length, 16)) + + fused_np = ops.convert_to_numpy(output_fused) + text_np = ops.convert_to_numpy(output_text) + self.assertFalse( + np.allclose(fused_np, text_np, atol=1e-5), + "Vision fusion should change the output.", + ) diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm.py new file mode 100644 index 0000000000..31445d9683 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm.py @@ -0,0 +1,345 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_causal_lm_preprocessor import ( + Qwen3OmniCausalLMPreprocessor, +) +from 
keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export( + "keras_hub.models.Qwen3OmniCausalLM", +) +class Qwen3OmniCausalLM(CausalLM): + """An end-to-end Qwen3-Omni model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on plain + text input, or to autoregressively generate plain text similar to the data + used for training. This task can be used for pre-training or fine-tuning a + Qwen3-Omni model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. + By default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()`, and `generate()`. This is done by + default when creating the model with `from_preset()`. + + Args: + backbone: A `keras_hub.models.Qwen3OmniBackbone` instance. + preprocessor: A `keras_hub.models.Qwen3OmniCausalLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do text generation. + ```python + qwen3_omni_lm = keras_hub.models.Qwen3OmniCausalLM.from_preset( + "hf://Qwen/Qwen3-Omni-30B-A3B-Thinking" + ) + qwen3_omni_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. + qwen3_omni_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Compile the `generate()` function with a custom sampler. + ```python + qwen3_omni_lm = keras_hub.models.Qwen3OmniCausalLM.from_preset( + "hf://Qwen/Qwen3-Omni-30B-A3B-Thinking" + ) + qwen3_omni_lm.compile(sampler="top_k") + qwen3_omni_lm.generate("I want to say", max_length=30) + + qwen3_omni_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2)) + qwen3_omni_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + prompt = { + # Token ids for " Qwen3 is". + "token_ids": np.array([[2, 12345, 678, 0, 0, 0, 0]] * 2), + # Use `"padding_mask"` to indicate values that should not be overridden. + "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2), + } + + qwen3_omni_lm = keras_hub.models.Qwen3OmniCausalLM.from_preset( + "hf://Qwen/Qwen3-Omni-30B-A3B-Thinking", + preprocessor=None, + ) + qwen3_omni_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + qwen3_omni_lm = keras_hub.models.Qwen3OmniCausalLM.from_preset( + "hf://Qwen/Qwen3-Omni-30B-A3B-Thinking" + ) + qwen3_omni_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. 
+ ```python + x = { + # Token ids for " Qwen3 is a language model" + "token_ids": np.array([[2, 12345, 678, 543, 9876, 1, 0, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2), + } + y = np.array([[12345, 678, 543, 9876, 1, 0, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 1, 1, 0, 0, 0]] * 2) + + qwen3_omni_lm = keras_hub.models.Qwen3OmniCausalLM.from_preset( + "hf://Qwen/Qwen3-Omni-30B-A3B-Thinking", + preprocessor=None, + ) + qwen3_omni_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + """ + + backbone_cls = Qwen3OmniBackbone + preprocessor_cls = Qwen3OmniCausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + self.backbone = backbone + self.preprocessor = preprocessor + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def compile( + self, + optimizer="auto", + loss="auto", + *, + weighted_metrics="auto", + sampler="greedy", + **kwargs, + ): + super().compile( + optimizer=optimizer, + loss=loss, + weighted_metrics=weighted_metrics, + sampler=sampler, + **kwargs, + ) + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `Qwen3OmniCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] + x, next_cache = self.backbone.transformer_layers[i]( + x, + cache=current_cache, + cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + + # Stack updated caches back into single tensor + cache = ops.stack(updated_cache, axis=1) + + # Final layer norm and projection to vocabulary + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Initialize KV cache and perform initial forward pass. + + Creates a zero-initialized cache tensor and seeds it with the initial + prompt tokens. This is called once at the start of generation. + + Args: + token_ids: Initial prompt tokens, shape + `(batch_size, prompt_length)`. + + Returns: + Tuple of (hidden_states, cache) from the initial forward pass. 
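+            The cache has shape `(batch_size, num_layers, 2, max_length,
+            num_key_value_heads, head_dim)`, where axis 2 separates the
+            key and value tensors for each decoder layer.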
+ """ + # Determine cache dimensions from input and model config + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.head_dim + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + + # Seed cache with initial forward pass + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: Tuple of id's of the end token to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." 
+ ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_preprocessor.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_preprocessor.py new file mode 100644 index 0000000000..0a7f17c85d --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_preprocessor.py @@ -0,0 +1,233 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, +) +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_converter import ( + Qwen3OmniAudioConverter, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_image_converter import ( + Qwen3OmniImageConverter, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_tokenizer import ( + Qwen3OmniTokenizer, +) +from keras_hub.src.utils.tensor_utils import preprocessing_function + + +@keras_hub_export( + "keras_hub.models.Qwen3OmniCausalLMPreprocessor", +) +class Qwen3OmniCausalLMPreprocessor(CausalLMPreprocessor): + """Multimodal preprocessor for Qwen3-Omni CausalLM. + + Handles preprocessing for text, audio, and vision inputs. + + Args: + tokenizer: Qwen3OmniTokenizer instance. + audio_converter: Qwen3OmniAudioConverter instance (optional). + image_converter: Qwen3OmniImageConverter instance (optional). + sequence_length: int. Maximum sequence length. Defaults to 1024. + add_start_token: bool. Whether to add start token. Defaults to True. + add_end_token: bool. Whether to add end token. Defaults to True. + **kwargs: Additional layer arguments. 
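+
+    If `audio_converter` or `image_converter` is `None`, any corresponding
+    `"audio"` or `"images"` entries in the input dict are ignored and only
+    the text inputs are preprocessed.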
+ + Examples: + ```python + # Text-only preprocessing + preprocessor = keras_hub.models.Qwen3OmniCausalLMPreprocessor.from_preset( + "qwen3_omni_instruct" + ) + x = {"prompts": "Hello", "responses": "Hi there!"} + output = preprocessor(x) + + # Multimodal preprocessing + x = { + "prompts": "What is in this image?", + "responses": "A cat", + "images": image_array, + "audio": audio_array, + } + output = preprocessor(x) + ``` + """ + + backbone_cls = Qwen3OmniBackbone + tokenizer_cls = Qwen3OmniTokenizer + audio_converter_cls = Qwen3OmniAudioConverter + image_converter_cls = Qwen3OmniImageConverter + + def __init__( + self, + tokenizer, + audio_converter=None, + image_converter=None, + sequence_length=1024, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + sequence_length=sequence_length, + add_start_token=add_start_token, + add_end_token=add_end_token, + **kwargs, + ) + self.audio_converter = audio_converter + self.image_converter = image_converter + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + super().build(input_shape) + self.multi_packer = MultiSegmentPacker( + start_value=self.tokenizer.start_token_id or [], + end_value=self.tokenizer.end_token_id or [], + pad_value=self.tokenizer.pad_token_id, + sep_value=[], + sequence_length=self.sequence_length, + ) + + def _process_multimodal_inputs(self, x): + """Extract and convert audio/image inputs from a dict.""" + audio_features = None + if "audio" in x and self.audio_converter: + audio_features = self.audio_converter(x["audio"]) + pixel_values = None + if "images" in x and self.image_converter: + pixel_values = self.image_converter(x["images"]) + return audio_features, pixel_values + + def _add_multimodal_to_output(self, output, audio_features, pixel_values): + """Attach multimodal features to the output dict if present.""" + if audio_features is not None: + output["audio_features"] = audio_features + if pixel_values is not None: + output["pixel_values"] = pixel_values + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + sequence_length = sequence_length or self.sequence_length + + # Text-only input (string or tensor) + if not isinstance(x, dict): + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + # Multimodal dict input + prompts = self.tokenizer(x["prompts"]) + audio_features, pixel_values = self._process_multimodal_inputs(x) + responses_text = x.get("responses", None) + + if responses_text is not None: + responses = self.tokenizer(responses_text) + # Pack prompt + response with one extra token for label shift. + token_ids, segment_ids = self.multi_packer( + (prompts, responses), + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + padding_mask = token_ids != self.tokenizer.pad_token_id + response_mask = segment_ids == 1 + + # Truncate last token (no next-token target for it). 
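+            # For example, packed ids `[s, p1, p2, r1, e]` become inputs
+            # `[s, p1, p2, r1]` with labels `[p1, p2, r1, e]`, while
+            # `sample_weight` keeps only the response positions.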
+ x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + self._add_multimodal_to_output(x, audio_features, pixel_values) + + y = token_ids[..., 1:] + sample_weight = response_mask[..., 1:] + else: + # No responses — single-segment next-token prediction. + token_ids, padding_mask = self.packer( + prompts, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + self._add_multimodal_to_output(x, audio_features, pixel_values) + + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + @preprocessing_function + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert inputs to integer token input for generation. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + # Text-only input (string or tensor) + if not isinstance(x, dict): + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + # Multimodal dict input + prompts = self.tokenizer(x["prompts"]) + audio_features, pixel_values = self._process_multimodal_inputs(x) + + if "responses" in x: + segments = (prompts, self.tokenizer(x["responses"])) + else: + segments = (prompts,) + + token_ids, segment_ids = self.multi_packer( + segments, + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=False, + ) + padding_mask = token_ids != self.tokenizer.pad_token_id + + result = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + self._add_multimodal_to_output(result, audio_features, pixel_values) + return result diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_preprocessor_test.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..b5449896f7 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_preprocessor_test.py @@ -0,0 +1,142 @@ +import numpy as np +import pytest + +from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_converter import ( + Qwen3OmniAudioConverter, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_causal_lm_preprocessor import ( + Qwen3OmniCausalLMPreprocessor, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_image_converter import ( + Qwen3OmniImageConverter, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_tokenizer import ( + Qwen3OmniTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|im_end|>", "<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = Qwen3OmniTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = 
["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=Qwen3OmniCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 4, 2, 5, 6, 7, 7]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[3, 4, 2, 5, 6, 7, 7, 7]], + [[1, 1, 1, 1, 1, 0, 0, 0]], + ), + ) + + def test_with_start_end_token(self): + input_data = ["airplane at airport"] * 4 + preprocessor = Qwen3OmniCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=True, + add_end_token=True, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 6, 7, 7]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0]] * 4) + self.assertAllEqual(y, [[3, 4, 2, 5, 6, 7, 7, 7]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "airplane at airport" + preprocessor = Qwen3OmniCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 7, 7, 7]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 3, 4, 2, 5, 7, 7, 7], + "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0], + } + preprocessor = Qwen3OmniCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "airplane at airport") + + def test_with_audio_converter(self): + audio_converter = Qwen3OmniAudioConverter(max_audio_length=2) + preprocessor = Qwen3OmniCausalLMPreprocessor( + tokenizer=self.tokenizer, + audio_converter=audio_converter, + sequence_length=8, + ) + input_data = { + "prompts": "airplane at airport", + "responses": "airplane", + "audio": np.ones((16000,), dtype=np.float32), + } + x, y, sw = preprocessor(input_data) + self.assertIn("audio_features", x) + self.assertIn("token_ids", x) + self.assertIn("padding_mask", x) + + def test_with_image_converter(self): + image_converter = Qwen3OmniImageConverter() + preprocessor = Qwen3OmniCausalLMPreprocessor( + tokenizer=self.tokenizer, + image_converter=image_converter, + sequence_length=8, + ) + input_data = { + "prompts": "airplane at airport", + "responses": "airplane", + "images": np.ones((224, 224, 3), dtype=np.uint8) * 128, + } + x, y, sw = preprocessor(input_data) + self.assertIn("pixel_values", x) + self.assertIn("token_ids", x) + self.assertIn("padding_mask", x) + + def test_multimodal_generate_preprocess(self): + audio_converter = Qwen3OmniAudioConverter(max_audio_length=2) + image_converter = Qwen3OmniImageConverter() + preprocessor = Qwen3OmniCausalLMPreprocessor( + tokenizer=self.tokenizer, + audio_converter=audio_converter, + image_converter=image_converter, + sequence_length=8, + ) + input_data = { + "prompts": "airplane", + "audio": np.ones((16000,), dtype=np.float32), + "images": np.ones((224, 224, 3), dtype=np.uint8) * 128, + } + x = preprocessor.generate_preprocess(input_data) + self.assertIn("token_ids", x) + self.assertIn("padding_mask", x) + self.assertIn("audio_features", x) + self.assertIn("pixel_values", x) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Qwen3OmniCausalLMPreprocessor.presets: + self.run_preset_test( + cls=Qwen3OmniCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_test.py 
b/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_test.py new file mode 100644 index 0000000000..cd2af52d60 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_causal_lm_test.py @@ -0,0 +1,134 @@ +from unittest.mock import patch + +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_causal_lm import ( + Qwen3OmniCausalLM, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_causal_lm_preprocessor import ( + Qwen3OmniCausalLMPreprocessor, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_tokenizer import ( + Qwen3OmniTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniCausalLMTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab += ["<|im_end|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.preprocessor = Qwen3OmniCausalLMPreprocessor( + Qwen3OmniTokenizer(vocabulary=self.vocab, merges=self.merges), + sequence_length=7, + ) + self.backbone = Qwen3OmniBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + moe_intermediate_dim=4, + head_dim=2, + num_experts=4, + num_experts_per_tok=2, + mrope_section=(1, 0, 0), + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = ([" airplane at airport", " airplane at airport"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=Qwen3OmniCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 7, 8), + ) + + def test_generate(self): + causal_lm = Qwen3OmniCausalLM(**self.init_kwargs) + prompt = " airplane at airport" + output = causal_lm.generate(" airplane at airport") + self.assertTrue(prompt in output) + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_generate_strip_prompt(self): + causal_lm = Qwen3OmniCausalLM(**self.init_kwargs) + prompt = " airplane at airport" + output = causal_lm.generate(prompt, strip_prompt=True) + self.assertFalse(output.startswith(prompt)) + + def test_early_stopping(self): + causal_lm = Qwen3OmniCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = [" airplane at airport", " airplane"] + output = causal_lm.generate(prompt) + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = 
Qwen3OmniCausalLM(**self.init_kwargs) + causal_lm.generate(" airplane at airport") + first_fn = causal_lm.generate_function + causal_lm.generate(" airplane at airport") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen3OmniCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_litert_export(self): + self.run_litert_export_test( + cls=Qwen3OmniCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Qwen3OmniCausalLM.presets: + self.run_preset_test( + cls=Qwen3OmniCausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_decoder.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_decoder.py new file mode 100644 index 0000000000..f4fc7cfa20 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_decoder.py @@ -0,0 +1,326 @@ +import keras +from keras import ops + +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_decoder import Qwen3MoeMLP +from keras_hub.src.models.qwen3_moe.qwen3_moe_decoder import Qwen3SparseMoeBlock +from keras_hub.src.models.qwen3_moe.qwen3_moe_decoder import ( + compute_load_balancing_loss, +) +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm +from keras_hub.src.models.qwen3_omni.qwen3_omni_attention import ( + Qwen3OmniAttention, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +class Qwen3OmniTransformerDecoder(keras.layers.Layer): + """Qwen3-Omni transformer decoder block with MoE. + + This decoder block combines: + - Multi-head attention with M-RoPE (Qwen3OmniAttention) + - Mixture-of-Experts feedforward network (from Qwen3MoE) + - Pre-normalization architecture + - Residual connections + + Args: + intermediate_dim: int. Dimension of dense FFN + (used when MoE is disabled). + num_query_heads: int. Number of query attention heads. + num_key_value_heads: int. Number of key/value attention heads (for GQA). + moe_intermediate_dim: int. Intermediate dimension for each MoE expert. + head_dim: int. Dimension of each attention head. + num_experts: int. Total number of experts in the MoE layer. + top_k: int. Number of experts to activate per token. + norm_top_k_prob: bool. Whether to normalize top-k probabilities. + mrope_section: tuple. M-RoPE section dimensions + [text, temporal, spatial]. + rope_max_wavelength: int. Maximum wavelength for M-RoPE. + rope_scaling_factor: float. Scaling factor for M-RoPE. + rope_attention_scaling: float. Attention scaling for M-RoPE + (default 1.0). + layer_norm_epsilon: float. Epsilon for layer normalization. + activation: callable. Activation function (typically SiLU). + kernel_initializer: initializer. Kernel initializer. + dropout: float. Dropout rate. + sliding_window_size: int or None. Size of sliding attention window. + router_aux_loss_coefficient: float. Auxiliary loss coefficient for MoE. + is_sparse_mlp: bool. Whether to use sparse MoE or dense FFN. + dtype: DType policy for the layer. + **kwargs: Additional layer arguments. 
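+
+    Example:
+
+    An illustrative sketch using the small dimensions from the unit
+    tests; the input tensors here are assumptions, not a fixed API:
+    ```python
+    import numpy as np
+
+    decoder = Qwen3OmniTransformerDecoder(
+        intermediate_dim=16,
+        num_query_heads=4,
+        num_key_value_heads=2,
+        moe_intermediate_dim=4,
+        head_dim=2,
+        num_experts=4,
+        top_k=2,
+        mrope_section=(1, 0, 0),
+    )
+    hidden_states = np.random.rand(2, 7, 8).astype("float32")
+    # Text-only positions: the same ids for all three M-RoPE sections.
+    position_ids = np.tile(np.arange(7)[None, None, :], (3, 2, 1))
+    outputs = decoder(hidden_states, position_ids=position_ids)
+    ```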
+ """ + + def __init__( + self, + intermediate_dim, + num_query_heads, + num_key_value_heads, + moe_intermediate_dim, + head_dim, + num_experts, + top_k, + norm_top_k_prob=True, + mrope_section=(24, 20, 20), + rope_max_wavelength=1000000, + rope_scaling_factor=1.0, + rope_attention_scaling=1.0, + layer_norm_epsilon=1e-6, + activation=None, + kernel_initializer="glorot_uniform", + dropout=0.0, + sliding_window_size=None, + router_aux_loss_coefficient=0.001, + is_sparse_mlp=True, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.moe_intermediate_dim = moe_intermediate_dim + self.head_dim = head_dim + self.num_experts = num_experts + self.top_k = top_k + self.norm_top_k_prob = norm_top_k_prob + self.mrope_section = mrope_section + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.rope_attention_scaling = rope_attention_scaling + self.layer_norm_epsilon = layer_norm_epsilon + self.activation = keras.activations.get(activation or "silu") + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.dropout_rate = dropout + self.sliding_window_size = sliding_window_size + self.router_aux_loss_coefficient = router_aux_loss_coefficient + self.is_sparse_mlp = is_sparse_mlp + self.supports_masking = True + + def build(self, input_shape): + hidden_dim = input_shape[-1] + + # Pre-attention layer norm + self.pre_attention_norm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_attention_norm", + ) + self.pre_attention_norm.build(input_shape) + + # Multi-head attention with M-RoPE + self.attention = Qwen3OmniAttention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + mrope_section=self.mrope_section, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, + rope_attention_scaling=self.rope_attention_scaling, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout_rate, + layer_norm_epsilon=self.layer_norm_epsilon, + sliding_window_size=self.sliding_window_size, + dtype=self.dtype_policy, + name="attention", + ) + self.attention.build(input_shape) + + # Post-attention layer norm + self.post_attention_layernorm = Qwen3MoeLayerNorm( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="post_attention_layernorm", + ) + self.post_attention_layernorm.build(input_shape) + + # MoE or dense FFN + if self.is_sparse_mlp: + # Sparse MoE feedforward reused from Qwen3Moe + self.sparse_moe = Qwen3SparseMoeBlock( + hidden_dim=hidden_dim, + moe_intermediate_dim=self.moe_intermediate_dim, + num_experts=self.num_experts, + top_k=self.top_k, + norm_top_k_prob=self.norm_top_k_prob, + router_aux_loss_coefficient=self.router_aux_loss_coefficient, + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, + name="sparse_moe", + ) + self.sparse_moe.build(input_shape) + else: + # Dense FFN for non-MoE layers + self.dense_mlp = Qwen3MoeMLP( + intermediate_dim=self.intermediate_dim, + hidden_dim=hidden_dim, + activation_fn="silu", + layer_norm_epsilon=self.layer_norm_epsilon, + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, + name="dense_mlp", + ) + self.dense_mlp.build(input_shape) + + # Dropout + if self.dropout_rate > 0: + 
self.dropout_layer = keras.layers.Dropout( + rate=self.dropout_rate, + dtype=self.dtype_policy, + ) + + self.built = True + + def call( + self, + inputs, + position_ids=None, + decoder_padding_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Forward pass of the decoder block. + + Args: + inputs: Input tensor of shape (batch, seq_len, hidden_dim). + position_ids: Position IDs for M-RoPE, shape (3, batch, seq_len). + decoder_padding_mask: Padding mask for attention. + cache: KV cache for generation (optional). + cache_update_index: Index for cache update. + training: Whether in training mode. + + Returns: + Output tensor of shape (batch, seq_len, hidden_dim). + """ + self_attention_mask = self._compute_self_attention_mask( + inputs=inputs, + decoder_padding_mask=decoder_padding_mask, + cache=cache, + cache_update_index=cache_update_index, + ) + residual = inputs + + x = self.pre_attention_norm(inputs) + + # Self attention block. + x = self.attention( + x, + position_ids=position_ids, + attention_mask=self_attention_mask, + cache=cache, + cache_update_index=cache_update_index, + training=training, + ) + + if cache is not None: + x, cache = x + + if self.dropout_rate > 0: + x = self.dropout_layer(x, training=training) + + x = x + residual + residual = x + + x = self.post_attention_layernorm(x) + if self.is_sparse_mlp: + x, router_logits = self.sparse_moe(x) + + # Compute auxiliary loss for load balancing + if training: + aux_loss = compute_load_balancing_loss( + router_logits, + self.num_experts, + self.top_k, + self_attention_mask, + ) + self.add_loss(self.router_aux_loss_coefficient * aux_loss) + else: + x = self.dense_mlp(x) + + x = ops.cast(x, ops.dtype(residual)) + x = x + residual + + if cache is not None: + return x, cache + return x + + def _compute_self_attention_mask( + self, + inputs, + decoder_padding_mask, + cache, + cache_update_index, + ): + """Computes the self-attention mask combining causal and padding masks. + + Args: + inputs: Input tensor. + decoder_padding_mask: Mask tensor for padding tokens. + cache: Optional cached key and value tensors. + cache_update_index: Index at which to update the cache. + + Returns: + Combined attention mask tensor. + """ + decoder_mask = merge_padding_and_attention_mask( + inputs, decoder_padding_mask, None + ) + batch_size = ops.shape(inputs)[0] + input_length = output_length = ops.shape(inputs)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `inputs` will + # generally be length 1, and `cache` will be the full generation length. 
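+        # For example, when decoding a single token at cache index 5 with
+        # a cache of length 8, the causal mask is one row of length 8 that
+        # allows attention to positions 0..5 and blocks 6..7.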
+ if cache is not None: + input_length = ops.shape(cache)[2] + + cache_update_index = ( + 0 if cache_update_index is None else cache_update_index + ) + + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "moe_intermediate_dim": self.moe_intermediate_dim, + "num_experts": self.num_experts, + "top_k": self.top_k, + "norm_top_k_prob": self.norm_top_k_prob, + "mrope_section": self.mrope_section, + "router_aux_loss_coefficient": self.router_aux_loss_coefficient, + "is_sparse_mlp": self.is_sparse_mlp, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "rope_attention_scaling": self.rope_attention_scaling, + "layer_norm_epsilon": self.layer_norm_epsilon, + "activation": keras.activations.serialize(self.activation), + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout_rate, + "sliding_window_size": self.sliding_window_size, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_image_converter.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_image_converter.py new file mode 100644 index 0000000000..b59f96839a --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_image_converter.py @@ -0,0 +1,10 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) + + +@keras_hub_export("keras_hub.layers.Qwen3OmniImageConverter") +class Qwen3OmniImageConverter(ImageConverter): + backbone_cls = Qwen3OmniBackbone diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_image_converter_test.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_image_converter_test.py new file mode 100644 index 0000000000..c789e38ae7 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_image_converter_test.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest + +from keras_hub.src.models.qwen3_omni.qwen3_omni_image_converter import ( + Qwen3OmniImageConverter, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniImageConverterTest(TestCase): + def setUp(self): + self.init_kwargs = { + "height": 224, + "width": 224, + } + + def test_converter_basics(self): + converter = Qwen3OmniImageConverter(**self.init_kwargs) + # Create dummy image + image = np.ones((512, 512, 3), dtype=np.uint8) * 128 + output = converter(image) + # Single image returns unbatched (height, width, channels) + self.assertEqual(len(output.shape), 3) + self.assertEqual(output.shape[0], 224) + self.assertEqual(output.shape[1], 224) + self.assertEqual(output.shape[2], 3) + + def test_batch_processing(self): + converter = Qwen3OmniImageConverter(**self.init_kwargs) + # Create batch of dummy images with uniform shape + batch_size = 2 + images = np.ones((batch_size, 512, 512, 3), dtype=np.uint8) * 128 + output = converter(images) + # Batch returns (batch, height, width, channels) + self.assertEqual(output.shape[0], batch_size) + self.assertEqual(output.shape[1], 224) + 
self.assertEqual(output.shape[2], 224) + self.assertEqual(output.shape[3], 3) + + def test_single_image(self): + converter = Qwen3OmniImageConverter(**self.init_kwargs) + # Create single image + image = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8) + output = converter(image) + # Single image returns unbatched (height, width, channels) + self.assertEqual(output.shape[0], 224) + self.assertEqual(output.shape[1], 224) + self.assertEqual(output.shape[2], 3) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Qwen3OmniImageConverter.presets: + self.run_preset_test( + cls=Qwen3OmniImageConverter, + preset=preset, + input_data=np.ones((224, 224, 3), dtype=np.uint8), + ) diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_presets.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_presets.py new file mode 100644 index 0000000000..2981b592fc --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_presets.py @@ -0,0 +1,26 @@ +"""Qwen3-Omni model preset configurations.""" + +# TODO: Upload Qwen3-Omni presets to Kaggle +# Qwen3-Omni models are available on HuggingFace +# (Qwen/Qwen3-Omni-30B-A3B-Instruct) +# but not yet converted and uploaded to Kaggle Models. +# +# Once weights are converted and uploaded, please add entries like: +# "qwen3_omni_30b_a3b_en": { +# "metadata": { +# "description": ( +# "Qwen3-Omni Thinker (comprehension) model with " +# "30.5 billion total parameters and 3.3 billion activated. " +# "This is a Mixture-of-Experts " +# "(MoE) based multimodal model supporting text, audio, image, and " +# "video inputs. Built on 48 layers with 32 query and 4 key/value " +# "attention heads, utilizing 128 experts (8 active per token). " +# "Features M-RoPE for multimodal position encoding." +# ), +# "params": 30532122624, +# "path": "qwen3_omni", +# }, +# "kaggle_handle": "kaggle://keras/qwen-3-omni/keras/qwen3_omni_30b_a3b_en/1", +# } + +backbone_presets = {} diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_rope.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_rope.py new file mode 100644 index 0000000000..77e662f964 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_rope.py @@ -0,0 +1,232 @@ +from keras import ops + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding + + +class MultimodalRotaryEmbedding(RotaryEmbedding): + """Multimodal Rotary Position Embedding (M-RoPE) for vision-language models. + + M-RoPE extends standard RoPE to handle multimodal inputs by dividing the + head dimension into three sections: + - Text section: Standard 1D position encoding for text tokens + - Temporal section: Position encoding for time/frame dimension (video/audio) + - Spatial section: Position encoding for spatial dimensions (image patches) + + For text-only tokens, all three sections use the same position ID, making it + equivalent to standard RoPE. For vision/audio tokens, each section gets + independent position IDs for temporal and spatial modeling. + + Args: + mrope_section: tuple of 3 ints. Dimension allocation for + (text, temporal, spatial). + For example, [24, 20, 20] means: + - 24 dims for text positions + - 20 dims for temporal positions + - 20 dims for spatial positions + Total must equal head_dim // 2 (since RoPE uses pairs). + max_wavelength: int. The maximum angular wavelength. Defaults to 10000. + scaling_factor: float. Scaling factor for positions. Defaults to 1.0. + sequence_axis: int. Sequence axis in input tensor. Defaults to 1. + feature_axis: int. Feature axis in input tensor. Defaults to -1. 
+ **kwargs: Additional arguments passed to parent RotaryEmbedding. + + Examples: + ```python + import numpy as np + import keras + + # Initialize M-RoPE with section [24, 20, 20] for 128-dim heads + mrope = MultimodalRotaryEmbedding( + mrope_section=[24, 20, 20], + max_wavelength=1000000, + ) + + q = keras.random.normal((2, 10, 32, 128)) + k = keras.random.normal((2, 10, 32, 128)) + + text_pos = np.arange(10) + position_ids = np.stack([ + text_pos, + text_pos, + text_pos, + ], axis=0) + position_ids = np.expand_dims(position_ids, 1).repeat(2, axis=1) + + q_embed, k_embed = mrope.apply_multimodal_rotary_embedding( + q, k, position_ids + ) + + # Vision tokens: different positions per section + # For image patches: text_pos=constant, temporal_pos=0, spatial_pos varies + ``` + """ + + def __init__( + self, + mrope_section, + max_wavelength=10000, + scaling_factor=1.0, + attention_scaling=1.0, + sequence_axis=1, + feature_axis=-1, + **kwargs, + ): + super().__init__( + max_wavelength=max_wavelength, + scaling_factor=scaling_factor, + sequence_axis=sequence_axis, + feature_axis=feature_axis, + **kwargs, + ) + + if len(mrope_section) != 3: + raise ValueError( + f"mrope_section must have 3 elements " + f"(text, temporal, spatial), got {len(mrope_section)}" + ) + + self.mrope_section = tuple(mrope_section) + self.attention_scaling = attention_scaling + + self.total_rope_dim = sum(mrope_section) + + def apply_multimodal_rotary_embedding(self, query, key, position_ids): + """Apply M-RoPE to query and key tensors. + + Args: + query: Query tensor of shape (batch, seq_len, num_heads, head_dim) + key: Key tensor of shape (batch, seq_len, num_heads, head_dim) + position_ids: Position IDs of shape (3, batch, seq_len) where: + position_ids[0] = text positions + position_ids[1] = temporal positions + position_ids[2] = spatial positions + + Returns: + Tuple of (query_embed, key_embed) with M-RoPE applied. 
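+
+            For text-only inputs, all three rows of `position_ids` are
+            identical, and the result matches standard 1D RoPE.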
+ """ + # Compute inverse frequencies for the full head dimension + head_dim_half = sum(self.mrope_section) + idx = ops.arange(0, head_dim_half * 2, 2, dtype="float32") + denom = ops.cast(head_dim_half * 2, "float32") + freq_range = idx / denom + inverse_freq = ops.power( + ops.cast(self.max_wavelength, "float32"), -freq_range + ) + inverse_freq = inverse_freq / ops.cast(self.scaling_factor, "float32") + + position_ids_float = ops.cast(position_ids, "float32") + + # Expand for broadcasting: + # position_ids (3, batch, seq_len) -> (3, batch, seq_len, 1) + # inverse_freq (head_dim_half,) -> (1, 1, 1, head_dim_half) + position_ids_expanded = ops.expand_dims(position_ids_float, axis=-1) + inverse_freq_expanded = ops.reshape( + inverse_freq, (1, 1, 1, head_dim_half) + ) + + # Compute frequencies: (3, batch, seq_len, head_dim_half) + freqs_stacked = position_ids_expanded * inverse_freq_expanded + + # Apply interleaved M-RoPE to reorganize frequency layout + freqs_interleaved = self._apply_interleaved_mrope( + freqs_stacked, self.mrope_section + ) + + embedding = ops.concatenate( + [freqs_interleaved, freqs_interleaved], axis=-1 + ) + + # Apply attention scaling to cos/sin embeddings + cos_full = ops.cos(embedding) * self.attention_scaling + sin_full = ops.sin(embedding) * self.attention_scaling + cos_full = ops.cast(cos_full, self.compute_dtype) + sin_full = ops.cast(sin_full, self.compute_dtype) + + # Expand for broadcasting with (batch, seq_len, num_heads, head_dim) + cos_full = ops.expand_dims( + cos_full, axis=2 + ) # (batch, seq_len, 1, head_dim) + sin_full = ops.expand_dims(sin_full, axis=2) + + # Apply rotary embedding + query_embed = self._apply_rotary_pos_emb_single( + query, cos_full, sin_full + ) + key_embed = self._apply_rotary_pos_emb_single(key, cos_full, sin_full) + + return query_embed, key_embed + + def _apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved M-RoPE to reorganize frequency layout. + + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. + + Args: + freqs: Frequency matrices of shape + (3, batch, seq_len, head_dim_half) where dim 0 + corresponds to [text, temporal, spatial] + All 3 matrices have the same dimension but were computed with + different position IDs. + mrope_section: Tuple of (text_dim, temporal_dim, spatial_dim) + + Returns: + Interleaved frequencies of shape + (batch, seq_len, sum(mrope_section)) + """ + freqs_t = freqs[0] + + head_dim_half = sum(mrope_section) + indices_list = [] + interleaved_length = min(mrope_section[1], mrope_section[2]) * 3 + + for pos in range(interleaved_length): + if pos % 3 == 0: + # Text dimension + indices_list.append(freqs[0][..., pos : pos + 1]) + elif pos % 3 == 1: + # Temporal dimension + indices_list.append(freqs[1][..., pos : pos + 1]) + else: + # Spatial dimension + indices_list.append(freqs[2][..., pos : pos + 1]) + + # Remaining positions will be all from text dimension + if interleaved_length < head_dim_half: + indices_list.append(freqs_t[..., interleaved_length:]) + + # Concatenate all selected frequencies + result = ops.concatenate(indices_list, axis=-1) + + return result + + def _apply_rotary_pos_emb_single(self, tensor, cos_emb, sin_emb): + """Apply rotary position embedding to a single tensor. 
+ + Args: + tensor: Input tensor of shape (batch, seq_len, num_heads, head_dim) + cos_emb: Cosine embedding of shape (batch, seq_len, 1, head_dim) + sin_emb: Sine embedding of shape (batch, seq_len, 1, head_dim) + + Returns: + Tensor with rotary embedding applied. + """ + x1, x2 = ops.split(tensor, 2, axis=-1) + + half_rot_tensor = ops.stack([-x2, x1], axis=-2) + half_rot_tensor = ops.reshape(half_rot_tensor, ops.shape(tensor)) + + return (tensor * cos_emb) + (half_rot_tensor * sin_emb) + + def get_config(self): + config = super().get_config() + config.update( + { + "mrope_section": self.mrope_section, + "attention_scaling": self.attention_scaling, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_tokenizer.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_tokenizer.py new file mode 100644 index 0000000000..8af9cfb0db --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_tokenizer.py @@ -0,0 +1,58 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + "keras_hub.tokenizers.Qwen3OmniTokenizer", +) +class Qwen3OmniTokenizer(BytePairTokenizer): + """Tokenizer for Qwen3-Omni model. + + This tokenizer implements byte-pair encoding (BPE) for Qwen3-Omni models, + handling special tokens like EOS (end of sequence) and PAD (padding). + + Args: + vocabulary: Dictionary mapping tokens to token IDs, or path to + vocabulary file. + merges: List of BPE merges, or path to merges file. + **kwargs: Additional keyword arguments passed to the parent + `BytePairTokenizer` class. + + Examples: + TODO: Update once presets registered + ```python + # Load a preset tokenizer + tokenizer = keras_hub.tokenizers.Qwen3OmniTokenizer.from_preset( + "qwen3_omni_30b_a3b_thinking_en" + ) + + # Tokenize text + tokenizer("The quick brown fox jumps.") + ``` + """ + + backbone_cls = Qwen3OmniBackbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + eos_token = "<|im_end|>" + self._add_special_token(eos_token, "end_token") + + pad_token = "<|endoftext|>" + self._add_special_token(pad_token, "pad_token") + + self.start_token_id = None + self.start_token = None + + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_vision_encoder.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_vision_encoder.py new file mode 100644 index 0000000000..a2671ded3b --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_vision_encoder.py @@ -0,0 +1,1014 @@ +import math + +import keras +import numpy as np +from keras import layers +from keras import ops + +from keras_hub.src.api_export import keras_hub_export + + +class Qwen3OmniVisionPatchEmbed(layers.Layer): + """3D patch embedding layer for Qwen3-Omni vision encoder. + + Converts video or image input into patches using 3D convolution. + For images, the temporal dimension is 1. For videos, the temporal + dimension represents frames. + + Args: + patch_size: int. The spatial patch size (height and width). + temporal_patch_size: int. The temporal patch size (frames). + in_channels: int. The number of input channels (e.g., 3 for RGB). + embed_dim: int. The output embedding dimension. + dtype: string or `keras.mixed_precision.DTypePolicy`. 
The dtype to use + for the layer's computations and weights. + """ + + def __init__( + self, + patch_size, + temporal_patch_size, + in_channels, + embed_dim, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + self.proj = layers.Conv3D( + filters=embed_dim, + kernel_size=(temporal_patch_size, patch_size, patch_size), + strides=(temporal_patch_size, patch_size, patch_size), + use_bias=True, + data_format="channels_last", + dtype=dtype, + name="proj", + ) + + def build(self, input_shape): + self.proj.build(input_shape) + self.built = True + + def call(self, pixel_values): + """Forward pass. + + Args: + pixel_values: Tensor with shape + `(batch_size, temporal, height, width, channels)`. + + Returns: + Tensor with shape `(batch_size, num_patches, embed_dim)`. + """ + hidden_states = self.proj(pixel_values) + batch_size = ops.shape(hidden_states)[0] + embed_dim = ops.shape(hidden_states)[-1] + hidden_states = ops.reshape(hidden_states, [batch_size, -1, embed_dim]) + return hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "temporal_patch_size": self.temporal_patch_size, + "in_channels": self.in_channels, + "embed_dim": self.embed_dim, + } + ) + return config + + def compute_output_shape(self, input_shape): + batch_size = input_shape[0] + temporal = input_shape[1] + height = input_shape[2] + width = input_shape[3] + num_patches = ( + (temporal // self.temporal_patch_size) + * (height // self.patch_size) + * (width // self.patch_size) + ) + return (batch_size, num_patches, self.embed_dim) + + +class Qwen3OmniVisionRotaryEmbedding(layers.Layer): + """Rotary position embedding for Qwen3-Omni vision encoder. + + Computes 2D rotary position embeddings from spatial position indices. + Unlike the text M-RoPE, this operates on (row, col) position pairs. + + Args: + dim: int. The embedding dimension (head_dim // 2). + theta: float. The base frequency. Defaults to `10000.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the layer's computations and weights. + """ + + def __init__(self, dim, theta=10000.0, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.dim = dim + self.theta = theta + inv_freq = 1.0 / ( + theta ** (np.arange(0, dim, 2, dtype="float32") / dim) + ) + self._inv_freq = inv_freq + + def call(self, seqlen): + """Compute rotary frequency table. + + Args: + seqlen: int. The maximum spatial extent. + + Returns: + Tensor with shape `(seqlen, dim // 2)`. + """ + seq = ops.arange(seqlen, dtype="float32") + inv_freq = ops.convert_to_tensor(self._inv_freq, dtype="float32") + freqs = ops.einsum("i,j->ij", seq, inv_freq) + return freqs + + def get_config(self): + config = super().get_config() + config.update({"dim": self.dim, "theta": self.theta}) + return config + + +def _rotate_half(x): + """Rotate half the hidden dims of the input.""" + x1 = x[..., : ops.shape(x)[-1] // 2] + x2 = x[..., ops.shape(x)[-1] // 2 :] + return ops.concatenate([-x2, x1], axis=-1) + + +def _apply_rotary_pos_emb_vision(q, k, cos, sin): + """Apply rotary position embeddings to query and key for vision. + + Args: + q: Query tensor of shape `(batch, num_heads, seq_len, head_dim)`. + k: Key tensor of shape `(batch, num_heads, seq_len, head_dim)`. + cos: Cosine embedding of shape `(seq_len, head_dim)`. 
+ sin: Sine embedding of shape `(seq_len, head_dim)`. + + Returns: + Tuple of (q_embed, k_embed) with rotary embeddings applied. + """ + # Reshape for broadcasting: (1, 1, seq_len, head_dim) + cos = ops.expand_dims(ops.expand_dims(cos, axis=0), axis=0) + sin = ops.expand_dims(ops.expand_dims(sin, axis=0), axis=0) + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3OmniVisionAttention(layers.Layer): + """Multi-head attention for Qwen3-Omni vision encoder. + + Uses fused QKV projection and applies 2D rotary position embeddings. + Attention is non-causal (bidirectional). + + Args: + hidden_size: int. The hidden dimension. + num_heads: int. The number of attention heads. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the layer's computations and weights. + """ + + def __init__(self, hidden_size, num_heads, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scaling = self.head_dim**-0.5 + + self.qkv = layers.Dense( + hidden_size * 3, + use_bias=True, + dtype=dtype, + name="qkv", + ) + self.proj = layers.Dense( + hidden_size, + use_bias=True, + dtype=dtype, + name="proj", + ) + + def build(self, input_shape): + self.qkv.build(input_shape) + proj_shape = list(input_shape) + proj_shape[-1] = self.hidden_size + self.proj.build(proj_shape) + self.built = True + + def call(self, hidden_states, position_embeddings=None, training=False): + """Forward pass. + + Args: + hidden_states: Tensor of shape + `(batch_size, seq_len, hidden_size)`. + position_embeddings: Tuple of (cos, sin) tensors for RoPE, + each of shape `(seq_len, head_dim)`. + training: bool. Whether in training mode. + + Returns: + Tensor of shape `(batch_size, seq_len, hidden_size)`. + """ + batch_size = ops.shape(hidden_states)[0] + seq_len = ops.shape(hidden_states)[1] + + # Fused QKV projection + qkv = self.qkv(hidden_states) + qkv = ops.reshape( + qkv, [batch_size, seq_len, 3, self.num_heads, self.head_dim] + ) + qkv = ops.transpose(qkv, [2, 0, 3, 1, 4]) + query, key, value = ops.split(qkv, 3, axis=0) + query = ops.squeeze(query, axis=0) + key = ops.squeeze(key, axis=0) + value = ops.squeeze(value, axis=0) + + # Apply rotary position embeddings + if position_embeddings is not None: + cos, sin = position_embeddings + query, key = _apply_rotary_pos_emb_vision(query, key, cos, sin) + + # Scaled dot-product attention (non-causal) + attn_weights = ( + ops.matmul(query, ops.transpose(key, [0, 1, 3, 2])) * self.scaling + ) + attn_weights = ops.softmax(ops.cast(attn_weights, "float32"), axis=-1) + attn_weights = ops.cast(attn_weights, self.compute_dtype) + + attn_output = ops.matmul(attn_weights, value) + attn_output = ops.transpose(attn_output, [0, 2, 1, 3]) + attn_output = ops.reshape( + attn_output, [batch_size, seq_len, self.hidden_size] + ) + attn_output = self.proj(attn_output) + return attn_output + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "num_heads": self.num_heads, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape + + +class Qwen3OmniVisionMLP(layers.Layer): + """Feed-forward MLP for Qwen3-Omni vision encoder. + + Args: + hidden_size: int. The hidden dimension. + intermediate_size: int. The MLP intermediate dimension. + hidden_act: string. The activation function name. 
+ dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the layer's computations and weights. + """ + + def __init__( + self, hidden_size, intermediate_size, hidden_act, dtype=None, **kwargs + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + + self.fc1 = layers.Dense( + intermediate_size, + use_bias=True, + dtype=dtype, + name="fc1", + ) + if hidden_act in ("gelu_pytorch_tanh", "gelu_approximate"): + self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) + else: + self.act_fn = layers.Activation( + hidden_act, dtype=dtype, name="act_fn" + ) + self.fc2 = layers.Dense( + hidden_size, + use_bias=True, + dtype=dtype, + name="fc2", + ) + + def build(self, input_shape): + self.fc1.build(input_shape) + mid_shape = list(input_shape) + mid_shape[-1] = self.intermediate_size + if hasattr(self.act_fn, "build"): + self.act_fn.build(mid_shape) + self.fc2.build(mid_shape) + self.built = True + + def call(self, hidden_states): + return self.fc2(self.act_fn(self.fc1(hidden_states))) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "intermediate_size": self.intermediate_size, + "hidden_act": self.hidden_act, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape + + +class Qwen3OmniVisionPatchMerger(layers.Layer): + """Spatial patch merger for Qwen3-Omni vision encoder. + + Merges spatially adjacent patches and projects to the output dimension. + When `use_postshuffle_norm` is True, LayerNorm is applied after merging + (used for deepstack mergers). When False, LayerNorm is applied before + merging (used for the main merger). + + Args: + hidden_size: int. The hidden dimension of the vision encoder. + spatial_merge_size: int. The spatial merge factor. + out_hidden_size: int. The output projection dimension. + use_postshuffle_norm: bool. Whether to apply LayerNorm after + spatial merging. Defaults to `False`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the layer's computations and weights. + """ + + def __init__( + self, + hidden_size, + spatial_merge_size, + out_hidden_size, + use_postshuffle_norm=False, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.spatial_merge_size = spatial_merge_size + self.out_hidden_size = out_hidden_size + self.use_postshuffle_norm = use_postshuffle_norm + + merge_dim = hidden_size * (spatial_merge_size**2) + norm_dim = merge_dim if use_postshuffle_norm else hidden_size + self.ln_q = layers.LayerNormalization( + epsilon=1e-6, + dtype=dtype, + name="ln_q", + ) + self._norm_dim = norm_dim + self._merge_dim = merge_dim + + self.mlp_fc1 = layers.Dense( + merge_dim, + use_bias=True, + dtype=dtype, + name="mlp_fc1", + ) + self.mlp_act = layers.Activation("gelu", dtype=dtype, name="mlp_act") + self.mlp_fc2 = layers.Dense( + out_hidden_size, + use_bias=True, + dtype=dtype, + name="mlp_fc2", + ) + + def build(self, input_shape): + self.ln_q.build([None, self._norm_dim]) + self.mlp_fc1.build([None, self._merge_dim]) + self.mlp_act.build([None, self._merge_dim]) + self.mlp_fc2.build([None, self._merge_dim]) + self.built = True + + def call(self, hidden_states): + """Forward pass. + + Args: + hidden_states: Tensor of shape `(num_tokens, hidden_size)`. + + Returns: + Tensor of shape `(num_merged_tokens, out_hidden_size)`. 
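+
+            With `spatial_merge_size=2`, each group of 4 adjacent patch
+            tokens is merged into one, so `num_merged_tokens` equals
+            `num_tokens // 4`.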
+ """ + if self.use_postshuffle_norm: + hidden_states = ops.reshape(hidden_states, [-1, self._merge_dim]) + hidden_states = self.ln_q(hidden_states) + else: + hidden_states = self.ln_q(hidden_states) + hidden_states = ops.reshape(hidden_states, [-1, self._merge_dim]) + + hidden_states = self.mlp_fc1(hidden_states) + hidden_states = self.mlp_act(hidden_states) + hidden_states = self.mlp_fc2(hidden_states) + return hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "spatial_merge_size": self.spatial_merge_size, + "out_hidden_size": self.out_hidden_size, + "use_postshuffle_norm": self.use_postshuffle_norm, + } + ) + return config + + +class Qwen3OmniVisionBlock(layers.Layer): + """Vision transformer block for Qwen3-Omni. + + Implements a Vision Transformer (ViT) block with pre-normalization, + multi-head attention with rotary position embeddings, and a + feed-forward MLP. + + Args: + hidden_size: int. The hidden dimension. + num_heads: int. The number of attention heads. + intermediate_size: int. The MLP intermediate dimension. + hidden_act: string. The activation function name. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the layer's computations and weights. + """ + + def __init__( + self, + hidden_size, + num_heads, + intermediate_size, + hidden_act="gelu", + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + + self.norm1 = layers.LayerNormalization( + epsilon=1e-6, + dtype=dtype, + name="norm1", + ) + self.attn = Qwen3OmniVisionAttention( + hidden_size=hidden_size, + num_heads=num_heads, + dtype=dtype, + name="attn", + ) + self.norm2 = layers.LayerNormalization( + epsilon=1e-6, + dtype=dtype, + name="norm2", + ) + self.mlp = Qwen3OmniVisionMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + dtype=dtype, + name="mlp", + ) + + def build(self, input_shape): + self.norm1.build(input_shape) + self.attn.build(input_shape) + self.norm2.build(input_shape) + self.mlp.build(input_shape) + self.built = True + + def call( + self, + hidden_states, + position_embeddings=None, + training=False, + ): + """Forward pass. + + Args: + hidden_states: Tensor with shape + `(batch_size, sequence_length, hidden_size)`. + position_embeddings: Tuple of (cos, sin) for RoPE, or None. + training: bool. Whether the layer is in training mode. + + Returns: + Tensor with shape `(batch_size, sequence_length, hidden_size)`. + """ + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + position_embeddings=position_embeddings, + training=training, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "num_heads": self.num_heads, + "intermediate_size": self.intermediate_size, + "hidden_act": self.hidden_act, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape + + +@keras_hub_export("keras_hub.models.Qwen3OmniVisionEncoder") +class Qwen3OmniVisionEncoder(keras.layers.Layer): + """Vision encoder for Qwen3-Omni. 
+
+    This encoder processes image and video inputs using a Vision Transformer
+    (ViT) architecture with:
+    - 3D patch embedding for spatiotemporal features
+    - Learnable position embeddings with bilinear interpolation
+    - 2D rotary position embeddings (RoPE) in attention
+    - Vision transformer blocks
+    - Spatial patch merging with output projection
+    - Deepstack intermediate feature collection
+
+    Args:
+        depth: int. The number of transformer layers. Defaults to `27`.
+        hidden_size: int. The hidden dimension. Defaults to `1152`.
+        hidden_act: string. The activation function name.
+            Defaults to `"gelu_pytorch_tanh"`.
+        intermediate_size: int. The MLP intermediate dimension.
+            Defaults to `4304`.
+        num_heads: int. The number of attention heads. Defaults to `16`.
+        in_channels: int. The number of input channels. Defaults to `3`.
+        patch_size: int. The spatial patch size. Defaults to `16`.
+        spatial_merge_size: int. The spatial merge factor for downsampling.
+            Defaults to `2`.
+        temporal_patch_size: int. The temporal patch size for videos.
+            Defaults to `2`.
+        out_hidden_size: int. The output projection dimension.
+            Defaults to `3584`.
+        num_position_embeddings: int. Number of position embeddings in the
+            learnable embedding table. Defaults to `2304`.
+        deepstack_visual_indexes: list of int. Layer indices at which to
+            collect deepstack intermediate features. Defaults to
+            `[8, 16, 24]`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the model's computations and weights.
+
+    Example:
+    ```python
+    import numpy as np
+    import keras_hub
+
+    # Create encoder
+    vision_encoder = keras_hub.models.Qwen3OmniVisionEncoder(
+        hidden_size=1152,
+        depth=27,
+        num_heads=16,
+        intermediate_size=4304,
+        out_hidden_size=3584,
+    )
+
+    # One image covering a 14x14 patch grid: 1 * 14 * 14 = 196 patches,
+    # each of shape (temporal_patch_size, patch_size, patch_size,
+    # in_channels) = (2, 16, 16, 3), matching grid_thw below.
+    pixel_values = np.random.uniform(size=(196, 2, 16, 16, 3))
+    grid_thw = np.array([[1, 14, 14]])
+    output = vision_encoder({
+        "pixel_values": pixel_values,
+        "grid_thw": grid_thw,
+    })
+    ```
+    """
+
+    def __init__(
+        self,
+        depth=27,
+        hidden_size=1152,
+        hidden_act="gelu_pytorch_tanh",
+        intermediate_size=4304,
+        num_heads=16,
+        in_channels=3,
+        patch_size=16,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        out_hidden_size=3584,
+        num_position_embeddings=2304,
+        deepstack_visual_indexes=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(dtype=dtype, **kwargs)
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.out_hidden_size = out_hidden_size
+        self.num_position_embeddings = num_position_embeddings
+        self.deepstack_visual_indexes = deepstack_visual_indexes or [
+            8,
+            16,
+            24,
+        ]
+
+        self.num_grid_per_side = int(math.sqrt(num_position_embeddings))
+        self.spatial_merge_unit = spatial_merge_size * spatial_merge_size
+
+        # === Patch embedding ===
+        self.patch_embed = Qwen3OmniVisionPatchEmbed(
+            patch_size=patch_size,
+            temporal_patch_size=temporal_patch_size,
+            in_channels=in_channels,
+            embed_dim=hidden_size,
+            dtype=dtype,
+            name="patch_embed",
+        )
+
+        # === Learnable position embeddings ===
+        self.pos_embed = layers.Embedding(
+            input_dim=num_position_embeddings,
+            output_dim=hidden_size,
+            dtype=dtype,
+            name="pos_embed",
+        )
+
+        # === Vision rotary position embeddings ===
+        head_dim = hidden_size // num_heads
+        self.rotary_pos_emb = 
Qwen3OmniVisionRotaryEmbedding( + head_dim // 2, + dtype=dtype, + name="rotary_pos_emb", + ) + + # === Vision transformer blocks === + self.blocks = [] + for i in range(depth): + block = Qwen3OmniVisionBlock( + hidden_size=hidden_size, + num_heads=num_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + dtype=dtype, + name=f"block_{i}", + ) + self.blocks.append(block) + + # === Main patch merger (pre-shuffle norm) === + self.merger = Qwen3OmniVisionPatchMerger( + hidden_size=hidden_size, + spatial_merge_size=spatial_merge_size, + out_hidden_size=out_hidden_size, + use_postshuffle_norm=False, + dtype=dtype, + name="merger", + ) + + # === Deepstack mergers (post-shuffle norm) === + self.merger_list = [] + for i in range(len(self.deepstack_visual_indexes)): + merger = Qwen3OmniVisionPatchMerger( + hidden_size=hidden_size, + spatial_merge_size=spatial_merge_size, + out_hidden_size=out_hidden_size, + use_postshuffle_norm=True, + dtype=dtype, + name=f"merger_list_{i}", + ) + self.merger_list.append(merger) + + def build(self, input_shape=None): + # Build patch embedding with a representative shape + patch_shape = ( + None, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + self.in_channels, + ) + self.patch_embed.build(patch_shape) + + # Build position embedding + self.pos_embed.build([None]) + + # Build transformer blocks + block_shape = (None, None, self.hidden_size) + for block in self.blocks: + block.build(block_shape) + + # Build mergers + merger_shape = (None, self.hidden_size) + self.merger.build(merger_shape) + for merger in self.merger_list: + merger.build(merger_shape) + + self.built = True + + def _compute_rot_pos_emb(self, grid_thw): + """Compute 2D rotary position embeddings from grid_thw. + + For each image/video in the batch, computes (row, col) position + indices accounting for the spatial merge pattern, then looks up + rotary frequencies. + + Args: + grid_thw: Integer tensor of shape `(num_images, 3)` where each + row is `(temporal, height, width)` in patch units. + + Returns: + Tuple of (cos, sin) tensors, each of shape + `(total_tokens, head_dim)`. 
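+
+        For example, with `grid_thw = [[1, 4, 4]]` and
+        `spatial_merge_size=2`, the (row, col) pairs follow the spatial
+        merge pattern rather than raster order: `(0, 0), (0, 1),
+        (1, 0), (1, 1), (0, 2), (0, 3), (1, 2), (1, 3), ...`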
+ """ + merge_size = self.spatial_merge_size + + # Find max spatial extent for frequency table + grid_h = grid_thw[:, 1] + grid_w = grid_thw[:, 2] + max_hw = ops.cast(ops.max(ops.maximum(grid_h, grid_w)), "int32") + freq_table = self.rotary_pos_emb(max_hw) + + # Build (row, col) position indices for all tokens + pos_ids_list = [] + for idx in range(ops.shape(grid_thw)[0]): + t = ops.cast(grid_thw[idx, 0], "int32") + h = ops.cast(grid_thw[idx, 1], "int32") + w = ops.cast(grid_thw[idx, 2], "int32") + merged_h = h // merge_size + merged_w = w // merge_size + + block_rows = ops.arange(merged_h, dtype="int32") + block_cols = ops.arange(merged_w, dtype="int32") + intra_row = ops.arange(merge_size, dtype="int32") + intra_col = ops.arange(merge_size, dtype="int32") + + # Full-resolution row positions: + # block_rows[:, None, None, None] * merge_size + # + intra_row[None, None, :, None] + row_idx = ops.reshape( + block_rows, [-1, 1, 1, 1] + ) * merge_size + ops.reshape(intra_row, [1, 1, -1, 1]) + col_idx = ops.reshape( + block_cols, [1, -1, 1, 1] + ) * merge_size + ops.reshape(intra_col, [1, 1, 1, -1]) + + # Broadcast to (merged_h, merged_w, merge_size, merge_size) + row_idx = ops.broadcast_to( + row_idx, + [merged_h, merged_w, merge_size, merge_size], + ) + col_idx = ops.broadcast_to( + col_idx, + [merged_h, merged_w, merge_size, merge_size], + ) + + row_idx = ops.reshape(row_idx, [-1]) + col_idx = ops.reshape(col_idx, [-1]) + + # Stack to (num_spatial_tokens, 2) + coords = ops.stack([row_idx, col_idx], axis=-1) + + # Repeat for temporal frames + if t > 1: + coords = ops.tile(coords, [t, 1]) + + pos_ids_list.append(coords) + + pos_ids = ops.concatenate(pos_ids_list, axis=0) + + # Look up rotary embeddings: (total_tokens, 2, dim//2) + # -> (total_tokens, dim) + embeddings = ops.take(freq_table, pos_ids, axis=0) + embeddings = ops.reshape(embeddings, [ops.shape(embeddings)[0], -1]) + + # Double the frequencies and compute cos/sin + emb = ops.concatenate([embeddings, embeddings], axis=-1) + cos = ops.cos(emb) + sin = ops.sin(emb) + return cos, sin + + def _fast_pos_embed_interpolate(self, grid_thw): + """Bilinear interpolation of learnable position embeddings. + + Given variable-resolution grids, performs bilinear interpolation + over the 2D embedding table, then reorders tokens to match the + spatial merge pattern. + + Args: + grid_thw: Integer tensor of shape `(num_images, 3)`. + + Returns: + Tensor of shape `(total_tokens, hidden_size)`. 
+ """ + grid_ts = grid_thw[:, 0] + grid_hs = grid_thw[:, 1] + grid_ws = grid_thw[:, 2] + merge_size = self.spatial_merge_size + n = self.num_grid_per_side + + pos_embed_weight = self.pos_embed.embeddings + + patch_pos_embeds_list = [] + for i in range(ops.shape(grid_thw)[0]): + t = ops.cast(grid_ts[i], "int32") + h = ops.cast(grid_hs[i], "int32") + w = ops.cast(grid_ws[i], "int32") + + h_idxs = ops.cast(ops.linspace(0.0, float(n - 1), h), "float32") + w_idxs = ops.cast(ops.linspace(0.0, float(n - 1), w), "float32") + + h_floor = ops.cast(ops.floor(h_idxs), "int32") + w_floor = ops.cast(ops.floor(w_idxs), "int32") + h_ceil = ops.minimum(h_floor + 1, n - 1) + w_ceil = ops.minimum(w_floor + 1, n - 1) + + dh = h_idxs - ops.cast(h_floor, "float32") + dw = w_idxs - ops.cast(w_floor, "float32") + + # 4-corner indices into the (n*n,) embedding table + # Shape: each is (h, w) -> flattened to (h*w,) + base_h_floor = h_floor * n + base_h_ceil = h_ceil * n + + # (h, 1) + (1, w) -> (h, w) via broadcasting + idx_tl = ops.reshape(base_h_floor, [-1, 1]) + ops.reshape( + w_floor, [1, -1] + ) + idx_tr = ops.reshape(base_h_floor, [-1, 1]) + ops.reshape( + w_ceil, [1, -1] + ) + idx_bl = ops.reshape(base_h_ceil, [-1, 1]) + ops.reshape( + w_floor, [1, -1] + ) + idx_br = ops.reshape(base_h_ceil, [-1, 1]) + ops.reshape( + w_ceil, [1, -1] + ) + + # Bilinear weights: (h, 1) * (1, w) -> (h, w) + w_tl = ops.reshape(1.0 - dh, [-1, 1]) * ops.reshape( + 1.0 - dw, [1, -1] + ) + w_tr = ops.reshape(1.0 - dh, [-1, 1]) * ops.reshape(dw, [1, -1]) + w_bl = ops.reshape(dh, [-1, 1]) * ops.reshape(1.0 - dw, [1, -1]) + w_br = ops.reshape(dh, [-1, 1]) * ops.reshape(dw, [1, -1]) + + # Flatten and gather + idx_tl = ops.reshape(idx_tl, [-1]) + idx_tr = ops.reshape(idx_tr, [-1]) + idx_bl = ops.reshape(idx_bl, [-1]) + idx_br = ops.reshape(idx_br, [-1]) + w_tl = ops.reshape(w_tl, [-1, 1]) + w_tr = ops.reshape(w_tr, [-1, 1]) + w_bl = ops.reshape(w_bl, [-1, 1]) + w_br = ops.reshape(w_br, [-1, 1]) + + pos_embed = ( + ops.take(pos_embed_weight, idx_tl, axis=0) * w_tl + + ops.take(pos_embed_weight, idx_tr, axis=0) * w_tr + + ops.take(pos_embed_weight, idx_bl, axis=0) * w_bl + + ops.take(pos_embed_weight, idx_br, axis=0) * w_br + ) + + # Repeat for temporal frames: (h*w, hidden) -> (t*h*w, hidden) + pos_embed = ops.tile(pos_embed, [t, 1]) + + # Reorder to spatial merge pattern: + # (t, h//m, w//m, m, m, hidden) -> permute -> flatten + pos_embed = ops.reshape( + pos_embed, + [ + t, + h // merge_size, + merge_size, + w // merge_size, + merge_size, + -1, + ], + ) + pos_embed = ops.transpose(pos_embed, [0, 1, 3, 2, 4, 5]) + pos_embed = ops.reshape(pos_embed, [-1, self.hidden_size]) + + patch_pos_embeds_list.append(pos_embed) + + return ops.concatenate(patch_pos_embeds_list, axis=0) + + def call(self, inputs, training=False): + """Forward pass. + + Args: + inputs: dict with keys: + - `"pixel_values"`: Tensor of shape + `(total_patches, temporal_patch_size, patch_size, + patch_size, in_channels)` — pre-chunked patches. + - `"grid_thw"`: Integer tensor of shape + `(num_images_or_videos, 3)` — the temporal, height, and + width of each image/video in patch grid units. + training: bool. Whether the model is in training mode. + + Returns: + dict with keys: + - `"last_hidden_state"`: Tensor of shape + `(1, total_tokens, hidden_size)`. + - `"pooler_output"`: Tensor of shape + `(1, total_merged_tokens, out_hidden_size)`. + - `"deepstack_features"`: List of tensors, each of shape + `(1, total_merged_tokens, out_hidden_size)`. 
+ """ + pixel_values = inputs["pixel_values"] + grid_thw = inputs["grid_thw"] + + # Patch embedding: (total_patches, t, h, w, c) -> (1, total, hidden) + hidden_states = self.patch_embed(pixel_values) + if len(ops.shape(hidden_states)) == 2: + hidden_states = ops.expand_dims(hidden_states, axis=0) + + # Add interpolated position embeddings + pos_embeds = self._fast_pos_embed_interpolate(grid_thw) + pos_embeds = ops.expand_dims(pos_embeds, axis=0) + hidden_states = hidden_states + pos_embeds + + # Compute rotary position embeddings + cos, sin = self._compute_rot_pos_emb(grid_thw) + position_embeddings = (cos, sin) + + # Apply transformer blocks and collect deepstack features + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + position_embeddings=position_embeddings, + training=training, + ) + if layer_num in self.deepstack_visual_indexes: + ds_idx = self.deepstack_visual_indexes.index(layer_num) + # Squeeze batch for merger, then re-add + hs_2d = ops.squeeze(hidden_states, axis=0) + deepstack_feat = self.merger_list[ds_idx](hs_2d) + deepstack_feat = ops.expand_dims(deepstack_feat, axis=0) + deepstack_feature_lists.append(deepstack_feat) + + hs_2d = ops.squeeze(hidden_states, axis=0) + merged_hidden_states = self.merger(hs_2d) + merged_hidden_states = ops.expand_dims(merged_hidden_states, axis=0) + + return { + "last_hidden_state": hidden_states, + "pooler_output": merged_hidden_states, + "deepstack_features": deepstack_feature_lists, + } + + def compute_output_spec(self, input_spec, **kwargs): + """Compute output shape for symbolic tracing.""" + pixel_values_spec = input_spec["pixel_values"] + num_patches = None + return { + "last_hidden_state": keras.KerasTensor( + shape=(1, num_patches, self.hidden_size), + dtype=pixel_values_spec.dtype, + ), + "pooler_output": keras.KerasTensor( + shape=(1, num_patches, self.out_hidden_size), + dtype=pixel_values_spec.dtype, + ), + "deepstack_features": [ + keras.KerasTensor( + shape=(1, num_patches, self.out_hidden_size), + dtype=pixel_values_spec.dtype, + ) + for _ in self.deepstack_visual_indexes + ], + } + + def get_config(self): + config = super().get_config() + config.update( + { + "depth": self.depth, + "hidden_size": self.hidden_size, + "hidden_act": self.hidden_act, + "intermediate_size": self.intermediate_size, + "num_heads": self.num_heads, + "in_channels": self.in_channels, + "patch_size": self.patch_size, + "spatial_merge_size": self.spatial_merge_size, + "temporal_patch_size": self.temporal_patch_size, + "out_hidden_size": self.out_hidden_size, + "num_position_embeddings": self.num_position_embeddings, + "deepstack_visual_indexes": self.deepstack_visual_indexes, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_omni/qwen3_omni_vision_encoder_test.py b/keras_hub/src/models/qwen3_omni/qwen3_omni_vision_encoder_test.py new file mode 100644 index 0000000000..055ad36dd3 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni/qwen3_omni_vision_encoder_test.py @@ -0,0 +1,68 @@ +import numpy as np + +from keras_hub.src.models.qwen3_omni.qwen3_omni_vision_encoder import ( + Qwen3OmniVisionEncoder, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_vision_encoder import ( + Qwen3OmniVisionPatchEmbed, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniVisionEncoderTest(TestCase): + def test_patch_embed_output_shape(self): + patch_embed = Qwen3OmniVisionPatchEmbed( + patch_size=2, + temporal_patch_size=2, + in_channels=3, + embed_dim=32, + 
dtype="float32", + ) + pixel_values = np.random.rand(1, 2, 4, 4, 3).astype("float32") + output = patch_embed(pixel_values) + self.assertEqual(output.shape, (1, 4, 32)) + + def test_encoder_output_shape(self): + encoder = Qwen3OmniVisionEncoder( + depth=2, + hidden_size=32, + num_heads=4, + intermediate_size=64, + patch_size=2, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=16, + num_position_embeddings=49, + deepstack_visual_indexes=[0], + dtype="float32", + ) + pixel_values = np.random.rand(1, 2, 4, 4, 3).astype("float32") + grid_thw = np.array([[1, 2, 2]], dtype="int32") + + output = encoder({"pixel_values": pixel_values, "grid_thw": grid_thw}) + + self.assertIsInstance(output, dict) + self.assertIn("last_hidden_state", output) + self.assertIn("pooler_output", output) + self.assertIn("deepstack_features", output) + self.assertEqual(output["pooler_output"].shape[-1], 16) + + def test_encoder_config_roundtrip(self): + encoder = Qwen3OmniVisionEncoder( + depth=2, + hidden_size=32, + num_heads=4, + intermediate_size=64, + patch_size=2, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=16, + num_position_embeddings=49, + deepstack_visual_indexes=[0], + dtype="float32", + ) + config = encoder.get_config() + restored = Qwen3OmniVisionEncoder.from_config(config) + self.assertEqual(restored.depth, 2) + self.assertEqual(restored.hidden_size, 32) + self.assertEqual(restored.out_hidden_size, 16) diff --git a/keras_hub/src/utils/transformers/convert_qwen3_omni.py b/keras_hub/src/utils/transformers/convert_qwen3_omni.py new file mode 100644 index 0000000000..610267ab14 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen3_omni.py @@ -0,0 +1,615 @@ +import numpy as np + +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) +from keras_hub.src.utils.preset_utils import get_file +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = Qwen3OmniBackbone + + +def convert_backbone_config(transformers_config): + """Convert HuggingFace Qwen3-Omni config to KerasHub config.""" + thinker_config = transformers_config.get("thinker_config", {}) + text_config = thinker_config.get("text_config", transformers_config) + rope_scaling = text_config.get("rope_scaling", {}) + mrope_section = rope_scaling.get("mrope_section", [24, 20, 20]) + backbone_config = { + "vocabulary_size": text_config["vocab_size"], + "hidden_dim": text_config["hidden_size"], + "head_dim": text_config["head_dim"], + "num_layers": text_config["num_hidden_layers"], + "num_query_heads": text_config["num_attention_heads"], + "num_key_value_heads": text_config["num_key_value_heads"], + "intermediate_dim": text_config["intermediate_size"], + "moe_intermediate_dim": text_config["moe_intermediate_size"], + "num_experts": text_config["num_experts"], + "num_experts_per_tok": text_config["num_experts_per_tok"], + "norm_topk_prob": text_config["norm_topk_prob"], + "decoder_sparse_step": text_config["decoder_sparse_step"], + "layer_norm_epsilon": text_config["rms_norm_eps"], + "rope_max_wavelength": text_config["rope_theta"], + "mrope_section": tuple(mrope_section), + "sliding_window_size": text_config.get("sliding_window"), + "router_aux_loss_coefficient": text_config["router_aux_loss_coef"], + "mlp_only_layers": text_config.get("mlp_only_layers", []), + "tie_word_embeddings": text_config.get("tie_word_embeddings", False), + } + + backbone_config["image_token_id"] = thinker_config.get( + "image_token_id", 151655 + ) + backbone_config["video_token_id"] = 
thinker_config.get( + "video_token_id", 151656 + ) + backbone_config["audio_token_id"] = thinker_config.get( + "audio_token_id", 151675 + ) + + audio_config = thinker_config.get("audio_config") + if audio_config: + from keras_hub.src.models.qwen3_omni.qwen3_omni_audio_encoder import ( + Qwen3OmniAudioEncoder, + ) + + backbone_config["audio_encoder"] = Qwen3OmniAudioEncoder( + num_mel_bins=audio_config["num_mel_bins"], + d_model=audio_config["d_model"], + encoder_layers=audio_config["encoder_layers"], + encoder_attention_heads=audio_config["encoder_attention_heads"], + encoder_ffn_dim=audio_config["encoder_ffn_dim"], + output_dim=audio_config["output_dim"], + downsample_hidden_size=audio_config.get( + "downsample_hidden_size", 480 + ), + max_source_positions=audio_config.get("max_source_positions", 1500), + scale_embedding=audio_config.get("scale_embedding", False), + activation_function=audio_config.get("activation_function", "gelu"), + dropout=audio_config.get("dropout", 0.0), + ) + + vision_config = thinker_config.get("vision_config") + if vision_config: + from keras_hub.src.models.qwen3_omni.qwen3_omni_vision_encoder import ( + Qwen3OmniVisionEncoder, + ) + + backbone_config["vision_encoder"] = Qwen3OmniVisionEncoder( + depth=vision_config["depth"], + hidden_size=vision_config["hidden_size"], + num_heads=vision_config["num_heads"], + patch_size=vision_config["patch_size"], + out_hidden_size=vision_config.get( + "out_hidden_size", + text_config.get("hidden_size", 3584), + ), + hidden_act=vision_config.get("hidden_act", "gelu_pytorch_tanh"), + intermediate_size=vision_config.get("intermediate_size", 4304), + in_channels=vision_config.get("in_channels", 3), + spatial_merge_size=vision_config.get("spatial_merge_size", 2), + temporal_patch_size=vision_config.get("temporal_patch_size", 2), + ) + + return backbone_config + + +def convert_weights(backbone, loader, transformers_config): + """Convert HF Thinker weights to KerasHub backbone.""" + + # === Audio Encoder Weights === + if backbone.audio_encoder is not None: + audio_enc = backbone.audio_encoder + + # Conv downsampling layers + def conv2d_transpose(x, _): + # PyTorch Conv2D: (out_channels, in_channels, H, W) + # Keras Conv2D: (H, W, in_channels, out_channels) + return np.transpose(x, (2, 3, 1, 0)) + + loader.port_weight( + keras_variable=audio_enc.conv2d1.kernel, + hf_weight_key="audio_tower.conv2d1.weight", + hook_fn=conv2d_transpose, + ) + loader.port_weight( + keras_variable=audio_enc.conv2d1.bias, + hf_weight_key="audio_tower.conv2d1.bias", + ) + loader.port_weight( + keras_variable=audio_enc.conv2d2.kernel, + hf_weight_key="audio_tower.conv2d2.weight", + hook_fn=conv2d_transpose, + ) + loader.port_weight( + keras_variable=audio_enc.conv2d2.bias, + hf_weight_key="audio_tower.conv2d2.bias", + ) + loader.port_weight( + keras_variable=audio_enc.conv2d3.kernel, + hf_weight_key="audio_tower.conv2d3.weight", + hook_fn=conv2d_transpose, + ) + loader.port_weight( + keras_variable=audio_enc.conv2d3.bias, + hf_weight_key="audio_tower.conv2d3.bias", + ) + + loader.port_weight( + keras_variable=audio_enc.conv_out.kernel, + hf_weight_key="audio_tower.conv_out.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + + # Transformer encoder layers + def audio_transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + for i in range(audio_enc.encoder_layers_count): + layer = audio_enc.encoder_layers[i] + + # Self-attention layer norm + loader.port_weight( + keras_variable=layer.self_attn_layer_norm.gamma, + 
hf_weight_key=f"audio_tower.layers.{i}.self_attn_layer_norm.weight", + ) + loader.port_weight( + keras_variable=layer.self_attn_layer_norm.beta, + hf_weight_key=f"audio_tower.layers.{i}.self_attn_layer_norm.bias", + ) + + # Attention QKV projections + loader.port_weight( + keras_variable=layer.self_attn._query_dense.kernel, + hf_weight_key=f"audio_tower.layers.{i}.self_attn.q_proj.weight", + hook_fn=audio_transpose_and_reshape, + ) + loader.port_weight( + keras_variable=layer.self_attn._query_dense.bias, + hf_weight_key=f"audio_tower.layers.{i}.self_attn.q_proj.bias", + hook_fn=lambda x, shape: np.reshape(x, shape), + ) + loader.port_weight( + keras_variable=layer.self_attn._key_dense.kernel, + hf_weight_key=f"audio_tower.layers.{i}.self_attn.k_proj.weight", + hook_fn=audio_transpose_and_reshape, + ) + loader.port_weight( + keras_variable=layer.self_attn._key_dense.bias, + hf_weight_key=f"audio_tower.layers.{i}.self_attn.k_proj.bias", + hook_fn=lambda x, shape: np.reshape(x, shape), + ) + loader.port_weight( + keras_variable=layer.self_attn._value_dense.kernel, + hf_weight_key=f"audio_tower.layers.{i}.self_attn.v_proj.weight", + hook_fn=audio_transpose_and_reshape, + ) + loader.port_weight( + keras_variable=layer.self_attn._value_dense.bias, + hf_weight_key=f"audio_tower.layers.{i}.self_attn.v_proj.bias", + hook_fn=lambda x, shape: np.reshape(x, shape), + ) + loader.port_weight( + keras_variable=layer.self_attn._output_dense.kernel, + hf_weight_key=f"audio_tower.layers.{i}.self_attn.out_proj.weight", + hook_fn=audio_transpose_and_reshape, + ) + loader.port_weight( + keras_variable=layer.self_attn._output_dense.bias, + hf_weight_key=f"audio_tower.layers.{i}.self_attn.out_proj.bias", + ) + + # Feed-forward + loader.port_weight( + keras_variable=layer.final_layer_norm.gamma, + hf_weight_key=f"audio_tower.layers.{i}.final_layer_norm.weight", + ) + loader.port_weight( + keras_variable=layer.final_layer_norm.beta, + hf_weight_key=f"audio_tower.layers.{i}.final_layer_norm.bias", + ) + loader.port_weight( + keras_variable=layer.fc1.kernel, + hf_weight_key=f"audio_tower.layers.{i}.fc1.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=layer.fc1.bias, + hf_weight_key=f"audio_tower.layers.{i}.fc1.bias", + ) + loader.port_weight( + keras_variable=layer.fc2.kernel, + hf_weight_key=f"audio_tower.layers.{i}.fc2.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=layer.fc2.bias, + hf_weight_key=f"audio_tower.layers.{i}.fc2.bias", + ) + + # Post-encoder layer norm + loader.port_weight( + keras_variable=audio_enc.ln_post.gamma, + hf_weight_key="audio_tower.ln_post.weight", + ) + loader.port_weight( + keras_variable=audio_enc.ln_post.beta, + hf_weight_key="audio_tower.ln_post.bias", + ) + + # Output projections + loader.port_weight( + keras_variable=audio_enc.proj1.kernel, + hf_weight_key="audio_tower.proj1.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=audio_enc.proj1.bias, + hf_weight_key="audio_tower.proj1.bias", + ) + loader.port_weight( + keras_variable=audio_enc.proj2.kernel, + hf_weight_key="audio_tower.proj2.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=audio_enc.proj2.bias, + hf_weight_key="audio_tower.proj2.bias", + ) + + # === Vision Encoder Weights === + if backbone.vision_encoder is not None: + vision_enc = backbone.vision_encoder + + # Patch embedding (Conv3D) + def conv3d_transpose(x, _): + # 
PyTorch Conv3D: (out_channels, in_channels, D, H, W) + # Keras Conv3D: (D, H, W, in_channels, out_channels) + return np.transpose(x, (2, 3, 4, 1, 0)) + + loader.port_weight( + keras_variable=vision_enc.patch_embed.proj.kernel, + hf_weight_key="visual.patch_embed.proj.weight", + hook_fn=conv3d_transpose, + ) + loader.port_weight( + keras_variable=vision_enc.patch_embed.proj.bias, + hf_weight_key="visual.patch_embed.proj.bias", + ) + + # Position embeddings + loader.port_weight( + keras_variable=vision_enc.pos_embed.embeddings, + hf_weight_key="visual.pos_embed.weight", + ) + + # Vision transformer blocks + for i in range(vision_enc.depth): + block = vision_enc.blocks[i] + + # Layer norms + loader.port_weight( + keras_variable=block.norm1.gamma, + hf_weight_key=f"visual.blocks.{i}.norm1.weight", + ) + loader.port_weight( + keras_variable=block.norm1.beta, + hf_weight_key=f"visual.blocks.{i}.norm1.bias", + ) + loader.port_weight( + keras_variable=block.norm2.gamma, + hf_weight_key=f"visual.blocks.{i}.norm2.weight", + ) + loader.port_weight( + keras_variable=block.norm2.beta, + hf_weight_key=f"visual.blocks.{i}.norm2.bias", + ) + + # Attention QKV + loader.port_weight( + keras_variable=block.attn.qkv.kernel, + hf_weight_key=f"visual.blocks.{i}.attn.qkv.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=block.attn.qkv.bias, + hf_weight_key=f"visual.blocks.{i}.attn.qkv.bias", + ) + loader.port_weight( + keras_variable=block.attn.proj.kernel, + hf_weight_key=f"visual.blocks.{i}.attn.proj.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=block.attn.proj.bias, + hf_weight_key=f"visual.blocks.{i}.attn.proj.bias", + ) + loader.port_weight( + keras_variable=block.mlp.fc1.kernel, + hf_weight_key=f"visual.blocks.{i}.mlp.linear_fc1.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=block.mlp.fc1.bias, + hf_weight_key=f"visual.blocks.{i}.mlp.linear_fc1.bias", + ) + loader.port_weight( + keras_variable=block.mlp.fc2.kernel, + hf_weight_key=f"visual.blocks.{i}.mlp.linear_fc2.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=block.mlp.fc2.bias, + hf_weight_key=f"visual.blocks.{i}.mlp.linear_fc2.bias", + ) + + loader.port_weight( + keras_variable=vision_enc.merger.ln_q.gamma, + hf_weight_key="visual.merger.ln_q.weight", + ) + loader.port_weight( + keras_variable=vision_enc.merger.ln_q.beta, + hf_weight_key="visual.merger.ln_q.bias", + ) + loader.port_weight( + keras_variable=vision_enc.merger.mlp_fc1.kernel, + hf_weight_key="visual.merger.mlp.0.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=vision_enc.merger.mlp_fc1.bias, + hf_weight_key="visual.merger.mlp.0.bias", + ) + loader.port_weight( + keras_variable=vision_enc.merger.mlp_fc2.kernel, + hf_weight_key="visual.merger.mlp.2.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=vision_enc.merger.mlp_fc2.bias, + hf_weight_key="visual.merger.mlp.2.bias", + ) + + # Deepstack mergers + for j in range(len(vision_enc.merger_list)): + merger = vision_enc.merger_list[j] + loader.port_weight( + keras_variable=merger.ln_q.gamma, + hf_weight_key=f"visual.merger_list.{j}.ln_q.weight", + ) + loader.port_weight( + keras_variable=merger.ln_q.beta, + hf_weight_key=f"visual.merger_list.{j}.ln_q.bias", + ) + loader.port_weight( + keras_variable=merger.mlp_fc1.kernel, + 
hf_weight_key=f"visual.merger_list.{j}.mlp.0.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=merger.mlp_fc1.bias, + hf_weight_key=f"visual.merger_list.{j}.mlp.0.bias", + ) + loader.port_weight( + keras_variable=merger.mlp_fc2.kernel, + hf_weight_key=f"visual.merger_list.{j}.mlp.2.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + loader.port_weight( + keras_variable=merger.mlp_fc2.bias, + hf_weight_key=f"visual.merger_list.{j}.mlp.2.bias", + ) + + # === Text Transformer Weights === + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="model.embed_tokens.weight", + ) + if not backbone.tie_word_embeddings: + loader.port_weight( + keras_variable=backbone.get_layer( + "token_embedding" + ).reverse_embeddings, + hf_weight_key="lm_head.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") + + # Input layernorm + loader.port_weight( + keras_variable=decoder_layer.pre_attention_norm.scale, + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + + # Attention layers + + ## Query + loader.port_weight( + keras_variable=decoder_layer.attention._query_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer.attention._query_dense_layer_norm.scale, + hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight", + ) + ## Key + loader.port_weight( + keras_variable=decoder_layer.attention._key_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer.attention._key_dense_layer_norm.scale, + hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight", + ) + ## Value + loader.port_weight( + keras_variable=decoder_layer.attention._value_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Output + loader.port_weight( + keras_variable=decoder_layer.attention._output_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + hook_fn=transpose_and_reshape, + ) + + # MLP layers + if ( + (i not in backbone.mlp_only_layers) + and backbone.num_experts > 0 + and ((i + 1) % backbone.decoder_sparse_step == 0) + ): + # MoE layers + loader.port_weight( + keras_variable=decoder_layer.sparse_moe._sparse_feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + # Batched experts: gate_up_proj and down_proj + gate_up_proj_list = [] + down_proj_list = [] + for expert_idx in range(backbone.num_experts): + gate_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.gate_proj.weight" + ) + up_proj = loader.get_tensor( + f"model.layers.{i}.mlp.experts.{expert_idx}.up_proj.weight" + ) + # Transpose to (hidden_dim, intermediate_dim) + gate_proj = np.transpose(gate_proj, axes=(1, 0)) + up_proj = np.transpose(up_proj, axes=(1, 0)) + # Concatenate gate_proj and up_proj along the last dimension + gate_up_proj = np.concatenate([gate_proj, up_proj], axis=-1) + gate_up_proj_list.append(gate_up_proj) + + # Load down_proj for each expert + down_proj = loader.get_tensor( + 
f"model.layers.{i}.mlp.experts.{expert_idx}.down_proj.weight" + ) + down_proj = np.transpose( + down_proj, axes=(1, 0) + ) # (intermediate_dim, hidden_dim) + down_proj_list.append(down_proj) + + # Stack the lists to create batched weights + gate_up_proj_batched = np.stack( + gate_up_proj_list, axis=0 + ) # (num_experts, hidden_dim, 2 * intermediate_dim) + down_proj_batched = np.stack( + down_proj_list, axis=0 + ) # (num_experts, intermediate_dim, hidden_dim) + + # Assign batched weights to expert_bank + decoder_layer.sparse_moe.expert_bank._expert_feedforward_gate_dense.assign( + gate_up_proj_batched + ) + decoder_layer.sparse_moe.expert_bank._expert_feedforward_output_dense.assign( + down_proj_batched + ) + else: + loader.port_weight( + keras_variable=decoder_layer.dense_mlp._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer.dense_mlp._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer.dense_mlp._feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + + # Feedforward layernorm + loader.port_weight( + keras_variable=decoder_layer.post_attention_layernorm.scale, + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Final normalization layer + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + return backbone + + +def load_image_converter_config(preset, transformers_config): + """Load image converter config from a HuggingFace preset.""" + thinker_config = transformers_config.get("thinker_config", {}) + vision_config = thinker_config.get("vision_config") + if vision_config: + preprocessor_config = load_json(preset, "preprocessor_config.json") + image_mean = preprocessor_config.get( + "image_mean", [0.48145466, 0.4578275, 0.40821073] + ) + image_std = preprocessor_config.get( + "image_std", [0.26862954, 0.26130258, 0.27577711] + ) + rescale_factor = preprocessor_config.get("rescale_factor", 1 / 255) + offset = [(-m / s) for m, s in zip(image_mean, image_std)] + scale = [(s * rescale_factor) for s in image_std] + image_size = vision_config.get("image_size", 768) + return { + "image_size": (image_size, image_size), + "scale": scale, + "offset": offset, + } + return None + + +def load_audio_converter_config(preset, transformers_config): + """Load audio converter config from a HuggingFace preset.""" + thinker_config = transformers_config.get("thinker_config", {}) + audio_config = thinker_config.get("audio_config") + if audio_config: + preprocessor_config = load_json(preset, "preprocessor_config.json") + sampling_rate = preprocessor_config.get("sampling_rate", 16000) + n_samples = preprocessor_config.get("n_samples", 4800000) + max_audio_length = n_samples / sampling_rate + return { + "num_mels": audio_config.get("num_mel_bins", 128), + "sampling_rate": sampling_rate, + "max_audio_length": max_audio_length, + } + return None + + +def convert_tokenizer(cls, preset, **kwargs): + vocab = load_json(preset, "vocab.json") + merges_file = get_file(preset, "merges.txt") + with open(merges_file, "r") as f: + merges = [line.strip() for line in f 
if line.strip()] + tokenizer_config = load_json(preset, "tokenizer_config.json") + special_tokens = [] + if "added_tokens_decoder" in tokenizer_config: + for token_id, token_info in tokenizer_config[ + "added_tokens_decoder" + ].items(): + content = token_info.get("content", "") + if not content.startswith("<|reserved_special_token_"): + special_tokens.append(content) + if content not in vocab: + vocab[content] = int(token_id) + kwargs.update({"unsplittable_tokens": special_tokens}) + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/utils/transformers/convert_qwen3_omni_test.py b/keras_hub/src/utils/transformers/convert_qwen3_omni_test.py new file mode 100644 index 0000000000..d5735337da --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen3_omni_test.py @@ -0,0 +1,36 @@ +import pytest + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.qwen3_omni.qwen3_omni_backbone import ( + Qwen3OmniBackbone, +) +from keras_hub.src.models.qwen3_omni.qwen3_omni_causal_lm import ( + Qwen3OmniCausalLM, +) +from keras_hub.src.tests.test_case import TestCase + + +class TestTask(TestCase): + @pytest.mark.extra_large + def test_convert_preset(self): + model = Qwen3OmniCausalLM.from_preset( + "hf://Qwen/Qwen3-Omni-30B-A3B-Thinking" + ) + prompt = "What is Keras?" + output = model.generate([prompt], max_length=15) + self.assertIsNotNone(output) + + @pytest.mark.extra_large + def test_class_detection(self): + preset_name = "hf://Qwen/Qwen3-Omni-30B-A3B-Thinking" + model = CausalLM.from_preset( + preset_name, + load_weights=False, + ) + self.assertIsInstance(model, Qwen3OmniCausalLM) + model = Backbone.from_preset( + preset_name, + load_weights=False, + ) + self.assertIsInstance(model, Qwen3OmniBackbone) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 92c6ea5ef5..210f016a91 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -24,6 +24,7 @@ from keras_hub.src.utils.transformers import convert_qwen from keras_hub.src.utils.transformers import convert_qwen3 from keras_hub.src.utils.transformers import convert_qwen3_moe +from keras_hub.src.utils.transformers import convert_qwen3_omni from keras_hub.src.utils.transformers import convert_qwen_moe from keras_hub.src.utils.transformers import convert_sam3 from keras_hub.src.utils.transformers import convert_smollm3 @@ -77,6 +78,8 @@ def __init__(self, preset, config): self.converter = convert_qwen_moe elif model_type == "qwen3_moe": self.converter = convert_qwen3_moe + elif model_type == "qwen3_omni_moe": + self.converter = convert_qwen3_omni elif model_type == "qwen3": self.converter = convert_qwen3 elif model_type == "sam3_video": @@ -134,6 +137,15 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs): return self.converter.convert_tokenizer(cls, self.preset, **kwargs) + def load_audio_converter(self, cls, **kwargs): + if hasattr(self.converter, "load_audio_converter_config"): + config = self.converter.load_audio_converter_config( + self.preset, self.config + ) + if config is not None: + return cls(**{**config, **kwargs}) + return None + def load_image_converter(self, cls, **kwargs): if hasattr(self.converter, "load_image_converter_config"): config = self.converter.load_image_converter_config( diff --git 
a/tools/checkpoint_conversion/convert_qwen3_omni_checkpoints.py b/tools/checkpoint_conversion/convert_qwen3_omni_checkpoints.py new file mode 100644 index 0000000000..1e9910154e --- /dev/null +++ b/tools/checkpoint_conversion/convert_qwen3_omni_checkpoints.py @@ -0,0 +1,169 @@ +import os +import traceback + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Hide any CUDA devices + +import numpy as np +import torch +from absl import app +from absl import flags + +device = torch.device("cpu") +# Force PyTorch to use CPU +torch.set_default_device(device) + +from keras import ops # noqa: E402 +from transformers import AutoModelForMultimodalLM # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +import keras_hub # noqa: E402 + +PRESET_MAP = { + "qwen3_omni_30b_a3b_en": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "qwen3_omni_30b_a3b_captioner_en": "Qwen/Qwen3-Omni-30B-A3B-Captioner", + "qwen3_omni_30b_a3b_thinking_en": "Qwen/Qwen3-Omni-30B-A3B-Thinking", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) + + +def test_model( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_model_tokenizer +): + # First, test that the number of parameters match + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + assert keras_hub_params == hf_params + + # Test the outputs of both the models + hf_inputs = hf_model_tokenizer(["What is Keras?"], return_tensors="pt").to( + device + ) + hf_outputs = hf_model(**hf_inputs) + hf_output_logits = hf_outputs.logits.detach().cpu().float().numpy() + + keras_hub_preprocessor = keras_hub.models.Qwen3OmniCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_inputs = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=5 + )[0] + keras_hub_inputs = {k: v.to(device) for k, v in keras_hub_inputs.items()} + + keras_hub_output = keras_hub_model(keras_hub_inputs) + keras_hub_logits = keras_hub_model.token_embedding( + keras_hub_output, reverse=True + ) + keras_hub_logits = ops.convert_to_numpy(keras_hub_logits) + + # High tolerance since bfloat16 is used as the default dtype for Qwen + + try: + np.testing.assert_allclose( + keras_hub_logits, hf_output_logits, atol=1e-3 + ) + print("All numerics match with tolerance limit 1e-3") + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + raise + + +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + keras_hub_preprocessor = keras_hub.models.Qwen3OmniCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_output = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=5 + ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) + + np.testing.assert_equal(keras_hub_output, hf_output) + + +def validate_output(qwen3_omni_lm, hf_model, hf_tokenizer): + input_str = "What is Keras?" 
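+    # Greedy decoding on both sides (KerasHub uses a "greedy" sampler,
+    # HF uses do_sample=False) keeps the two generations directly
+    # comparable.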
+ length = 32 + + keras_output = qwen3_omni_lm.generate([input_str], max_length=length) + keras_output = keras_output[0] + print("🔶 KerasHub output:", keras_output) + + # Transformers + hf_inputs = hf_tokenizer([input_str], return_tensors="pt").to(device) + outputs = hf_model.generate( + **hf_inputs, + max_length=length, # Match KerasHub's max_length + do_sample=False, # Greedy decoding (matches KerasHub) + pad_token_id=hf_tokenizer.pad_token_id, + ) + print("HF Token outputs = ", outputs) + hf_generated_text = hf_tokenizer.batch_decode( + outputs, skip_special_tokens=True + )[0] + print("🔶 Huggingface output:", hf_generated_text) + + +def main(_): + # === Get the preset name === + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + + # === Load the Huggingface model === + hf_full_model = AutoModelForMultimodalLM.from_pretrained( + hf_preset, + device_map=device, + trust_remote_code=True, + ) + + # Use full Thinker model (includes audio/vision encoders) + hf_model = hf_full_model.thinker + hf_tokenizer = AutoTokenizer.from_pretrained( + hf_preset, + return_tensors="pt", + trust_remote_code=True, + ) + hf_model.eval() + + keras_hub_model = keras_hub.models.Qwen3OmniBackbone.from_preset( + f"hf://{hf_preset}" + ) + keras_hub_tokenizer = keras_hub.tokenizers.Qwen3OmniTokenizer.from_preset( + f"hf://{hf_preset}" + ) + + print("\n-> Huggingface model and tokenizer loaded") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) + test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) + + preprocessor = keras_hub.models.Qwen3OmniCausalLMPreprocessor( + keras_hub_tokenizer + ) + qwen3_omni_lm = keras_hub.models.Qwen3OmniCausalLM( + backbone=keras_hub_model, preprocessor=preprocessor, sampler="greedy" + ) + # == Validate model.generate output == + validate_output(qwen3_omni_lm, hf_model, hf_tokenizer) + print("\n-> Tests passed!") + qwen3_omni_lm.save_to_preset(f"./{preset}") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)
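+
+# Example invocation (run from the repo root; the preset must be one of
+# PRESET_MAP's keys):
+#   python tools/checkpoint_conversion/convert_qwen3_omni_checkpoints.py \
+#       --preset qwen3_omni_30b_a3b_en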