Skip to content

Commit 342e534

Browse files
committed
refactor: Remove explicit channels_last data format handling for Conv3D as Keras now manages it internally.
1 parent a1148bd commit 342e534

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

keras_hub/src/models/qwen2_vl/qwen2_vl_vision_encoder.py

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,10 @@ def __init__(
3434
self.temporal_patch_size = temporal_patch_size
3535
self.in_channels = in_channels
3636
self.embed_dim = embed_dim
37-
self.data_format = keras.config.image_data_format()
3837

38+
# The model's internal pipeline always produces patches in
39+
# channels-first format: (batch, C, T, H, W). Keras handles
40+
# cross-backend compatibility internally for Conv3D.
3941
self.proj = keras.layers.Conv3D(
4042
filters=embed_dim,
4143
kernel_size=(temporal_patch_size, patch_size, patch_size),
@@ -52,20 +54,11 @@ def call(self, hidden_states):
5254
Args:
5355
hidden_states: Tensor of shape
5456
`(total_patches, in_channels, temporal_patch_size,
55-
patch_size, patch_size)` when using
56-
``channels_first``, or
57-
`(total_patches, temporal_patch_size, patch_size,
58-
patch_size, in_channels)` when using
59-
``channels_last``.
57+
patch_size, patch_size)`.
6058
6159
Returns:
6260
Tensor of shape `(total_patches, embed_dim)`.
6361
"""
64-
# Conv3D always uses channels_first internally; transpose if
65-
# the user's default data format is channels_last.
66-
if self.data_format == "channels_last":
67-
# (batch, T, H, W, C) -> (batch, C, T, H, W)
68-
hidden_states = ops.transpose(hidden_states, (0, 4, 1, 2, 3))
6962
hidden_states = self.proj(hidden_states)
7063
# Flatten spatial and temporal dims:
7164
# (batch, embed_dim, 1, 1, 1) -> (batch, embed_dim)

0 commit comments

Comments (0)