@@ -34,8 +34,10 @@ def __init__(
3434 self .temporal_patch_size = temporal_patch_size
3535 self .in_channels = in_channels
3636 self .embed_dim = embed_dim
37- self .data_format = keras .config .image_data_format ()
3837
38+ # The model's internal pipeline always produces patches in
39+ # channels-first format: (batch, C, T, H, W). Keras handles
40+ # cross-backend compatibility internally for Conv3D.
3941 self .proj = keras .layers .Conv3D (
4042 filters = embed_dim ,
4143 kernel_size = (temporal_patch_size , patch_size , patch_size ),
@@ -52,20 +54,11 @@ def call(self, hidden_states):
5254 Args:
5355 hidden_states: Tensor of shape
5456 `(total_patches, in_channels, temporal_patch_size,
55- patch_size, patch_size)` when using
56- ``channels_first``, or
57- `(total_patches, temporal_patch_size, patch_size,
58- patch_size, in_channels)` when using
59- ``channels_last``.
57+ patch_size, patch_size)`.
6058
6159 Returns:
6260 Tensor of shape `(total_patches, embed_dim)`.
6361 """
64- # Conv3D always uses channels_first internally; transpose if
65- # the user's default data format is channels_last.
66- if self .data_format == "channels_last" :
67- # (batch, T, H, W, C) -> (batch, C, T, H, W)
68- hidden_states = ops .transpose (hidden_states , (0 , 4 , 1 , 2 , 3 ))
6962 hidden_states = self .proj (hidden_states )
7063 # Flatten spatial and temporal dims:
7164 # (batch, embed_dim, 1, 1, 1) -> (batch, embed_dim)
0 commit comments