22import torch .nn as nn
33import torch .nn .functional as F
44
5- from typing import TypedDict , Literal , Union , Callable , Optional
5+ from collections .abc import Mapping
6+ from typing import TypedDict , Literal , Union , Callable , Optional , NamedTuple
67from flash_attn import flash_attn_varlen_func
78from einops import rearrange
89from functools import partial , lru_cache
10+ from transformers .models .qwen2_vl import (Qwen2VLImageProcessor ,
11+ Qwen2VLProcessor )
12+ from transformers .models .qwen2_vl .configuration_qwen2_vl import (
13+ Qwen2VLConfig , Qwen2VLVisionConfig )
14+ from transformers .models .qwen2_vl .image_processing_qwen2_vl import smart_resize
15+ from transformers .models .qwen2_vl .video_processing_qwen2_vl import Qwen2VLVideoProcessor
916
1017from gllm .layers .activation import SiluAndMul
1118from gllm .layers .layernorm import RMSNorm
1724from gllm .utils import cast_overflow_tensors
1825
1926
27+ # For profile run
28+ _MAX_FRAMES_PER_VIDEO = 16
29+
30+ class ImageSize (NamedTuple ):
31+ width : int
32+ height : int
33+
2034# === Vision Inputs === #
2135
2236class Qwen2_5_VLImagePixelInputs (TypedDict ):
@@ -595,3 +609,174 @@ def forward(
595609 reverse_indices = torch .argsort (window_index )
596610 hidden_states = hidden_states [reverse_indices , :]
597611 return hidden_states
612+
613+ class Qwen2_5_VLProcessingInfo ():
614+
615+ def get_hf_config (self ):
616+ return self .ctx .get_hf_config (Qwen2VLConfig )
617+
618+ def get_hf_processor (self , ** kwargs : object ) -> Qwen2VLProcessor :
619+ return self .ctx .get_hf_processor (
620+ Qwen2VLProcessor ,
621+ use_fast = kwargs .pop ("use_fast" , True ),
622+ ** kwargs ,
623+ )
624+
625+ def get_image_processor (self , ** kwargs : object ) -> Qwen2VLImageProcessor :
626+ return self .get_hf_processor (** kwargs ).image_processor
627+
628+ def get_supported_mm_limits (self ) -> Mapping [str , Optional [int ]]:
629+ return {"image" : None , "video" : None }
630+
631+ def get_mm_max_tokens_per_item (
632+ self ,
633+ seq_len : int ,
634+ mm_counts : Mapping [str , int ],
635+ ) -> Mapping [str , int ]:
636+ max_image_tokens = self .get_max_image_tokens ()
637+ max_video_tokens = self .get_max_video_tokens (seq_len , mm_counts )
638+ return {"image" : max_image_tokens , "video" : max_video_tokens }
639+
640+ def _get_vision_info (
641+ self ,
642+ * ,
643+ image_width : int ,
644+ image_height : int ,
645+ num_frames : int = 1 ,
646+ do_resize : bool = True ,
647+ image_processor : Optional [Qwen2VLImageProcessor ],
648+ ) -> tuple [ImageSize , int ]:
649+ if image_processor is None :
650+ image_processor = self .get_image_processor ()
651+
652+ hf_config = self .get_hf_config ()
653+ vision_config = hf_config .vision_config
654+ patch_size = vision_config .patch_size
655+ merge_size = vision_config .spatial_merge_size
656+ temporal_patch_size = vision_config .temporal_patch_size
657+
658+ if do_resize :
659+ resized_height , resized_width = smart_resize (
660+ height = image_height ,
661+ width = image_width ,
662+ factor = patch_size * merge_size ,
663+ min_pixels = image_processor .min_pixels ,
664+ max_pixels = image_processor .max_pixels ,
665+ )
666+ preprocessed_size = ImageSize (width = resized_width ,
667+ height = resized_height )
668+ else :
669+ preprocessed_size = ImageSize (width = image_width ,
670+ height = image_height )
671+
672+ # NOTE: Frames are padded to be divisible by `temporal_patch_size`
673+ # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
674+ padded_num_frames = num_frames + num_frames % temporal_patch_size
675+
676+ grid_t = max (padded_num_frames // temporal_patch_size , 1 )
677+ grid_h = preprocessed_size .height // patch_size
678+ grid_w = preprocessed_size .width // patch_size
679+
680+ num_patches = grid_t * grid_h * grid_w
681+ num_vision_tokens = num_patches // (merge_size ** 2 )
682+
683+ return preprocessed_size , num_vision_tokens
684+
685+ def get_num_image_tokens (
686+ self ,
687+ * ,
688+ image_width : int ,
689+ image_height : int ,
690+ image_processor : Optional [Qwen2VLImageProcessor ],
691+ ) -> int :
692+ _ , num_image_tokens = self ._get_vision_info (
693+ image_width = image_width ,
694+ image_height = image_height ,
695+ image_processor = image_processor ,
696+ )
697+ return num_image_tokens
698+
699+ def get_num_video_tokens (
700+ self ,
701+ * ,
702+ image_width : int ,
703+ image_height : int ,
704+ num_frames : int ,
705+ image_processor : Optional [Qwen2VLImageProcessor ],
706+ ) -> int :
707+ _ , num_video_tokens = self ._get_vision_info (
708+ image_width = image_width ,
709+ image_height = image_height ,
710+ num_frames = num_frames ,
711+ image_processor = image_processor ,
712+ )
713+ return num_video_tokens
714+
715+ def get_image_size_with_most_features (self ) -> ImageSize :
716+ max_image_size , _ = self ._get_vision_info (
717+ image_width = 9999999 ,
718+ image_height = 9999999 ,
719+ image_processor = None ,
720+ )
721+ return max_image_size
722+
723+ def get_max_image_tokens (self ) -> int :
724+ target_width , target_height = self .get_image_size_with_most_features ()
725+
726+ return self .get_num_image_tokens (
727+ image_width = target_width ,
728+ image_height = target_height ,
729+ image_processor = None ,
730+ )
731+
732+ def _get_max_video_frames (self , max_tokens : int ) -> int :
733+ target_width , target_height = self .get_image_size_with_most_features ()
734+
735+ num_frames = 0
736+
737+ while True :
738+ next_num_frames = num_frames + 1
739+ next_max_tokens = self .get_num_video_tokens (
740+ image_width = target_width ,
741+ image_height = target_height ,
742+ num_frames = next_num_frames ,
743+ image_processor = None ,
744+ )
745+
746+ if next_max_tokens > max_tokens :
747+ break
748+
749+ num_frames = next_num_frames
750+
751+ return num_frames
752+
753+ def get_num_frames_with_most_features (
754+ self ,
755+ seq_len : int ,
756+ mm_counts : Mapping [str , int ],
757+ ) -> int :
758+ max_images = mm_counts .get ("image" , 0 )
759+ max_videos = mm_counts .get ("video" , 0 )
760+
761+ max_image_tokens = self .get_max_image_tokens () * max_images
762+ max_total_frames = self ._get_max_video_frames (seq_len -
763+ max_image_tokens )
764+ max_frames_per_video = min (max_total_frames // max (max_videos , 1 ),
765+ _MAX_FRAMES_PER_VIDEO )
766+
767+ return max (max_frames_per_video , 1 )
768+
769+ def get_max_video_tokens (
770+ self ,
771+ seq_len : int ,
772+ mm_counts : Mapping [str , int ],
773+ ) -> int :
774+ target_width , target_height = self .get_image_size_with_most_features ()
775+
776+ return self .get_num_video_tokens (
777+ image_width = target_width ,
778+ image_height = target_height ,
779+ num_frames = self .get_num_frames_with_most_features (
780+ seq_len , mm_counts ),
781+ image_processor = None ,
782+ )
0 commit comments