 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Mapping, Optional
+from typing import Any, List, Mapping, Optional
 
 import torch
 from PIL import Image
@@ -154,6 +154,88 @@ def transform_image(
         tiles = torch.cat((tiles, thumbnail.unsqueeze(0)), dim=0)
         return tiles, ar
 
+    def transform_video(
+        self, video: List[torch.Tensor], inference: bool = False
+    ) -> tuple[torch.Tensor, int]:
+        """
+        Transform video clips into per-frame tiles for the vision encoder.
+
+        Args:
+            video: List of 4D torch.Tensors, each with shape [frames_per_clip, height, width, channels].
+                Length of the list should equal clips_per_video.
+            inference: Whether running in inference mode.
+
+        Returns:
+            tuple: (processed_frames, total_num_frames), where processed_frames has shape
+                (total_num_frames, num_tiles, channels, height, width).
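+
+        Example (illustrative only; ``transform`` stands for an instance of this
+        class, and the exact tile count depends on the configured image transform):
+            >>> clips = [torch.zeros(8, 224, 224, 3, dtype=torch.uint8)]
+            >>> frames, num_frames = transform.transform_video(clips)
+            >>> num_frames
+            8
+            >>> frames.dim()  # (total_frames, num_tiles, channels, height, width)
+            5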
170+ """
171+ if not isinstance (video , list ):
172+ raise ValueError (f"Expected list of 4D tensors, got { type (video )} " )
173+
174+ if not video :
175+ raise ValueError ("Empty video clips list" )
176+
177+ # Validate input format
178+ for i , clip in enumerate (video ):
179+ if not isinstance (clip , torch .Tensor ):
180+ raise ValueError (f"Clip { i } is not a torch.Tensor, got { type (clip )} " )
181+ if clip .dim () != 4 :
182+ raise ValueError (f"Clip { i } has { clip .dim ()} dimensions, expected 4D tensor [frames_per_clip, height, width, channels]" )
183+
184+ processed_clips = []
185+ total_frames = 0
186+
+        # Process each clip
+        for clip in video:
+            frames_per_clip, height, width, channels = clip.shape
+            total_frames += frames_per_clip
+
+            # Normalize tensor values to [0, 1] range if needed
+            if clip.dtype == torch.uint8:
+                clip = clip.float() / 255.0
+            elif clip.min() < 0 or clip.max() > 1:
+                # Assume values are in [-1, 1] or another range; normalize to [0, 1]
+                clip = (clip - clip.min()) / (clip.max() - clip.min())
+
+            processed_frames_in_clip = []
+
+            # Process each frame in the clip
+            for frame_idx in range(frames_per_clip):
+                frame = clip[frame_idx]  # Shape: [height, width, channels]
+
+                # Convert HWC to CHW format for PIL conversion
+                if frame.shape[-1] == channels and channels in [1, 3]:  # HWC format
+                    frame = frame.permute(2, 0, 1)  # Convert to CHW
+                else:
+                    raise ValueError(f"Unexpected frame shape: {frame.shape}, expected [height, width, channels]")
+
+                # Convert to PIL Image for processing through existing transforms
+                frame_uint8 = (frame * 255).to(torch.uint8)
+
+                if frame.shape[0] == 3:  # RGB
+                    frame_pil = Image.fromarray(frame_uint8.permute(1, 2, 0).cpu().numpy(), mode='RGB')
+                elif frame.shape[0] == 1:  # Grayscale
+                    frame_pil = Image.fromarray(frame_uint8.squeeze(0).cpu().numpy(), mode='L')
+                else:
+                    raise ValueError(f"Unsupported number of channels: {frame.shape[0]}")
+
+                # Process the PIL image through the existing image transform
+                tiles, _ = self.transform_image(frame_pil, inference=inference)
+                processed_frames_in_clip.append(tiles)
+
+            # Stack frames in this clip: (frames_per_clip, num_tiles, channels, height, width)
+            clip_processed = torch.stack(processed_frames_in_clip, dim=0)
+            processed_clips.append(clip_processed)
+
+        # Stack all clips: (clips_per_video, frames_per_clip, num_tiles, channels, height, width)
+        # Then reshape to: (total_frames, num_tiles, channels, height, width)
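+        # For example (illustrative numbers only): 2 clips of 8 frames, with each
+        # frame producing 5 tiles of 448x448 RGB, give all_clips of shape
+        # (2, 8, 5, 3, 448, 448) and processed_frames of shape (16, 5, 3, 448, 448).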
+        all_clips = torch.stack(processed_clips, dim=0)
+        clips_per_video, frames_per_clip, num_tiles, channels, height, width = all_clips.shape
+
+        # Reshape to flatten clips and frames into a single frames dimension
+        processed_frames = all_clips.view(clips_per_video * frames_per_clip, num_tiles, channels, height, width)
+
+        return processed_frames, total_frames
+
     def encode(
         self,
         text: str,
@@ -235,7 +317,7 @@ def __call__(
         self, sample: Mapping[str, Any], inference: bool = False
     ) -> Mapping[str, Any]:
         """
-        Apply image decoding, transformations and tokenization to messages in the sample.
+        Apply image/video decoding, transformations and tokenization to messages in the sample.
 
         Args:
             sample (Mapping[str, Any]): A sample with a "messages" field.
@@ -245,21 +327,49 @@ def __call__(
             Mapping[str, Any]: The transformed sample with the following fields:
                 - tokens: list[int] of tokenized messages
                 - mask: list[bool] of masks for the tokenized messages
-                - encoder_input: dict[str, Any] of transformed images
+                - encoder_input: dict[str, Any] of transformed images and videos
         """
-        encoder_input = {"vision": {"images": []}}
+        images_list = []
+        videos_list = []
+
         messages = sample["messages"]
         for message in messages:
             for content in message.content:
                 if content["type"] == "image":
                     image = content["content"]
                     tiles, ar = self.transform_image(image, inference=inference)
-                    encoder_input["vision"]["images"].append(tiles)
+                    images_list.append(tiles)
 
                     # Add number of patch tokens, tiles, and aspect ratio to metadata
                     # so tokenizer can add the corresponding special tokens
                     content["patch_tokens_per_tile"] = self.patch_tokens_per_tile
                     content["aspect_ratio"] = ar
+                elif content["type"] == "video":
+                    video = content["content"]
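+                    # "content" is expected to be a list of per-clip tensors, each of
+                    # shape [frames_per_clip, height, width, channels] (see transform_video)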
+                    processed_frames, num_frames = self.transform_video(video, inference=inference)
+
+                    # Flatten video frames to individual images for the vision encoder
+                    # processed_frames shape: (num_frames, num_tiles, channels, height, width)
+                    # We need to flatten to: (num_frames * num_tiles, channels, height, width)
+                    flattened_frames = processed_frames.view(-1, *processed_frames.shape[2:])
+                    videos_list.append(flattened_frames)
+
+                    # Add metadata for video tokenization
+                    # Each frame is treated like an image for patch token calculation
+                    content["num_frames"] = num_frames
+                    content["patch_tokens_per_frame"] = self.patch_tokens_per_tile
+
+        # Create encoder_input in the format expected by EarlyFusionModel.
+        # Images and videos use the same vision encoder, so they share the same data
+        # structure; the differentiation happens at the token level (<|patch|> vs
+        # <|video|>), not at the data level. Video frames therefore reuse the "images"
+        # key, and the two lists are concatenated so that images are not dropped when
+        # a sample contains both images and videos.
+        encoder_input = {}
+        if images_list or videos_list:
+            encoder_input["vision"] = {"images": images_list + videos_list}
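+        # Illustrative result (hypothetical shapes): one image plus one 8-frame video
+        # whose frames each yield num_tiles tiles gives
+        #     encoder_input["vision"]["images"] == [image_tiles, video_frame_tiles]
+        # with image_tiles of shape (num_tiles, channels, height, width) and
+        # video_frame_tiles of shape (8 * num_tiles, channels, height, width).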
 
         sample["encoder_input"] = encoder_input
         sample = self.tokenizer(sample, inference=inference)