4242from .modeling_radio import RADIOVisionModel , calc_seq_lens
4343from .modeling_utils import register_auto_model
4444
45- VIDEO_PRUNING_RATIO = float (os .getenv ("TLLM_VIDEO_PRUNING_RATIO" , "0" ))
4645# Set max_num_tiles to 1 for video modality, to match the training behavior.
4746VIDEO_MAX_NUM_TILES = 1
4847IMAGE_PLACEHOLDER = "<image>"
@@ -257,7 +256,10 @@ def __init__(self, model_config: ModelConfig[transformers.PretrainedConfig]):
257256 raise NotImplementedError (
258257 f"Unsupported { config .ps_version = } . Supported versions: { supported_versions } ."
259258 )
260- self .video_pruning_ratio = VIDEO_PRUNING_RATIO
259+ # Use config value if explicitly set (EVS enabled), otherwise default to 0.0 (EVS disabled)
260+ self .video_pruning_rate = (
261+ model_config .video_pruning_rate if model_config .video_pruning_rate is not None else 0.0
262+ )
261263
262264 # Construct the vision projection.
263265 self .vit_hidden_size = config .vit_hidden_size
@@ -414,7 +416,7 @@ def apply_evs_per_video(
414416 video_embeds = reshaped_partial_mm_embed ,
415417 video_size = (t , p * ih , iw ),
416418 spatial_merge_size = self .spatial_merge_size ,
417- pruning_ratio = self .video_pruning_ratio ,
419+ pruning_ratio = self .video_pruning_rate ,
418420 flatten_output = False ,
419421 ).flatten (start_dim = 1 )
420422 # -> [num_frames, num_patches_per_frame*h*w]
@@ -437,7 +439,7 @@ def apply_evs(
437439 ) -> Tuple [List [torch .Tensor ], Optional [List [List [int ] | None ]]]:
438440 """Apply EVS to the multimodal embedding."""
439441 # Skip EVS if the pruning rate is 0.
440- if self .video_pruning_ratio <= 0 :
442+ if self .video_pruning_rate <= 0 :
441443 return mm_embedding , None
442444
443445 modality_types = [
@@ -448,7 +450,7 @@ def apply_evs(
448450 return mm_embedding , None
449451
450452 video_size_list = [
451- multimodal_data [modality_type ][ "video_size" ]
453+ multimodal_data [modality_type ]. get ( "video_size" ) if modality_type == "video" else None
452454 for modality_type , multimodal_data in zip (modality_types , multimodal_data_lst )
453455 ]
454456 mm_embedding_evs = []
@@ -487,17 +489,23 @@ def forward(
487489 pixel_values_flat = data ["pixel_values" ]
488490 image_sizes = data ["image_sizes" ]
489491 embeds = self .extract_feature_dynamic (pixel_values_flat , image_sizes )
490- mm_embedding .append (embeds .reshape (- 1 , self .llm_hidden_size ))
492+ # Keep 3D shape for apply_evs, will reshape to 2D after EVS
493+ mm_embedding .append (embeds )
491494 # This applies to images without dynamic resolution, or videos.
492495 else :
493496 # Fallback to fixed-tile extraction for this modality.
494497 pixel_values = data ["pixel_values" ]
495498 embeds = self .extract_feature (pixel_values )
496- mm_embedding .append (embeds .reshape (- 1 , self .llm_hidden_size ))
499+ # Keep 3D shape [num_patches, h*w, hidden] for apply_evs
500+ mm_embedding .append (embeds )
497501
498- return mm_embedding , [None ] * len (modality_types )
502+ # Apply EVS if video_pruning_rate > 0
503+ mm_embedding , num_tokens_in_videos = self .apply_evs (mm_embedding , multimodal_data_lst )
504+ # Reshape to 2D after EVS: [num_patches*h*w, hidden_size]
505+ mm_embedding = [m .reshape (- 1 , self .llm_hidden_size ) for m in mm_embedding ]
506+ return mm_embedding , num_tokens_in_videos
499507
500- # Existing fixed-tile path.
508+ # Existing fixed-tile path (unreachable, kept for reference) .
501509 pixel_values = [
502510 multimodal_data [modality_type ]["pixel_values" ]
503511 for modality_type , multimodal_data in zip (modality_types , multimodal_data_lst )
@@ -530,6 +538,9 @@ def __init__(
530538 trust_remote_code : bool = True ,
531539 ** kwargs ,
532540 ):
541+ # Extract video_pruning_rate before passing kwargs to parent
542+ video_pruning_rate = kwargs .pop ("video_pruning_rate" , None ) or 0.0
543+
533544 super ().__init__ (
534545 model_path = model_path ,
535546 config = config ,
@@ -563,7 +574,7 @@ def __init__(
563574 self .num_image_token = int (
564575 (self .image_size // self .patch_size ) ** 2 * (self .downsample_ratio ** 2 )
565576 )
566- self .video_pruning_ratio = VIDEO_PRUNING_RATIO
577+ self .video_pruning_rate = video_pruning_rate
567578 self .img_context_token = self .config .img_context_token
568579 self .video_context_token = self .config .video_context_token
569580 self .img_start_token = self .config .img_start_token
@@ -747,15 +758,15 @@ def get_num_tokens_per_video(
747758 self ,
748759 * ,
749760 video : List [Image .Image ],
750- video_pruning_ratio : Optional [float ] = None ,
761+ video_pruning_rate : Optional [float ] = None ,
751762 ** kwargs ,
752763 ):
753764 # Fall back to self.video_pruning_rate if not explicitly provided
754- if video_pruning_ratio is None :
755- video_pruning_ratio = self .video_pruning_ratio
765+ if video_pruning_rate is None :
766+ video_pruning_rate = self .video_pruning_rate
756767
757768 num_frames = len (video )
758- if video_pruning_ratio > 0 :
769+ if video_pruning_rate > 0 :
759770 num_tokens_per_frame = self .get_num_tokens_per_image (
760771 image = video [0 ],
761772 max_num_tiles = VIDEO_MAX_NUM_TILES ,
@@ -767,7 +778,7 @@ def get_num_tokens_per_video(
767778 num_total_tokens = compute_retained_tokens_count (
768779 video_size = video_size ,
769780 spatial_merge_size = self .spatial_merge_size ,
770- pruning_ratio = video_pruning_ratio ,
781+ pruning_ratio = video_pruning_rate ,
771782 )
772783 # Add special tokens for each frame.
773784 num_total_tokens += num_frames * len (self .get_mm_special_token_ids ())
@@ -776,7 +787,7 @@ def get_num_tokens_per_video(
776787 num_total_tokens = sum (
777788 self .get_num_tokens_per_image (
778789 image = frame ,
779- video_pruning_ratio = None ,
790+ video_pruning_rate = None ,
780791 max_num_tiles = VIDEO_MAX_NUM_TILES ,
781792 ** kwargs ,
782793 )
@@ -961,7 +972,7 @@ def _process_video_prompts(
961972 processed_query .extend (frame_prompts )
962973 # Use video_context_token as a placeholder;
963974 # it will be replaced with the real per-frame image tokens during the model forward pass.
964- if self .video_pruning_ratio > 0 :
975+ if self .video_pruning_rate > 0 :
965976 evs_query .append (split_text_prompt [video_index ])
966977 evs_query .append ("This is a video:\n " )
967978 for frame_sep in frame_separators :
@@ -986,7 +997,7 @@ def _process_video_prompts(
986997 ]
987998 input_ids = torch .cat (input_ids_lst , dim = 1 )
988999
989- if self .video_pruning_ratio > 0 :
1000+ if self .video_pruning_rate > 0 :
9901001 evs_query .append (split_text_prompt [- 1 ])
9911002 evs_ids = [
9921003 self .tokenizer .encode (
@@ -1009,11 +1020,11 @@ def _compute_token_numbers_per_video(self, video_size_lst: List[Tuple]) -> List[
10091020 img_height = video_size [2 ]
10101021 img_width = video_size [3 ]
10111022
1012- if self .video_pruning_ratio > 0 :
1023+ if self .video_pruning_rate > 0 :
10131024 desired_num_tokens = compute_retained_tokens_count (
10141025 video_size = (num_frames , num_patches_per_frame * img_height , img_width ),
10151026 spatial_merge_size = self .spatial_merge_size ,
1016- pruning_ratio = self .video_pruning_ratio ,
1027+ pruning_ratio = self .video_pruning_rate ,
10171028 )
10181029 # These are dummy tokens and will be adjusted in the VisionEncoder after EVS is applied.
10191030 # Need to know the length of the full input ids ahead,
@@ -1069,7 +1080,7 @@ def __call__(
10691080 # Store input_ids for image modality here when EVS is enabled,
10701081 # which will be used in merge_evs_mm_embeds later.
10711082 modality_data ["evs_ids" ] = (
1072- input_ids [0 ].to (torch .int32 ) if self .video_pruning_ratio > 0 else None
1083+ input_ids [0 ].to (torch .int32 ) if self .video_pruning_rate > 0 else None
10731084 )
10741085 elif videos is not None :
10751086 modality_type = "video"
@@ -1249,7 +1260,10 @@ def __init__(self, model_config: ModelConfig):
12491260 self .sound_context_token_id = getattr (config , "sound_context_token_id" , None )
12501261 self .post_config ()
12511262 self .is_loaded = True
1252- self .video_pruning_ratio = VIDEO_PRUNING_RATIO
1263+ # Use config value if explicitly set (EVS enabled), otherwise default to 0.0 (EVS disabled)
1264+ self .video_pruning_rate = (
1265+ model_config .video_pruning_rate if model_config .video_pruning_rate is not None else 0.0
1266+ )
12531267
12541268 def load_weights (self , weights ):
12551269 # Load vision encoder weights.
@@ -1378,7 +1392,7 @@ def _encode_multimodal(
13781392 if modality_type in ("image" , "video" ):
13791393 embs , num_tokens = self .vision_encoder ([param ])
13801394 mm_embeddings .append (embs [0 ])
1381- mm_num_tokens .append (num_tokens [0 ])
1395+ mm_num_tokens .append (num_tokens [0 ] if num_tokens is not None else None )
13821396 elif modality_type == "audio" :
13831397 mm_embeddings .append (self ._encode_audio (param ))
13841398 mm_num_tokens .append (None )
@@ -1421,7 +1435,7 @@ def forward(
14211435 "the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
14221436 )
14231437 # Adjust input_ids in videos if EVS is applied.
1424- if self .video_pruning_ratio > 0 :
1438+ if self .video_pruning_rate > 0 :
14251439 input_ids = self .merge_evs_mm_embeds (
14261440 num_tokens_in_videos ,
14271441 multimodal_params = multimodal_params [:num_context_requests ],
0 commit comments