diff --git a/keras_hub/src/models/qwen2_vl/__init__.py b/keras_hub/src/models/qwen2_vl/__init__.py new file mode 100644 index 0000000000..99a7735a73 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone +from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm import Qwen2VLCausalLM +from keras_hub.src.models.qwen2_vl.qwen2_vl_projector import Qwen2VLProjector +from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import ( + Qwen2VLVisionEncoder, +) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py new file mode 100644 index 0000000000..0697de25ea --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py @@ -0,0 +1,100 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.Qwen2VLBackbone")
class Qwen2VLBackbone(Backbone):
    """Qwen2-VL backbone wiring a vision encoder to a text backbone.

    The functional graph only runs the vision branch; the tokenized text
    inputs pass through untouched. Qwen2-VL merges image tokens into the
    text sequence by token replacement, which needs the tokenized prompt,
    so that merging is left to `Qwen2VLCausalLM` and the backbone exposes
    the pieces separately.

    Args:
        vision_encoder: A `keras.Model` mapping a 5D image batch
            `(batch, time, height, width, 3)` to vision features.
        text_backbone: The language-model backbone. Not called inside this
            graph; exposed as an attribute for the causal LM task.
        image_converter: Optional preprocessing layer, stored for
            convenience only.
    """

    def __init__(
        self,
        vision_encoder,
        text_backbone,
        image_converter=None,
        **kwargs,
    ):
        # --- Inputs ---
        # Images are 5D (batch, time, height, width, channels). Spatial and
        # temporal dims stay dynamic (None) to support smart resizing.
        images = keras.Input(shape=(None, None, None, 3), name="images")
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # --- Forward pass ---
        # Only the vision branch runs here; see the class docstring for why
        # text/vision merging is deferred to the causal LM.
        vision_features = vision_encoder(images)

        outputs = {
            "vision_features": vision_features,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

        super().__init__(
            inputs={
                "images": images,
                "token_ids": token_ids,
                "padding_mask": padding_mask,
            },
            outputs=outputs,
            **kwargs,
        )

        self.vision_encoder = vision_encoder
        self.text_backbone = text_backbone
        self.image_converter = image_converter

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vision_encoder": keras.saving.serialize_keras_object(
                    self.vision_encoder
                ),
                "text_backbone": keras.saving.serialize_keras_object(
                    self.text_backbone
                ),
                "image_converter": keras.saving.serialize_keras_object(
                    self.image_converter
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # BUG FIX: `get_config` serializes the nested sub-models, but the
        # class had no matching `from_config`, so `cls(**config)` would
        # receive raw serialization dicts on reload and fail to rebuild
        # the backbone. Deserialize each nested object before construction.
        config = config.copy()
        for key in ("vision_encoder", "text_backbone", "image_converter"):
            if isinstance(config.get(key), dict):
                config[key] = keras.saving.deserialize_keras_object(
                    config[key]
                )
        return cls(**config)
+import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM + + +@keras_hub_export("keras_hub.models.Qwen2VLCausalLM") +class Qwen2VLCausalLM(CausalLM): + """Qwen2-VL Causal LM model.""" + + def __init__(self, backbone, preprocessor=None, **kwargs): + super().__init__(backbone=backbone, preprocessor=preprocessor, **kwargs) + self.backbone = backbone + + def call(self, inputs, training=False, mask=None): + images = inputs["images"] + token_ids = inputs["token_ids"] + + vision_encoder = self.backbone.vision_encoder + text_backbone = self.backbone.text_backbone + + image_embeds = vision_encoder(images, training=training) + text_embeds = text_backbone.token_embedding(token_ids) + + x = keras.ops.concatenate([image_embeds, text_embeds], axis=1) + + for layer in text_backbone.transformer_layers: + x = layer(x, training=training) + + if hasattr(text_backbone, "layer_norm"): + x = text_backbone.layer_norm(x) + + x = self.backbone.text_backbone.token_embedding(x, reverse=True) + return x + + def get_config(self): + return super().get_config() diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py new file mode 100644 index 0000000000..fd6422d7e9 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py @@ -0,0 +1,83 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor + + +@keras_hub_export("keras_hub.models.Qwen2VLCausalLMPreprocessor") +class Qwen2VLCausalLMPreprocessor(CausalLMPreprocessor): + """Qwen2-VL Causal LM Preprocessor. + + This class handles the preprocessing of inputs for the Qwen2-VL model. + It combines text tokenization with image preprocessing for the vision + encoder. + + Args: + tokenizer: A `keras_hub.models.Tokenizer` instance. + image_converter: A callable or layer that converts raw images + to tensors. If `None`, image inputs will pass through unchanged. + """ + + def __init__( + self, + tokenizer, + image_converter=None, + sequence_length=1024, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + sequence_length=sequence_length, + **kwargs, + ) + self.image_converter = image_converter + + def generate_preprocess(self, x, sequence_length=None): + if isinstance(x, dict): + text = x.get("text", "") + images = x.get("images", None) + else: + text = x + images = None + + token_ids = self.tokenizer(text) + + if images is not None and self.image_converter: + images = self.image_converter(images) + + return { + "token_ids": token_ids, + "images": images, + } + + def get_config(self): + config = super().get_config() + config.update( + { + "image_converter": keras.saving.serialize_keras_object( + self.image_converter + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + if "image_converter" in config: + config["image_converter"] = keras.saving.deserialize_keras_object( + config["image_converter"] + ) + return cls(**config) diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_test.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_test.py new file mode 100644 index 0000000000..98b60ff443 --- /dev/null +++ 
import numpy as np

from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm_preprocessor import (
    Qwen2VLCausalLMPreprocessor,
)
from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import (
    Qwen2VLImageConverter,
)
from keras_hub.src.tests.test_case import TestCase


class MockTokenizer:
    """Stub tokenizer that returns a fixed (1, 5) id matrix for any text."""

    def __init__(self):
        self.pad_token_id = 0

    def __call__(self, text):
        return np.array([[1, 2, 3, 4, 5]], dtype="int32")


class Qwen2VLIntegrationTest(TestCase):
    def test_smart_resizing_flow(self):
        # Use a real converter with a small pixel budget so that resizing
        # actually triggers on a modest test image.
        converter = Qwen2VLImageConverter(
            min_pixels=100 * 100, max_pixels=1000 * 1000
        )
        preprocessor = Qwen2VLCausalLMPreprocessor(
            tokenizer=MockTokenizer(),
            image_converter=converter,
            sequence_length=16,
        )

        # An extreme 50x300 aspect ratio exercises the aspect-preserving
        # resize path.
        input_h, input_w = 50, 300
        raw_image = np.random.randint(0, 255, (input_h, input_w, 3)).astype(
            "float32"
        )

        processed = preprocessor.generate_preprocess(
            {"text": "Hello world", "images": raw_image}
        )

        images = processed["images"]
        print(f"\nOriginal Shape: {(input_h, input_w)}")
        print(f"Resized Shape: {images.shape}")

        # The converter must emit 4D output: (time, height, width, channels).
        self.assertEqual(len(images.shape), 4)

        # A still image gets a singleton time axis.
        self.assertEqual(images.shape[0], 1)

        # Both spatial dims must snap to multiples of 2 * patch_size = 28.
        h, w = images.shape[1], images.shape[2]
        self.assertTrue(h % 28 == 0, f"Height {h} is not multiple of 28")
        self.assertTrue(w % 28 == 0, f"Width {w} is not multiple of 28")

        print("āœ… Smart Resizing Logic successful!")
import math

import numpy as np
from keras import layers
from keras import ops

from keras_hub.src.api_export import keras_hub_export


@keras_hub_export("keras_hub.layers.Qwen2VLImageConverter")
class Qwen2VLImageConverter(layers.Layer):
    """Image converter for Qwen2-VL: smart resizing plus normalization.

    Resizes inputs so that the total pixel count falls inside
    `[min_pixels, max_pixels]` while preserving aspect ratio, snapping each
    spatial dimension to a multiple of `2 * patch_size`, then rescales to
    `[0, 1]` and normalizes with the configured mean/std.

    Args:
        min_pixels: Int. Minimum number of pixels for the resized image.
        max_pixels: Int. Maximum number of pixels for the resized image.
        patch_size: Int. The patch size of the vision encoder (default 14).
        mean: List/Tuple. Mean values for normalization.
        std: List/Tuple. Standard deviation values for normalization.
    """

    def __init__(
        self,
        min_pixels=224 * 224,
        max_pixels=1280 * 28 * 28,
        patch_size=14,
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.mean = np.array(mean, dtype="float32")
        self.std = np.array(std, dtype="float32")
        self.rescaling_layer = layers.Rescaling(scale=1.0 / 255.0)

    def _smart_resize(self, height, width):
        """Return the optimal (new_h, new_w) for the given input size."""
        pixel_count = height * width
        scale = 1.0

        # Scale uniformly so the pixel count lands inside the budget.
        if pixel_count < self.min_pixels:
            scale = math.sqrt(self.min_pixels / pixel_count)
        elif pixel_count > self.max_pixels:
            scale = math.sqrt(self.max_pixels / pixel_count)

        new_h = int(height * scale)
        new_w = int(width * scale)

        # Snap to multiples of 2x patch_size (28 by default).
        snap = self.patch_size * 2
        # BUG FIX: `round()` can collapse a very small or very skewed
        # dimension to 0, making the resize target invalid; clamp each
        # dimension to at least one snap unit.
        new_h = max(round(new_h / snap) * snap, snap)
        new_w = max(round(new_w / snap) * snap, snap)

        return new_h, new_w

    def call(self, image):
        input_shape = ops.shape(image)
        h, w = input_shape[-3], input_shape[-2]

        # NOTE(review): `float(h)` needs concrete (eager) dimensions; this
        # layer is intended for CPU-side preprocessing pipelines, not for
        # traced/compiled graphs with symbolic shapes.
        new_h, new_w = self._smart_resize(float(h), float(w))

        resized_image = ops.image.resize(image, (new_h, new_w))

        # Rescale to [0, 1], then normalize channel-wise.
        x = self.rescaling_layer(resized_image)
        x = (x - self.mean) / self.std

        # Add a singleton time dimension for still images.
        if len(ops.shape(x)) == 3:
            x = ops.expand_dims(x, axis=0)

        return x

    def get_config(self):
        config = super().get_config()
        # BUG FIX: `mean` and `std` were omitted from the config, so any
        # custom normalization was silently dropped on save/reload.
        # Serialize them as plain lists.
        config.update(
            {
                "min_pixels": self.min_pixels,
                "max_pixels": self.max_pixels,
                "patch_size": self.patch_size,
                "mean": self.mean.tolist(),
                "std": self.std.tolist(),
            }
        )
        return config
Mock the Tokenizer (Simulates turning text into numbers) +class MockTokenizer: + def __init__(self): + self.pad_token_id = 0 + + def __call__(self, text): + # Just return random integers simulating token IDs + # Shape: (Batch, Seq_Len) -> (1, 5) + return np.array([[1, 2, 3, 4, 5]], dtype="int32") + + +class Qwen2VLIntegrationTest(TestCase): + def test_end_to_end_flow(self): + # Setup Preprocessor with Mocks + preprocessor = Qwen2VLCausalLMPreprocessor( + tokenizer=MockTokenizer(), + image_converter=MockImageConverter(), + sequence_length=16, + ) + + # Inputs + # Note: In a real scenario, this would be a real image path or array + input_data = { + "text": "Hello world", + "images": np.random.random((224, 224, 3)), + } + + # Run Preprocessor + # The preprocessor handles the dictionary unpacking + processed = preprocessor.generate_preprocess(input_data) + + # Verify Structure + self.assertTrue("token_ids" in processed) + self.assertTrue("images" in processed) + + # Check shapes + # Token IDs should come from our MockTokenizer (1, 5) + self.assertEqual(processed["token_ids"].shape, (1, 5)) + + # Images should come from our MockImageConverter (1, 1, 14, 14, 3) + self.assertEqual(processed["images"].shape, (1, 1, 14, 14, 3)) + + print("\nāœ… End-to-End Preprocessing flow successful!") diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_projector.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_projector.py new file mode 100644 index 0000000000..23a66765b9 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_projector.py @@ -0,0 +1,52 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras import layers +from keras import ops + + +class Qwen2VLProjector(layers.Layer): + """Qwen2-VL Projector. + + This layer projects the vision encoder outputs to the text embedding space. + It reshapes the 5D video/image features into a 2D sequence of tokens. + """ + + def __init__(self, hidden_dim, output_dim, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.dense = layers.Dense(output_dim, name="dense") + + def call(self, x): + # x shape: (Batch, Time, Height, Width, Channels) + shape = ops.shape(x) + B, _, _, _, C = shape[0], shape[1], shape[2], shape[3], shape[4] + + # Flatten Time, Height, Width into Sequence Length + # (Batch, T*H*W, Channels) + x = ops.reshape(x, (B, -1, C)) + + # Project to text embedding dimension + x = self.dense(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + } + ) + return config diff --git a/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py new file mode 100644 index 0000000000..93f6a6e204 --- /dev/null +++ b/keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py @@ -0,0 +1,216 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
import numpy as np
from keras import layers
from keras import ops

from keras_hub.src.models.backbone import Backbone


class Qwen2VLVisionEncoder(Backbone):
    """Qwen2-VL Vision Encoder (3D ViT).

    Patch-embeds video/image input with a 3D convolution, flattens the
    patches into a token sequence, and runs them through a stack of
    transformer blocks with (optional) rotary position embeddings.
    """

    def __init__(
        self,
        patch_size=14,
        temporal_patch_size=2,
        hidden_size=1152,
        depth=27,
        num_heads=16,
        mlp_ratio=4,
        activation="silu",
        dtype=None,
        **kwargs,
    ):
        inputs = keras.Input(
            shape=(None, None, None, 3), dtype=dtype, name="images"
        )

        # 1. Patch embedding: non-overlapping 3D conv over (time, h, w).
        self.patch_embed = layers.Conv3D(
            filters=hidden_size,
            kernel_size=(temporal_patch_size, patch_size, patch_size),
            strides=(temporal_patch_size, patch_size, patch_size),
            padding="valid",
            name="patch_embed",
        )
        features = self.patch_embed(inputs)

        # Flatten temporal/spatial dims into one sequence axis:
        # (batch, seq_len, hidden), batch kept dynamic.
        features = ops.reshape(
            features, (ops.shape(features)[0], -1, hidden_size)
        )

        # 2. Rotary embedding helper. Used by `call` when grid info is
        # supplied; the functional graph below runs without it.
        self.rotary_emb = Qwen2VLRotaryEmbedding(hidden_size // num_heads)

        # 3. Transformer stack.
        self.blocks = [
            Qwen2VLVisionBlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                activation=activation,
                name=f"blocks.{i}",
            )
            for i in range(depth)
        ]
        for block in self.blocks:
            features = block(features)

        super().__init__(inputs=inputs, outputs=features, dtype=dtype, **kwargs)

        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.activation = activation

    def call(self, x, grid_thw=None):
        x = self.patch_embed(x)

        # Remember the post-conv 5D shape so it can be restored afterwards.
        shape = ops.shape(x)
        batch, t, h, w, c = shape[0], shape[1], shape[2], shape[3], shape[4]

        # Flatten for the transformer stack.
        x = ops.reshape(x, (batch, -1, self.hidden_size))

        # RoPE only applies when the caller supplies grid dimensions.
        rotary_pos_emb = (
            self.rotary_emb(grid_thw) if grid_thw is not None else None
        )

        for block in self.blocks:
            x = block(x, rotary_pos_emb=rotary_pos_emb)

        # Restore (batch, time, height, width, channels).
        # NOTE(review): this 5D output differs from the 3D output of the
        # functional graph built in `__init__` — confirm which layout
        # downstream consumers expect.
        return ops.reshape(x, (batch, t, h, w, c))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "temporal_patch_size": self.temporal_patch_size,
                "hidden_size": self.hidden_size,
                "depth": self.depth,
                "num_heads": self.num_heads,
                "mlp_ratio": self.mlp_ratio,
                "activation": self.activation,
            }
        )
        return config
class Qwen2VLVisionBlock(layers.Layer):
    """Pre-norm transformer block (self-attention + MLP) for the vision ViT."""

    def __init__(self, hidden_size, num_heads, mlp_ratio, activation, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.activation = activation

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=hidden_size // num_heads,
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = keras.Sequential(
            [
                layers.Dense(int(hidden_size * mlp_ratio)),
                layers.Activation(activation),
                layers.Dense(hidden_size),
            ]
        )

    def call(self, x, rotary_pos_emb=None):
        # Attention sub-layer: pre-norm, self-attention, residual add.
        # `rotary_pos_emb` is accepted but not yet applied — standard Keras
        # MHA has no rotary hook, so a custom attention layer would be
        # needed to wire it in.
        shortcut = x
        normed = self.norm1(x)
        x = self.attn(normed, normed) + shortcut

        # MLP sub-layer: pre-norm, feed-forward, residual add.
        shortcut = x
        x = self.mlp(self.norm2(x)) + shortcut
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "num_heads": self.num_heads,
                "mlp_ratio": self.mlp_ratio,
                "activation": self.activation,
            }
        )
        return config


class Qwen2VLRotaryEmbedding(layers.Layer):
    """Builds 3D (time/height/width) rotary positional embedding tables."""

    def __init__(self, dim, base=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.base = base
        # Inverse frequencies depend only on dim/base; compute them once.
        self.inv_freq = self._compute_inv_freq(dim, base)

    def _compute_inv_freq(self, dim, base):
        # Standard RoPE schedule: base^(-2i/dim) over even indices i.
        exponents = np.arange(0, dim, 2).astype("float32") / dim
        return 1.0 / (base**exponents)

    def call(self, grid_thw):
        # grid_thw: rows of (time, height, width) grid sizes. Build one
        # frequency table per axis sized by the largest grid seen.
        inv_freq = ops.convert_to_tensor(self.inv_freq, dtype="float32")

        tables = []
        for axis in range(3):
            positions = ops.arange(
                ops.max(grid_thw[:, axis]), dtype="float32"
            )
            table = ops.outer(positions, inv_freq)
            # Duplicate along the last axis (cos/sin halves share freqs).
            tables.append(ops.concatenate([table, table], axis=-1))

        t_emb, h_emb, w_emb = tables
        return t_emb, h_emb, w_emb

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "base": self.base})
        return config