
Commit 3a2b526

llama4 video support (#2942)
Summary:
Adds video processing support to the Llama4 model by extending the existing vision encoder infrastructure to handle video content. It introduces video-specific special tokens (<|video|>, <|vid_start|>, <|vid_end|>, <|vid_frame_separator|>) in the tokenizer, implements a new transform_video() method that processes video clips as sequences of frames through the existing image transform pipeline, and registers a "video" encoder in the EarlyFusionModel that reuses the vision encoder while maintaining separate tokenization paths for images and videos. (The HF implementation was used as a reference to ensure consistent changes in _tokenizer.py.)

Reviewed By: felipemello1

Differential Revision: D89577119

Pulled By: awasthiabhijeet
1 parent 44271b5 commit 3a2b526
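
The summary describes the new data path in prose only, so here is a minimal sketch of the video input format implied by the diffs below: a list of per-clip [frames_per_clip, height, width, channels] tensors attached to a content item of type "video", which transform_video() consumes and the tokenizer later wraps in the new video special tokens. The shapes are placeholders and the surrounding transform/message plumbing names are assumptions, not part of this commit.

    # Sketch only: shapes are placeholders; the transform class itself is not shown in this commit.
    import torch

    clips_per_video, frames_per_clip, height, width, channels = 2, 4, 336, 336, 3

    # One uint8 tensor per clip, HWC frame layout as documented in transform_video()
    video_clips = [
        torch.randint(0, 256, (frames_per_clip, height, width, channels), dtype=torch.uint8)
        for _ in range(clips_per_video)
    ]

    # Content item shape consumed by the transform's __call__ in this diff
    video_content_item = {"type": "video", "content": video_clips}

    # After the transform runs, it attaches metadata the tokenizer uses to emit
    # <|vid_start|> / <|vid_frame_separator|> / <|vid_end|> around per-frame patch tokens:
    #   video_content_item["num_frames"]             -> clips_per_video * frames_per_clip
    #   video_content_item["patch_tokens_per_frame"] -> transform.patch_tokens_per_tile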

File tree

3 files changed: +168, -15 lines


torchtune/models/llama4/_model_builders.py

Lines changed: 9 additions & 3 deletions
@@ -84,12 +84,14 @@ def llama4_scout_17b_16e(
     )
     return EarlyFusionModel(
         decoder,
-        encoders={"vision": vision_encoder},
+        encoders={"vision": vision_encoder, "video": vision_encoder},
         encoder_tokens={
             "vision": LLAMA4_SPECIAL_TOKENS["<|patch|>"],
+            "video": LLAMA4_SPECIAL_TOKENS["<|video|>"],
         },
         encoders_trainable={
             "vision": encoder_trainable,
+            "video": encoder_trainable,
         },
         decoder_trainable=decoder_trainable,
         fusion_trainable=fusion_trainable,
@@ -154,12 +156,14 @@ def llama4_maverick_17b_128e(
     )
     return EarlyFusionModel(
         decoder,
-        encoders={"vision": vision_encoder},
+        encoders={"vision": vision_encoder, "video": vision_encoder},
        encoder_tokens={
             "vision": LLAMA4_SPECIAL_TOKENS["<|patch|>"],
+            "video": LLAMA4_SPECIAL_TOKENS["<|video|>"],
         },
         encoders_trainable={
             "vision": encoder_trainable,
+            "video": encoder_trainable,
         },
         decoder_trainable=decoder_trainable,
         fusion_trainable=fusion_trainable,
@@ -314,12 +318,14 @@ def lora_llama4_scout_17b_16e(
     )
     return EarlyFusionModel(
         decoder,
-        encoders={"vision": vision_encoder},
+        encoders={"vision": vision_encoder, "video": vision_encoder},
         encoder_tokens={
             "vision": LLAMA4_SPECIAL_TOKENS["<|patch|>"],
+            "video": LLAMA4_SPECIAL_TOKENS["<|video|>"],
         },
         encoders_trainable={
             "vision": encoder_trainable != TrainableParams.FROZEN,
+            "video": encoder_trainable != TrainableParams.FROZEN,
         },
         decoder_trainable=decoder_trainable != TrainableParams.FROZEN,
         fusion_trainable=fusion_trainable != TrainableParams.FROZEN,
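
Note that both the "vision" and "video" keys are bound to the same vision_encoder object, so registering the second key shares weights rather than adding parameters. A plain-PyTorch illustration of that behavior (not the EarlyFusionModel implementation itself):

    # nn.Module.parameters() deduplicates by object identity, so a module registered
    # under two ModuleDict keys contributes its weights only once.
    import torch.nn as nn

    shared_encoder = nn.Linear(16, 16)  # stand-in for the Llama4 vision encoder
    encoders = nn.ModuleDict({"vision": shared_encoder, "video": shared_encoder})

    assert encoders["vision"] is encoders["video"]
    n_params = sum(p.numel() for p in encoders.parameters())
    assert n_params == sum(p.numel() for p in shared_encoder.parameters())  # no duplication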

torchtune/models/llama4/_tokenizer.py

Lines changed: 44 additions & 7 deletions
@@ -59,16 +59,16 @@ def get_reserved_special_tokens(start_id, end_id, name=None, start_reserved=0):
 VISION_SPECIAL_TOKENS = {
     "<|image_start|>": 200080,
     "<|image_end|>": 200081,
-    "<|vision_reserved_special_token_0|>": 200082,
-    "<|vision_reserved_special_token_1|>": 200083,
+    "<|vid_start|>": 200082,
+    "<|vid_end|>": 200083,
     "<|tile_x_separator|>": 200084,
     "<|tile_y_separator|>": 200085,
-    "<|vision_reserved_special_token_2|>": 200086,
-    "<|vision_reserved_special_token_3|>": 200087,
-    "<|vision_reserved_special_token_4|>": 200088,
-    "<|vision_reserved_special_token_5|>": 200089,
+    "<|vision_reserved_special_token_0|>": 200086,
+    "<|vid_frame_separator|>": 200087,
+    "<|vision_reserved_special_token_1|>": 200088,
+    "<|vision_reserved_special_token_2|>": 200089,
     "<|image|>": 200090,
-    "<|vision_reserved_special_token_6|>": 200091,
+    "<|video|>": 200091,
     "<|patch|>": 200092,
 } | get_reserved_special_tokens(200093, 201134, "vision", 7)

@@ -166,6 +166,12 @@ def __init__(
         self.tile_x_separator = self.special_tokens["<|tile_x_separator|>"]
         self.tile_y_separator = self.special_tokens["<|tile_y_separator|>"]
 
+        # Video tokens
+        self.video_id = self.special_tokens["<|patch|>"]
+        self.video_start = self.special_tokens["<|vid_start|>"]
+        self.video_end = self.special_tokens["<|vid_end|>"]
+        self.frame_separator = self.special_tokens["<|vid_frame_separator|>"]
+
         # Reasoning tokens
         self.reasoning_start = self.special_tokens["<|reasoning_thinking_start|>"]
         self.reasoning_end = self.special_tokens["<|reasoning_thinking_end|>"]
@@ -302,6 +308,31 @@ def _get_tile_grid_tokens(
             tokens.extend(single_tile_tokens)
         return tokens
 
+    def _get_video_tokens(self, num_frames: int, patch_tokens_per_frame: int) -> list[int]:
+        """
+        Tokenize video content with frame structure similar to Huggingface implementation.
+
+        Args:
+            num_frames (int): Number of frames in the video
+            patch_tokens_per_frame (int): Number of patch tokens per frame
+
+        Returns:
+            list[int]: Video tokens with frame separators
+        """
+        tokens = []
+        tokens.append(self.video_start)
+
+        for frame_idx in range(num_frames):
+            # Add video patch tokens for this frame
+            tokens.extend([self.video_id] * patch_tokens_per_frame)
+
+            # Add frame separator (except for the last frame)
+            if frame_idx < num_frames - 1:
+                tokens.append(self.frame_separator)
+
+        tokens.append(self.video_end)
+        return tokens
+
     def _tokenize_header(self, message: Message) -> list[int]:
         """
         Tokenize header start, message role, and header end as list of ids
@@ -335,6 +366,12 @@ def _tokenize_body(self, message: Message) -> list[int]:
                     tokenized_body += self._get_tile_grid_tokens(
                         patch_tokens_per_tile, aspect_ratio
                     )
+                elif item["type"] == "video":
+                    num_frames = item.get("num_frames", 1)
+                    patch_tokens_per_frame = item.get("patch_tokens_per_frame", 1)
+                    tokenized_body += self._get_video_tokens(
+                        num_frames, patch_tokens_per_frame
+                    )
 
         return tokenized_body
 
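The new _get_video_tokens() interleaves per-frame patch tokens with frame separators between the start and end markers. A standalone sketch of the resulting layout, using the token IDs added above (note that __init__ binds self.video_id to <|patch|> = 200092, so frames are filled with patch tokens):

    # Mirrors the logic of _get_video_tokens() with the literal IDs from VISION_SPECIAL_TOKENS.
    VID_START, VID_END, FRAME_SEP, PATCH = 200082, 200083, 200087, 200092

    def get_video_tokens(num_frames: int, patch_tokens_per_frame: int) -> list[int]:
        tokens = [VID_START]
        for frame_idx in range(num_frames):
            tokens.extend([PATCH] * patch_tokens_per_frame)
            if frame_idx < num_frames - 1:  # separator between frames, not after the last one
                tokens.append(FRAME_SEP)
        tokens.append(VID_END)
        return tokens

    # 2 frames, 3 patch tokens per frame:
    # [<|vid_start|>, <|patch|> x3, <|vid_frame_separator|>, <|patch|> x3, <|vid_end|>]
    assert get_video_tokens(2, 3) == [200082] + [200092] * 3 + [200087] + [200092] * 3 + [200083]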

torchtune/models/llama4/_transform.py

Lines changed: 115 additions & 5 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Mapping, Optional
+from typing import Any, List, Mapping, Optional
 
 import torch
 from PIL import Image
@@ -154,6 +154,88 @@ def transform_image(
             tiles = torch.cat((tiles, thumbnail.unsqueeze(0)), dim=0)
         return tiles, ar
 
+    def transform_video(
+        self, video: List[torch.Tensor], inference: bool = False
+    ) -> tuple[torch.Tensor, int]:
+        """
+        Transform video content for video processing.
+
+        Args:
+            video: List of 4D torch.Tensors, each with shape [frames_per_clip, height, width, channels].
+                Length of list should equal clips_per_video.
+            inference: Whether running in inference mode
+
+        Returns:
+            tuple: (processed_frames, total_num_frames)
+        """
+        if not isinstance(video, list):
+            raise ValueError(f"Expected list of 4D tensors, got {type(video)}")
+
+        if not video:
+            raise ValueError("Empty video clips list")
+
+        # Validate input format
+        for i, clip in enumerate(video):
+            if not isinstance(clip, torch.Tensor):
+                raise ValueError(f"Clip {i} is not a torch.Tensor, got {type(clip)}")
+            if clip.dim() != 4:
+                raise ValueError(f"Clip {i} has {clip.dim()} dimensions, expected 4D tensor [frames_per_clip, height, width, channels]")
+
+        processed_clips = []
+        total_frames = 0
+
+        # Process each clip
+        for clip in video:
+            frames_per_clip, height, width, channels = clip.shape
+            total_frames += frames_per_clip
+
+            # Normalize tensor values to [0, 1] range if needed
+            if clip.dtype == torch.uint8:
+                clip = clip.float() / 255.0
+            elif clip.min() < 0 or clip.max() > 1:
+                # Assume values are in [-1, 1] or other range, normalize to [0, 1]
+                clip = (clip - clip.min()) / (clip.max() - clip.min())
+
+            processed_frames_in_clip = []
+
+            # Process each frame in the clip
+            for frame_idx in range(frames_per_clip):
+                frame = clip[frame_idx]  # Shape: [height, width, channels]
+
+                # Convert HWC to CHW format for PIL conversion
+                if frame.shape[-1] == channels and channels in [1, 3]:  # HWC format
+                    frame = frame.permute(2, 0, 1)  # Convert to CHW
+                else:
+                    raise ValueError(f"Unexpected frame shape: {frame.shape}, expected [height, width, channels]")
+
+                # Convert to PIL Image for processing through existing transforms
+                frame_uint8 = (frame * 255).to(torch.uint8)
+
+                if frame.shape[0] == 3:  # RGB
+                    frame_pil = Image.fromarray(frame_uint8.permute(1, 2, 0).cpu().numpy(), mode='RGB')
+                elif frame.shape[0] == 1:  # Grayscale
+                    frame_pil = Image.fromarray(frame_uint8.squeeze(0).cpu().numpy(), mode='L')
+                else:
+                    raise ValueError(f"Unsupported number of channels: {frame.shape[0]}")
+
+                # Process the PIL image through existing transform
+                tiles, _ = self.transform_image(frame_pil, inference=inference)
+                processed_frames_in_clip.append(tiles)
+
+            # Stack frames in this clip: (frames_per_clip, num_tiles, channels, height, width)
+            clip_processed = torch.stack(processed_frames_in_clip, dim=0)
+            processed_clips.append(clip_processed)
+
+        # Stack all clips: (clips_per_video, frames_per_clip, num_tiles, channels, height, width)
+        # Then reshape to: (total_frames, num_tiles, channels, height, width)
+        all_clips = torch.stack(processed_clips, dim=0)
+        clips_per_video, frames_per_clip, num_tiles, channels, height, width = all_clips.shape
+
+        # Reshape to flatten clips and frames into a single frames dimension
+        processed_frames = all_clips.view(clips_per_video * frames_per_clip, num_tiles, channels, height, width)
+
+        return processed_frames, total_frames
+
     def encode(
         self,
         text: str,
@@ -235,7 +317,7 @@ def __call__(
         self, sample: Mapping[str, Any], inference: bool = False
     ) -> Mapping[str, Any]:
         """
-        Apply image decoding, transformations and tokenization to messages in the sample.
+        Apply image/video decoding, transformations and tokenization to messages in the sample.
 
         Args:
             sample (Mapping[str, Any]): A sample with a "messages" field.
@@ -245,21 +327,49 @@ def __call__(
             Mapping[str, Any]: The transformed sample with the following fields:
                 - tokens: list[int] of tokenized messages
                 - mask: list[bool] of masks for the tokenized messages
-                - encoder_input: dict[str, Any] of transformed images
+                - encoder_input: dict[str, Any] of transformed images and videos
         """
-        encoder_input = {"vision": {"images": []}}
+        images_list = []
+        videos_list = []
+
         messages = sample["messages"]
         for message in messages:
             for content in message.content:
                 if content["type"] == "image":
                     image = content["content"]
                     tiles, ar = self.transform_image(image, inference=inference)
-                    encoder_input["vision"]["images"].append(tiles)
+                    images_list.append(tiles)
 
                     # Add number of patch tokens, tiles, and aspect ratio to metadata
                     # so tokenizer can add the corresponding special tokens
                     content["patch_tokens_per_tile"] = self.patch_tokens_per_tile
                     content["aspect_ratio"] = ar
+                elif content["type"] == "video":
+                    video = content["content"]
+                    processed_frames, num_frames = self.transform_video(video, inference=inference)
+
+                    # Flatten video frames to individual images for the vision encoder
+                    # processed_frames shape: (num_frames, num_tiles, channels, height, width)
+                    # We need to flatten to: (num_frames * num_tiles, channels, height, width)
+                    flattened_frames = processed_frames.view(-1, *processed_frames.shape[2:])
+                    videos_list.append(flattened_frames)
+
+                    # Add metadata for video tokenization
+                    # Each frame is treated like an image for patch token calculation
+                    content["num_frames"] = num_frames
+                    content["patch_tokens_per_frame"] = self.patch_tokens_per_tile
+
+        # Create encoder_input in the format expected by EarlyFusionModel
+        # Both vision and video use the same vision encoder, so they share the same data structure.
+        # The differentiation happens at token level (<|patch|> vs <|video|>), not data level.
+        encoder_input = {}
+        if images_list:
+            encoder_input["vision"] = {"images": images_list}
+
+        if videos_list:
+            # Videos use "images" key because they're processed as sequences of images
+            # by the same vision encoder used for static images
+            encoder_input["vision"] = {"images": videos_list}
 
         sample["encoder_input"] = encoder_input
         sample = self.tokenizer(sample, inference=inference)
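
The shape comments above describe two reshapes: transform_video() merges clips and frames into one frames dimension, and __call__ then flattens frames and tiles before handing the tensors to the vision encoder. A small sketch reproducing that shape flow with dummy tensors; the tile count and spatial size come from transform_image() internals and are placeholders here:

    import torch

    clips_per_video, frames_per_clip = 2, 4
    num_tiles, channels, height, width = 5, 3, 336, 336

    # transform_video(): stack clips, then merge (clips, frames) into one frames dimension
    all_clips = torch.zeros(clips_per_video, frames_per_clip, num_tiles, channels, height, width)
    processed_frames = all_clips.view(clips_per_video * frames_per_clip, num_tiles, channels, height, width)
    assert processed_frames.shape == (8, 5, 3, 336, 336)  # (total_frames, num_tiles, C, H, W)

    # __call__(): flatten frames and tiles into one leading dimension for the vision encoder
    flattened_frames = processed_frames.view(-1, *processed_frames.shape[2:])
    assert flattened_frames.shape == (40, 3, 336, 336)  # (total_frames * num_tiles, C, H, W)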
