diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index a7138bccad..064f0d862b 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -123,6 +123,9 @@ from keras_hub.src.models.parseq.parseq_image_converter import ( PARSeqImageConverter as PARSeqImageConverter, ) +from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import ( + Qwen2VLImageConverter as Qwen2VLImageConverter, +) from keras_hub.src.models.resnet.resnet_image_converter import ( ResNetImageConverter as ResNetImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index aa6f4f2023..c1624d515d 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -545,6 +545,18 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import ( + Qwen2VLBackbone as Qwen2VLBackbone, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm import ( + Qwen2VLCausalLM as Qwen2VLCausalLM, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm_preprocessor import ( + Qwen2VLCausalLMPreprocessor as Qwen2VLCausalLMPreprocessor, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import ( + Qwen2VLTokenizer as Qwen2VLTokenizer, +) from keras_hub.src.models.qwen3.qwen3_backbone import ( Qwen3Backbone as Qwen3Backbone, ) diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 97a68ab009..aecdc4c8d5 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -81,6 +81,9 @@ from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as QwenTokenizer, ) +from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import ( + Qwen2VLTokenizer as Qwen2VLTokenizer, +) from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import ( Qwen3MoeTokenizer as Qwen3MoeTokenizer, ) diff --git 
a/keras_hub/src/models/qwen2_vl/__init__.py b/keras_hub/src/models/qwen2_vl/__init__.py new file mode 100644 index 0000000000..09b895aa95 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, Qwen2VLBackbone) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py new file mode 100644 index 0000000000..49a7bc1786 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py @@ -0,0 +1,296 @@ +import keras +from keras import ops +from keras.layers import ReversibleEmbedding + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.qwen.qwen_decoder import QwenTransformerDecoder +from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLVisionEncoder, +) + + +def _qwen2vl_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export("keras_hub.models.Qwen2VLBackbone") +class Qwen2VLBackbone(Backbone): + """Qwen2-VL multimodal backbone. + + Combines a 3D Vision Encoder (ViT with RoPE + PatchMerger) with a + Qwen2 causal language model decoder. Vision tokens produced by the + encoder replace the ``image_token_id`` placeholder tokens in the text + sequence before being passed through the decoder layers. + + Args: + vocabulary_size: int. Vocabulary size of the text model. + num_layers: int. Number of transformer decoder layers. + num_query_heads: int. Number of query attention heads. + num_key_value_heads: int. Number of key/value attention heads (GQA). + hidden_dim: int. LLM hidden dimension. + intermediate_dim: int. 
Feed-forward intermediate dimension. + vision_patch_size: int. Spatial patch size for the vision encoder. + Defaults to ``14``. + vision_temporal_patch_size: int. Temporal patch size. Defaults to + ``2``. + vision_in_channels: int. Vision input channels. Defaults to ``3``. + vision_embed_dim: int. Vision encoder internal dimension. Defaults + to ``1280``. + vision_depth: int. Number of vision transformer blocks. Defaults to + ``32``. + vision_num_heads: int. Vision attention heads. Defaults to ``16``. + vision_mlp_ratio: float. Vision MLP hidden dim multiplier. Defaults + to ``4``. + spatial_merge_size: int. Spatial merge factor for PatchMerger. + Defaults to ``2``. + image_token_id: int. Token id used as image placeholder in the text + sequence. The number of ``image_token_id`` placeholders in the + input must exactly equal the number of merged vision tokens + produced by encoding ``patch_values`` with ``image_grid_thw``. + Defaults to ``151655``. + rope_max_wavelength: int. RoPE base wavelength for the text model. + Defaults to ``1000000``. + rope_scaling_factor: float. RoPE scaling factor. Defaults to ``1.0``. + layer_norm_epsilon: float. Epsilon for RMS norm layers. Defaults to + ``1e-6``. + dropout: float. Dropout rate. Defaults to ``0``. + tie_word_embeddings: bool. Whether to tie input/output embeddings. + Defaults to ``False``. + use_sliding_window_attention: bool. Whether to use sliding window + attention. Defaults to ``False``. + sliding_window_size: int. Sliding window size. Defaults to ``32768``. + dtype: string or ``keras.mixed_precision.DTypePolicy``. 
+ """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + vision_patch_size=14, + vision_temporal_patch_size=2, + vision_in_channels=3, + vision_embed_dim=1280, + vision_depth=32, + vision_num_heads=16, + vision_mlp_ratio=4, + spatial_merge_size=2, + image_token_id=151655, + rope_max_wavelength=1000000, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + dropout=0, + tie_word_embeddings=False, + use_sliding_window_attention=False, + sliding_window_size=32768, + dtype=None, + **kwargs, + ): + # === Vision encoder === + self.vision_encoder = Qwen2VLVisionEncoder( + patch_size=vision_patch_size, + temporal_patch_size=vision_temporal_patch_size, + in_channels=vision_in_channels, + embed_dim=vision_embed_dim, + hidden_size=hidden_dim, + depth=vision_depth, + num_heads=vision_num_heads, + mlp_ratio=vision_mlp_ratio, + spatial_merge_size=spatial_merge_size, + dtype=dtype, + name="vision_encoder", + ) + + # === Text decoder === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=_qwen2vl_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = QwenTransformerDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_qwen2vl_kernel_initializer(stddev=0.02), + dropout=dropout, + dtype=dtype, + use_sliding_window_attention=use_sliding_window_attention, + sliding_window_size=sliding_window_size, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = QwenLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + 
name="sequence_output_layernorm", + ) + + # === Functional model === + # Only text inputs in functional graph; vision inputs handled in call() + token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") + padding_mask = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + + # Text embedding path (vision merging happens in call()) + token_embeddings = self.token_embedding(token_ids) + x = token_embeddings + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask) + sequence_output = self.layer_norm(x) + + super().__init__( + inputs={ + "token_ids": token_ids, + "padding_mask": padding_mask, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.vision_patch_size = vision_patch_size + self.vision_temporal_patch_size = vision_temporal_patch_size + self.vision_in_channels = vision_in_channels + self.vision_embed_dim = vision_embed_dim + self.vision_depth = vision_depth + self.vision_num_heads = vision_num_heads + self.vision_mlp_ratio = vision_mlp_ratio + self.spatial_merge_size = spatial_merge_size + self.image_token_id = image_token_id + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.tie_word_embeddings = tie_word_embeddings + self.use_sliding_window_attention = use_sliding_window_attention + self.sliding_window_size = sliding_window_size + + def call(self, inputs, training=None): + """Forward pass with vision token replacement. 
+ + Embeds text tokens, encodes vision patches, replaces + ``image_token_id`` placeholder positions in the embedding sequence + with the merged vision features, then runs the decoder. + + Args: + inputs: Dict with keys ``"token_ids"``, ``"padding_mask"``, + ``"patch_values"`` (optional), ``"image_grid_thw"`` + (optional). + training: bool or None. + + Returns: + Hidden-state tensor of shape ``(batch, seq_len, hidden_dim)``. + """ + token_ids = inputs["token_ids"] + padding_mask = inputs["padding_mask"] + patch_values = inputs.get("patch_values", None) + grid_thw = inputs.get("image_grid_thw", None) + + # Embed text tokens → (batch, seq_len, hidden_dim) + x = self.token_embedding(token_ids) + + # If vision inputs are present, encode and scatter into x. + if patch_values is not None and grid_thw is not None: + # vision_features: (total_merged_tokens, hidden_dim) + vision_features = self.vision_encoder( + patch_values, grid_thw, training=training + ) + # Build a boolean mask of image placeholder positions. + # image_mask: (batch, seq_len) + image_mask = ops.equal( + token_ids, + ops.cast(self.image_token_id, token_ids.dtype), + ) + # Flatten batch+seq dims, replace masked positions with + # vision features, then restore shape. + batch_size = ops.shape(x)[0] + seq_len = ops.shape(x)[1] + x_flat = ops.reshape(x, (-1, self.hidden_dim)) + mask_flat = ops.reshape(image_mask, (-1,)) + # vision_features is already in the right order (same order as + # the image placeholder tokens appear left-to-right). 
+ vision_indices = ops.where(mask_flat) + if isinstance(vision_indices, (list, tuple)): + vision_indices = vision_indices[0] + vision_indices = ops.reshape(vision_indices, (-1, 1)) + vision_indices = ops.cast(vision_indices, "int32") + n_placeholders = ops.shape(vision_indices)[0] + n_vision = ops.shape(vision_features)[0] + if n_placeholders != n_vision: + raise ValueError( + f"Vision token count mismatch: the number of " + f"image_token_id={self.image_token_id} placeholders " + f"in token_ids ({n_placeholders}) does not equal the " + f"number of merged vision tokens produced by the " + f"vision encoder from patch_values/image_grid_thw " + f"({n_vision}). Ensure the preprocessor inserts " + f"exactly one placeholder per merged vision token." + ) + x_flat = ops.scatter_update(x_flat, vision_indices, vision_features) + x = ops.reshape(x_flat, (batch_size, seq_len, self.hidden_dim)) + + # Decoder layers + for transformer_layer in self.transformer_layers: + x = transformer_layer( + x, + decoder_padding_mask=padding_mask, + training=training, + ) + + # Final layer norm + x = self.layer_norm(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "vision_patch_size": self.vision_patch_size, + "vision_temporal_patch_size": self.vision_temporal_patch_size, + "vision_in_channels": self.vision_in_channels, + "vision_embed_dim": self.vision_embed_dim, + "vision_depth": self.vision_depth, + "vision_num_heads": self.vision_num_heads, + "vision_mlp_ratio": self.vision_mlp_ratio, + "spatial_merge_size": self.spatial_merge_size, + "image_token_id": self.image_token_id, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "layer_norm_epsilon": 
self.layer_norm_epsilon, + "dropout": self.dropout, + "tie_word_embeddings": self.tie_word_embeddings, + "use_sliding_window_attention": ( + self.use_sliding_window_attention + ), + "sliding_window_size": self.sliding_window_size, + } + ) + return config diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone_test.py new file mode 100644 index 0000000000..594b12202d --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone_test.py @@ -0,0 +1,86 @@ +import numpy as np +import pytest + +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.tests.test_case import TestCase + + +class Qwen2VLBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 1000, + "num_layers": 2, + "num_query_heads": 8, + "num_key_value_heads": 4, + "hidden_dim": 128, + "intermediate_dim": 256, + "vision_patch_size": 14, + "vision_temporal_patch_size": 2, + "vision_in_channels": 3, + "vision_embed_dim": 64, + "vision_depth": 2, + "vision_num_heads": 4, + "spatial_merge_size": 2, + } + self.input_data = { + "token_ids": np.array([[1, 2, 3, 4, 5]]), + "padding_mask": np.array([[1, 1, 1, 1, 1]]), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Qwen2VLBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(1, 5, 128), + ) + + def test_backbone_with_vision_inputs(self): + """Test that the backbone correctly handles vision inputs.""" + # image_token_id defaults to 151655, but vocab_size=1000, + # so use a custom backbone with image_token_id=999. + init_kwargs = dict(self.init_kwargs) + init_kwargs["image_token_id"] = 999 + backbone = Qwen2VLBackbone(**init_kwargs) + + # Token sequence with 2 image placeholder tokens at positions 1,2. + token_ids = np.array([[1, 999, 999, 4, 5]]) + padding_mask = np.array([[1, 1, 1, 1, 1]]) + + # Vision: grid_thw = [1, 2, 2] → 4 raw patches. 
+ # After merger (spatial_merge_size=2): 4/4 = 1 merged token. + # But we have 2 placeholder tokens, so we need 2 merged tokens. + # grid_thw = [1, 4, 4] → 16 raw patches → 16/4 = 4 merged. + # Use [2, 2, 2] → 8 raw patches → 8/4 = 2 merged tokens. + grid_thw = np.array([[2, 2, 2]], dtype="int32") + patch_flat_dim = 3 * 2 * 14 * 14 # in_channels * temporal * patch² + total_patches = 8 + patch_values = np.random.rand(total_patches, patch_flat_dim).astype( + "float32" + ) + + inputs = { + "token_ids": token_ids, + "padding_mask": padding_mask, + "patch_values": patch_values, + "image_grid_thw": grid_thw, + } + output = backbone(inputs) + self.assertEqual(output.shape, (1, 5, 128)) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen2VLBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Qwen2VLBackbone.presets: + self.run_preset_test( + cls=Qwen2VLBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm.py new file mode 100644 index 0000000000..f5e3e811ac --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm.py @@ -0,0 +1,354 @@ +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm_preprocessor import ( + Qwen2VLCausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.Qwen2VLCausalLM") +class Qwen2VLCausalLM(CausalLM): + """End-to-end Qwen2-VL model for causal vision-language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. 
This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a Qwen2-VL model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt and optional image(s). The generation strategy used is controlled + by an additional `sampler` argument on `compile()`. You can recompile the + model with different `keras_hub.samplers` objects to control the + generation. By default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs + during `fit()`, `predict()`, `evaluate()`, and `generate()`. This is done + by default when creating the model with `from_preset()`. + + Args: + backbone: A `keras_hub.models.Qwen2VLBackbone` instance. + preprocessor: A `keras_hub.models.Qwen2VLCausalLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do vision-language generation. + ```python + qwen2_vl_lm = keras_hub.models.Qwen2VLCausalLM.from_preset( + "qwen2_vl_2b_instruct" + ) + qwen2_vl_lm.generate( + {"prompts": "Describe this image", "images": image}, + max_length=128 + ) + ``` + + Use `generate()` with batched prompts and images. + ```python + qwen2_vl_lm.generate( + { + "prompts": ["What is in this image?", "Describe the scene"], + "images": [image1, image2] + }, + max_length=128 + ) + ``` + + Compile the `generate()` function with a custom sampler. 
+ ```python + qwen2_vl_lm = keras_hub.models.Qwen2VLCausalLM.from_preset( + "qwen2_vl_2b_instruct" + ) + qwen2_vl_lm.compile(sampler="top_k") + qwen2_vl_lm.generate( + {"prompts": "What do you see?", "images": image}, + max_length=128 + ) + + qwen2_vl_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2)) + qwen2_vl_lm.generate( + {"prompts": "Describe this", "images": image}, + max_length=128 + ) + ``` + + Use `generate()` without preprocessing. + ```python + prompt = { + "token_ids": np.array([[151644, 872, 4320]] * 2), + "padding_mask": np.array([[1, 1, 1]] * 2), + "patch_values": np.random.rand(256, 1176), + "image_grid_thw": np.array([[2, 8, 8]]), + } + + qwen2_vl_lm = keras_hub.models.Qwen2VLCausalLM.from_preset( + "qwen2_vl_2b_instruct", + preprocessor=None, + ) + qwen2_vl_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = { + "prompts": ["Describe this image", "What is this?"], + "images": [image1, image2], + } + qwen2_vl_lm = keras_hub.models.Qwen2VLCausalLM.from_preset( + "qwen2_vl_2b_instruct" + ) + qwen2_vl_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. + ```python + x = { + "token_ids": np.array([[151644, 872, 4320, 151645]] * 2), + "padding_mask": np.array([[1, 1, 1, 1]] * 2), + "patch_values": np.random.rand(256, 1176), + "image_grid_thw": np.array([[2, 8, 8], [2, 8, 8]]), + } + y = np.array([[872, 4320, 151645, 0]] * 2) + sw = np.array([[1, 1, 1, 0]] * 2) + + qwen2_vl_lm = keras_hub.models.Qwen2VLCausalLM.from_preset( + "qwen2_vl_2b_instruct", + preprocessor=None, + ) + qwen2_vl_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary. 
+ ```python + tokenizer = keras_hub.models.Qwen2VLTokenizer( + vocabulary="./vocab.json", + merges="./merges.txt", + ) + preprocessor = keras_hub.models.Qwen2VLCausalLMPreprocessor( + tokenizer=tokenizer, + sequence_length=128, + ) + backbone = keras_hub.models.Qwen2VLBackbone( + vocabulary_size=50000, + num_layers=12, + num_query_heads=12, + hidden_dim=768, + intermediate_dim=3072, + ) + qwen2_vl_lm = keras_hub.models.Qwen2VLCausalLM( + backbone=backbone, + preprocessor=preprocessor, + ) + qwen2_vl_lm.fit(x=features, batch_size=2) + ``` + """ + + backbone_cls = Qwen2VLBackbone + preprocessor_cls = Qwen2VLCausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional model === + inputs = backbone.input + hidden_states = backbone(inputs=inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + img_embeddings=None, + ): + """Forward pass of `Qwen2VLCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this + method allows caching previous key/value Tensors in multi-head + attention layer, and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape + `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current + inputs in the whole sequence. + img_embeddings: optional float Tensor of shape + `(total_merged_tokens, hidden_dim)`. Pre-computed vision + features from the vision encoder. When provided, they are + scattered into the text embeddings at positions matching + ``image_token_id``. + + Returns: + A (logits, hidden_states, cache) tuple. 
+ """ + x = self.backbone.token_embedding(token_ids) + + # Scatter vision features into image placeholder positions. + if img_embeddings is not None: + image_mask = ops.equal( + token_ids, + ops.cast(self.backbone.image_token_id, token_ids.dtype), + ) + batch_size = ops.shape(x)[0] + seq_len = ops.shape(x)[1] + x_flat = ops.reshape(x, (-1, self.backbone.hidden_dim)) + mask_flat = ops.reshape(image_mask, (-1,)) + vision_indices = ops.where(mask_flat) + if isinstance(vision_indices, (list, tuple)): + vision_indices = vision_indices[0] + vision_indices = ops.reshape(vision_indices, (-1, 1)) + vision_indices = ops.cast(vision_indices, "int32") + x_flat = ops.scatter_update(x_flat, vision_indices, img_embeddings) + x = ops.reshape( + x_flat, (batch_size, seq_len, self.backbone.hidden_dim) + ) + + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] + x, next_cache = self.backbone.transformer_layers[i]( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids, img_embeddings=None): + """Build an empty cache for use with `call_with_cache()`. + + Args: + token_ids: int Tensor of shape `(batch_size, max_length)`. + img_embeddings: optional float Tensor of pre-computed vision + features to scatter into image placeholder positions + during the initial seeding pass. 
+ """ + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache( + token_ids, + cache, + 0, + img_embeddings=img_embeddings, + ) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"`, `"padding_mask"`, + and optionally `"patch_values"` and `"image_grid_thw"`. + + Args: + inputs: A dictionary with keys `"token_ids"`, `"padding_mask"`, + and optionally `"patch_values"` and `"image_grid_thw"`. + stop_token_ids: Tuple of id's of the end token to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + patch_values = inputs.get("patch_values", None) + image_grid_thw = inputs.get("image_grid_thw", None) + + # Run vision encoder if images are present. + img_embeddings = None + if patch_values is not None and image_grid_thw is not None: + img_embeddings = self.backbone.vision_encoder( + patch_values, image_grid_thw + ) + + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache( + token_ids, + img_embeddings=img_embeddings, + ) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. 
+ index = ops.min(row_lengths) + + def next_token(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next_token, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. 
+ padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py new file mode 100644 index 0000000000..72660606af --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py @@ -0,0 +1,175 @@ +import numpy as np + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import ( + Qwen2VLImageConverter, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import Qwen2VLTokenizer + + +@keras_hub_export("keras_hub.models.Qwen2VLCausalLMPreprocessor") +class Qwen2VLCausalLMPreprocessor(CausalLMPreprocessor): + """Qwen2-VL Causal LM Preprocessor. + + Handles tokenization, image preprocessing, and assembly of the full + input dict required by ``Qwen2VLBackbone``. + + When images are provided the preprocessor: + 1. Runs ``image_converter`` to get flat patches and ``grid_thw``. + 2. Computes the number of image placeholder tokens as + ``grid_t * grid_h * grid_w // spatial_merge_size²``. + 3. Inserts ``<|vision_start|>`` + N × ``<|image_pad|>`` + + ``<|vision_end|>`` tokens into the text token sequence at the + position of the first ``<|image_pad|>`` marker (or prepends + them if no marker is present). + 4. Pads / truncates to ``sequence_length`` and builds ``padding_mask``. + + Returns a dict with keys: + - ``"token_ids"``: int32 array of shape ``(seq_len,)``. + - ``"padding_mask"``: int32 array of shape ``(seq_len,)``. + - ``"patch_values"``: float32 array of shape + ``(total_patches, patch_flat_dim)`` or ``None``. + - ``"image_grid_thw"``: int32 array of shape ``(num_images, 3)`` or + ``None``. 
+ + Args: + tokenizer: A ``keras_hub`` tokenizer instance. + image_converter: A ``Qwen2VLImageConverter`` instance or ``None``. + sequence_length: int. Maximum token sequence length. Defaults to + ``1024``. + spatial_merge_size: int. Must match the backbone's + ``spatial_merge_size``. Used to compute the number of image + placeholder tokens. Defaults to ``2``. + """ + + backbone_cls = Qwen2VLBackbone + tokenizer_cls = Qwen2VLTokenizer + image_converter_cls = Qwen2VLImageConverter + + def __init__( + self, + tokenizer, + image_converter=None, + sequence_length=1024, + spatial_merge_size=2, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + sequence_length=sequence_length, + **kwargs, + ) + self.image_converter = image_converter + self.spatial_merge_size = spatial_merge_size + + def generate_preprocess(self, x, sequence_length=None): + """Preprocess a single example for generation. + + Args: + x: Either a plain string, or a dict with keys ``"text"`` + (str) and optionally ``"images"`` (NumPy array). + sequence_length: int or ``None``. Overrides + ``self.sequence_length`` when provided. + + Returns: + Dict with keys ``"token_ids"``, ``"padding_mask"``, + ``"patch_values"``, ``"image_grid_thw"``. + """ + seq_len = sequence_length or self.sequence_length + + if isinstance(x, dict): + text = x.get("text", "") + images = x.get("images", None) + else: + text = x + images = None + + # Tokenize text + token_ids = self.tokenizer(text) + if hasattr(token_ids, "cpu"): + token_ids = token_ids.cpu() + if hasattr(token_ids, "numpy"): + token_ids = token_ids.numpy() + token_ids = np.asarray(token_ids, dtype="int32").reshape(-1) + + patch_values = None + grid_thw = None + + if images is not None and self.image_converter is not None: + patches, grid_thw = self.image_converter.call(images) + patch_values = patches + + # Build vision token blocks for all images. 
+ vision_blocks = [] + for i in range(grid_thw.shape[0]): + gt = int(grid_thw[i, 0]) + gh = int(grid_thw[i, 1]) + gw = int(grid_thw[i, 2]) + num_vision_tokens = (gt * gh * gw) // ( + self.spatial_merge_size**2 + ) + vision_block = np.array( + [self.tokenizer.vision_start_token_id] + + [self.tokenizer.image_token_id] * num_vision_tokens + + [self.tokenizer.vision_end_token_id], + dtype="int32", + ) + vision_blocks.append(vision_block) + + # Insert vision blocks: replace image markers if present, + # otherwise prepend all blocks. + combined_block = np.concatenate(vision_blocks) + img_marker_positions = np.where( + token_ids == self.tokenizer.image_token_id + )[0] + if len(img_marker_positions) > 0: + # Replace the first marker with all vision blocks + # concatenated (multi-image). + pos = img_marker_positions[0] + token_ids = np.concatenate( + [token_ids[:pos], combined_block, token_ids[pos + 1 :]] + ) + else: + token_ids = np.concatenate([combined_block, token_ids]) + + # Pad or truncate to seq_len + current_len = len(token_ids) + if current_len >= seq_len: + token_ids = token_ids[:seq_len] + padding_mask = np.ones(seq_len, dtype="int32") + else: + pad_len = seq_len - current_len + padding_mask = np.concatenate( + [ + np.ones(current_len, dtype="int32"), + np.zeros(pad_len, dtype="int32"), + ] + ) + token_ids = np.concatenate( + [ + token_ids, + np.full( + pad_len, + self.tokenizer.pad_token_id, + dtype="int32", + ), + ] + ) + + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + "patch_values": patch_values, + "image_grid_thw": grid_thw, + } + + def get_config(self): + config = super().get_config() + config.update( + { + "spatial_merge_size": self.spatial_merge_size, + } + ) + return config diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..acb1f062f7 --- /dev/null +++ 
import numpy as np
import pytest

from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm_preprocessor import (
    Qwen2VLCausalLMPreprocessor,
)
from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import (
    Qwen2VLImageConverter,
)
from keras_hub.src.models.qwen2_vl.qwen2_vl_tokenizer import Qwen2VLTokenizer
from keras_hub.src.tests.test_case import TestCase


class Qwen2VLCausalLMPreprocessorTest(TestCase):
    """Unit tests for `Qwen2VLCausalLMPreprocessor`."""

    def setUp(self):
        # Tiny BPE vocabulary: six word pieces (ids 0-5) followed by the
        # special tokens the tokenizer resolves at construction time,
        # including the four vision tokens (ids 8-11).
        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
        self.vocab += ["<|endoftext|>"]
        self.vocab += ["<|eot_id|>"]
        self.vocab += ["<|vision_start|>"]
        self.vocab += ["<|vision_end|>"]
        self.vocab += ["<|image_pad|>"]
        self.vocab += ["<|video_pad|>"]
        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
        # Minimal merge rules sufficient to tokenize "airplane at airport".
        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
        self.merges += ["Ġai r", "Ġa i", "pla ne"]
        self.tokenizer = Qwen2VLTokenizer(
            vocabulary=self.vocab,
            merges=self.merges,
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = ["airplane at airport"]

    def test_preprocessor_basics(self):
        # (x, y, sample_weight) triple: y is token_ids shifted left by one.
        self.run_preprocessor_test(
            cls=Qwen2VLCausalLMPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[1, 3, 4, 2, 5, 6, 0, 0]],
                    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
                },
                [[3, 4, 2, 5, 6, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 0, 0, 0]],
            ),
        )

    def test_with_end_token(self):
        input_data = ["airplane at airport"] * 4

        preprocessor = Qwen2VLCausalLMPreprocessor(
            **self.init_kwargs,
            add_end_token=True,
        )
        x, y, sw = preprocessor(input_data)
        self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 6, 0, 0]] * 4)
        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
        self.assertAllEqual(y, [[3, 4, 2, 5, 6, 0, 0, 0]] * 4)
        self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)

    def test_generate_preprocess(self):
        input_data = "airplane at airport"
        preprocessor = Qwen2VLCausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_preprocess(input_data)
        self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 0, 0, 0])
        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])

    def test_generate_postprocess(self):
        # Padded positions (mask 0) must be dropped from the decoded text.
        input_data = {
            "token_ids": [1, 3, 4, 2, 5, 6, 0, 0],
            "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
        }
        preprocessor = Qwen2VLCausalLMPreprocessor(**self.init_kwargs)
        x = preprocessor.generate_postprocess(input_data)
        self.assertAllEqual(x, "airplane at airport")

    @pytest.mark.extra_large
    def test_all_presets(self):
        for preset in Qwen2VLCausalLMPreprocessor.presets:
            self.run_preset_test(
                cls=Qwen2VLCausalLMPreprocessor,
                preset=preset,
                input_data=self.input_data,
            )

    def test_generate_preprocess_with_image(self):
        """Test that generate_preprocess handles image inputs."""
        image_converter = Qwen2VLImageConverter()
        preprocessor = Qwen2VLCausalLMPreprocessor(
            tokenizer=self.tokenizer,
            image_converter=image_converter,
            sequence_length=64,
            spatial_merge_size=2,
        )
        # Create a small dummy image (56x56 is the minimum after smart_resize).
        dummy_image = np.random.randint(0, 255, (56, 56, 3), dtype=np.uint8)
        x = {"text": "Describe this image", "images": dummy_image}
        result = preprocessor.generate_preprocess(x)

        self.assertIn("token_ids", result)
        self.assertIn("padding_mask", result)
        self.assertIn("patch_values", result)
        self.assertIn("image_grid_thw", result)
        self.assertEqual(len(result["token_ids"]), 64)
        self.assertEqual(len(result["padding_mask"]), 64)
        self.assertIsNotNone(result["patch_values"])
        self.assertIsNotNone(result["image_grid_thw"])
