Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 7 additions & 43 deletions keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,27 +605,13 @@ def call(

# === Vision processing ===

batch_size = tf.shape(prompts)[0]
desired_height = self.image_converter.image_size[0]
desired_width = self.image_converter.image_size[1]
if images is None:
# == Branch: vision model, with `None` value for `images` ==

# To handle the text-only input case, we need to pass an empty
# tensor so as to skip the vision layers of the model.

# TODO: Once functional models accept `None` inputs, consider
# passing this as `None` directly.
images = tf.ones(
shape=[
batch_size,
0,
desired_height,
desired_width,
3,
],
dtype="float32",
)
images = None

vision_mask = tf.zeros_like(token_ids, dtype=bool)

Expand Down Expand Up @@ -682,30 +668,26 @@ def generate_preprocess(
if isinstance(x, dict):
images = x.get("images", None)

# TODO: do we even need `responses` for generation? Makes sense for
# finetuning only (i.e., `call()`).
responses = x.get("responses", None)
prompts = x["prompts"]
else:
images = None
responses = None

prompts = x

# Find out if the input is batched/not batched. Uprank if not batched.
# In other preprocessors, we don't have to do this, but here, all
# the following logic (indices, etc.) uses tensors with a batch dim.
# We will squeeze these back at the end.
batched = True

batched = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This batched = True assignment is a duplicate of the one on line 681 and can be removed.

if isinstance(prompts, str):
batched = False
prompts = [prompts]
if responses is not None:
responses = [responses]
if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0:
batched = False

prompts = tf.expand_dims(prompts, axis=0)
if responses is not None:
responses = tf.expand_dims(responses, axis=0)

# We have the same 8 cases here, as in `call()`.
if self.text_only_model and images is not None:
Expand All @@ -729,11 +711,7 @@ def generate_preprocess(
# === Tokenization, padding, etc. ===
prompts = self.tokenizer(prompts)

if responses is not None:
responses = self.tokenizer(responses)
segments = (prompts, responses)
else:
segments = (prompts,)
segments = (prompts,)

# Padding.
token_ids, segment_ids = self.packer(
Expand All @@ -759,27 +737,13 @@ def generate_preprocess(

# === Vision processing ===

batch_size = tf.shape(prompts)[0]
desired_height = self.image_converter.image_size[0]
desired_width = self.image_converter.image_size[1]
if images is None:
# == Branch: vision model, with `None` value for `images` ==

# To handle the text-only input case, we need to pass an empty
# tensor so as to skip the vision layers of the model.

# TODO: Once functional models accept `None` inputs, consider
# passing this as `None` directly.
images = tf.ones(
shape=[
batch_size,
0,
desired_height,
desired_width,
3,
],
dtype="float32",
)
images = None

vision_mask = tf.zeros_like(token_ids, dtype=bool)

Expand Down
Empty file.
41 changes: 41 additions & 0 deletions keras_hub/src/models/qwen2_vl/qwen2_vl_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import keras
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a trailing whitespace at the end of this line.

Suggested change
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone

from keras_hub.src.models.qwen2_vl.qwen2_vl_vision_encoder import Qwen2VLVisionEncoder
from keras_hub.src.models.qwen2_vl.qwen2_vl_projector import Qwen2VLProjector

class Qwen2VLBackbone(Backbone):
    """Qwen2-VL multimodal backbone (work in progress).

    Combines a vision encoder, a projector, and a standard Qwen text
    backbone. Image features are projected into the text embedding space
    and, for now, simply prepended to the token embeddings.

    Args:
        vision_encoder: A `Qwen2VLVisionEncoder` producing image features.
        projector: A `Qwen2VLProjector` mapping vision features to the
            LLM hidden size.
        text_backbone: The standard Qwen (2/2.5) language-model backbone.
    """

    def __init__(
        self,
        vision_encoder,
        projector,
        text_backbone,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vision_encoder = vision_encoder
        self.projector = projector
        # This is the standard Qwen (2/2.5) LLM.
        self.text_backbone = text_backbone

    def call(self, inputs):
        # `inputs` is a dict containing "images" and "token_ids".
        images = inputs["images"]
        token_ids = inputs["token_ids"]

        # Encode images, then project them into the text embedding space.
        image_features = self.vision_encoder(images)
        image_embeddings = self.projector(image_features)

        # Embed text tokens normally.
        text_embeddings = self.text_backbone.token_embedding(token_ids)

        # Placeholder fusion: prepend image tokens to the text sequence.
        # TODO: interleave image embeddings at placeholder-token positions
        # as in the reference Qwen2-VL implementation.
        combined_embeddings = keras.ops.concatenate(
            [image_embeddings, text_embeddings], axis=1
        )

        # `transformer_layers` is a list of layers, not a single callable,
        # so it must be iterated.
        # NOTE(review): no padding mask is threaded through yet — confirm
        # the Qwen decoder-layer signature and pass one. Reaching into the
        # text backbone's internals is brittle; consider reusing its
        # functional call instead.
        x = combined_embeddings
        for transformer_layer in self.text_backbone.transformer_layers:
            x = transformer_layer(x)

        x = self.text_backbone.layer_norm(x)

        return x
43 changes: 43 additions & 0 deletions keras_hub/src/models/qwen2_vl/qwen2_vl_projector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import keras
from keras import layers
from keras import ops

class Qwen2VLProjector(layers.Layer):
    """Projector layer for Qwen2-VL.

    This layer downsamples vision features by merging 2x2 neighboring
    patches into a single token and projecting them to the LLM's hidden
    size.

    Args:
        hidden_size: Hidden size of the incoming vision features.
        output_hidden_size: Hidden size of the target LLM.
    """
    def __init__(self, hidden_size, output_hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.output_hidden_size = output_hidden_size

        # `Sequential` lives on the top-level `keras` namespace, not on
        # `keras.layers` — the original `layers.Sequential` raises an
        # `AttributeError` in Keras 3.
        self.merger = keras.Sequential(
            [
                layers.Dense(output_hidden_size, name="merger_proj"),
                layers.Activation("gelu", name="activation"),
                layers.Dense(output_hidden_size, name="output_proj"),
            ],
            name="merger",
        )

    def call(self, x):
        """Merge 2x2 patch blocks and project to `output_hidden_size`.

        Args:
            x: Tensor of shape `(batch, height, width, channels)`.
                Height and width must be even.

        Returns:
            Tensor of shape
            `(batch, height // 2, width // 2, output_hidden_size)`.
        """
        input_shape = ops.shape(x)
        H, W, C = input_shape[1], input_shape[2], input_shape[3]

        # Reshape to isolate 2x2 blocks.
        # Shape: (B, H/2, 2, W/2, 2, C)
        x = ops.reshape(x, (-1, H // 2, 2, W // 2, 2, C))

        # Permute to bring the 2x2 blocks together.
        # Shape: (B, H/2, W/2, 2, 2, C)
        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))

        # Flatten the 2x2xC block into a single vector.
        # Shape: (B, H/2, W/2, 4*C)
        x = ops.reshape(x, (-1, H // 2, W // 2, 4 * C))

        x = self.merger(x)

        return x
85 changes: 85 additions & 0 deletions keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import keras
from keras import layers
from keras import ops

from keras_hub.src.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_hub.src.models.backbone import Backbone


class Qwen2VLVisionEncoder(Backbone):
    """Vision encoder for Qwen2-VL (work in progress).

    Patchifies an image/video with a 3D convolution (the temporal axis
    covers video frames), runs a stack of transformer blocks, and
    spatially merges 2x2 token neighborhoods before handing features to
    the LLM.

    Args:
        patch_size: Spatial patch size for the convolutional stem.
        temporal_patch_size: Number of frames merged per temporal patch.
        hidden_size: Width of the transformer.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Expansion factor of each block's MLP.
        activation: Activation name used by the blocks' MLPs.
    """
    def __init__(
        self,
        patch_size=14,
        temporal_patch_size=2,
        hidden_size=1152,
        depth=27,
        num_heads=16,
        mlp_ratio=4,
        activation="silu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.activation = activation

        # 3D convolution to handle both video (time) and images.
        self.patch_embed = layers.Conv3D(
            filters=hidden_size,
            kernel_size=(temporal_patch_size, patch_size, patch_size),
            strides=(temporal_patch_size, patch_size, patch_size),
            padding="valid",
            name="patch_embed",
        )

        # Placeholder for Qwen2VL transformer blocks.
        self.blocks = [
            Qwen2VLVisionBlock(
                hidden_size, num_heads, mlp_ratio, activation, name=f"blocks.{i}"
            )
            for i in range(depth)
        ]

        # Patch merger to downsample tokens before sending to the LLM.
        self.merger = layers.Conv2D(
            filters=hidden_size,
            kernel_size=2,
            strides=2,
            padding="valid",
            name="merger",
        )

    def call(self, x, grid_thw=None):
        # x shape: (batch, time, height, width, channels).
        x = self.patch_embed(x)  # -> (batch, T', H', W', hidden_size)

        # Note: 3D-RoPE implementation pending.

        for block in self.blocks:
            x = block(x, grid_thw=grid_thw)

        # `Conv2D` only accepts rank-4 input, but `patch_embed` emits a
        # rank-5 tensor; fold the temporal axis into the batch axis for
        # the spatial merge, then restore it afterwards.
        shape = ops.shape(x)
        batch_size, num_frames = shape[0], shape[1]
        x = ops.reshape(x, (-1, shape[2], shape[3], self.hidden_size))
        x = self.merger(x)
        merged_shape = ops.shape(x)
        x = ops.reshape(
            x,
            (
                batch_size,
                num_frames,
                merged_shape[1],
                merged_shape[2],
                self.hidden_size,
            ),
        )

        return x

class Qwen2VLVisionBlock(layers.Layer):
    """Pre-norm transformer block for the Qwen2-VL vision encoder.

    Args:
        hidden_size: Width of the block's input and output features.
        num_heads: Number of attention heads.
        mlp_ratio: Expansion factor of the hidden MLP layer.
        activation: Activation name for the MLP's hidden layer.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio, activation, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=hidden_size // num_heads
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        # The MLP must project back down to `hidden_size`; the original
        # single expanding Dense made the residual add below fail with a
        # shape mismatch, and the `activation` argument was never used.
        self.mlp_up = layers.Dense(hidden_size * mlp_ratio, activation=activation)
        self.mlp_down = layers.Dense(hidden_size)

    def call(self, x, grid_thw=None):
        # `grid_thw` is accepted for API compatibility; 3D-RoPE pending.
        residual = x
        x = self.norm1(x)
        x = self.attn(x, x)
        x = x + residual

        residual = x
        x = self.norm2(x)
        x = self.mlp_down(self.mlp_up(x))
        x = x + residual
        return x
Loading