Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions keras_hub/src/models/qwen2_vl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Comment on lines +1 to +13

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

is not present in any other model's init file, curious as to why you put it here?

from keras_hub.src.models.qwen2_vl.qwen2_vl_backbone import Qwen2VLBackbone
from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm import Qwen2VLCausalLM
from keras_hub.src.models.qwen2_vl.qwen2_vl_projector import Qwen2VLProjector
from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import (
Qwen2VLVisionEncoder,
)
Comment on lines +15 to +19
Copy link

@samudraneel05 samudraneel05 Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm import Qwen2VLCausalLM
from keras_hub.src.models.qwen2_vl.qwen2_vl_projector import Qwen2VLProjector
from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import (
Qwen2VLVisionEncoder,
)
from keras_hub.src.models.qwen2_vl.qwen2_vl_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets
register_presets(backbone_presets, Qwen2VLBackbone)

the other imports inside the init are unnecessary and go against repo standards

100 changes: 100 additions & 0 deletions keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Comment on lines +1 to +13

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

why is this here when it does not exist on any other model's backbone file?

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.Qwen2VLBackbone")
class Qwen2VLBackbone(Backbone):
    """Qwen2-VL backbone model.

    Wires a vision encoder and a text backbone together following the
    KerasHub functional API pattern. Only the vision branch is part of the
    functional graph here: because Qwen2-VL merges image features into the
    text sequence (token replacement), the merge is deferred to the
    `Qwen2VLCausalLM` task, and this backbone exposes the vision features
    and text inputs separately.

    Args:
        vision_encoder: A `keras.Model` mapping a 5D image tensor of shape
            `(batch, time, height, width, 3)` to vision features.
        text_backbone: A text backbone model. It is not invoked in this
            model's functional graph; it is stored for use by the task
            model.
        image_converter: Optional callable/layer for raw-image
            preprocessing. Stored for downstream use only.
    """

    def __init__(
        self,
        vision_encoder,
        text_backbone,
        image_converter=None,
        **kwargs,
    ):
        # === Functional inputs ===
        # Images are 5D (batch, time, height, width, channels); flexible
        # (None) dims support dynamic resizing.
        images = keras.Input(shape=(None, None, None, 3), name="images")
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # === Functional graph ===
        # Vision branch only; text inputs pass through untouched so the
        # causal LM task can perform the modality merge.
        vision_features = vision_encoder(images)
        outputs = {
            "vision_features": vision_features,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

        super().__init__(
            inputs={
                "images": images,
                "token_ids": token_ids,
                "padding_mask": padding_mask,
            },
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.vision_encoder = vision_encoder
        self.text_backbone = text_backbone
        self.image_converter = image_converter

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vision_encoder": keras.saving.serialize_keras_object(
                    self.vision_encoder
                ),
                "text_backbone": keras.saving.serialize_keras_object(
                    self.text_backbone
                ),
                "image_converter": keras.saving.serialize_keras_object(
                    self.image_converter
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # `get_config` serializes the sub-models, so they must be
        # deserialized here, otherwise `cls(**config)` would receive
        # plain dicts instead of model instances.
        config = dict(config)
        for key in ("vision_encoder", "text_backbone", "image_converter"):
            if key in config:
                config[key] = keras.saving.deserialize_keras_object(
                    config[key]
                )
        return cls(**config)
50 changes: 50 additions & 0 deletions keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm import CausalLM


@keras_hub_export("keras_hub.models.Qwen2VLCausalLM")
class Qwen2VLCausalLM(CausalLM):
    """Qwen2-VL causal language model task.

    Embeds the image and text inputs separately, prepends the image
    embeddings to the text token embeddings, runs the combined sequence
    through the text backbone's transformer layers, and projects back to
    vocabulary logits with the tied (reversed) token embedding.

    Args:
        backbone: A `Qwen2VLBackbone` instance.
        preprocessor: Optional `Qwen2VLCausalLMPreprocessor` instance.
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        # The base `CausalLM` constructor stores `backbone` and
        # `preprocessor`; no extra attribute assignment is needed here,
        # and the inherited `get_config` already covers both.
        super().__init__(backbone=backbone, preprocessor=preprocessor, **kwargs)

    def call(self, inputs, training=False, mask=None):
        """Forward pass over a dict with `"images"` and `"token_ids"`.

        NOTE(review): reference Qwen2-VL merges image features by
        replacing image-placeholder tokens inside the text sequence;
        here image embeddings are simply prepended — confirm this matches
        the intended checkpoint behavior. `mask` is currently unused.
        """
        images = inputs["images"]
        token_ids = inputs["token_ids"]

        vision_encoder = self.backbone.vision_encoder
        text_backbone = self.backbone.text_backbone

        # Embed each modality independently.
        image_embeds = vision_encoder(images, training=training)
        text_embeds = text_backbone.token_embedding(token_ids)

        # Prepend image embeddings along the sequence axis.
        x = keras.ops.concatenate([image_embeds, text_embeds], axis=1)

        for layer in text_backbone.transformer_layers:
            x = layer(x, training=training)

        # Some text backbones expose a final layer norm; apply it if present.
        if hasattr(text_backbone, "layer_norm"):
            x = text_backbone.layer_norm(x)

        # Tied-embedding projection from hidden states to vocab logits.
        return text_backbone.token_embedding(x, reverse=True)
83 changes: 83 additions & 0 deletions keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor


@keras_hub_export("keras_hub.models.Qwen2VLCausalLMPreprocessor")
class Qwen2VLCausalLMPreprocessor(CausalLMPreprocessor):
    """Qwen2-VL Causal LM Preprocessor.

    Combines text tokenization with optional image conversion for the
    vision encoder.

    Args:
        tokenizer: A `keras_hub.models.Tokenizer` instance.
        image_converter: A callable or layer that converts raw images to
            tensors. If `None`, image inputs pass through unchanged.
        sequence_length: Maximum tokenized sequence length. Defaults to
            1024.
    """

    def __init__(
        self,
        tokenizer,
        image_converter=None,
        sequence_length=1024,
        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            **kwargs,
        )
        self.image_converter = image_converter

    def generate_preprocess(self, x, sequence_length=None):
        """Tokenize text and convert images for generation.

        `x` may be a dict with `"text"` and optional `"images"` keys, or a
        bare text input (in which case no images are processed).

        NOTE(review): `sequence_length` is accepted but never applied —
        token ids are returned untruncated/unpadded. TODO: confirm whether
        truncation to `sequence_length` is intended here.
        """
        if isinstance(x, dict):
            text = x.get("text", "")
            images = x.get("images", None)
        else:
            text = x
            images = None

        token_ids = self.tokenizer(text)

        # Only run the converter when both images and a converter exist.
        if images is not None and self.image_converter:
            images = self.image_converter(images)

        return {
            "token_ids": token_ids,
            "images": images,
        }

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_converter": keras.saving.serialize_keras_object(
                    self.image_converter
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Copy first: mutating the caller's dict in place is a subtle bug
        # if the same config is reused (e.g. by the serialization machinery).
        config = dict(config)
        if "image_converter" in config:
            config["image_converter"] = keras.saving.deserialize_keras_object(
                config["image_converter"]
            )
        return cls(**config)
65 changes: 65 additions & 0 deletions keras_hub/src/models/qwen2_vl/qwen2_vl_causal_lm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np

from keras_hub.src.models.qwen2_vl.qwen2_vl_causal_lm_preprocessor import (
Qwen2VLCausalLMPreprocessor,
)

# FIX: Import the Real Image Converter
from keras_hub.src.models.qwen2_vl.qwen2_vl_image_converter import (
Qwen2VLImageConverter,
)
from keras_hub.src.tests.test_case import TestCase


class MockTokenizer:
    """Minimal stand-in tokenizer for preprocessor tests.

    Ignores its input text and always returns the same `(1, 5)` int32
    batch of token ids.
    """

    def __init__(self):
        # Matches the pad id real Qwen2 tokenizers would expose.
        self.pad_token_id = 0

    def __call__(self, text):
        fixed_ids = [[1, 2, 3, 4, 5]]
        return np.asarray(fixed_ids, dtype="int32")


class Qwen2VLIntegrationTest(TestCase):
    def test_smart_resizing_flow(self):
        """End-to-end check of the preprocessor's smart image resizing.

        A non-square image run through the real `Qwen2VLImageConverter`
        must come out 4D `(time, height, width, channels)` with a time
        dimension of 1 and spatial dims snapped to multiples of 28.
        """
        # Small min_pixels so the resize path is easy to trigger.
        image_converter = Qwen2VLImageConverter(
            min_pixels=100 * 100, max_pixels=1000 * 1000
        )

        preprocessor = Qwen2VLCausalLMPreprocessor(
            tokenizer=MockTokenizer(),
            image_converter=image_converter,
            sequence_length=16,
        )

        # A very wide (50x300) image stresses the aspect-preserving resize.
        input_h, input_w = 50, 300
        raw_image = np.random.randint(0, 255, (input_h, input_w, 3)).astype(
            "float32"
        )

        processed = preprocessor.generate_preprocess(
            {"text": "Hello world", "images": raw_image}
        )
        images = processed["images"]

        # 4D output: (time, height, width, channels), single frame.
        self.assertEqual(len(images.shape), 4)
        self.assertEqual(images.shape[0], 1)

        # Spatial dims must snap to the 28-pixel patch grid.
        h, w = images.shape[1], images.shape[2]
        self.assertEqual(h % 28, 0, f"Height {h} is not a multiple of 28")
        self.assertEqual(w % 28, 0, f"Width {w} is not a multiple of 28")
Loading
Loading