diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index a7138bccad..064f0d862b 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -123,6 +123,9 @@ from keras_hub.src.models.parseq.parseq_image_converter import ( PARSeqImageConverter as PARSeqImageConverter, ) +from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import ( + Qwen2VLImageConverter as Qwen2VLImageConverter, +) from keras_hub.src.models.resnet.resnet_image_converter import ( ResNetImageConverter as ResNetImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index aa6f4f2023..c1624d515d 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -545,6 +545,18 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import ( + Qwen2VLBackbone as Qwen2VLBackbone, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm import ( + Qwen2VLCausalLM as Qwen2VLCausalLM, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm_preprocessor import ( + Qwen2VLCausalLMPreprocessor as Qwen2VLCausalLMPreprocessor, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import ( + Qwen2VLTokenizer as Qwen2VLTokenizer, +) from keras_hub.src.models.qwen3.qwen3_backbone import ( Qwen3Backbone as Qwen3Backbone, ) diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 97a68ab009..aecdc4c8d5 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -81,6 +81,9 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import ( + Qwen2VLTokenizer as Qwen2VLTokenizer, +) from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import ( Qwen3MoeTokenizer as Qwen3MoeTokenizer, ) diff --git a/keras_hub/src/models/qwen2_vl/__init__.py b/keras_hub/src/models/qwen2_vl/__init__.py new file mode 100644 index 0000000000..09b895aa95 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, Qwen2VLBackbone) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_attention.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_attention.py new file mode 100644 index 0000000000..4d0ba7e273 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_attention.py @@ -0,0 +1,372 @@ +import math + +import keras +from keras import ops + +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.utils.keras_utils import fused_attention_op_available + + +def _rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return ops.concatenate((-x2, x1), axis=-1) + + +def _apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section): + """Applies M-RoPE to query and key tensors. + + Splits the head dimension into temporal, height, and width sections. + Each section receives its corresponding positional embedding component. + + Args: + q: Query tensor of shape `(batch, num_heads, seq_len, head_dim)`. + k: Key tensor of shape `(batch, num_heads, seq_len, head_dim)`. 
+ cos: Cosine embeddings of shape `(3, batch, seq_len, head_dim)`. + sin: Sine embeddings of shape `(3, batch, seq_len, head_dim)`. + mrope_section: List of 3 ints specifying how many dims for + each of [temporal, height, width]. + + Returns: + Tuple of rotated query and key tensors. + """ + # mrope_section is [t_section, h_section, w_section] in terms of + # half-head-dim. Double it since cos/sin are full head_dim. + mrope_section_doubled = [s * 2 for s in mrope_section] + + # Split cos and sin along head_dim into sections + cos_sections = ops.split( + cos, _cumsum_sections(mrope_section_doubled), axis=-1 + ) + sin_sections = ops.split( + sin, _cumsum_sections(mrope_section_doubled), axis=-1 + ) + + # Pick the right component (temporal=0, height=1, width=2) for each + # section, cycling through the 3 components. + cos_parts = [] + sin_parts = [] + for i, (c, s) in enumerate(zip(cos_sections, sin_sections)): + component = i % 3 # 0=temporal, 1=height, 2=width + cos_parts.append(c[component]) # (batch, seq_len, section_dim) + sin_parts.append(s[component]) + + cos_combined = ops.expand_dims( + ops.concatenate(cos_parts, axis=-1), axis=1 + ) # (batch, 1, seq_len, head_dim) + sin_combined = ops.expand_dims( + ops.concatenate(sin_parts, axis=-1), axis=1 + ) # (batch, 1, seq_len, head_dim) + + q_embed = q * cos_combined + _rotate_half(q) * sin_combined + k_embed = k * cos_combined + _rotate_half(k) * sin_combined + return q_embed, k_embed + + +def _cumsum_sections(sizes): + """Convert section sizes to split indices (cumulative sum minus last). + + E.g., [8, 8, 8] -> [8, 16] for use with ops.split. + """ + result = [] + acc = 0 + for s in sizes[:-1]: + acc += s + result.append(acc) + return result + + +class Qwen2VLAttention(keras.layers.Layer): + """Multi-head attention with Multimodal RoPE for Qwen2-VL. + + Supports Grouped-Query Attention (GQA) and sliding window attention. + Uses separate Q, K, V projections (all with bias) and an output + projection (without bias). + + The key difference from standard QwenAttention is the M-RoPE: + position embeddings are provided as `(cos, sin)` of shape + `(3, batch, seq_len, head_dim)` — one component each for temporal, + height, and width — combined via `mrope_section`. + + Args: + num_query_heads: int. Number of query attention heads. + num_key_value_heads: int. Number of key/value heads (for GQA). + hidden_dim: int. Model hidden dimension. + mrope_section: list. List of 3 ints specifying how many + half-head-dim elements are allocated to + [temporal, height, width]. + rope_max_wavelength: float. Max wavelength for RoPE base. + Defaults to `10000`. + kernel_initializer: string or `keras.initializers`. Initializer + for the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: string or `keras.initializers`. Initializer + for the bias weights. Defaults to `"zeros"`. + dropout: float. Dropout rate for attention weights. + use_sliding_window_attention: bool. Whether to use sliding window. + sliding_window_size: int. Sliding window size. + dtype: string or `keras.mixed_precision.DTypePolicy`. 
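
    Example:

    A minimal, illustrative sketch (shapes and values are arbitrary, not
    taken from a real checkpoint):

    ```python
    import numpy as np

    layer = Qwen2VLAttention(
        num_query_heads=4,
        num_key_value_heads=2,
        hidden_dim=32,
        mrope_section=[1, 1, 2],  # sums to head_dim // 2 = 4
    )
    hidden_states = np.random.rand(1, 8, 32).astype("float32")
    # One (cos, sin) component each for temporal, height and width,
    # shaped `(3, batch, seq_len, head_dim)`.
    cos = np.ones((3, 1, 8, 8), dtype="float32")
    sin = np.zeros((3, 1, 8, 8), dtype="float32")
    output = layer(hidden_states, position_embeddings=(cos, sin))
    ```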
+ """ + + def __init__( + self, + num_query_heads, + num_key_value_heads, + hidden_dim, + mrope_section, + rope_max_wavelength=10000, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + dropout=0, + use_sliding_window_attention=False, + sliding_window_size=4096, + **kwargs, + ): + super().__init__(**kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.mrope_section = mrope_section + self.rope_max_wavelength = rope_max_wavelength + self.dropout = dropout + self.use_sliding_window_attention = use_sliding_window_attention + self.sliding_window_size = sliding_window_size + + self.num_key_value_groups = num_query_heads // num_key_value_heads + self.head_dim = hidden_dim // num_query_heads + + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + self.bias_initializer = keras.initializers.get( + clone_initializer(bias_initializer) + ) + + self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + + # Q, K, V with bias; O without bias + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + bias_axes="uh", + dtype=self.dtype_policy, + name="query", + ) + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + bias_axes="vh", + dtype=self.dtype_policy, + name="key", + ) + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + bias_axes="vh", + dtype=self.dtype_policy, + name="value", + ) + self._output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, self.hidden_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + def call( + self, + hidden_states, + attention_mask=None, + position_embeddings=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Forward pass with M-RoPE attention. + + Args: + hidden_states: Tensor of shape `(batch, seq_len, hidden_dim)`. + attention_mask: Optional mask of shape + `(batch, seq_len, seq_len)`. + position_embeddings: Tuple of `(cos, sin)`, each of shape + `(3, batch, seq_len, head_dim)`. + cache: Optional cached key/value states. + cache_update_index: Index for cache update. + training: Boolean training mode flag. + + Returns: + attention_output: Tensor of shape + `(batch, seq_len, hidden_dim)`. + cache: Updated cache (if cache was provided). + """ + query = self._query_dense(hidden_states) + + def _compute_key_value(x): + key, value = self._key_dense(x), self._value_dense(x) + return key, value + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] 
+ if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update, value_update = _compute_key_value(hidden_states) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + key, value = _compute_key_value(hidden_states) + + # Apply M-RoPE + if position_embeddings is not None: + cos, sin = position_embeddings + # query: (batch, seq_len, num_heads, head_dim) + # -> (batch, num_heads, seq_len, head_dim) for RoPE + query_t = ops.transpose(query, (0, 2, 1, 3)) + key_t = ops.transpose(key, (0, 2, 1, 3)) + + query_t, key_t = _apply_multimodal_rotary_pos_emb( + query_t, key_t, cos, sin, self.mrope_section + ) + + query = ops.transpose(query_t, (0, 2, 1, 3)) + key = ops.transpose(key_t, (0, 2, 1, 3)) + + # GQA: repeat key/value heads + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, + key, + value, + attention_mask, + cache_update_index=cache_update_index, + ) + + attention_output = self._dropout_layer( + attention_output, training=training + ) + attention_output = self._output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + if attention_mask is not None: + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) + + def _compute_attention( + self, + query, + key, + value, + attention_mask=None, + cache_update_index=None, + ): + if fused_attention_op_available(): + if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + attention_output = ops.dot_product_attention( + query, + key, + value, + mask=attention_mask, + scale=self._inv_norm_factor, + ) + return attention_output + + attention_scores = ops.einsum(self._dot_product_equation, query, key) + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + if self.use_sliding_window_attention: + attention_mask = self._mask_sliding_window( + attention_mask, + cache_update_index=( + cache_update_index if cache_update_index is not None else 0 + ), + ) + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + return attention_output + + def _mask_sliding_window( + self, + attention_mask, + cache_update_index=0, + ): + _, query_len, key_len = ops.shape(attention_mask) + all_ones = ops.ones((key_len, key_len), "bool") + if keras.config.backend() == "tensorflow": + import tensorflow as tf + + band_size = ops.minimum(key_len, self.sliding_window_size - 1) + band_size = ops.cast(band_size, "int32") + sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size) + else: + sliding_mask = ops.triu( + all_ones, -1 * self.sliding_window_size + 1 + ) * ops.tril(all_ones, self.sliding_window_size - 1) + start = (cache_update_index, 0) + sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len)) + sliding_mask = ops.expand_dims(sliding_mask, 0) + return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool")) + + 
def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "mrope_section": self.mrope_section, + "rope_max_wavelength": self.rope_max_wavelength, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "dropout": self.dropout, + "use_sliding_window_attention": ( + self.use_sliding_window_attention + ), + "sliding_window_size": self.sliding_window_size, + } + ) + return config diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py new file mode 100644 index 0000000000..46c078ec17 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py @@ -0,0 +1,411 @@ +import keras +from keras import ops +from keras.layers import ReversibleEmbedding + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm +from keras_hub.src.models.qwen2_vl.qwen2_vl_decoder import ( + Qwen2VLTransformerDecoder, +) + + +def _qwen2_vl_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export("keras_hub.models.Qwen2VLBackbone") +class Qwen2VLBackbone(Backbone): + """Qwen2-VL core network with optional vision encoder. + + This network implements a Transformer-based decoder with Multimodal + Rotary Position Embedding (M-RoPE) support for handling text, image, + and video inputs. When a ``vision_encoder`` is provided, image/video + patches are encoded and interleaved with text token embeddings. + + The default constructor gives a fully customizable, randomly + initialized Qwen2-VL model with any number of layers, heads, and + embedding dimensions. To load preset architectures and weights, use + the ``from_preset`` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer decoder layers. + num_query_heads: int. The number of query attention heads for + each transformer layer. + num_key_value_heads: int. The number of key and value attention + heads for each transformer layer (for GQA). + hidden_dim: int. The size of the transformer hidden + representation. + intermediate_dim: int. The output dimension of the first Dense + layer in the SwiGLU feedforward network for each + transformer layer. + mrope_section: list. List of 3 ints specifying how many + half-head-dim elements are allocated to + [temporal, height, width] for M-RoPE. + rope_max_wavelength: float. The maximum angular wavelength of + the sine/cosine curves for rotary embeddings. Defaults to + `10000`. + layer_norm_epsilon: float. Epsilon for the RMS layer + normalization layers in the transformer decoder. Defaults + to `1e-6`. + dropout: float. Dropout rate for attention and hidden layers. + Defaults to `0`. + tie_word_embeddings: bool. Whether to tie input and output + embeddings. Defaults to `True`. + vision_encoder: ``keras.layers.Layer``. An optional + ``Qwen2VLVisionEncoder`` instance for processing image/video + inputs. When ``None``, the model operates in text-only mode. + Defaults to ``None``. + dtype: string or ``keras.mixed_precision.DTypePolicy``. The + dtype to use for model computations and weights. 
Note that + some computations, such as softmax and layer normalization, + will always be done at float32 precision regardless of + dtype. + + Examples: + + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] + ), + "mrope_position_ids": np.broadcast_to( + np.arange(12)[None, :, None], + (1, 12, 3), + ).astype("int32"), + } + + # Randomly initialized Qwen2-VL decoder with custom config. + model = keras_hub.models.Qwen2VLBackbone( + vocabulary_size=152000, + hidden_dim=1536, + num_layers=28, + num_query_heads=12, + num_key_value_heads=2, + intermediate_dim=8960, + mrope_section=[16, 24, 24], + layer_norm_epsilon=1e-6, + dtype="float32", + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + mrope_section, + rope_max_wavelength=10000, + layer_norm_epsilon=1e-6, + dropout=0, + tie_word_embeddings=True, + vision_encoder=None, + dtype=None, + **kwargs, + ): + text_only_model = vision_encoder is None + head_dim = hidden_dim // num_query_heads + + # === Layers === + token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=_qwen2_vl_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + self.vision_encoder = vision_encoder + self._vision_interleaver = Qwen2VLInterleaveEmbeddings( + dtype=dtype, name="vision_interleaver" + ) + self._vision_input_flattener = None + if vision_encoder is not None: + self._vision_input_flattener = Qwen2VLFlattenVisionInputs( + in_channels=vision_encoder.in_channels, + temporal_patch_size=vision_encoder.temporal_patch_size, + patch_size=vision_encoder.patch_size, + dtype=dtype, + name="vision_input_flattener", + ) + + transformer_layers = [] + for i in range(num_layers): + layer = Qwen2VLTransformerDecoder( + intermediate_dim=intermediate_dim, + hidden_dim=hidden_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + mrope_section=mrope_section, + rope_max_wavelength=rope_max_wavelength, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_qwen2_vl_kernel_initializer(stddev=0.02), + dropout=dropout, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + transformer_layers.append(layer) + + layer_norm = QwenLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + mrope_position_ids_input = keras.Input( + shape=(None, 3), dtype="int32", name="mrope_position_ids" + ) + + inputs = { + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + "mrope_position_ids": mrope_position_ids_input, + } + + text_embeddings = token_embedding(token_id_input) + + x = text_embeddings + if not text_only_model: + image_input = keras.Input( + shape=( + None, + vision_encoder.in_channels, + vision_encoder.temporal_patch_size, + vision_encoder.patch_size, + vision_encoder.patch_size, + ), + dtype="float32", + name="images", + ) + vision_indices_input = keras.Input( + shape=(None,), dtype="int32", name="vision_indices" + ) + grid_thw_input = keras.Input( + shape=(None, 3), dtype="int32", name="grid_thw" + ) + + inputs["images"] = image_input + 
inputs["vision_indices"] = vision_indices_input + inputs["grid_thw"] = grid_thw_input + + flat_images, flat_grid_thw = self._vision_input_flattener( + images=image_input, grid_thw=grid_thw_input + ) + vision_embeddings = self.vision_encoder( + flat_images, grid_thw=flat_grid_thw + ) + x = self._vision_interleaver( + vision_embeddings=vision_embeddings, + text_embeddings=text_embeddings, + vision_indices=vision_indices_input, + ) + + position_embeddings = _compute_mrope_embeddings( + mrope_position_ids_input, + head_dim, + rope_max_wavelength, + ) + + for transformer_layer in transformer_layers: + x = transformer_layer( + x, + attention_mask=padding_mask_input, + position_embeddings=position_embeddings, + ) + + sequence_output = layer_norm(x) + + super().__init__( + inputs=inputs, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.mrope_section = mrope_section + self.rope_max_wavelength = rope_max_wavelength + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.tie_word_embeddings = tie_word_embeddings + self.text_only_model = text_only_model + self.token_embedding = token_embedding + self.transformer_layers = transformer_layers + self.layer_norm = layer_norm + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "mrope_section": self.mrope_section, + "rope_max_wavelength": self.rope_max_wavelength, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "tie_word_embeddings": self.tie_word_embeddings, + } + ) + if self.vision_encoder is not None: + config["vision_encoder"] = keras.layers.serialize( + self.vision_encoder + ) + return config + + @classmethod + def from_config(cls, config): + vision_encoder = config.pop("vision_encoder", None) + if vision_encoder is not None: + vision_encoder = keras.layers.deserialize(vision_encoder) + return cls(vision_encoder=vision_encoder, **config) + + +class Qwen2VLInterleaveEmbeddings(keras.layers.Layer): + """Interleaves vision embeddings into text embeddings at given indices.""" + + def call(self, vision_embeddings, text_embeddings, vision_indices): + batch_size, seq_length, hidden_dim = ops.shape(text_embeddings) + + # Use -1 to avoid NumPy-triggering tuples with GPU tensors in Torch. + flat_text_embeddings = ops.reshape(text_embeddings, (-1, hidden_dim)) + # `vision_embeddings` is flattened as + # `(batch * num_vision_tokens, hidden_dim)`. + flat_vision_embeddings = vision_embeddings + + offsets = ops.multiply( + ops.arange(batch_size, dtype="int32"), seq_length + ) + offsets = ops.expand_dims(offsets, axis=-1) + flat_indices = ops.reshape(vision_indices + offsets, (-1, 1)) + flat_indices = ops.cast(flat_indices, "int32") + + # `vision_indices` is padded with 0. Restore token 0 after scatter. 
+ zeroth_index_text_embeddings = ops.take( + flat_text_embeddings, + indices=ops.squeeze(offsets, axis=-1), + axis=0, + ) + reconstructed_embedding = ops.scatter_update( + inputs=flat_text_embeddings, + indices=flat_indices, + updates=flat_vision_embeddings, + ) + reconstructed_embedding = ops.scatter_update( + inputs=reconstructed_embedding, + indices=offsets, + updates=zeroth_index_text_embeddings, + ) + + return ops.reshape( + reconstructed_embedding, (batch_size, seq_length, hidden_dim) + ) + + def compute_output_shape( + self, + vision_embeddings_shape, + text_embeddings_shape, + vision_indices_shape, + ): + return text_embeddings_shape + + def compute_output_spec( + self, vision_embeddings, text_embeddings, vision_indices + ): + output_shape = self.compute_output_shape( + vision_embeddings.shape, + text_embeddings.shape, + vision_indices.shape, + ) + return keras.KerasTensor(output_shape, dtype=text_embeddings.dtype) + + +class Qwen2VLFlattenVisionInputs(keras.layers.Layer): + """Flattens batched vision patches and `grid_thw` to encoder format.""" + + def __init__( + self, + in_channels, + temporal_patch_size, + patch_size, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.in_channels = in_channels + self.temporal_patch_size = temporal_patch_size + self.patch_size = patch_size + + def build(self, input_shape): + self.built = True + + def call(self, images, grid_thw): + flat_images = ops.reshape( + images, + ( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ), + ) + flat_grid_thw = ops.reshape(grid_thw, (-1, 3)) + return flat_images, flat_grid_thw + + +def _compute_mrope_embeddings( + mrope_position_ids, head_dim, rope_max_wavelength +): + """Compute M-RoPE cos/sin embeddings from position IDs.""" + dim = head_dim + inv_freq = 1.0 / ( + rope_max_wavelength + ** (ops.cast(ops.arange(0, dim, 2), "float32") / dim) + ) + + position_ids = ops.transpose( + ops.cast(mrope_position_ids, "float32"), (2, 0, 1) + ) + + inv_freq_expanded = ops.reshape(inv_freq, (1, 1, -1, 1)) + inv_freq_expanded = ops.tile(inv_freq_expanded, (3, 1, 1, 1)) + + position_ids_expanded = ops.expand_dims(position_ids, axis=2) + + freqs = ops.matmul( + ops.cast(inv_freq_expanded, "float32"), + ops.cast(position_ids_expanded, "float32"), + ) + freqs = ops.transpose(freqs, (0, 1, 3, 2)) + + emb = ops.concatenate([freqs, freqs], axis=-1) + + cos = ops.cos(emb) + sin = ops.sin(emb) + + return (cos, sin) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone_test.py new file mode 100644 index 0000000000..6cad73ac28 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone_test.py @@ -0,0 +1,139 @@ +import numpy as np +import pytest + +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLVisionEncoder, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen2VLBackboneTextOnlyTest(TestCase): + def setUp(self): + self.batch_size = 2 + self.vocabulary_size = 256 + self.seq_length = 16 + self.hidden_dim = 32 + self.head_dim = self.hidden_dim // 4 # 8 + + self.init_kwargs = { + "vocabulary_size": self.vocabulary_size, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": self.hidden_dim, + "intermediate_dim": 64, + "mrope_section": [1, 1, 2], # sums to head_dim // 2 = 4 + } + + # For M-RoPE, position_ids shape is (batch, seq_len, 3) + # 
For text-only, all 3 components are the same sequential IDs + pos_ids = np.broadcast_to( + np.arange(self.seq_length)[None, :, None], + (self.batch_size, self.seq_length, 3), + ).astype("int32") + + self.input_data = { + "token_ids": np.random.randint( + 0, + self.vocabulary_size, + (self.batch_size, self.seq_length), + ).astype("int32"), + "padding_mask": np.ones( + (self.batch_size, self.seq_length), + dtype="int32", + ), + "mrope_position_ids": pos_ids, + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Qwen2VLBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=( + self.batch_size, + self.seq_length, + self.hidden_dim, + ), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen2VLBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = Qwen2VLBackbone(**self.init_kwargs) + # Check that model has expected number of transformer layers + self.assertEqual(len(model.transformer_layers), 2) + # Check layer norm exists + self.assertIsNotNone(model.layer_norm) + + +class Qwen2VLBackboneMultimodalTest(TestCase): + def test_multimodal_forward(self): + hidden_dim = 32 + batch_size = 2 + seq_length = 24 + + vision_encoder = Qwen2VLVisionEncoder( + hidden_size=hidden_dim, + embed_dim=16, + depth=1, + num_heads=4, + patch_size=2, + temporal_patch_size=2, + in_channels=3, + mlp_ratio=2.0, + spatial_merge_size=2, + ) + + model = Qwen2VLBackbone( + vocabulary_size=256, + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=hidden_dim, + intermediate_dim=64, + mrope_section=[1, 1, 2], + vision_encoder=vision_encoder, + ) + + # Per sample: 16 patches -> 4 merged vision tokens. + images = np.random.rand(batch_size, 16, 3, 2, 2, 2).astype("float32") + grid_thw = np.array( + [ + [[1, 4, 4]], + [[1, 4, 4]], + ], + dtype="int32", + ) + vision_indices = np.array( + [ + [4, 5, 6, 7], + [8, 9, 10, 11], + ], + dtype="int32", + ) + mrope_position_ids = np.broadcast_to( + np.arange(seq_length)[None, :, None], + (batch_size, seq_length, 3), + ).astype("int32") + + inputs = { + "token_ids": np.random.randint( + 0, 256, (batch_size, seq_length), dtype="int32" + ), + "padding_mask": np.ones((batch_size, seq_length), dtype="int32"), + "mrope_position_ids": mrope_position_ids, + "images": images, + "vision_indices": vision_indices, + "grid_thw": grid_thw, + } + + output = model(inputs) + self.assertEqual(output.shape, (batch_size, seq_length, hidden_dim)) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm.py new file mode 100644 index 0000000000..b5a55a04a9 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm.py @@ -0,0 +1,270 @@ +import numpy as np +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm_preprocessor import ( + Qwen2VLCausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.Qwen2VLCausalLM") +class Qwen2VLCausalLM(CausalLM): + """An end-to-end multimodal Qwen2-VL model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. 
This model supports both image+text and text-only inputs. + + This model has a ``generate()`` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + ``sampler`` argument on ``compile()``. You can recompile the model with + different ``keras_hub.samplers`` objects to control the generation. By + default, ``"greedy"`` sampling will be used. + + This model can optionally be configured with a ``preprocessor`` layer, in + which case it will automatically apply preprocessing to string inputs during + ``fit()``, ``predict()``, ``evaluate()`` and ``generate()``. + + Args: + preprocessor: A ``keras_hub.models.Qwen2VLCausalLMPreprocessor`` or + ``None``. If ``None``, this model will not apply preprocessing + and inputs should be preprocessed before calling the model. + backbone: A ``keras_hub.models.Qwen2VLBackbone`` instance. + """ + + backbone_cls = Qwen2VLBackbone + preprocessor_cls = Qwen2VLCausalLMPreprocessor + + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def compile( + self, + optimizer="auto", + loss="auto", + *, + weighted_metrics="auto", + sampler="greedy", + **kwargs, + ): + super().compile( + optimizer=optimizer, + loss=loss, + weighted_metrics=weighted_metrics, + sampler=sampler, + **kwargs, + ) + + def _normalize_generate_inputs(self, inputs): + """Handle unbatched image inputs for generation.""" + if tf and isinstance(inputs, tf.data.Dataset): + return inputs.as_numpy_iterator(), False + + if self.preprocessor is None: + return [inputs], False + + def normalize(x): + if isinstance(x, str): + return [x], True + if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0: + return x[tf.newaxis], True + return x, False + + if isinstance(inputs, dict): + inputs["prompts"], input_is_scalar = normalize(inputs["prompts"]) + # If prompt is scalar, images can be a single 3D array. + if input_is_scalar and "images" in inputs: + x = inputs["images"] + if isinstance(x, np.ndarray) and len(x.shape) == 3: + inputs["images"] = [x] + elif tf and isinstance(x, tf.Tensor) and x.shape.rank == 3: + inputs["images"] = x[tf.newaxis] + elif isinstance(x, list): + inputs["images"] = [x] + if "responses" in inputs: + inputs["responses"], _ = normalize(inputs["responses"]) + else: + inputs, input_is_scalar = normalize(inputs) + + return [inputs], input_is_scalar + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + padding_mask=None, + mrope_position_ids=None, + ): + """Forward pass with cache for autoregressive inference. + + Args: + token_ids: Dense int tensor ``(batch, seq_len)``. + cache: Dense float tensor with cached key/value states. + cache_update_index: int or int tensor. Current index in sequence. + padding_mask: Optional mask ``(batch, seq_len)``. + mrope_position_ids: Optional tensor ``(batch, seq_len, 3)`` + for M-RoPE position IDs. + + Returns: + Tuple of ``(logits, hidden_states, cache)``. + """ + from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import ( + _compute_mrope_embeddings, + ) + + x = self.backbone.token_embedding(token_ids) + + # Compute position embeddings if position ids provided. 
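        # `position_embeddings` is a `(cos, sin)` tuple, each of shape
        # `(3, batch, seq_len, head_dim)`, which is what the decoder layers
        # expect for M-RoPE.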
+ position_embeddings = None + if mrope_position_ids is not None: + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + position_embeddings = _compute_mrope_embeddings( + mrope_position_ids, + head_dim, + self.backbone.rope_max_wavelength, + self.backbone.mrope_section, + ) + + # Each decoder layer has a cache; we update them separately. + caches = [] + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + current_cache = cache[:, i, ...] + x, next_cache = transformer_layer( + x, + attention_mask=padding_mask, + position_embeddings=position_embeddings, + cache=current_cache, + cache_update_index=cache_update_index, + ) + caches.append(next_cache) + + cache = ops.stack(caches, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache( + self, token_ids, padding_mask=None, mrope_position_ids=None + ): + """Build an empty cache for use with ``call_with_cache()``.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + logits, hidden_states, cache = self.call_with_cache( + token_ids=token_ids, + cache=cache, + cache_update_index=0, + padding_mask=padding_mask, + mrope_position_ids=mrope_position_ids, + ) + return hidden_states, cache + + def generate_step(self, inputs, stop_token_ids=None): + """A compilable generation function for a single batch of inputs. + + Args: + inputs: A dictionary with keys ``"token_ids"``, + ``"padding_mask"``, and optionally ``"mrope_position_ids"``. + stop_token_ids: Tuple of end token IDs to stop on. + """ + token_ids = inputs["token_ids"] + padding_mask = inputs["padding_mask"] + mrope_position_ids = inputs.get("mrope_position_ids", None) + + # Create and seed cache. + hidden_states, cache = self._build_cache( + token_ids, + padding_mask=padding_mask, + mrope_position_ids=mrope_position_ids, + ) + + # Compute the lengths of all user inputted token ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + # Compute M-RoPE position IDs for the single generated token. + # During text generation, all three components (temporal, height, + # width) share the same sequential position. 
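            # For example, the token at sequence index `i` gets position `i`
            # in each of the three components, matching the text-only scheme
            # used by the preprocessor.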
+ step_position = ops.cast( + ops.reshape(cache_update_index, (1, 1, 1)), "int32" + ) + step_mrope_ids = ops.broadcast_to( + step_position, ops.shape(prompt)[:2] + (3,) + ) + logits, hidden_states, cache = self.call_with_cache( + token_ids=prompt, + cache=cache, + cache_update_index=cache_update_index, + padding_mask=padding_mask, + mrope_position_ids=step_mrope_ids, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + padding_mask = ops.ones_like(token_ids, dtype="bool") + + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py new file mode 100644 index 0000000000..438573d8a0 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py @@ -0,0 +1,275 @@ +import keras +import tensorflow as tf + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, +) +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import ( + Qwen2VLImageConverter, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import Qwen2VLTokenizer +from keras_hub.src.utils.tensor_utils import preprocessing_function + + +@keras_hub_export("keras_hub.models.Qwen2VLCausalLMPreprocessor") +class Qwen2VLCausalLMPreprocessor(CausalLMPreprocessor): + """Qwen2-VL Causal LM preprocessor. + + This preprocessing layer is meant for use with + ``keras_hub.models.Qwen2VLCausalLM``. It takes in batches of strings + (and optionally images), and returns outputs in a + ``(x, y, sample_weight)`` format, where the ``y`` label is the next + token id in the ``x`` sequence. ``sample_weight`` is 0 for "prompt" + tokens, and 1 for "response" tokens, so that the loss is computed only + on the "response" tokens. + + For use with generation, the layer also exposes two methods + ``generate_preprocess()`` and ``generate_postprocess()``. + + Args: + tokenizer: A ``keras_hub.models.Qwen2VLTokenizer`` instance. + image_converter: A ``keras_hub.layers.Qwen2VLImageConverter`` instance. + Defaults to ``None``. + sequence_length: int. The length of the packed inputs. + Defaults to 1024. + add_start_token: bool. Whether to prepend the start token. + Defaults to ``False`` (Qwen models do not use a start token). + add_end_token: bool. Whether to append the end token. + Defaults to ``True``. 
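
    Example:

    A minimal sketch, assuming ``tokenizer`` is an already-constructed
    ``Qwen2VLTokenizer``:

    ```python
    preprocessor = keras_hub.models.Qwen2VLCausalLMPreprocessor(
        tokenizer=tokenizer,
        sequence_length=128,
    )
    x, y, sample_weight = preprocessor(
        {
            "prompts": ["What is in this image?"],
            "responses": ["A cat."],
        }
    )
    # `x` contains "token_ids", "padding_mask" and "mrope_position_ids".
    ```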
+ """ + + backbone_cls = Qwen2VLBackbone + tokenizer_cls = Qwen2VLTokenizer + image_converter_cls = Qwen2VLImageConverter + + def __init__( + self, + tokenizer, + image_converter=None, + sequence_length=1024, + add_start_token=False, + add_end_token=True, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + sequence_length=sequence_length, + add_start_token=add_start_token, + add_end_token=add_end_token, + **kwargs, + ) + self.image_converter = image_converter + + def build(self, input_shape): + self.packer = MultiSegmentPacker( + start_value=self.tokenizer.end_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sep_value=[], + sequence_length=self.sequence_length, + ) + self.built = True + + def _compute_mrope_position_ids(self, token_ids, padding_mask): + """Compute M-RoPE position IDs for text-only input. + + For text-only mode, all three components (temporal, height, width) + are identical and equal to sequential position indices. + + Args: + token_ids: Tensor of shape ``(batch, seq_len)``. + padding_mask: Tensor of shape ``(batch, seq_len)``. + + Returns: + Tensor of shape ``(batch, seq_len, 3)``. + """ + seq_len = tf.shape(token_ids)[1] + positions = tf.range(seq_len, dtype=tf.int32) + # Broadcast to (batch, seq_len) + batch_size = tf.shape(token_ids)[0] + positions = tf.broadcast_to(positions, (batch_size, seq_len)) + # Mask out padding positions + positions = positions * tf.cast(padding_mask, tf.int32) + # Stack 3 identical copies for text-only M-RoPE + mrope_position_ids = tf.stack( + [positions, positions, positions], axis=-1 + ) + return mrope_position_ids + + def _format_output( + self, + token_ids, + padding_mask, + mrope_position_ids, + response_mask, + return_labels=False, + batched=False, + ): + """Format output dictionary, optionally computing labels.""" + if return_labels: + y = token_ids[..., 1:] + sample_weight = response_mask[..., 1:] + token_ids = token_ids[..., :-1] + padding_mask = padding_mask[..., :-1] + mrope_position_ids = mrope_position_ids[..., :-1, :] + + x = { + "token_ids": ( + token_ids if batched else tf.squeeze(token_ids, axis=0) + ), + "padding_mask": ( + padding_mask if batched else tf.squeeze(padding_mask, axis=0) + ), + "mrope_position_ids": ( + mrope_position_ids + if batched + else tf.squeeze(mrope_position_ids, axis=0) + ), + } + + if return_labels: + if not batched: + y = tf.squeeze(y, axis=0) + sample_weight = tf.squeeze(sample_weight, axis=0) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + return x + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + sequence_length = sequence_length or self.sequence_length + + # Extract text. + prompts, responses = x["prompts"], x["responses"] + + # Handle batching. + batched = True + if isinstance(prompts, str): + batched = False + prompts = [prompts] + responses = [responses] + if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0: + batched = False + prompts = tf.expand_dims(prompts, axis=0) + responses = tf.expand_dims(responses, axis=0) + + # Tokenize. + prompts = self.tokenizer(prompts) + responses = self.tokenizer(responses) + + # Pack. + token_ids, segment_ids = self.packer( + (prompts, responses), + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + response_mask = segment_ids == 1 + padding_mask = token_ids != self.tokenizer.pad_token_id + + # Compute M-RoPE position IDs. 
+ mrope_position_ids = self._compute_mrope_position_ids( + token_ids, padding_mask + ) + + return self._format_output( + token_ids=token_ids, + padding_mask=padding_mask, + mrope_position_ids=mrope_position_ids, + response_mask=response_mask, + return_labels=True, + batched=batched, + ) + + @preprocessing_function + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in + strings or tensor strings, tokenizes and packs the input, and + computes a padding mask and M-RoPE position IDs. + + Unlike calling the layer for training, this method does not + compute labels and will never append a ``tokenizer.end_token_id`` + to the end of the sequence. + """ + if not self.built: + self.build(None) + + # Extract inputs. + if isinstance(x, dict): + prompts = x["prompts"] + responses = x.get("responses", None) + else: + prompts = x + responses = None + + # Handle batching. + batched = True + if isinstance(prompts, str): + batched = False + prompts = [prompts] + if responses is not None: + responses = [responses] + if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0: + batched = False + prompts = tf.expand_dims(prompts, axis=0) + if responses is not None: + responses = tf.expand_dims(responses, axis=0) + + # Tokenize. + prompts = self.tokenizer(prompts) + if responses is not None: + responses = self.tokenizer(responses) + segments = (prompts, responses) + else: + segments = (prompts,) + + # Pack (no end token for generation). + token_ids, segment_ids = self.packer( + segments, + sequence_length=sequence_length, + add_end_value=False, + ) + padding_mask = token_ids != self.tokenizer.pad_token_id + + # Compute M-RoPE position IDs. + mrope_position_ids = self._compute_mrope_position_ids( + token_ids, padding_mask + ) + + return self._format_output( + token_ids=token_ids, + padding_mask=padding_mask, + mrope_position_ids=mrope_position_ids, + response_mask=segment_ids == 1, + return_labels=False, + batched=batched, + ) + + def generate_postprocess(self, x): + """Convert integer token output to strings for generation.""" + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + # Strip padding and special tokens. 
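        # `boolean_mask` drops positions where `padding_mask` is False,
        # producing a ragged batch for `detokenize`.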
+ token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) + + def get_config(self): + config = super().get_config() + if self.image_converter is not None: + config["image_converter"] = keras.layers.serialize( + self.image_converter + ) + return config diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_test.py new file mode 100644 index 0000000000..e5c5f4308e --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_test.py @@ -0,0 +1,79 @@ +import pytest + +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm import Qwen2VLCausalLM +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm_preprocessor import ( + Qwen2VLCausalLMPreprocessor, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import Qwen2VLTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class Qwen2VLCausalLMTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "\u0120air", "plane", "\u0120at", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab += ["<|eot_id|>"] + self.vocab += ["<|vision_start|>"] + self.vocab += ["<|vision_end|>"] + self.vocab += ["<|vision_pad|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = [ + "\u0120 a", + "\u0120 t", + "\u0120 i", + "\u0120 b", + "a i", + "p l", + "n e", + ] + self.merges += [ + "\u0120a t", + "p o", + "r t", + "\u0120t h", + "ai r", + "pl a", + "po rt", + ] + self.merges += ["\u0120ai r", "\u0120a i", "pla ne"] + self.preprocessor = Qwen2VLCausalLMPreprocessor( + Qwen2VLTokenizer(vocabulary=self.vocab, merges=self.merges), + sequence_length=7, + ) + self.backbone = Qwen2VLBackbone( + vocabulary_size=(self.preprocessor.tokenizer.vocabulary_size()), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=32, + intermediate_dim=64, + mrope_section=[1, 1, 2], + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = ( + { + "prompts": [" airplane at airport"] * 2, + "responses": [" airplane"] * 2, + }, + ) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=Qwen2VLCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 7, 11), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen2VLCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_decoder.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_decoder.py new file mode 100644 index 0000000000..04103f2920 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_decoder.py @@ -0,0 +1,270 @@ +import keras +from keras import ops + +from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm +from keras_hub.src.models.qwen2_vl.qwen2_vl_attention import Qwen2VLAttention +from keras_hub.src.utils.keras_utils import clone_initializer + + +class Qwen2VLTransformerDecoder(keras.layers.Layer): + """A single Transformer decoder block for Qwen2-VL. + + Structure: RMSNorm → M-RoPE Attention → residual → RMSNorm → SwiGLU + MLP → residual. + + The key difference from the standard Qwen decoder is that attention + receives pre-computed M-RoPE position embeddings of shape + `(3, batch, seq_len, head_dim)`. + + Args: + intermediate_dim: int. 
Dimension of the MLP intermediate + (up/gate) projections. + hidden_dim: int. Model hidden dimension. + num_query_heads: int. Number of query heads. + num_key_value_heads: int. Number of key/value heads. + mrope_section: list. The M-RoPE section sizes `[t, h, w]`. + rope_max_wavelength: float. Max wavelength for RoPE base. + layer_norm_epsilon: float. Epsilon for RMS normalization. + activation: callable. Activation for the gated MLP. + kernel_initializer: Initializer for kernels. + dropout: float. Dropout rate. + use_sliding_window_attention: bool. Whether to use sliding window. + sliding_window_size: int. Size of the sliding window. + dtype: string or `keras.mixed_precision.DTypePolicy`. + """ + + def __init__( + self, + intermediate_dim, + hidden_dim, + num_query_heads, + num_key_value_heads, + mrope_section, + rope_max_wavelength=10000, + layer_norm_epsilon=1e-6, + activation=None, + kernel_initializer="glorot_uniform", + dropout=0, + use_sliding_window_attention=False, + sliding_window_size=4096, + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.mrope_section = mrope_section + self.rope_max_wavelength = rope_max_wavelength + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.use_sliding_window_attention = use_sliding_window_attention + self.sliding_window_size = sliding_window_size + + if activation is None: + activation = ops.silu + self.activation = activation + + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + + self._self_attention_layer = Qwen2VLAttention( + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + hidden_dim=hidden_dim, + mrope_section=mrope_section, + rope_max_wavelength=rope_max_wavelength, + kernel_initializer=clone_initializer(kernel_initializer), + dropout=dropout, + use_sliding_window_attention=use_sliding_window_attention, + sliding_window_size=sliding_window_size, + dtype=self.dtype_policy, + name="self_attention", + ) + + self._self_attention_layernorm = QwenLayerNorm( + epsilon=layer_norm_epsilon, + dtype=self.dtype_policy, + name="self_attention_layernorm", + ) + + self._feedforward_layernorm = QwenLayerNorm( + epsilon=layer_norm_epsilon, + dtype=self.dtype_policy, + name="feedforward_layernorm", + ) + + # SwiGLU MLP: gate_proj and up_proj -> activation -> down_proj + self._feedforward_gate_dense = keras.layers.Dense( + intermediate_dim, + use_bias=False, + kernel_initializer=clone_initializer(kernel_initializer), + dtype=self.dtype_policy, + name="feedforward_gate_dense", + ) + self._feedforward_intermediate_dense = keras.layers.Dense( + intermediate_dim, + use_bias=False, + kernel_initializer=clone_initializer(kernel_initializer), + dtype=self.dtype_policy, + name="feedforward_intermediate_dense", + ) + self._feedforward_output_dense = keras.layers.Dense( + hidden_dim, + use_bias=False, + kernel_initializer=clone_initializer(kernel_initializer), + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + + def build(self, decoder_sequence_shape): + self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_layer.build(decoder_sequence_shape) + self._feedforward_layernorm.build(decoder_sequence_shape) + self._feedforward_gate_dense.build(decoder_sequence_shape) + self._feedforward_intermediate_dense.build(decoder_sequence_shape) + 
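        # The down projection consumes the `intermediate_dim`-sized SwiGLU
        # activations, so it is built with that feature size.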
self._feedforward_output_dense.build( + decoder_sequence_shape[:-1] + (self.intermediate_dim,) + ) + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + position_embeddings=None, + cache=None, + cache_update_index=None, + training=None, + ): + """Forward pass through the decoder block. + + Args: + hidden_states: Input tensor of shape + `(batch, seq_len, hidden_dim)`. + attention_mask: Optional mask of shape + `(batch, seq_len, seq_len)`. + position_embeddings: Tuple of `(cos, sin)`, each of shape + `(3, batch, seq_len, head_dim)` for M-RoPE. + cache: Optional cached key/value states. + cache_update_index: Index for cache update. + training: Boolean training flag. + + Returns: + hidden_states: Output tensor. + cache: Updated cache (if provided). + """ + # Self-attention with residual + residual = hidden_states + hidden_states = self._self_attention_layernorm(hidden_states) + + attention_output = self._self_attention_layer( + hidden_states, + attention_mask=self._compute_self_attention_mask( + hidden_states=hidden_states, + attention_mask=attention_mask, + cache=cache, + cache_update_index=cache_update_index, + ), + position_embeddings=position_embeddings, + cache=cache, + cache_update_index=cache_update_index, + training=training, + ) + + if cache is not None: + attention_output, cache = attention_output + + hidden_states = residual + attention_output + + # SwiGLU MLP with residual + residual = hidden_states + hidden_states = self._feedforward_layernorm(hidden_states) + + gate = self.activation(self._feedforward_gate_dense(hidden_states)) + hidden_states = self._feedforward_intermediate_dense(hidden_states) + hidden_states = self._feedforward_output_dense(gate * hidden_states) + + hidden_states = residual + hidden_states + + if cache is not None: + return hidden_states, cache + return hidden_states + + def _compute_self_attention_mask( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + ): + """Computes the causal self-attention mask. + + Combines padding mask with causal mask. During generation with + cache, only produces mask for the current token(s). 
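
        Returns:
            Boolean mask of shape `(batch_size, input_length, output_length)`,
            where `output_length` is the cached sequence length when a cache
            is passed and `input_length` otherwise.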
+ """ + batch_size = ops.shape(hidden_states)[0] + input_length = ops.shape(hidden_states)[1] + + if cache is not None: + output_length = ops.shape(cache)[2] + else: + output_length = input_length + + # Causal mask + causal_mask = ops.triu( + ops.ones((output_length, output_length), dtype="bool"), + k=1, + ) + causal_mask = ops.logical_not(causal_mask) + + if cache_update_index is not None: + # Slice for the current step + causal_mask = ops.slice( + causal_mask, + (cache_update_index, 0), + (input_length, output_length), + ) + + # Combine with padding mask + if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + causal_mask = ops.expand_dims(causal_mask, axis=0) + causal_mask = ops.broadcast_to( + causal_mask, + (batch_size, input_length, output_length), + ) + causal_mask = ops.logical_and(causal_mask, attention_mask) + else: + causal_mask = ops.expand_dims(causal_mask, axis=0) + causal_mask = ops.broadcast_to( + causal_mask, + (batch_size, input_length, output_length), + ) + + return causal_mask + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "hidden_dim": self.hidden_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "mrope_section": self.mrope_section, + "rope_max_wavelength": self.rope_max_wavelength, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "use_sliding_window_attention": ( + self.use_sliding_window_attention + ), + "sliding_window_size": self.sliding_window_size, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + } + ) + return config diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_image_converter.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_image_converter.py new file mode 100644 index 0000000000..5474444275 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_image_converter.py @@ -0,0 +1,77 @@ +import math + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone + + +def smart_resize( + height, width, factor=28, min_pixels=56 * 56, max_pixels=12845056 +): + """Resize image dimensions so both are divisible by ``factor`` and the + total pixel count stays within ``[min_pixels, max_pixels]``. + + Args: + height: int. Original image height. + width: int. Original image width. + factor: int. Both output dims must be multiples of this value. + Defaults to ``28`` (``patch_size * merge_size = 14 * 2``). + min_pixels: int. Minimum total pixel count. Defaults to + ``56 * 56 = 3136``. + max_pixels: int. Maximum total pixel count. Defaults to + ``12845056`` (matching HuggingFace). + + Returns: + Tuple ``(h_bar, w_bar)`` of resized dimensions. + + Raises: + ValueError: If the absolute aspect ratio exceeds 200. + """ + if height <= 0 or width <= 0: + raise ValueError( + f"Height and width must be positive, " + f"got height={height}, width={width}." + ) + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"Absolute aspect ratio must be smaller than 200, got " + f"{max(height, width) / min(height, width):.1f}." 
+ ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +@keras_hub_export("keras_hub.layers.Qwen2VLImageConverter") +class Qwen2VLImageConverter(ImageConverter): + """Image converter for Qwen2-VL models. + + This layer handles image preprocessing (resize, normalize) for the + Qwen2-VL vision encoder. Image processing is always performed in + ``float32``. + + The ``smart_resize`` utility (defined above) can be used to compute + target dimensions that are divisible by the patch/merge factor before + passing images to this converter. + + Args: + **kwargs: Keyword arguments passed to the base ``ImageConverter``, + including ``height``, ``width``, ``scale``, ``offset``, + ``crop_to_aspect_ratio``, ``interpolation``, etc. + """ + + backbone_cls = Qwen2VLBackbone + + def __init__(self, **kwargs): + # Always do image preprocessing in float32. + kwargs.pop("dtype", None) + dtype = "float32" + super().__init__(dtype=dtype, **kwargs) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_layernorm.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_layernorm.py new file mode 100644 index 0000000000..0d16044046 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_layernorm.py @@ -0,0 +1,17 @@ +from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm + + +class Qwen2VLLayerNorm(QwenLayerNorm): + """Qwen2-VL RMS LayerNorm. + + Reuses the existing ``QwenLayerNorm`` implementation (RMS + normalization without centering). + + Args: + epsilon: float. A small float added to the denominator to + avoid dividing by zero. Defaults to `1e-6`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The + dtype for computations and weights. + """ + + pass diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_presets.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_presets.py new file mode 100644 index 0000000000..0b39df8166 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_presets.py @@ -0,0 +1,44 @@ +"""Qwen2-VL preset configurations.""" + +backbone_presets = {} + +# Presets will be added after they are uploaded to Kaggle. +# "qwen2_vl_2b_instruct": { +# "metadata": { +# "description": ( +# "28-layer Qwen2-VL multimodal model with 2 billion " +# "parameters, instruction-tuned." +# ), +# "params": 2210000000, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": ( +# "kaggle://keras/qwen2-vl/keras/qwen2_vl_2b_instruct/1" +# ), +# }, +# "qwen2_vl_7b_instruct": { +# "metadata": { +# "description": ( +# "28-layer Qwen2-VL multimodal model with 7 billion " +# "parameters, instruction-tuned." +# ), +# "params": 8290000000, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": ( +# "kaggle://keras/qwen2-vl/keras/qwen2_vl_7b_instruct/1" +# ), +# }, +# "qwen2_vl_72b_instruct": { +# "metadata": { +# "description": ( +# "80-layer Qwen2-VL multimodal model with 72 billion " +# "parameters, instruction-tuned." 
+# ), +# "params": 73400000000, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": ( +# "kaggle://keras/qwen2-vl/keras/qwen2_vl_72b_instruct/1" +# ), +# }, diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer.py new file mode 100644 index 0000000000..e1b0c0defb --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer.py @@ -0,0 +1,53 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone + +VISION_START_TOKEN = "<|vision_start|>" +VISION_END_TOKEN = "<|vision_end|>" +VISION_PAD_TOKEN = "<|vision_pad|>" +IMAGE_PAD_TOKEN = "<|image_pad|>" +VIDEO_PAD_TOKEN = "<|video_pad|>" + + +@keras_hub_export( + [ + "keras_hub.tokenizers.Qwen2VLTokenizer", + "keras_hub.models.Qwen2VLTokenizer", + ] +) +class Qwen2VLTokenizer(QwenTokenizer): + """Tokenizer for Qwen2-VL models. + + This tokenizer extends the base Qwen tokenizer with vision-related + special tokens for multimodal input handling. + + Args: + vocabulary: Dictionary mapping tokens to IDs, or path to file. + merges: List of BPE merges, or path to merges file. + """ + + backbone_cls = Qwen2VLBackbone + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_vision_token_ids() + + def set_vocabulary_and_merges(self, vocabulary, merges): + super().set_vocabulary_and_merges(vocabulary, merges) + self._init_vision_token_ids() + + def _safe_token_to_id(self, token): + if self.vocabulary is None: + return None + return self.vocabulary.get(token) + + def _init_vision_token_ids(self): + # Multimodal token IDs used by preprocessing/model plumbing. + self.image_token_id = self._safe_token_to_id(IMAGE_PAD_TOKEN) + self.video_token_id = self._safe_token_to_id(VIDEO_PAD_TOKEN) + self.vision_start_token_id = self._safe_token_to_id(VISION_START_TOKEN) + self.vision_end_token_id = self._safe_token_to_id(VISION_END_TOKEN) + self.vision_pad_token_id = self._safe_token_to_id(VISION_PAD_TOKEN) + # Common alias names. 
+ self.image_pad_token_id = self.image_token_id + self.video_pad_token_id = self.video_token_id diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer_test.py new file mode 100644 index 0000000000..3b244aaae1 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer_test.py @@ -0,0 +1,74 @@ +import pytest + +from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import Qwen2VLTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class Qwen2VLTokenizerTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "\u0120air", "plane", "\u0120at", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab += ["<|eot_id|>"] + self.vocab += ["<|vision_start|>"] + self.vocab += ["<|vision_end|>"] + self.vocab += ["<|vision_pad|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = [ + "\u0120 a", + "\u0120 t", + "\u0120 i", + "\u0120 b", + "a i", + "p l", + "n e", + ] + self.merges += [ + "\u0120a t", + "p o", + "r t", + "\u0120t h", + "ai r", + "pl a", + "po rt", + ] + self.merges += ["\u0120ai r", "\u0120a i", "pla ne"] + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + } + self.input_data = [ + " airplane at airport<|endoftext|>", + " airplane airport", + ] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=Qwen2VLTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[2, 3, 4, 2, 5, 6], [2, 3, 2, 5]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + Qwen2VLTokenizer(vocabulary=["a", "b", "c"], merges=[]) + + def test_vision_special_tokens(self): + """Verify vision special tokens are registered.""" + tokenizer = Qwen2VLTokenizer(**self.init_kwargs) + self.assertIsNotNone(tokenizer.vision_start_token_id) + self.assertIsNotNone(tokenizer.vision_end_token_id) + self.assertIsNotNone(tokenizer.vision_pad_token_id) + # Check they map to the right vocabulary ids. + self.assertEqual(tokenizer.vision_start_token_id, 8) + self.assertEqual(tokenizer.vision_end_token_id, 9) + self.assertEqual(tokenizer.vision_pad_token_id, 10) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Qwen2VLTokenizer.presets: + self.run_preset_test( + cls=Qwen2VLTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py new file mode 100644 index 0000000000..200dfdc138 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py @@ -0,0 +1,693 @@ +import keras +from keras import ops + +from keras_hub.src.models.qwen2_vl.qwen2_vl_attention import _rotate_half + + +class Qwen2VLPatchEmbed(keras.layers.Layer): + """3D convolution-based patch embedding for Qwen2-VL. + + Processes image/video frames using a 3D convolution with kernel size + `(temporal_patch_size, patch_size, patch_size)` to produce patch embeddings. + + Args: + patch_size: int. Spatial patch size (height and width). + temporal_patch_size: int. Temporal patch size (number of frames + grouped together). + in_channels: int. Number of input channels. + embed_dim: int. Embedding dimension. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype + for computations and weights. 
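Example:

A minimal illustrative sketch; the shapes mirror the unit test in this PR rather than a real preset:

```python
import numpy as np

patch_embed = Qwen2VLPatchEmbed(
    patch_size=4, temporal_patch_size=2, in_channels=3, embed_dim=32
)
# (total_patches, in_channels, temporal_patch_size, patch_size, patch_size)
patches = np.random.rand(8, 3, 2, 4, 4).astype("float32")
embeddings = patch_embed(patches)  # -> shape (8, 32)
```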
+ """ + + def __init__( + self, + patch_size, + temporal_patch_size, + in_channels, + embed_dim, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + # The model's internal pipeline always produces patches in + # channels-first format: (batch, C, T, H, W). Keras handles + # cross-backend compatibility internally for Conv3D. + self.proj = keras.layers.Conv3D( + filters=embed_dim, + kernel_size=(temporal_patch_size, patch_size, patch_size), + strides=(temporal_patch_size, patch_size, patch_size), + use_bias=False, + data_format="channels_first", + dtype=dtype, + name="proj", + ) + + def call(self, hidden_states): + """Processes input patches through the 3D convolution. + + Args: + hidden_states: Tensor of shape + `(total_patches, in_channels, temporal_patch_size, + patch_size, patch_size)`. + + Returns: + Tensor of shape `(total_patches, embed_dim)`. + """ + hidden_states = self.proj(hidden_states) + # Flatten spatial and temporal dims: + # (batch, embed_dim, 1, 1, 1) -> (batch, embed_dim) + hidden_states = ops.reshape(hidden_states, (-1, self.embed_dim)) + return hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "temporal_patch_size": self.temporal_patch_size, + "in_channels": self.in_channels, + "embed_dim": self.embed_dim, + } + ) + return config + + +class Qwen2VLVisionRotaryEmbedding(keras.layers.Layer): + """2D Rotary position embedding for the Qwen2-VL vision encoder. + + Computes rotary embeddings from spatial position indices. The embedding + dimension is split in half: one half for height positions and the other + half for width positions. + + Args: + dim: int. Dimension of the rotary embedding (typically `head_dim // 2`). + theta: float. Base frequency for the rotary embedding. + Defaults to `10000.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype + for computations. + """ + + def __init__(self, dim, theta=10000.0, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.dim = dim + self.theta = theta + # Compute inverse frequencies: 1 / (theta ^ (2i / dim)) + inv_freq = 1.0 / ( + theta ** (ops.cast(ops.arange(0, dim, 2), "float32") / dim) + ) + self.inv_freq = self.add_weight( + name="inv_freq", + shape=inv_freq.shape, + initializer="zeros", + trainable=False, + dtype="float32", + ) + self.inv_freq.assign(inv_freq) + + def call(self, seqlen): + """Computes rotary embeddings for a given sequence length. + + Args: + seqlen: int. The maximum sequence length (max grid dimension). + + Returns: + Tensor of shape `(seqlen, dim // 2)` containing the rotary + frequencies. + """ + seq = ops.cast(ops.arange(seqlen), "float32") + # Outer product: (seqlen,) x (dim//2,) -> (seqlen, dim//2) + freqs = ops.einsum("i,j->ij", seq, self.inv_freq) + return freqs + + def get_config(self): + config = super().get_config() + config.update( + { + "dim": self.dim, + "theta": self.theta, + } + ) + return config + + +class Qwen2VLVisionMLP(keras.layers.Layer): + """MLP block for the Qwen2-VL vision encoder. + + A two-layer feedforward network with GELU activation. + + Args: + dim: int. Input/output dimension. + hidden_dim: int. Hidden dimension. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype + for computations and weights. 
+ """ + + def __init__(self, dim, hidden_dim, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.dim = dim + self.hidden_dim = hidden_dim + self.fc1 = keras.layers.Dense( + hidden_dim, + use_bias=True, + dtype=dtype, + name="fc1", + ) + self.fc2 = keras.layers.Dense( + dim, + use_bias=True, + dtype=dtype, + name="fc2", + ) + + def call(self, x): + x = self.fc1(x) + x = ops.gelu(x, approximate=False) + x = self.fc2(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "dim": self.dim, + "hidden_dim": self.hidden_dim, + } + ) + return config + + +def _apply_rotary_pos_emb_vision(q, k, cos, sin): + """Applies rotary position embedding to query and key tensors. + + Args: + q: Query tensor. + k: Key tensor. + cos: Cosine part of rotary embedding. + sin: Sine part of rotary embedding. + + Returns: + Tuple of rotated query and key tensors. + """ + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q = ops.cast(q, "float32") + k = ops.cast(k, "float32") + cos = ops.cast(ops.expand_dims(cos, axis=-2), "float32") + sin = ops.cast(ops.expand_dims(sin, axis=-2), "float32") + + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return ops.cast(q_embed, orig_q_dtype), ops.cast(k_embed, orig_k_dtype) + + +class Qwen2VLVisionAttention(keras.layers.Layer): + """Multi-head attention for the Qwen2-VL vision encoder. + + Uses a fused QKV projection and 2D rotary position embeddings. + + Args: + embed_dim: int. Embedding dimension. + num_heads: int. Number of attention heads. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype + for computations and weights. + """ + + def __init__(self, embed_dim, num_heads, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.qkv = keras.layers.Dense( + embed_dim * 3, + use_bias=True, + dtype=dtype, + name="qkv", + ) + self.proj = keras.layers.Dense( + embed_dim, + use_bias=True, + dtype=dtype, + name="proj", + ) + + def call(self, hidden_states, position_embeddings=None): + """Forward pass of vision attention. + + Args: + hidden_states: Tensor of shape `(seq_len, embed_dim)`. + Note: no batch dim — all images are concatenated. + position_embeddings: Tuple of `(cos, sin)` for rotary + embeddings, each of shape `(seq_len, head_dim)`. + + Returns: + Tensor of shape `(seq_len, embed_dim)`. 
+ """ + seq_length = ops.shape(hidden_states)[0] + + # QKV projection: (seq_len, 3 * embed_dim) + qkv = self.qkv(hidden_states) + # Reshape to (seq_len, 3, num_heads, head_dim) + qkv = ops.reshape(qkv, (seq_length, 3, self.num_heads, self.head_dim)) + # Transpose to (3, seq_len, num_heads, head_dim) + qkv = ops.transpose(qkv, (1, 0, 2, 3)) + query, key, value = qkv[0], qkv[1], qkv[2] + + # Apply rotary position embeddings + if position_embeddings is not None: + cos, sin = position_embeddings + query, key = _apply_rotary_pos_emb_vision(query, key, cos, sin) + + # Transpose for attention: (1, num_heads, seq_len, head_dim) + query = ops.transpose(ops.expand_dims(query, axis=0), (0, 2, 1, 3)) + key = ops.transpose(ops.expand_dims(key, axis=0), (0, 2, 1, 3)) + value = ops.transpose(ops.expand_dims(value, axis=0), (0, 2, 1, 3)) + + # Scaled dot-product attention + scale = self.head_dim**-0.5 + attn_weights = ops.matmul(query, ops.transpose(key, (0, 1, 3, 2))) + attn_weights = attn_weights * scale + attn_weights = ops.softmax(ops.cast(attn_weights, "float32"), axis=-1) + attn_weights = ops.cast(attn_weights, query.dtype) + + attn_output = ops.matmul(attn_weights, value) + # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim) + attn_output = ops.transpose(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape(attn_output, (seq_length, -1)) + + attn_output = self.proj(attn_output) + return attn_output + + def get_config(self): + config = super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + } + ) + return config + + +class Qwen2VLVisionBlock(keras.layers.Layer): + """A single transformer block for the Qwen2-VL vision encoder. + + Pre-norm architecture: LN → Attention → residual → LN → MLP → residual. + + Args: + embed_dim: int. Embedding dimension. + num_heads: int. Number of attention heads. + mlp_ratio: float. Ratio of MLP hidden dim to embed_dim. + Defaults to `4.0`. + layer_norm_epsilon: float. Epsilon for layer normalization. + Defaults to `1e-6`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype + for computations and weights. + """ + + def __init__( + self, + embed_dim, + num_heads, + mlp_ratio=4.0, + layer_norm_epsilon=1e-6, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.embed_dim = embed_dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.layer_norm_epsilon = layer_norm_epsilon + + mlp_hidden_dim = int(embed_dim * mlp_ratio) + + self.norm1 = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="norm1", + ) + self.norm2 = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="norm2", + ) + self.attn = Qwen2VLVisionAttention( + embed_dim=embed_dim, + num_heads=num_heads, + dtype=dtype, + name="attn", + ) + self.mlp = Qwen2VLVisionMLP( + dim=embed_dim, + hidden_dim=mlp_hidden_dim, + dtype=dtype, + name="mlp", + ) + + def call(self, hidden_states, position_embeddings=None): + """Forward pass through the vision block. + + Args: + hidden_states: Tensor of shape `(seq_len, embed_dim)`. + position_embeddings: Tuple of `(cos, sin)` for rotary + embeddings. + + Returns: + Tensor of shape `(seq_len, embed_dim)`. 
+ """ + # Self-attention with residual + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + position_embeddings=position_embeddings, + ) + # MLP with residual + hidden_states = hidden_states + self.mlp( + self.norm2(hidden_states), + ) + return hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + "mlp_ratio": self.mlp_ratio, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + +class Qwen2VLPatchMerger(keras.layers.Layer): + """Spatial patch merger for the Qwen2-VL vision encoder. + + Merges `spatial_merge_size × spatial_merge_size` adjacent patches into + a single token, reducing the number of vision tokens by + `spatial_merge_size²`. + + Architecture: LayerNorm → reshape (group patches) → Dense → GELU → Dense. + + Args: + hidden_size: int. Output dimension (the LLM hidden dimension). + context_dim: int. The ViT embedding dimension. + spatial_merge_size: int. Size of the spatial merge window. + Defaults to `2`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype + for computations and weights. + """ + + def __init__( + self, + hidden_size, + context_dim, + spatial_merge_size=2, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.context_dim = context_dim + self.spatial_merge_size = spatial_merge_size + + merge_hidden = context_dim * (spatial_merge_size**2) + + self.ln_q = keras.layers.LayerNormalization( + epsilon=1e-6, + dtype=dtype, + name="ln_q", + ) + self.dense1 = keras.layers.Dense( + merge_hidden, + use_bias=True, + dtype=dtype, + name="dense1", + ) + self.dense2 = keras.layers.Dense( + hidden_size, + use_bias=True, + dtype=dtype, + name="dense2", + ) + + def call(self, x): + """Merges adjacent patches. + + Args: + x: Tensor of shape `(total_patches, context_dim)`. + + Returns: + Tensor of shape `(merged_patches, hidden_size)`. + """ + x = self.ln_q(x) + merge_size = self.spatial_merge_size**2 + # Reshape to group adjacent patches + x = ops.reshape(x, (-1, merge_size * self.context_dim)) + x = self.dense1(x) + x = ops.gelu(x, approximate=False) + x = self.dense2(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "context_dim": self.context_dim, + "spatial_merge_size": self.spatial_merge_size, + } + ) + return config + + +class Qwen2VLVisionEncoder(keras.layers.Layer): + """Qwen2-VL Vision Transformer encoder. + + Full vision encoder: PatchEmbed → 2D RoPE → VisionBlocks → PatchMerger. + + This encoder takes pixel values and grid dimensions as input, and + produces vision embeddings suitable for interleaving with text token + embeddings in the Qwen2-VL decoder. + + Args: + hidden_size: int. The LLM hidden dimension (output dim of merger). + embed_dim: int. ViT embedding dimension. Defaults to `1280`. + depth: int. Number of vision transformer blocks. Defaults to `32`. + num_heads: int. Number of attention heads. Defaults to `16`. + patch_size: int. Spatial patch size. Defaults to `14`. + temporal_patch_size: int. Temporal patch size. Defaults to `2`. + in_channels: int. Number of input channels. Defaults to `3`. + mlp_ratio: float. MLP hidden dim ratio. Defaults to `4.0`. + spatial_merge_size: int. Spatial merge window size. Defaults to `2`. + layer_norm_epsilon: float. Epsilon for layer norms. + Defaults to `1e-6`. + dtype: string or `keras.mixed_precision.DTypePolicy`. 
The dtype + for computations and weights. + """ + + def __init__( + self, + hidden_size, + embed_dim=1280, + depth=32, + num_heads=16, + patch_size=14, + temporal_patch_size=2, + in_channels=3, + mlp_ratio=4.0, + spatial_merge_size=2, + layer_norm_epsilon=1e-6, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.embed_dim = embed_dim + self.depth = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.mlp_ratio = mlp_ratio + self.spatial_merge_size = spatial_merge_size + self.layer_norm_epsilon = layer_norm_epsilon + + head_dim = embed_dim // num_heads + + self.patch_embed = Qwen2VLPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + dtype=dtype, + name="patch_embed", + ) + + self.rotary_pos_emb = Qwen2VLVisionRotaryEmbedding( + dim=head_dim // 2, + dtype=dtype, + name="rotary_pos_emb", + ) + + self.blocks = [ + Qwen2VLVisionBlock( + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + layer_norm_epsilon=layer_norm_epsilon, + dtype=dtype, + name=f"block_{i}", + ) + for i in range(depth) + ] + + self.merger = Qwen2VLPatchMerger( + hidden_size=hidden_size, + context_dim=embed_dim, + spatial_merge_size=spatial_merge_size, + dtype=dtype, + name="merger", + ) + + def _compute_rotary_pos_emb(self, grid_thw): + """Computes 2D rotary position embeddings from grid dimensions. + + Args: + grid_thw: Tensor of shape `(num_images, 3)` with + `[temporal, height, width]` for each image/video. + + Returns: + Tensor of shape `(total_patches, head_dim // 2)` with rotary + frequencies. + """ + all_pos_ids = [] + + spatial_merge = self.spatial_merge_size + for idx in range(ops.shape(grid_thw)[0]): + t = grid_thw[idx, 0] + h = grid_thw[idx, 1] + w = grid_thw[idx, 2] + + # Height position IDs + hpos_ids = ops.repeat( + ops.expand_dims(ops.arange(h), axis=1), w, axis=1 + ) + # Reshape for spatial merge grouping + hpos_ids = ops.reshape( + hpos_ids, + ( + h // spatial_merge, + spatial_merge, + w // spatial_merge, + spatial_merge, + ), + ) + hpos_ids = ops.transpose(hpos_ids, (0, 2, 1, 3)) + hpos_ids = ops.reshape(hpos_ids, (-1,)) + + # Width position IDs + wpos_ids = ops.repeat( + ops.expand_dims(ops.arange(w), axis=0), h, axis=0 + ) + wpos_ids = ops.reshape( + wpos_ids, + ( + h // spatial_merge, + spatial_merge, + w // spatial_merge, + spatial_merge, + ), + ) + wpos_ids = ops.transpose(wpos_ids, (0, 2, 1, 3)) + wpos_ids = ops.reshape(wpos_ids, (-1,)) + + # Stack [h, w] and repeat for each temporal frame + pos_ids = ops.stack([hpos_ids, wpos_ids], axis=-1) + pos_ids = ops.tile(pos_ids, (t, 1)) + all_pos_ids.append(pos_ids) + + pos_ids = ops.concatenate(all_pos_ids, axis=0) + + # Get max grid size for frequency computation + max_grid_size = ops.max(grid_thw[:, 1:]) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + + # Gather embeddings for each position + # pos_ids: (total_patches, 2) — height and width indices + # rotary_pos_emb_full: (max_grid_size, dim//2) + h_emb = ops.take(rotary_pos_emb_full, pos_ids[:, 0], axis=0) + w_emb = ops.take(rotary_pos_emb_full, pos_ids[:, 1], axis=0) + # Concatenate height and width embeddings along last dim + rotary_pos_emb = ops.concatenate([h_emb, w_emb], axis=-1) + + return rotary_pos_emb + + def call(self, hidden_states, grid_thw): + """Forward pass of the vision encoder. 
+ + Args: + hidden_states: Pixel values tensor of shape + `(total_patches, in_channels, temporal_patch_size, + patch_size, patch_size)`. + grid_thw: Tensor of shape `(num_images, 3)` containing + `[temporal, height, width]` grid dimensions for each + image/video. + + Returns: + Tensor of shape `(merged_total_patches, hidden_size)`. + """ + hidden_states = self.patch_embed(hidden_states) + + # Compute rotary position embeddings + rotary_pos_emb = self._compute_rotary_pos_emb(grid_thw) + emb = ops.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1) + position_embeddings = ( + ops.cos(ops.cast(emb, "float32")), + ops.sin(ops.cast(emb, "float32")), + ) + + # Apply transformer blocks + for blk in self.blocks: + hidden_states = blk( + hidden_states, + position_embeddings=position_embeddings, + ) + + # Spatial merge + merged = self.merger(hidden_states) + return merged + + def compute_output_shape(self, hidden_states_shape, grid_thw_shape): + del grid_thw_shape + return (hidden_states_shape[0], self.hidden_size) + + def compute_output_spec(self, hidden_states, grid_thw): + output_shape = self.compute_output_shape( + hidden_states.shape, grid_thw.shape + ) + return keras.KerasTensor(output_shape, dtype=self.compute_dtype) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "embed_dim": self.embed_dim, + "depth": self.depth, + "num_heads": self.num_heads, + "patch_size": self.patch_size, + "temporal_patch_size": self.temporal_patch_size, + "in_channels": self.in_channels, + "mlp_ratio": self.mlp_ratio, + "spatial_merge_size": self.spatial_merge_size, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder_test.py new file mode 100644 index 0000000000..16d9cdabac --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder_test.py @@ -0,0 +1,118 @@ +"""Tests for Qwen2-VL Vision Encoder components.""" + +import numpy as np + +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLPatchEmbed, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLPatchMerger, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLVisionBlock, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLVisionEncoder, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLVisionRotaryEmbedding, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen2VLPatchEmbedTest(TestCase): + def test_output_shape(self): + patch_embed = Qwen2VLPatchEmbed( + patch_size=4, + temporal_patch_size=2, + in_channels=3, + embed_dim=32, + ) + # Input: (batch, in_channels, temporal, patch_h, patch_w) + dummy_input = np.random.rand(8, 3, 2, 4, 4).astype("float32") + output = patch_embed(dummy_input) + self.assertEqual(output.shape, (8, 32)) + + +class Qwen2VLVisionRotaryEmbeddingTest(TestCase): + def test_output_shape(self): + rope = Qwen2VLVisionRotaryEmbedding(dim=16) + freqs = rope(seqlen=10) + self.assertEqual(freqs.shape, (10, 8)) + + +class Qwen2VLVisionBlockTest(TestCase): + def test_output_shape(self): + block = Qwen2VLVisionBlock( + embed_dim=32, + num_heads=4, + mlp_ratio=4.0, + ) + # Input: (seq_len, embed_dim) — no batch dim in vision encoder + dummy_input = np.random.rand(16, 32).astype("float32") + # Create dummy position embeddings (shape: seq_len, head_dim) + head_dim = 
32 // 4 # embed_dim // num_heads = 8 + cos = np.ones((16, head_dim), dtype="float32") + sin = np.zeros((16, head_dim), dtype="float32") + output = block(dummy_input, position_embeddings=(cos, sin)) + self.assertEqual(output.shape, (16, 32)) + + +class Qwen2VLPatchMergerTest(TestCase): + def test_output_shape(self): + merger = Qwen2VLPatchMerger( + hidden_size=64, + context_dim=32, + spatial_merge_size=2, + ) + # 16 patches, merge 2x2 -> 4 merged patches + dummy_input = np.random.rand(16, 32).astype("float32") + output = merger(dummy_input) + self.assertEqual(output.shape, (4, 64)) + + +class Qwen2VLVisionEncoderTest(TestCase): + def test_encoder_output_shape(self): + encoder = Qwen2VLVisionEncoder( + hidden_size=64, + embed_dim=32, + depth=2, + num_heads=4, + patch_size=4, + temporal_patch_size=2, + in_channels=3, + mlp_ratio=4.0, + spatial_merge_size=2, + ) + # Create input: 1 image with grid_thw = (1, 4, 4) + # Total patches = 1 * 4 * 4 = 16; after 3D patch embed these + # become (num_patches, in_channels, temporal, patch_h, patch_w) + num_patches = 16 + dummy_input = np.random.rand(num_patches, 3, 2, 4, 4).astype("float32") + grid_thw = np.array([[1, 4, 4]], dtype="int32") + + output = encoder(dummy_input, grid_thw) + # After PatchMerger with spatial_merge_size=2: + # 16 patches / (2*2) = 4 merged patches + self.assertEqual(output.shape, (4, 64)) + + def test_multi_image_output_shape(self): + encoder = Qwen2VLVisionEncoder( + hidden_size=64, + embed_dim=32, + depth=2, + num_heads=4, + patch_size=4, + temporal_patch_size=2, + in_channels=3, + mlp_ratio=4.0, + spatial_merge_size=2, + ) + # 2 images, each with grid_thw (1, 4, 4) -> 16 patches each + num_patches = 32 # 16 + 16 + dummy_input = np.random.rand(num_patches, 3, 2, 4, 4).astype("float32") + grid_thw = np.array([[1, 4, 4], [1, 4, 4]], dtype="int32") + + output = encoder(dummy_input, grid_thw) + # 32 patches / 4 = 8 merged patches + self.assertEqual(output.shape, (8, 64)) diff --git a/keras_hub/src/utils/transformers/convert_qwen2_vl.py b/keras_hub/src/utils/transformers/convert_qwen2_vl.py new file mode 100644 index 0000000000..9dbed092ae --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen2_vl.py @@ -0,0 +1,318 @@ +import numpy as np + +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = Qwen2VLBackbone + + +def convert_backbone_config(transformers_config): + from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLVisionEncoder, + ) + + vision_config = transformers_config.get("vision_config", {}) + mrope_section = transformers_config.get("rope_scaling", {}).get( + "mrope_section", [16, 24, 24] + ) + + kwargs = { + "vocabulary_size": transformers_config["vocab_size"], + "hidden_dim": transformers_config["hidden_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_query_heads": transformers_config["num_attention_heads"], + "num_key_value_heads": transformers_config["num_key_value_heads"], + "intermediate_dim": transformers_config["intermediate_size"], + "layer_norm_epsilon": transformers_config.get("rms_norm_eps", 1e-6), + "rope_max_wavelength": transformers_config.get("rope_theta", 1000000), + "tie_word_embeddings": transformers_config.get( + "tie_word_embeddings", True + ), + "mrope_section": mrope_section, + } + + # Instantiate vision encoder if config is present. 
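    # These keys come from the `vision_config` block of the HF config.json;
    # any key missing there falls back to the defaults used below
    # (embed_dim=1280, depth=32, num_heads=16, spatial_patch_size=14,
    # temporal_patch_size=2, spatial_merge_size=2, mlp_ratio=4.0).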
+ if vision_config: + vision_encoder = Qwen2VLVisionEncoder( + hidden_size=vision_config.get( + "hidden_size", transformers_config["hidden_size"] + ), + embed_dim=vision_config.get("embed_dim", 1280), + depth=vision_config.get("depth", 32), + num_heads=vision_config.get("num_heads", 16), + patch_size=vision_config.get("spatial_patch_size", 14), + temporal_patch_size=vision_config.get("temporal_patch_size", 2), + in_channels=vision_config.get("in_chans", 3), + mlp_ratio=vision_config.get("mlp_ratio", 4.0), + spatial_merge_size=vision_config.get("spatial_merge_size", 2), + name="vision_encoder", + ) + kwargs["vision_encoder"] = vision_encoder + + return kwargs + + +def convert_weights(backbone, loader, transformers_config): + # === Token embedding === + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="model.embed_tokens.weight", + ) + if not backbone.tie_word_embeddings: + loader.port_weight( + keras_variable=backbone.get_layer( + "token_embedding" + ).reverse_embeddings, + hf_weight_key="lm_head.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + # === Text decoder layers === + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") + + # Input layernorm + loader.port_weight( + keras_variable=decoder_layer._self_attention_layernorm.scale, + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + + # Attention — Query + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.bias, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.bias", + hook_fn=transpose_and_reshape, + ) + # Attention — Key + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.bias, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.bias", + hook_fn=transpose_and_reshape, + ) + # Attention — Value + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.bias, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.bias", + hook_fn=transpose_and_reshape, + ) + # Attention — Output + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + hook_fn=transpose_and_reshape, + ) + + # MLP layers + loader.port_weight( + keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_gate_dense.kernel, + 
hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + # Feedforward layernorm + loader.port_weight( + keras_variable=decoder_layer._feedforward_layernorm.scale, + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Final normalization layer + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + # === Vision encoder weights (if present) === + if backbone.vision_encoder is not None: + vision = backbone.vision_encoder + + # Build the vision encoder weights before porting variables. + h_w = vision.spatial_merge_size + num_patches = h_w * h_w + build_images = np.zeros( + ( + num_patches, + vision.in_channels, + vision.temporal_patch_size, + vision.patch_size, + vision.patch_size, + ), + dtype="float32", + ) + build_grid = np.array([[1, h_w, h_w]], dtype="int32") + vision(build_images, grid_thw=build_grid) + + # Patch embedding (Conv3D) + loader.port_weight( + keras_variable=vision.patch_embed.proj.kernel, + hf_weight_key="visual.patch_embed.proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(2, 3, 4, 1, 0) + ), + ) + + # Vision transformer blocks + for i in range(vision.depth): + block = vision.blocks[i] + prefix = f"visual.blocks.{i}" + + # Layer norms + loader.port_weight( + keras_variable=block.norm1.gamma, + hf_weight_key=f"{prefix}.norm1.weight", + ) + loader.port_weight( + keras_variable=block.norm1.beta, + hf_weight_key=f"{prefix}.norm1.bias", + ) + loader.port_weight( + keras_variable=block.norm2.gamma, + hf_weight_key=f"{prefix}.norm2.weight", + ) + loader.port_weight( + keras_variable=block.norm2.beta, + hf_weight_key=f"{prefix}.norm2.bias", + ) + + # Attention QKV (fused) + loader.port_weight( + keras_variable=block.attn.qkv.kernel, + hf_weight_key=f"{prefix}.attn.qkv.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=block.attn.qkv.bias, + hf_weight_key=f"{prefix}.attn.qkv.bias", + ) + + # Attention output projection + loader.port_weight( + keras_variable=block.attn.proj.kernel, + hf_weight_key=f"{prefix}.attn.proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=block.attn.proj.bias, + hf_weight_key=f"{prefix}.attn.proj.bias", + ) + + # MLP + loader.port_weight( + keras_variable=block.mlp.fc1.kernel, + hf_weight_key=f"{prefix}.mlp.fc1.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=block.mlp.fc1.bias, + hf_weight_key=f"{prefix}.mlp.fc1.bias", + ) + loader.port_weight( + keras_variable=block.mlp.fc2.kernel, + hf_weight_key=f"{prefix}.mlp.fc2.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor, axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=block.mlp.fc2.bias, + hf_weight_key=f"{prefix}.mlp.fc2.bias", + ) + + # Patch merger + loader.port_weight( + keras_variable=vision.merger.ln_q.gamma, + hf_weight_key="visual.merger.ln_q.weight", + ) + loader.port_weight( + keras_variable=vision.merger.ln_q.beta, + hf_weight_key="visual.merger.ln_q.bias", + ) + loader.port_weight( + keras_variable=vision.merger.dense1.kernel, + hf_weight_key="visual.merger.mlp.0.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=vision.merger.dense1.bias, + 
hf_weight_key="visual.merger.mlp.0.bias", + ) + loader.port_weight( + keras_variable=vision.merger.dense2.kernel, + hf_weight_key="visual.merger.mlp.2.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=vision.merger.dense2.bias, + hf_weight_key="visual.merger.mlp.2.bias", + ) + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + + # Load all special tokens with the exception of "reserved" ones. + special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + # Also load from tokenizer_config.json — some special tokens + # (e.g. <|image_pad|> 151655, <|video_pad|> 151656) only appear + # in tokenizer_config.json's added_tokens_decoder. + try: + tok_cfg = load_json(preset, "tokenizer_config.json") + for _id_str, meta in tok_cfg.get("added_tokens_decoder", {}).items(): + content = meta["content"] + if content not in vocab and not content.startswith( + "<|reserved_special_token_" + ): + vocab[content] = int(_id_str) + special_tokens.add(content) + except (FileNotFoundError, KeyError): + pass + + kwargs.update( + { + "unsplittable_tokens": list(special_tokens), + } + ) + + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 92c6ea5ef5..df90103cd1 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -22,6 +22,7 @@ from keras_hub.src.utils.transformers import convert_mixtral from keras_hub.src.utils.transformers import convert_pali_gemma from keras_hub.src.utils.transformers import convert_qwen +from keras_hub.src.utils.transformers import convert_qwen2_vl from keras_hub.src.utils.transformers import convert_qwen3 from keras_hub.src.utils.transformers import convert_qwen3_moe from keras_hub.src.utils.transformers import convert_qwen_moe @@ -71,6 +72,8 @@ def __init__(self, preset, config): self.converter = convert_vit elif model_type == "qwen2": self.converter = convert_qwen + elif model_type == "qwen2_vl": + self.converter = convert_qwen2_vl elif model_type == "mixtral": self.converter = convert_mixtral elif model_type == "qwen2_moe": diff --git a/tools/checkpoint_conversion/convert_qwen2_vl_checkpoints.py b/tools/checkpoint_conversion/convert_qwen2_vl_checkpoints.py new file mode 100644 index 0000000000..09d056f5bd --- /dev/null +++ b/tools/checkpoint_conversion/convert_qwen2_vl_checkpoints.py @@ -0,0 +1,247 @@ +"""Convert Qwen2-VL checkpoints from HuggingFace to KerasHub format. 
+ +Usage: + python tools/checkpoint_conversion/convert_qwen2_vl_checkpoints.py \ + --preset qwen2_vl_2b_instruct +""" + +import gc +import os +import traceback + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Hide any CUDA devices + +import numpy as np +import torch +from absl import app +from absl import flags +from keras import ops +from transformers import AutoModelForImageTextToText +from transformers import AutoTokenizer + +import keras_hub + +device = torch.device("cpu") +# Force PyTorch to use CPU +torch.set_default_device(device) + + +PRESET_MAP = { + "qwen2_vl_2b_instruct": "Qwen/Qwen2-VL-2B-Instruct", + "qwen2_vl_7b_instruct": "Qwen/Qwen2-VL-7B-Instruct", + "qwen2_vl_72b_instruct": "Qwen/Qwen2-VL-72B-Instruct", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) +flags.DEFINE_string( + "save_dtype", + "bfloat16", + "Dtype to save the model in. Defaults to bfloat16.", +) + + +# Tolerance for logit comparison (float32 only validation). +DTYPE_TOLERANCES = { + "float32": {"atol": 1e-4, "rtol": 1e-4}, +} + + +def test_model( + keras_hub_model, + keras_hub_preprocessor, + hf_model, + hf_tokenizer, + keras_dtype, +): + # First, test that the number of parameters match. + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + assert keras_hub_params == hf_params + print(f"\n✓ Parameter count match: {keras_hub_params:,} params") + + # Test the outputs of both the models using identical inputs. + keras_hub_inputs = keras_hub_preprocessor.generate_preprocess( + ["What is Keras?"], sequence_length=6 + ) + hf_inputs = { + "input_ids": torch.tensor(keras_hub_inputs["token_ids"]).to(device), + "attention_mask": torch.tensor(keras_hub_inputs["padding_mask"]).to( + device + ), + } + hf_outputs = hf_model.model(**hf_inputs) + hf_output_logits = ( + hf_outputs.last_hidden_state.detach().cpu().float().numpy() + ) + + keras_hub_output = keras_hub_model(keras_hub_inputs) + keras_hub_logits = ops.convert_to_numpy(keras_hub_output) + + # Compute difference stats for reporting. + abs_diff = np.abs(keras_hub_logits - hf_output_logits) + max_abs_diff = np.max(abs_diff) + mean_abs_diff = np.mean(abs_diff) + + # Get dtype-appropriate tolerances. + tolerances = DTYPE_TOLERANCES.get(keras_dtype, {"atol": 1e-4, "rtol": 1e-4}) + atol = tolerances["atol"] + rtol = tolerances["rtol"] + + print(f"\nHidden state comparison (dtype: {keras_dtype}):") + print(f" Max absolute difference: {max_abs_diff:.6f}") + print(f" Mean absolute difference: {mean_abs_diff:.6f}") + print(f" Tolerance - atol: {atol}, rtol: {rtol}") + + try: + np.testing.assert_allclose( + keras_hub_logits, hf_output_logits, atol=atol, rtol=rtol + ) + print("✓ All hidden states within tolerance.") + except AssertionError as err: + print( + "Some hidden states exceed tolerance.\n" + "NOTE: Generated text comparison is the authoritative check." + ) + print("Traceback:") + print(traceback.format_exc()) + print("Assertion message:") + print(err.args[0]) + + +def validate_output( + keras_hub_model, + keras_hub_preprocessor, + hf_model, + hf_tokenizer, +): + """Validate end-to-end text generation between KerasHub and HF models.""" + prompt = "What is Keras?" + + # KerasHub generation. 
+ keras_hub_lm = keras_hub.models.Qwen2VLCausalLM( + backbone=keras_hub_model, + preprocessor=keras_hub_preprocessor, + ) + keras_hub_text = keras_hub_lm.generate([prompt], max_length=20) + print(f"\nKerasHub generated: {keras_hub_text}") + + # HuggingFace generation. + hf_inputs = hf_tokenizer([prompt], return_tensors="pt").to(device) + hf_output_ids = hf_model.generate(**hf_inputs, max_new_tokens=20) + hf_text = hf_tokenizer.batch_decode(hf_output_ids, skip_special_tokens=True) + print(f"HF generated: {hf_text}") + + print("\n✓ Output validation complete (manual comparison above).") + + +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + keras_hub_preprocessor = keras_hub.models.Qwen2VLCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_output = keras_hub_preprocessor.generate_preprocess( + ["What is Keras?"], sequence_length=6 + ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output["token_ids"]) + + np.testing.assert_equal(keras_hub_output, hf_output) + + +def main(_): + # === Get the preset name === + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + + # === Load the Huggingface model === + # Force float32 load for validation. + target_dtype = torch.float32 + + hf_model = AutoModelForImageTextToText.from_pretrained( + hf_preset, + device_map=device, + torch_dtype=target_dtype, + ) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") + hf_model.eval() + + # Verify the actual loaded dtype. + hf_dtype = next(hf_model.parameters()).dtype + keras_dtype = "float32" + print( + f"-> Actual loaded dtype: {hf_dtype} -> " + f"Using Keras dtype: {keras_dtype}" + ) + + # Load Keras backbone with matching dtype. + keras_hub_backbone = keras_hub.models.Qwen2VLBackbone.from_preset( + f"hf://{hf_preset}", dtype=keras_dtype + ) + keras_hub_tokenizer = keras_hub.models.Qwen2VLTokenizer.from_preset( + f"hf://{hf_preset}" + ) + keras_hub_preprocessor = keras_hub.models.Qwen2VLCausalLMPreprocessor( + keras_hub_tokenizer + ) + + print("\n-> Huggingface model and tokenizer loaded") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) + test_model( + keras_hub_backbone, + keras_hub_preprocessor, + hf_model, + hf_tokenizer, + keras_dtype, + ) + validate_output( + keras_hub_backbone, + keras_hub_preprocessor, + hf_model, + hf_tokenizer, + ) + print("\n-> Tests passed!") + + # === Save the model === + keras_hub_lm = keras_hub.models.Qwen2VLCausalLM( + backbone=keras_hub_backbone, + preprocessor=keras_hub_preprocessor, + ) + + save_dtype = FLAGS.save_dtype + if save_dtype == "float32": + print(f"\n-> Saving model in {save_dtype}...") + keras_hub_lm.save_to_preset(f"./{preset}") + else: + del keras_hub_lm + del keras_hub_backbone + del hf_model + gc.collect() + + # Reload in target dtype for saving. 
+ print(f"\n-> Reloading model in {save_dtype} for saving...") + keras_hub_backbone_save = keras_hub.models.Qwen2VLBackbone.from_preset( + f"hf://{hf_preset}", dtype=save_dtype + ) + keras_hub_lm_save = keras_hub.models.Qwen2VLCausalLM( + backbone=keras_hub_backbone_save, + preprocessor=keras_hub_preprocessor, + ) + keras_hub_lm_save.save_to_preset(f"./{preset}") + + print(f"\n-> Saved converted model ({save_dtype}) to ./{preset}") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)