def smart_resize(
    height, width, factor=28, min_pixels=56 * 56, max_pixels=12845056
):
    """Snap image dimensions to multiples of ``factor`` within a pixel budget.

    Both returned dimensions are divisible by ``factor`` and their product
    lies in ``[min_pixels, max_pixels]`` (mirroring HuggingFace's
    Qwen2-VL preprocessing).

    Args:
        height: int. Original image height.
        width: int. Original image width.
        factor: int. Both output dims must be multiples of this value.
            Defaults to ``28`` (``patch_size * merge_size = 14 * 2``).
        min_pixels: int. Minimum total pixel count. Defaults to
            ``56 * 56 = 3136``.
        max_pixels: int. Maximum total pixel count. Defaults to
            ``12845056`` (matching HuggingFace).

    Returns:
        Tuple ``(h_bar, w_bar)`` of resized dimensions.

    Raises:
        ValueError: If either dimension is non-positive, or the absolute
            aspect ratio exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(
            f"Height and width must be positive, "
            f"got height={height}, width={width}."
        )
    aspect = max(height, width) / min(height, width)
    if aspect > 200:
        raise ValueError(
            f"Absolute aspect ratio must be smaller than 200, got "
            f"{aspect:.1f}."
        )
    # First snap each side to the nearest multiple of `factor`.
    out_h = round(height / factor) * factor
    out_w = round(width / factor) * factor
    area = out_h * out_w
    if area > max_pixels:
        # Shrink proportionally (floor keeps us under budget), but never
        # below one `factor` per side.
        scale = math.sqrt((height * width) / max_pixels)
        out_h = max(factor, math.floor(height / scale / factor) * factor)
        out_w = max(factor, math.floor(width / scale / factor) * factor)
    elif area < min_pixels:
        # Grow proportionally; ceil guarantees the minimum is reached.
        scale = math.sqrt(min_pixels / (height * width))
        out_h = math.ceil(height * scale / factor) * factor
        out_w = math.ceil(width * scale / factor) * factor
    return out_h, out_w
Per-channel normalization mean. + Defaults to CLIP mean. + image_std: list of float. Per-channel normalization std. + Defaults to CLIP std. + + Note: + The ``dtype`` is always forced to ``float32`` for preprocessing + regardless of any value passed by the caller. A warning is emitted + if a non-default ``dtype`` is supplied. + """ + + backbone_cls = Qwen2VLBackbone + + def __init__( + self, + min_pixels=56 * 56, + max_pixels=12845056, + patch_size=14, + temporal_patch_size=2, + merge_size=2, + image_mean=(0.48145466, 0.4578275, 0.40821073), + image_std=(0.26862954, 0.26130258, 0.27577711), + **kwargs, + ): + # Force float32 for image preprocessing. + user_dtype = kwargs.pop("dtype", None) + if user_dtype is not None: + warnings.warn( + f"Qwen2VLImageConverter forces dtype='float32' for " + f"preprocessing. The supplied dtype='{user_dtype}' " + f"will be ignored.", + stacklevel=2, + ) + super().__init__(dtype="float32", **kwargs) + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.image_mean = np.array(image_mean, dtype="float32") + self.image_std = np.array(image_std, dtype="float32") + self._factor = patch_size * merge_size + + def call(self, image): + """Preprocess a single image, a list of video frames, or + a list of separate images. + + Args: + image: NumPy array of shape ``(H, W, C)`` for a single image, + ``(T, H, W, C)`` for a video clip, or a **list** of + NumPy arrays (each ``(H, W, C)``) for multiple separate + images. Pixel values should be in ``[0, 255]``. + + Returns: + Tuple ``(patches, grid_thw)``: + - ``patches``: float32 NumPy array of shape + ``(total_patches, + C * temporal_patch_size * patch_size²)``. + - ``grid_thw``: int32 NumPy array of shape + ``(num_images, 3)`` with ``[grid_t, grid_h, grid_w]`` + per image. 
+ """ + # Handle a list of separate images by processing each one + # independently and concatenating the results. + if isinstance(image, list): + all_patches = [] + all_grids = [] + for img in image: + p, g = self.call(img) + all_patches.append(p) + all_grids.append(g) + return ( + np.concatenate(all_patches, axis=0), + np.concatenate(all_grids, axis=0), + ) + + image = np.asarray(image, dtype="float32") + if image.ndim == 3: + image = image[np.newaxis] + + height, width = image.shape[1], image.shape[2] + resized_h, resized_w = smart_resize( + height, + width, + factor=self._factor, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + + frames = [] + for frame in image: + if resized_h != height or resized_w != width: + frame = self._resize_frame(frame, resized_h, resized_w) + frames.append(frame) + + patches = np.stack(frames, axis=0) + + patches = patches / np.float32(255.0) + patches = (patches - self.image_mean) / self.image_std + + patches = patches.transpose(0, 3, 1, 2) + + num_frames = patches.shape[0] + if num_frames % self.temporal_patch_size != 0: + repeat = self.temporal_patch_size - ( + num_frames % self.temporal_patch_size + ) + patches = np.concatenate( + [patches, np.repeat(patches[-1:], repeat, axis=0)], axis=0 + ) + + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h = resized_h // self.patch_size + grid_w = resized_w // self.patch_size + + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + self.merge_size, + self.patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + patches = patches.reshape( + grid_t * grid_h * grid_w, + channel + * self.temporal_patch_size + * self.patch_size + * self.patch_size, + ) + + grid_thw = np.array([[grid_t, grid_h, grid_w]], dtype="int32") + return patches, grid_thw + + def _resize_frame(self, frame, target_h, target_w): + 
"""Resize a single frame using PIL (preferred) or NumPy fallback.""" + try: + from PIL import Image as PILImage + + if hasattr(PILImage, "Resampling"): + resample = PILImage.Resampling.BICUBIC + else: + resample = PILImage.BICUBIC + pil = PILImage.fromarray(frame.astype("uint8")) + pil = pil.resize((target_w, target_h), resample) + return np.array(pil, dtype="float32") + except ImportError: + return _numpy_resize(frame, target_h, target_w) + + def get_config(self): + config = super().get_config() + config.update( + { + "min_pixels": self.min_pixels, + "max_pixels": self.max_pixels, + "patch_size": self.patch_size, + "temporal_patch_size": self.temporal_patch_size, + "merge_size": self.merge_size, + "image_mean": list(self.image_mean), + "image_std": list(self.image_std), + } + ) + return config + + +def _numpy_resize(frame, new_h, new_w): + """Fallback bilinear resize using NumPy (no PIL dependency).""" + old_h, old_w = frame.shape[:2] + row_idx = np.linspace(0, old_h - 1, new_h) + col_idx = np.linspace(0, old_w - 1, new_w) + r0 = np.floor(row_idx).astype(int).clip(0, old_h - 1) + r1 = np.ceil(row_idx).astype(int).clip(0, old_h - 1) + c0 = np.floor(col_idx).astype(int).clip(0, old_w - 1) + c1 = np.ceil(col_idx).astype(int).clip(0, old_w - 1) + dr = (row_idx - r0)[:, np.newaxis, np.newaxis] + dc = (col_idx - c0)[np.newaxis, :, np.newaxis] + top = frame[r0][:, c0] * (1 - dc) + frame[r0][:, c1] * dc + bot = frame[r1][:, c0] * (1 - dc) + frame[r1][:, c1] * dc + return (top * (1 - dr) + bot * dr).astype("float32") diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_image_converter_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_image_converter_test.py new file mode 100644 index 0000000000..a06627736c --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_image_converter_test.py @@ -0,0 +1,103 @@ +import numpy as np + +from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import ( + Qwen2VLImageConverter, +) +from 
class SmartResizeTest(TestCase):
    """Unit tests for the `smart_resize` helper."""

    def test_smart_resize_basic(self):
        # Should round to nearest multiple of 28
        h, w = smart_resize(100, 100, factor=28)
        self.assertEqual(h % 28, 0)
        self.assertEqual(w % 28, 0)
        self.assertEqual(h, 112)
        self.assertEqual(w, 112)

    def test_smart_resize_max_pixels(self):
        # Should scale down if exceeds max_pixels
        h, w = smart_resize(5000, 5000, factor=28, max_pixels=1000000)
        self.assertLessEqual(h * w, 1000000)
        self.assertEqual(h % 28, 0)
        self.assertEqual(w % 28, 0)

    def test_smart_resize_min_pixels(self):
        # Should scale up if below min_pixels
        h, w = smart_resize(10, 10, factor=28, min_pixels=56 * 56)
        self.assertGreaterEqual(h * w, 56 * 56)
        self.assertEqual(h % 28, 0)
        self.assertEqual(w % 28, 0)

    def test_smart_resize_aspect_ratio_error(self):
        # Should raise error if aspect ratio > 200
        with self.assertRaises(ValueError):
            smart_resize(10000, 10, factor=28)


class Qwen2VLImageConverterTest(TestCase):
    """Unit tests for `Qwen2VLImageConverter`."""

    def setUp(self):
        self.converter = Qwen2VLImageConverter(
            patch_size=14,
            temporal_patch_size=2,
            merge_size=2,
        )

    def test_converter_output_shape(self):
        # Single image: (H, W, C)
        image = np.random.rand(224, 224, 3).astype("float32")
        patches, grid_thw = self.converter(image)

        # patches should be flat: (total_patches, patch_flat_dim)
        self.assertEqual(len(patches.shape), 2)
        # grid_thw should be (num_images, 3)
        self.assertEqual(grid_thw.shape, (1, 3))

    def test_converter_multiple_images(self):
        # Multiple images as list
        images = [
            np.random.rand(224, 224, 3).astype("float32"),
            np.random.rand(448, 224, 3).astype("float32"),
        ]
        patches, grid_thw = self.converter(images)

        # grid_thw should have 2 rows (2 images)
        self.assertEqual(grid_thw.shape[0], 2)
        self.assertEqual(grid_thw.shape[1], 3)

    def test_converter_config_roundtrip(self):
        # Serialize and rebuild; the rebuilt converter must behave the same.
        config = self.converter.get_config()
        new_converter = Qwen2VLImageConverter.from_config(config)

        # Test that it works the same
        image = np.random.rand(224, 224, 3).astype("float32")
        patches1, grid1 = self.converter(image)
        patches2, grid2 = new_converter(image)

        self.assertEqual(patches1.shape, patches2.shape)
        np.testing.assert_array_equal(grid1, grid2)

    def test_image_normalization(self):
        # Test that images are normalized with correct mean/std.
        # A uniform 255 image becomes (1.0 - mean) / std per channel.
        image = np.ones((224, 224, 3), dtype="float32") * 255
        patches, grid_thw = self.converter(image)

        mean = np.array([0.48145466, 0.4578275, 0.40821073])
        std = np.array([0.26862954, 0.26130258, 0.27577711])
        expected = (1.0 - mean) / std  # per-channel expected values

        # Each patch row is flattened as
        # (temporal_patch_size * patch_size² values) per channel, so
        # reshape to (..., 3) to extract per-channel means.
        ps = self.converter.patch_size
        tps = self.converter.temporal_patch_size
        block = tps * ps * ps  # values per channel per patch
        # patches shape: (total_patches, 3 * block)
        reshaped = patches.reshape(-1, 3, block)  # (N, C, block)
        channel_means = reshaped.mean(axis=(0, 2))  # (3,)

        np.testing.assert_allclose(
            channel_means, expected, atol=1e-5, rtol=1e-5
        )
Supports image " +# "and video understanding with multimodal " +# "rotary position embeddings." +# ), +# "params": 2014341120, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": "kaggle://keras/qwen2_vl/keras/qwen2_vl_2b_base/1", +# }, +# "qwen2_vl_2b_instruct": { +# "metadata": { +# "description": ( +# "2 billion parameter Qwen2-VL vision-language model. " +# "Instruction-tuned for image and video understanding with " +# "multimodal rotary position embeddings." +# ), +# "params": 2014341120, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": "kaggle://keras/qwen2_vl/keras/qwen2_vl_2b_instruct/1", +# }, +# "qwen2_vl_7b_base": { +# "metadata": { +# "description": ( +# "7 billion parameter Qwen2-VL vision-language base model. " +# "Pretrained model without instruction tuning. Supports image " +# "and video understanding with multimodal " +# "rotary position embeddings." +# ), +# "params": 7616192512, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": "kaggle://keras/qwen2_vl/keras/qwen2_vl_7b_base/1", +# }, +# "qwen2_vl_7b_instruct": { +# "metadata": { +# "description": ( +# "7 billion parameter Qwen2-VL vision-language model. " +# "Instruction-tuned for image and video understanding with " +# "multimodal rotary position embeddings." +# ), +# "params": 7616192512, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": "kaggle://keras/qwen2_vl/keras/qwen2_vl_7b_instruct/1", +# }, +# "qwen2_vl_72b_base": { +# "metadata": { +# "description": ( +# "72 billion parameter Qwen2-VL vision-language base model. " +# "Pretrained model without instruction tuning. Supports image " +# "and video understanding with multimodal " +# "rotary position embeddings." +# ), +# "params": 72706203648, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": "kaggle://keras/qwen2_vl/keras/qwen2_vl_72b_base/1", +# }, +# "qwen2_vl_72b_instruct": { +# "metadata": { +# "description": ( +# "72 billion parameter Qwen2-VL vision-language model. 
" +# "Instruction-tuned for image and video understanding with " +# "multimodal rotary position embeddings." +# ), +# "params": 72706203648, +# "path": "qwen2_vl", +# }, +# "kaggle_handle": "kaggle://keras/qwen2_vl/keras/qwen2_vl_72b_instruct/1", +# }, +# } diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer.py new file mode 100644 index 0000000000..2129753e2d --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_tokenizer.py @@ -0,0 +1,91 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone + + +@keras_hub_export( + [ + "keras_hub.models.Qwen2VLTokenizer", + "keras_hub.tokenizers.Qwen2VLTokenizer", + ] +) +class Qwen2VLTokenizer(QwenTokenizer): + """Qwen2-VL tokenizer layer. + + This tokenizer layer provides an implementation of a Qwen2-VL tokenizer + using the BytePair (BPE) method. It includes vocabulary and merges data + necessary for tokenizing Qwen2-VL model inputs. + + In addition to standard text tokenization, this tokenizer exposes + vision-related token IDs used by the preprocessor to construct + multimodal input sequences: + + - ``image_token_id``: resolved from ``<|image_pad|>`` (HF ID 151655). + One placeholder per merged vision patch is inserted by the + preprocessor. + - ``video_token_id``: resolved from ``<|video_pad|>`` (HF ID 151656). + - ``vision_start_token_id``: resolved from ``<|vision_start|>`` + (HF ID 151652). Marks the start of a vision token block. + - ``vision_end_token_id``: resolved from ``<|vision_end|>`` + (HF ID 151653). Marks the end of a vision token block. + + Note: ``<|image_pad|>`` and ``<|video_pad|>`` are defined in + HuggingFace's ``tokenizer_config.json`` (``added_tokens_decoder``) + but are absent from ``tokenizer.json``'s ``added_tokens`` list. 
+ The converter loads them from ``tokenizer_config.json`` so they + are present in the vocabulary passed to this class. + + Args: + vocabulary: string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rule. If it is a string, + it should be the file path to merge rules. The merge rule file + should have one merge rule per line. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_hub.models.Qwen2VLTokenizer.from_preset( + "qwen2_vl_2b_instruct", + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize([[151643, 791, 4320, 14198]]) + ``` + """ + + backbone_cls = Qwen2VLBackbone + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_vision_token_ids() + + def set_vocabulary_and_merges(self, vocabulary, merges): + """Override to re-resolve vision token IDs after vocabulary load. + + Keras preset deserialization is two-phase: ``__init__`` is called + with ``vocabulary=None``, and then ``load_preset_assets`` + loads the real vocabulary from files and calls this method. + By hooking here we ensure vision token IDs are set + correctly after both phases. 
+ """ + super().set_vocabulary_and_merges(vocabulary, merges) + self._init_vision_token_ids() + + def _init_vision_token_ids(self): + """Resolve vision token IDs from the vocabulary.""" + if self.vocabulary is not None: + self.image_token_id = self.token_to_id("<|image_pad|>") + self.video_token_id = self.token_to_id("<|video_pad|>") + self.vision_start_token_id = self.token_to_id("<|vision_start|>") + self.vision_end_token_id = self.token_to_id("<|vision_end|>") + else: + self.image_token_id = None + self.video_token_id = None + self.vision_start_token_id = None + self.vision_end_token_id = None diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py new file mode 100644 index 0000000000..576039d38e --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py @@ -0,0 +1,505 @@ +import numpy as np +from keras import layers +from keras import ops + + +def _quick_gelu(x): + """Quick GELU: x * sigmoid(1.702 * x).""" + return x * ops.sigmoid(ops.cast(1.702, x.dtype) * x) + + +def _rotate_half(x): + """Rotate last dim by splitting and negating halves.""" + half = ops.shape(x)[-1] // 2 + x1 = x[..., :half] + x2 = x[..., half:] + return ops.concatenate([-x2, x1], axis=-1) + + +def _apply_rotary_pos_emb_vision(q, k, cos, sin): + """Apply RoPE to vision q and k. + + Args: + q: (seq_len, num_heads, head_dim) + k: (seq_len, num_heads, head_dim) + cos: (seq_len, head_dim) + sin: (seq_len, head_dim) + """ + cos = ops.expand_dims(cos, axis=-2) + sin = ops.expand_dims(sin, axis=-2) + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2VLVisionRotaryEmbedding(layers.Layer): + """Vision RoPE: returns raw frequency table of shape (seqlen, dim//2). + + Args: + dim: int. Half of the attention head dimension (head_dim // 2). + theta: float. RoPE base. Defaults to 10000.0. 
+ """ + + def __init__(self, dim, theta=10000.0, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.dim = dim + self.theta = theta + + inv_freq_init = 1.0 / ( + theta ** (np.arange(0, dim, 2, dtype="float32") / dim) + ) + self.inv_freq = self.add_weight( + name="inv_freq", + shape=(len(inv_freq_init),), + initializer="zeros", + trainable=False, + ) + self.inv_freq.assign(inv_freq_init) + + def call(self, seqlen): + seq = ops.cast(ops.arange(seqlen), "float32") + freqs = ops.outer(seq, ops.cast(self.inv_freq, "float32")) + return freqs + + def get_config(self): + config = super().get_config() + config.update({"dim": self.dim, "theta": self.theta}) + return config + + +class Qwen2VLVisionAttention(layers.Layer): + """Fused-QKV self-attention with RoPE for the vision encoder. + + Has single qkv Dense (bias=True), + proj Dense (bias=True), manual scaled dot-product attention. + + Args: + embed_dim: int. Vision encoder embedding dimension. + num_heads: int. Number of attention heads. 
+ """ + + def __init__(self, embed_dim, num_heads, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim**-0.5 + self.qkv = layers.Dense( + embed_dim * 3, use_bias=True, dtype=dtype, name="qkv" + ) + self.proj = layers.Dense( + embed_dim, use_bias=True, dtype=dtype, name="proj" + ) + + def build(self, input_shape): + self.qkv.build(input_shape) + self.proj.build([None, self.embed_dim]) + self.built = True + + def call(self, x, position_embeddings=None): + seq_len = ops.shape(x)[0] + qkv = self.qkv(x) + qkv = ops.reshape(qkv, (seq_len, 3, self.num_heads, self.head_dim)) + q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2] + if position_embeddings is not None: + cos, sin = position_embeddings + q, k = _apply_rotary_pos_emb_vision(q, k, cos, sin) + q = ops.transpose(q, (1, 0, 2)) + k = ops.transpose(k, (1, 0, 2)) + v = ops.transpose(v, (1, 0, 2)) + attn = ops.matmul(q, ops.transpose(k, (0, 2, 1))) * self.scale + attn = ops.softmax(ops.cast(attn, "float32"), axis=-1) + attn = ops.cast(attn, q.dtype) + out = ops.matmul(attn, v) + out = ops.transpose(out, (1, 0, 2)) + out = ops.reshape(out, (seq_len, self.embed_dim)) + return self.proj(out) + + def get_config(self): + config = super().get_config() + config.update( + {"embed_dim": self.embed_dim, "num_heads": self.num_heads} + ) + return config + + +class Qwen2VLVisionMlp(layers.Layer): + """Two-layer MLP with quick_gelu for the vision transformer block. + + Args: + embed_dim: int. Input/output dimension. + mlp_dim: int. Hidden dimension. 
+ """ + + def __init__(self, embed_dim, mlp_dim, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.embed_dim = embed_dim + self.mlp_dim = mlp_dim + self.fc1 = layers.Dense(mlp_dim, use_bias=True, dtype=dtype, name="fc1") + self.fc2 = layers.Dense( + embed_dim, use_bias=True, dtype=dtype, name="fc2" + ) + + def build(self, input_shape): + self.fc1.build(input_shape) + self.fc2.build([None, self.mlp_dim]) + self.built = True + + def call(self, x): + return self.fc2(_quick_gelu(self.fc1(x))) + + def get_config(self): + config = super().get_config() + config.update({"embed_dim": self.embed_dim, "mlp_dim": self.mlp_dim}) + return config + + +class Qwen2VLVisionBlock(layers.Layer): + """Single vision transformer block. + + Pre-norm, fused-QKV attention, quick_gelu MLP. + + Args: + embed_dim: int. Vision encoder embedding dimension. + num_heads: int. Number of attention heads. + mlp_ratio: float. MLP hidden dim multiplier. Defaults to 4. + """ + + def __init__(self, embed_dim, num_heads, mlp_ratio=4, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.embed_dim = embed_dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.norm1 = layers.LayerNormalization( + epsilon=1e-6, dtype=dtype, name="norm1" + ) + self.attn = Qwen2VLVisionAttention( + embed_dim=embed_dim, num_heads=num_heads, dtype=dtype, name="attn" + ) + self.norm2 = layers.LayerNormalization( + epsilon=1e-6, dtype=dtype, name="norm2" + ) + self.mlp = Qwen2VLVisionMlp( + embed_dim=embed_dim, + mlp_dim=int(embed_dim * mlp_ratio), + dtype=dtype, + name="mlp", + ) + + def build(self, input_shape): + self.norm1.build(input_shape) + self.attn.build(input_shape) + self.norm2.build(input_shape) + self.mlp.build(input_shape) + self.built = True + + def call(self, x, position_embeddings=None): + x = x + self.attn( + self.norm1(x), position_embeddings=position_embeddings + ) + x = x + self.mlp(self.norm2(x)) + return x + + def get_config(self): + config = 
super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + "mlp_ratio": self.mlp_ratio, + } + ) + return config + + +class Qwen2VLPatchMerger(layers.Layer): + """Merges spatial patches and projects to the LLM hidden dimension. + + Consists of: + - LayerNorm on vision features. + - Reshape: group spatial_merge_size² adjacent tokens. + - Two-layer MLP (Linear → GELU → Linear). + + Args: + hidden_size: int. Output dimension (LLM hidden size). + embed_dim: int. Vision encoder embedding dimension. + spatial_merge_size: int. Spatial merge factor. Defaults to 2. + """ + + def __init__( + self, hidden_size, embed_dim, spatial_merge_size=2, dtype=None, **kwargs + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.embed_dim = embed_dim + self.spatial_merge_size = spatial_merge_size + self.mlp_hidden = embed_dim * (spatial_merge_size**2) + self.ln_q = layers.LayerNormalization( + epsilon=1e-6, dtype=dtype, name="ln_q" + ) + self.mlp_fc1 = layers.Dense( + self.mlp_hidden, use_bias=True, dtype=dtype, name="mlp_fc1" + ) + self.mlp_fc2 = layers.Dense( + hidden_size, use_bias=True, dtype=dtype, name="mlp_fc2" + ) + + def build(self, input_shape): + self.ln_q.build(input_shape) + self.mlp_fc1.build([None, self.mlp_hidden]) + self.mlp_fc2.build([None, self.mlp_hidden]) + self.built = True + + def call(self, x): + x = self.ln_q(x) + total = ops.shape(x)[0] + merge_sq = self.spatial_merge_size**2 + x = ops.reshape(x, (total // merge_sq, self.mlp_hidden)) + x = ops.gelu(self.mlp_fc1(x)) + x = self.mlp_fc2(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "embed_dim": self.embed_dim, + "spatial_merge_size": self.spatial_merge_size, + } + ) + return config + + +class Qwen2VLVisionEncoder(layers.Layer): + """Qwen2-VL Vision Encoder (3D ViT with RoPE and PatchMerger). 
+ + Accepts a flat patch tensor produced by ``Qwen2VLImageConverter`` + of shape ``(total_patches, C * temp_patch * patch²)`` + and a ``grid_thw`` tensor of shape ``(num_images, 3)``. + + Returns merged vision features of shape + ``(total_patches // spatial_merge_size², hidden_size)``. + + Args: + patch_size: int. Spatial patch size. Defaults to 14. + temporal_patch_size: int. Temporal patch size. Defaults to 2. + in_channels: int. Input image channels. Defaults to 3. + embed_dim: int. ViT internal embedding dimension. Defaults to 1280. + hidden_size: int. LLM hidden dimension (PatchMerger output). + Defaults to 3584. + depth: int. Number of vision transformer blocks. Defaults to 32. + num_heads: int. Number of attention heads. Defaults to 16. + mlp_ratio: float. MLP hidden dim multiplier. Defaults to 4. + spatial_merge_size: int. Spatial merge factor. Defaults to 2. + """ + + def __init__( + self, + patch_size=14, + temporal_patch_size=2, + in_channels=3, + embed_dim=1280, + hidden_size=3584, + depth=32, + num_heads=16, + mlp_ratio=4, + spatial_merge_size=2, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.depth = depth + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.spatial_merge_size = spatial_merge_size + + self.patch_embed = layers.Conv3D( + filters=embed_dim, + kernel_size=(temporal_patch_size, patch_size, patch_size), + strides=(temporal_patch_size, patch_size, patch_size), + padding="valid", + use_bias=False, + dtype=dtype, + name="patch_embed", + ) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Qwen2VLVisionRotaryEmbedding( + dim=head_dim // 2, theta=10000.0, dtype=dtype, name="rotary_pos_emb" + ) + self.blocks = [ + Qwen2VLVisionBlock( + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + dtype=dtype, + 
name=f"blocks_{i}", + ) + for i in range(depth) + ] + self.merger = Qwen2VLPatchMerger( + hidden_size=hidden_size, + embed_dim=embed_dim, + spatial_merge_size=spatial_merge_size, + dtype=dtype, + name="merger", + ) + + # Eagerly build all sublayers so their variables exist for + # weight loading. This is necessary because the vision encoder + # is NOT part of the backbone's Functional graph — Keras will + # not auto-build it during deserialization. + self.build() + + def build(self, input_shape=None): + """Build all sublayers so their variables exist for weight loading.""" + if self.built: + return + # Conv3D sees 5-D input after reshape+transpose in call(). + conv_shape = ( + None, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + self.in_channels, + ) + self.patch_embed.build(conv_shape) + + # Blocks and merger operate on (seq_len, embed_dim). + block_shape = (None, self.embed_dim) + for block in self.blocks: + block.build(block_shape) + self.merger.build(block_shape) + self.built = True + + def _rot_pos_emb(self, grid_thw): + """Build per-token (cos, sin) from grid_thw. + + Replicates HF rot_pos_emb: spatial-merge interleaved h/w pos ids, + indexed into the rotary frequency table. + + Args: + grid_thw: int tensor of shape (num_images, 3) — [t, h, w]. + + Returns: + Tuple (cos, sin) each of shape (total_tokens, head_dim). + """ + pos_ids_list = [] + # grid_thw is a NumPy array at preprocessing time; use len() so this + # works both eagerly and in traced graphs where shape[0] is known. 
+ num_images = ( + grid_thw.shape[0] + if hasattr(grid_thw, "shape") + else ops.shape(grid_thw)[0] + ) + for i in range(int(num_images)): + t = grid_thw[i, 0] + h = grid_thw[i, 1] + w = grid_thw[i, 2] + hpos = ops.reshape(ops.arange(h), (h, 1)) + hpos = ops.broadcast_to(hpos, (h, w)) + hpos = ops.reshape( + hpos, + ( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ), + ) + hpos = ops.transpose(hpos, (0, 2, 1, 3)) + hpos = ops.reshape(hpos, (-1,)) + wpos = ops.reshape(ops.arange(w), (1, w)) + wpos = ops.broadcast_to(wpos, (h, w)) + wpos = ops.reshape( + wpos, + ( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ), + ) + wpos = ops.transpose(wpos, (0, 2, 1, 3)) + wpos = ops.reshape(wpos, (-1,)) + hw_ids = ops.stack([hpos, wpos], axis=-1) + hw_ids = ops.tile(hw_ids, [int(t), 1]) + pos_ids_list.append(hw_ids) + pos_ids = ops.concatenate(pos_ids_list, axis=0) + max_grid_size = ops.max(grid_thw[:, 1:]) + rotary_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_emb = ops.take( + rotary_emb_full, ops.reshape(pos_ids, (-1,)), axis=0 + ) + rotary_emb = ops.reshape(rotary_emb, (ops.shape(pos_ids)[0], -1)) + emb = ops.concatenate([rotary_emb, rotary_emb], axis=-1) + cos = ops.cos(emb) + sin = ops.sin(emb) + return cos, sin + + def call(self, hidden_states, grid_thw=None): + """Forward pass. + + Args: + hidden_states: Flat patch tensor of shape + ``(total_patches, C * temp_patch * patch²)``. + Each row is one flattened patch group as produced by + ``Qwen2VLImageConverter``. + grid_thw: Int tensor of shape ``(num_images, 3)``. + + Returns: + Merged features of shape + ``(total_patches // spatial_merge_size², hidden_size)``. 
+ """ + # Reshape flat patches into 5D for Conv3D: + # (N, in_channels, temporal_patch_size, patch_size, patch_size) + hidden_states = ops.reshape( + hidden_states, + ( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ), + ) + # Conv3D expects (batch, d, h, w, channels) in channels-last Keras. + # Transpose from (N, C, T, P, P) → (N, T, P, P, C). + hidden_states = ops.transpose(hidden_states, (0, 2, 3, 4, 1)) + hidden_states = self.patch_embed(hidden_states) + hidden_states = ops.reshape(hidden_states, (-1, self.embed_dim)) + + position_embeddings = None + if grid_thw is not None: + cos, sin = self._rot_pos_emb(grid_thw) + position_embeddings = (cos, sin) + + for block in self.blocks: + hidden_states = block( + hidden_states, position_embeddings=position_embeddings + ) + + hidden_states = self.merger(hidden_states) + return hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "temporal_patch_size": self.temporal_patch_size, + "in_channels": self.in_channels, + "embed_dim": self.embed_dim, + "hidden_size": self.hidden_size, + "depth": self.depth, + "num_heads": self.num_heads, + "mlp_ratio": self.mlp_ratio, + "spatial_merge_size": self.spatial_merge_size, + } + ) + return config diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder_test.py new file mode 100644 index 0000000000..27a0e7be92 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder_test.py @@ -0,0 +1,103 @@ +import numpy as np +import pytest + +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLVisionEncoder, +) +from keras_hub.src.tests.test_case import TestCase + + +class Qwen2VLVisionEncoderTest(TestCase): + def setUp(self): + self.init_kwargs = { + "patch_size": 14, + "temporal_patch_size": 2, + "in_channels": 3, + "embed_dim": 64, + "hidden_size": 128, + "depth": 
2, + "num_heads": 4, + "mlp_ratio": 4, + "spatial_merge_size": 2, + } + + def test_vision_encoder_basics(self): + encoder = Qwen2VLVisionEncoder(**self.init_kwargs) + + # Derive patch dimensions from init_kwargs to avoid drift. + kw = self.init_kwargs + patch_flat_dim = ( + kw["in_channels"] + * kw["temporal_patch_size"] + * kw["patch_size"] ** 2 + ) + + # 1 image with t=2, h=2, w=2 → total_patches = 8 + grid_thw = np.array([[2, 2, 2]], dtype="int32") + total_patches = int(np.prod(grid_thw)) + hidden_states = np.random.rand(total_patches, patch_flat_dim).astype( + "float32" + ) + + output = encoder(hidden_states, grid_thw) + + # After merger, should reduce by spatial_merge_size^2 + merge_sq = kw["spatial_merge_size"] ** 2 + expected_tokens = total_patches // merge_sq + self.assertEqual(output.shape, (expected_tokens, kw["hidden_size"])) + + def test_vision_encoder_config_roundtrip(self): + encoder = Qwen2VLVisionEncoder(**self.init_kwargs) + config = encoder.get_config() + new_encoder = Qwen2VLVisionEncoder.from_config(config) + + # Verify config values match + self.assertEqual(encoder.patch_size, new_encoder.patch_size) + self.assertEqual( + encoder.temporal_patch_size, new_encoder.temporal_patch_size + ) + self.assertEqual(encoder.in_channels, new_encoder.in_channels) + self.assertEqual(encoder.embed_dim, new_encoder.embed_dim) + self.assertEqual(encoder.hidden_size, new_encoder.hidden_size) + self.assertEqual(encoder.depth, new_encoder.depth) + self.assertEqual(encoder.num_heads, new_encoder.num_heads) + self.assertEqual(encoder.mlp_ratio, new_encoder.mlp_ratio) + self.assertEqual( + encoder.spatial_merge_size, new_encoder.spatial_merge_size + ) + + @pytest.mark.large + def test_vision_encoder_with_multiple_images(self): + encoder = Qwen2VLVisionEncoder(**self.init_kwargs) + + kw = self.init_kwargs + patch_flat_dim = ( + kw["in_channels"] + * kw["temporal_patch_size"] + * kw["patch_size"] ** 2 + ) + + # 2 images with different grid sizes + grid_thw = 
np.array([[2, 2, 2], [2, 4, 4]], dtype="int32") + total_patches = int(np.sum(np.prod(grid_thw, axis=1))) + hidden_states = np.random.rand(total_patches, patch_flat_dim).astype( + "float32" + ) + + output = encoder(hidden_states, grid_thw) + + merge_sq = kw["spatial_merge_size"] ** 2 + expected_tokens = total_patches // merge_sq + self.assertEqual(output.shape, (expected_tokens, kw["hidden_size"])) + + def test_rotary_embeddings(self): + encoder = Qwen2VLVisionEncoder(**self.init_kwargs) + + # Test that rotary embeddings are generated correctly + grid_thw = np.array([[1, 2, 2]], dtype="int32") + cos, sin = encoder._rot_pos_emb(grid_thw) + + # Should have embeddings for all tokens + # 1 * 2 * 2 = 4 patches total + self.assertEqual(cos.shape[0], 4) + self.assertEqual(sin.shape[0], 4) diff --git a/keras_hub/src/utils/transformers/convert_qwen2_vl.py b/keras_hub/src/utils/transformers/convert_qwen2_vl.py new file mode 100644 index 0000000000..6e1a4c2978 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_qwen2_vl.py @@ -0,0 +1,296 @@ +"""Convert HuggingFace Qwen2-VL weights to KerasHub.""" + +import numpy as np + +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = Qwen2VLBackbone + + +def convert_backbone_config(transformers_config): + # Newer transformers nest text params under "text_config". 
+ tc = transformers_config.get("text_config", transformers_config) + vision_config = transformers_config.get("vision_config", {}) + rope_params = tc.get("rope_parameters", {}) + rope_theta = tc.get("rope_theta", rope_params.get("rope_theta", 1000000)) + return { + "vocabulary_size": tc["vocab_size"], + "num_layers": tc["num_hidden_layers"], + "num_query_heads": tc["num_attention_heads"], + "num_key_value_heads": tc["num_key_value_heads"], + "hidden_dim": tc["hidden_size"], + "intermediate_dim": tc["intermediate_size"], + "vision_patch_size": vision_config.get("patch_size", 14), + "vision_temporal_patch_size": vision_config.get( + "temporal_patch_size", 2 + ), + "vision_in_channels": vision_config.get("in_channels", 3), + "vision_embed_dim": vision_config.get( + "embed_dim", vision_config.get("hidden_size", 1280) + ), + "vision_depth": vision_config.get( + "depth", vision_config.get("num_hidden_layers", 32) + ), + "vision_num_heads": vision_config.get( + "num_heads", vision_config.get("num_attention_heads", 16) + ), + "vision_mlp_ratio": vision_config.get("mlp_ratio", 4), + "spatial_merge_size": vision_config.get("spatial_merge_size", 2), + "image_token_id": transformers_config.get("image_token_id", 151655), + "rope_max_wavelength": rope_theta, + "layer_norm_epsilon": tc.get("rms_norm_eps", 1e-6), + "tie_word_embeddings": transformers_config.get( + "tie_word_embeddings", False + ), + "use_sliding_window_attention": tc.get("use_sliding_window", False), + "sliding_window_size": tc.get("sliding_window", 32768), + } + + +def convert_weights(backbone, loader, transformers_config): + # ── helpers ────────────────────────────────────────────────────── + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + def transpose_2d(x, _): + return np.transpose(x, axes=(1, 0)) + + # ── token embeddings ──────────────────────────────────────────── + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + 
hf_weight_key="model.embed_tokens.weight", + ) + if not backbone.tie_word_embeddings: + loader.port_weight( + keras_variable=backbone.get_layer( + "token_embedding" + ).reverse_embeddings, + hf_weight_key="lm_head.weight", + hook_fn=transpose_2d, + ) + + # ── text decoder layers ───────────────────────────────────────── + for i in range(backbone.num_layers): + decoder = backbone.get_layer(f"transformer_layer_{i}") + + # Pre-attention RMSNorm + loader.port_weight( + keras_variable=decoder._self_attention_layernorm.scale, + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + + # Q/K/V projections (EinsumDense → transpose + reshape) + loader.port_weight( + keras_variable=(decoder._self_attention_layer._query_dense.kernel), + hf_weight_key=(f"model.layers.{i}.self_attn.q_proj.weight"), + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder._self_attention_layer._query_dense.bias, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.bias", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=(decoder._self_attention_layer._key_dense.kernel), + hf_weight_key=(f"model.layers.{i}.self_attn.k_proj.weight"), + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder._self_attention_layer._key_dense.bias, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.bias", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=(decoder._self_attention_layer._value_dense.kernel), + hf_weight_key=(f"model.layers.{i}.self_attn.v_proj.weight"), + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder._self_attention_layer._value_dense.bias, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.bias", + hook_fn=transpose_and_reshape, + ) + + # Output projection + loader.port_weight( + keras_variable=(decoder._self_attention_layer._output_dense.kernel), + hf_weight_key=(f"model.layers.{i}.self_attn.o_proj.weight"), + hook_fn=transpose_and_reshape, + ) + + # 
MLP (gate / up / down) + loader.port_weight( + keras_variable=decoder._feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + hook_fn=transpose_2d, + ) + loader.port_weight( + keras_variable=decoder._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + hook_fn=transpose_2d, + ) + loader.port_weight( + keras_variable=decoder._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + hook_fn=transpose_2d, + ) + + # Post-attention RMSNorm + loader.port_weight( + keras_variable=decoder._feedforward_layernorm.scale, + hf_weight_key=(f"model.layers.{i}.post_attention_layernorm.weight"), + ) + + # Final layernorm + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + # ── vision encoder ────────────────────────────────────────────── + vision = backbone.get_layer("vision_encoder") + + # Patch embedding (Conv3D) + # HF: (embed_dim, C, T, H, W) → Keras: (T, H, W, C, embed_dim) + loader.port_weight( + keras_variable=vision.patch_embed.kernel, + hf_weight_key="visual.patch_embed.proj.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 4, 1, 0)), + ) + + # Vision blocks + for i in range(vision.depth): + block = vision.blocks[i] + prefix = f"visual.blocks.{i}" + + # LayerNorm 1 + loader.port_weight( + keras_variable=block.norm1.gamma, + hf_weight_key=f"{prefix}.norm1.weight", + ) + loader.port_weight( + keras_variable=block.norm1.beta, + hf_weight_key=f"{prefix}.norm1.bias", + ) + + # Fused QKV attention + loader.port_weight( + keras_variable=block.attn.qkv.kernel, + hf_weight_key=f"{prefix}.attn.qkv.weight", + hook_fn=transpose_2d, + ) + loader.port_weight( + keras_variable=block.attn.qkv.bias, + hf_weight_key=f"{prefix}.attn.qkv.bias", + ) + + # Output projection + loader.port_weight( + keras_variable=block.attn.proj.kernel, + hf_weight_key=f"{prefix}.attn.proj.weight", + 
hook_fn=transpose_2d, + ) + loader.port_weight( + keras_variable=block.attn.proj.bias, + hf_weight_key=f"{prefix}.attn.proj.bias", + ) + + # LayerNorm 2 + loader.port_weight( + keras_variable=block.norm2.gamma, + hf_weight_key=f"{prefix}.norm2.weight", + ) + loader.port_weight( + keras_variable=block.norm2.beta, + hf_weight_key=f"{prefix}.norm2.bias", + ) + + # MLP + loader.port_weight( + keras_variable=block.mlp.fc1.kernel, + hf_weight_key=f"{prefix}.mlp.fc1.weight", + hook_fn=transpose_2d, + ) + loader.port_weight( + keras_variable=block.mlp.fc1.bias, + hf_weight_key=f"{prefix}.mlp.fc1.bias", + ) + loader.port_weight( + keras_variable=block.mlp.fc2.kernel, + hf_weight_key=f"{prefix}.mlp.fc2.weight", + hook_fn=transpose_2d, + ) + loader.port_weight( + keras_variable=block.mlp.fc2.bias, + hf_weight_key=f"{prefix}.mlp.fc2.bias", + ) + + # Patch merger + merger = vision.merger + loader.port_weight( + keras_variable=merger.ln_q.gamma, + hf_weight_key="visual.merger.ln_q.weight", + ) + loader.port_weight( + keras_variable=merger.ln_q.beta, + hf_weight_key="visual.merger.ln_q.bias", + ) + # HF merger MLP is nn.Sequential(Linear, GELU, Linear) + # sub-indices: .0 = fc1, .1 = GELU (no weights), .2 = fc2 + loader.port_weight( + keras_variable=merger.mlp_fc1.kernel, + hf_weight_key="visual.merger.mlp.0.weight", + hook_fn=transpose_2d, + ) + loader.port_weight( + keras_variable=merger.mlp_fc1.bias, + hf_weight_key="visual.merger.mlp.0.bias", + ) + loader.port_weight( + keras_variable=merger.mlp_fc2.kernel, + hf_weight_key="visual.merger.mlp.2.weight", + hook_fn=transpose_2d, + ) + loader.port_weight( + keras_variable=merger.mlp_fc2.bias, + hf_weight_key="visual.merger.mlp.2.bias", + ) + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + + # Load all special tokens except reserved placeholders. 
+ special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + # HF's tokenizer.json added_tokens goes up to <|vision_pad|> (151654) + # but tokenizer_config.json also defines <|image_pad|> (151655) and + # <|video_pad|> (151656) in added_tokens_decoder. Load those too. + try: + tok_cfg = load_json(preset, "tokenizer_config.json") + for _id_str, meta in tok_cfg.get("added_tokens_decoder", {}).items(): + content = meta["content"] + if content not in vocab and not content.startswith( + "<|reserved_special_token_" + ): + vocab[content] = int(_id_str) + special_tokens.add(content) + except Exception: + pass + + kwargs.update( + { + "unsplittable_tokens": list(special_tokens), + } + ) + + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 92c6ea5ef5..df90103cd1 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -22,6 +22,7 @@ from keras_hub.src.utils.transformers import convert_mixtral from keras_hub.src.utils.transformers import convert_pali_gemma from keras_hub.src.utils.transformers import convert_qwen +from keras_hub.src.utils.transformers import convert_qwen2_vl from keras_hub.src.utils.transformers import convert_qwen3 from keras_hub.src.utils.transformers import convert_qwen3_moe from keras_hub.src.utils.transformers import convert_qwen_moe @@ -71,6 +72,8 @@ def __init__(self, preset, config): self.converter = convert_vit elif model_type == "qwen2": self.converter = convert_qwen + elif model_type == "qwen2_vl": + self.converter = convert_qwen2_vl elif model_type == "mixtral": self.converter = convert_mixtral elif model_type == "qwen2_moe": diff --git a/tools/checkpoint_conversion/convert_qwen2_vl_checkpoints.py 
b/tools/checkpoint_conversion/convert_qwen2_vl_checkpoints.py new file mode 100644 index 0000000000..5d300ac08a --- /dev/null +++ b/tools/checkpoint_conversion/convert_qwen2_vl_checkpoints.py @@ -0,0 +1,433 @@ +import os +import traceback + +os.environ["KERAS_BACKEND"] = "torch" + +import json # noqa: E402 + +import numpy as np # noqa: E402 +import torch # noqa: E402 +from absl import app # noqa: E402 +from absl import flags # noqa: E402 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_default_device(device) + +from keras import ops # noqa: E402 +from transformers import AutoProcessor # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 +from transformers import Qwen2VLForConditionalGeneration # noqa: E402 + +import keras_hub # noqa: E402 +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import ( # noqa: E402 + Qwen2VLBackbone, +) +from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import ( # noqa: E402, E501 + Qwen2VLImageConverter, +) + +PRESET_MAP = { + "qwen2_vl_2b_instruct": "Qwen/Qwen2-VL-2B-Instruct", + "qwen2_vl_7b_instruct": "Qwen/Qwen2-VL-7B-Instruct", + "qwen2_vl_72b_instruct": "Qwen/Qwen2-VL-72B-Instruct", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", + None, + f"Must be one of {','.join(PRESET_MAP.keys())}", +) + + +# Helpers + + +def transpose_and_reshape(x, shape): + """Transpose a 2-D HF weight and reshape to Keras EinsumDense shape.""" + return np.reshape(np.transpose(x), shape) + + +def transpose_2d(x): + """Simple 2-D transpose for Dense kernels.""" + return np.transpose(x, axes=(1, 0)) + + +# Build config from HF + + +def build_backbone_config(hf_config): + """Map HF ``config.json`` fields to ``Qwen2VLBackbone`` kwargs.""" + vc = hf_config.get("vision_config", {}) + return { + "vocabulary_size": hf_config["vocab_size"], + "num_layers": hf_config["num_hidden_layers"], + "num_query_heads": hf_config["num_attention_heads"], + "num_key_value_heads": 
hf_config["num_key_value_heads"], + "hidden_dim": hf_config["hidden_size"], + "intermediate_dim": hf_config["intermediate_size"], + "vision_patch_size": vc.get("patch_size", 14), + "vision_temporal_patch_size": vc.get("temporal_patch_size", 2), + "vision_in_channels": vc.get("in_channels", 3), + "vision_embed_dim": vc.get("embed_dim", vc.get("hidden_size", 1280)), + "vision_depth": vc.get("depth", vc.get("num_hidden_layers", 32)), + "vision_num_heads": vc.get( + "num_heads", vc.get("num_attention_heads", 16) + ), + "vision_mlp_ratio": vc.get("mlp_ratio", 4), + "spatial_merge_size": vc.get("spatial_merge_size", 2), + "image_token_id": hf_config.get("image_token_id", 151655), + "rope_max_wavelength": hf_config.get("rope_theta", 1000000), + "layer_norm_epsilon": hf_config.get("rms_norm_eps", 1e-6), + "tie_word_embeddings": hf_config.get("tie_word_embeddings", False), + "use_sliding_window_attention": hf_config.get( + "use_sliding_window", False + ), + "sliding_window_size": hf_config.get("sliding_window", 32768), + } + + +# Port weights + + +def port_weights(backbone, hf_state_dict): + """Assign every HF weight to the corresponding KerasHub variable.""" + + def get(key): + return hf_state_dict[key].cpu().float().numpy() + + # ── Token embeddings ──────────────────────────────────────── + backbone.get_layer("token_embedding").embeddings.assign( + get("model.embed_tokens.weight") + ) + if not backbone.tie_word_embeddings: + backbone.get_layer("token_embedding").reverse_embeddings.assign( + transpose_2d(get("lm_head.weight")) + ) + + # ── Text decoder layers ───────────────────────────────────── + for i in range(backbone.num_layers): + d = backbone.get_layer(f"transformer_layer_{i}") + pfx = f"model.layers.{i}" + + # Pre-attention RMSNorm + d._self_attention_layernorm.scale.assign( + get(f"{pfx}.input_layernorm.weight") + ) + + # Q projection + q_w = get(f"{pfx}.self_attn.q_proj.weight") + q_b = get(f"{pfx}.self_attn.q_proj.bias") + q_k_shape = 
list(d._self_attention_layer._query_dense.kernel.shape) + q_b_shape = list(d._self_attention_layer._query_dense.bias.shape) + d._self_attention_layer._query_dense.kernel.assign( + transpose_and_reshape(q_w, q_k_shape) + ) + d._self_attention_layer._query_dense.bias.assign( + transpose_and_reshape(q_b, q_b_shape) + ) + + # K projection + k_w = get(f"{pfx}.self_attn.k_proj.weight") + k_b = get(f"{pfx}.self_attn.k_proj.bias") + k_k_shape = list(d._self_attention_layer._key_dense.kernel.shape) + k_b_shape = list(d._self_attention_layer._key_dense.bias.shape) + d._self_attention_layer._key_dense.kernel.assign( + transpose_and_reshape(k_w, k_k_shape) + ) + d._self_attention_layer._key_dense.bias.assign( + transpose_and_reshape(k_b, k_b_shape) + ) + + # V projection + v_w = get(f"{pfx}.self_attn.v_proj.weight") + v_b = get(f"{pfx}.self_attn.v_proj.bias") + v_k_shape = list(d._self_attention_layer._value_dense.kernel.shape) + v_b_shape = list(d._self_attention_layer._value_dense.bias.shape) + d._self_attention_layer._value_dense.kernel.assign( + transpose_and_reshape(v_w, v_k_shape) + ) + d._self_attention_layer._value_dense.bias.assign( + transpose_and_reshape(v_b, v_b_shape) + ) + + # O projection + o_w = get(f"{pfx}.self_attn.o_proj.weight") + o_k_shape = list(d._self_attention_layer._output_dense.kernel.shape) + d._self_attention_layer._output_dense.kernel.assign( + transpose_and_reshape(o_w, o_k_shape) + ) + + # MLP (gate / up / down) + d._feedforward_gate_dense.kernel.assign( + transpose_2d(get(f"{pfx}.mlp.gate_proj.weight")) + ) + d._feedforward_intermediate_dense.kernel.assign( + transpose_2d(get(f"{pfx}.mlp.up_proj.weight")) + ) + d._feedforward_output_dense.kernel.assign( + transpose_2d(get(f"{pfx}.mlp.down_proj.weight")) + ) + + # Post-attention RMSNorm + d._feedforward_layernorm.scale.assign( + get(f"{pfx}.post_attention_layernorm.weight") + ) + + # Final layernorm + backbone.get_layer("sequence_output_layernorm").scale.assign( + get("model.norm.weight") + ) + 
+ # Vision encoder + vision = backbone.get_layer("vision_encoder") + + # Conv3D patch embedding + # HF: (embed_dim, C, T, H, W) → Keras: (T, H, W, C, embed_dim) + vision.patch_embed.kernel.assign( + np.transpose(get("visual.patch_embed.proj.weight"), (2, 3, 4, 1, 0)) + ) + + # Vision blocks + for i in range(vision.depth): + blk = vision.blocks[i] + bp = f"visual.blocks.{i}" + + # LayerNorm 1 + blk.norm1.gamma.assign(get(f"{bp}.norm1.weight")) + blk.norm1.beta.assign(get(f"{bp}.norm1.bias")) + + # Fused QKV + blk.attn.qkv.kernel.assign(transpose_2d(get(f"{bp}.attn.qkv.weight"))) + blk.attn.qkv.bias.assign(get(f"{bp}.attn.qkv.bias")) + + # Output projection + blk.attn.proj.kernel.assign(transpose_2d(get(f"{bp}.attn.proj.weight"))) + blk.attn.proj.bias.assign(get(f"{bp}.attn.proj.bias")) + + # LayerNorm 2 + blk.norm2.gamma.assign(get(f"{bp}.norm2.weight")) + blk.norm2.beta.assign(get(f"{bp}.norm2.bias")) + + # MLP + blk.mlp.fc1.kernel.assign(transpose_2d(get(f"{bp}.mlp.fc1.weight"))) + blk.mlp.fc1.bias.assign(get(f"{bp}.mlp.fc1.bias")) + blk.mlp.fc2.kernel.assign(transpose_2d(get(f"{bp}.mlp.fc2.weight"))) + blk.mlp.fc2.bias.assign(get(f"{bp}.mlp.fc2.bias")) + + # Patch merger + merger = vision.merger + merger.ln_q.gamma.assign(get("visual.merger.ln_q.weight")) + merger.ln_q.beta.assign(get("visual.merger.ln_q.bias")) + # HF Sequential: .0 = fc1, .1 = GELU (no params), .2 = fc2 + merger.mlp_fc1.kernel.assign( + transpose_2d(get("visual.merger.mlp.0.weight")) + ) + merger.mlp_fc1.bias.assign(get("visual.merger.mlp.0.bias")) + merger.mlp_fc2.kernel.assign( + transpose_2d(get("visual.merger.mlp.2.weight")) + ) + merger.mlp_fc2.bias.assign(get("visual.merger.mlp.2.bias")) + + print(f" Ported {len(hf_state_dict)} HF weights → KerasHub backbone") + return backbone + + +# Verify tokenizer + + +def verify_tokenizer(keras_tokenizer, hf_tokenizer): + print("\n── Tokenizer verification ──") + test_strings = [ + "What is Keras?", + "Describe the weather today.", + "Hello, world! 
🌍", + ] + for s in test_strings: + hf_ids = hf_tokenizer(s, add_special_tokens=False)["input_ids"] + keras_ids = keras_tokenizer(s) + if hasattr(keras_ids, "numpy"): + keras_ids = keras_ids.numpy() + keras_ids = np.asarray(keras_ids).flatten().tolist() + np.testing.assert_equal(keras_ids, hf_ids, err_msg=f"Mismatch: {s!r}") + print(f" ✅ '{s}' → {len(hf_ids)} tokens match") + print(" All tokenizer checks passed") + + +# Verify preprocessor + + +def verify_preprocessor(keras_tokenizer, hf_processor): + # Text-only + text = "Describe the weather" + hf_ids = hf_processor.tokenizer(text, add_special_tokens=False)["input_ids"] + keras_pp = keras_hub.models.Qwen2VLCausalLMPreprocessor( + tokenizer=keras_tokenizer, sequence_length=32 + ) + result = keras_pp.generate_preprocess(text) + padding_mask = np.asarray(result["padding_mask"]) + num_real = int(np.sum(padding_mask)) + keras_ids = np.asarray(result["token_ids"])[:num_real].tolist() + np.testing.assert_equal(keras_ids, hf_ids) + print("Text-only preprocessing matches") + + # With image + image_converter = Qwen2VLImageConverter() + keras_pp_img = keras_hub.models.Qwen2VLCausalLMPreprocessor( + tokenizer=keras_tokenizer, + image_converter=image_converter, + sequence_length=512, + spatial_merge_size=2, + ) + dummy = np.random.randint(0, 255, (56, 56, 3), dtype=np.uint8) + result = keras_pp_img.generate_preprocess( + {"text": "Describe this image", "images": dummy} + ) + assert result["patch_values"] is not None + assert result["image_grid_thw"] is not None + grid_thw = result["image_grid_thw"] + expected = int(np.prod(grid_thw[0]) // 4) + actual = int( + np.sum( + np.asarray(result["token_ids"]) + == keras_tokenizer.image_pad_token_id + ) + ) + assert actual == expected, f"vision tokens: {actual} != {expected}" + print( + f"Image preprocessing: {expected} vision tokens " + f"from grid {grid_thw[0].tolist()}" + ) + print("All preprocessor checks passed") + + +# Verify backbone outputs + + +def 
verify_backbone(keras_backbone, keras_tokenizer, hf_model, hf_tokenizer): + # Parameter counts + keras_params = keras_backbone.count_params() + hf_params = hf_model.num_parameters() + print(f"KerasHub params: {keras_params:,}") + print(f"HF total params: {hf_params:,}") + + # Hidden state comparison (text-only path) + test_text = "What is Keras?" + hf_inputs = hf_tokenizer( + test_text, return_tensors="pt", add_special_tokens=False + ).to(device) + seq_len = hf_inputs["input_ids"].shape[1] + + with torch.no_grad(): + hf_outputs = hf_model.model(**hf_inputs) + hf_hidden = hf_outputs.last_hidden_state.detach().cpu().float().numpy() + + keras_pp = keras_hub.models.Qwen2VLCausalLMPreprocessor( + tokenizer=keras_tokenizer, sequence_length=seq_len + ) + k_in = keras_pp([test_text], sequence_length=seq_len)[0] + k_in = {k: v.to(device) for k, v in k_in.items()} + keras_hidden = ops.convert_to_numpy(keras_backbone(k_in)) + + print(f"HF hidden shape: {hf_hidden.shape}") + print(f"Keras hidden shape: {keras_hidden.shape}") + + try: + np.testing.assert_allclose( + keras_hidden, hf_hidden, atol=1e-4, rtol=1e-4 + ) + print("Hidden states match (atol=1e-4)") + except AssertionError: + max_diff = float(np.max(np.abs(keras_hidden - hf_hidden))) + mean_diff = float(np.mean(np.abs(keras_hidden - hf_hidden))) + print(f"Max abs diff: {max_diff:.6e}") + print(f"Mean abs diff: {mean_diff:.6e}") + print(traceback.format_exc()) + + # Logits via LM head + keras_logits = ops.convert_to_numpy( + keras_backbone.token_embedding( + ops.convert_to_tensor(keras_hidden), reverse=True + ) + ) + with torch.no_grad(): + hf_logits = ( + hf_model.lm_head(hf_outputs.last_hidden_state) + .detach() + .cpu() + .float() + .numpy() + ) + + try: + np.testing.assert_allclose( + keras_logits, hf_logits, atol=1e-4, rtol=1e-4 + ) + print("Logits match (atol=1e-4)") + except AssertionError: + max_diff = float(np.max(np.abs(keras_logits - hf_logits))) + print(f"Logits max diff: {max_diff:.6e}") + 
print(traceback.format_exc()) + + # Mean diff as a single number (like the DistilBERT example) + mean_hidden = float(np.mean(keras_hidden - hf_hidden)) + print(f" Mean diff (keras - hf): {mean_hidden:.6e}") + + print(" Backbone verification complete") + + +def main(_): + preset = FLAGS.preset + if preset not in PRESET_MAP: + raise ValueError( + f"Invalid preset {preset}. " + f"Must be one of {','.join(PRESET_MAP.keys())}" + ) + hf_id = PRESET_MAP[preset] + + # Load HF model + print(f"Loading HF model: {hf_id}") + hf_model = Qwen2VLForConditionalGeneration.from_pretrained( + hf_id, device_map=device, torch_dtype=torch.float32 + ) + hf_model.eval() + hf_tokenizer = AutoTokenizer.from_pretrained(hf_id) + hf_processor = AutoProcessor.from_pretrained(hf_id) + hf_state_dict = hf_model.state_dict() + + # Use in-memory config + hf_config = hf_model.config.to_dict() + print(f" HF state_dict: {len(hf_state_dict)} tensors") + + # Build KerasHub backbone from config + backbone_kwargs = build_backbone_config(hf_config) + print(f" Config: {json.dumps(backbone_kwargs, indent=2)}") + keras_backbone = Qwen2VLBackbone(**backbone_kwargs) + keras_backbone.summary() + + # Port weights + port_weights(keras_backbone, hf_state_dict) + + # Load KerasHub tokenizer + keras_tokenizer = keras_hub.models.Qwen2VLTokenizer.from_preset( + f"hf://{hf_id}" + ) + print(" KerasHub tokenizer loaded") + + # Verify + verify_tokenizer(keras_tokenizer, hf_tokenizer) + verify_preprocessor(keras_tokenizer, hf_processor) + verify_backbone(keras_backbone, keras_tokenizer, hf_model, hf_tokenizer) + + # Save preset + save_dir = f"./{preset}" + keras_backbone.save_to_preset(save_dir) + keras_tokenizer.save_to_preset(save_dir) + print(f"Preset saved to {save_dir}/") + print(f"Contents: {os.listdir(save_dir)}") + print(f"All checks passed for {preset}!") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)