Skip to content

Commit e9d7325

Browse files
committed
Add Qwen2_5_VLProcessingInfo
1 parent 32e8273 commit e9d7325

1 file changed

Lines changed: 186 additions & 1 deletion

File tree

gllm/models/qwen2_5_vl.py

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
from typing import TypedDict, Literal, Union, Callable, Optional
5+
from collections.abc import Mapping
6+
from typing import TypedDict, Literal, Union, Callable, Optional, NamedTuple
67
from flash_attn import flash_attn_varlen_func
78
from einops import rearrange
89
from functools import partial, lru_cache
10+
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
11+
Qwen2VLProcessor)
12+
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
13+
Qwen2VLConfig, Qwen2VLVisionConfig)
14+
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
15+
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
916

1017
from gllm.layers.activation import SiluAndMul
1118
from gllm.layers.layernorm import RMSNorm
@@ -17,6 +24,13 @@
1724
from gllm.utils import cast_overflow_tensors
1825

1926

27+
# For profile run
28+
_MAX_FRAMES_PER_VIDEO = 16
29+
30+
class ImageSize(NamedTuple):
31+
width: int
32+
height: int
33+
2034
# === Vision Inputs === #
2135

2236
class Qwen2_5_VLImagePixelInputs(TypedDict):
@@ -595,3 +609,174 @@ def forward(
595609
reverse_indices = torch.argsort(window_index)
596610
hidden_states = hidden_states[reverse_indices, :]
597611
return hidden_states
612+
613+
class Qwen2_5_VLProcessingInfo():
614+
615+
def get_hf_config(self):
616+
return self.ctx.get_hf_config(Qwen2VLConfig)
617+
618+
def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
619+
return self.ctx.get_hf_processor(
620+
Qwen2VLProcessor,
621+
use_fast=kwargs.pop("use_fast", True),
622+
**kwargs,
623+
)
624+
625+
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
626+
return self.get_hf_processor(**kwargs).image_processor
627+
628+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
629+
return {"image": None, "video": None}
630+
631+
def get_mm_max_tokens_per_item(
632+
self,
633+
seq_len: int,
634+
mm_counts: Mapping[str, int],
635+
) -> Mapping[str, int]:
636+
max_image_tokens = self.get_max_image_tokens()
637+
max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
638+
return {"image": max_image_tokens, "video": max_video_tokens}
639+
640+
def _get_vision_info(
641+
self,
642+
*,
643+
image_width: int,
644+
image_height: int,
645+
num_frames: int = 1,
646+
do_resize: bool = True,
647+
image_processor: Optional[Qwen2VLImageProcessor],
648+
) -> tuple[ImageSize, int]:
649+
if image_processor is None:
650+
image_processor = self.get_image_processor()
651+
652+
hf_config = self.get_hf_config()
653+
vision_config = hf_config.vision_config
654+
patch_size = vision_config.patch_size
655+
merge_size = vision_config.spatial_merge_size
656+
temporal_patch_size = vision_config.temporal_patch_size
657+
658+
if do_resize:
659+
resized_height, resized_width = smart_resize(
660+
height=image_height,
661+
width=image_width,
662+
factor=patch_size * merge_size,
663+
min_pixels=image_processor.min_pixels,
664+
max_pixels=image_processor.max_pixels,
665+
)
666+
preprocessed_size = ImageSize(width=resized_width,
667+
height=resized_height)
668+
else:
669+
preprocessed_size = ImageSize(width=image_width,
670+
height=image_height)
671+
672+
# NOTE: Frames are padded to be divisible by `temporal_patch_size`
673+
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
674+
padded_num_frames = num_frames + num_frames % temporal_patch_size
675+
676+
grid_t = max(padded_num_frames // temporal_patch_size, 1)
677+
grid_h = preprocessed_size.height // patch_size
678+
grid_w = preprocessed_size.width // patch_size
679+
680+
num_patches = grid_t * grid_h * grid_w
681+
num_vision_tokens = num_patches // (merge_size**2)
682+
683+
return preprocessed_size, num_vision_tokens
684+
685+
def get_num_image_tokens(
686+
self,
687+
*,
688+
image_width: int,
689+
image_height: int,
690+
image_processor: Optional[Qwen2VLImageProcessor],
691+
) -> int:
692+
_, num_image_tokens = self._get_vision_info(
693+
image_width=image_width,
694+
image_height=image_height,
695+
image_processor=image_processor,
696+
)
697+
return num_image_tokens
698+
699+
def get_num_video_tokens(
700+
self,
701+
*,
702+
image_width: int,
703+
image_height: int,
704+
num_frames: int,
705+
image_processor: Optional[Qwen2VLImageProcessor],
706+
) -> int:
707+
_, num_video_tokens = self._get_vision_info(
708+
image_width=image_width,
709+
image_height=image_height,
710+
num_frames=num_frames,
711+
image_processor=image_processor,
712+
)
713+
return num_video_tokens
714+
715+
def get_image_size_with_most_features(self) -> ImageSize:
716+
max_image_size, _ = self._get_vision_info(
717+
image_width=9999999,
718+
image_height=9999999,
719+
image_processor=None,
720+
)
721+
return max_image_size
722+
723+
def get_max_image_tokens(self) -> int:
724+
target_width, target_height = self.get_image_size_with_most_features()
725+
726+
return self.get_num_image_tokens(
727+
image_width=target_width,
728+
image_height=target_height,
729+
image_processor=None,
730+
)
731+
732+
def _get_max_video_frames(self, max_tokens: int) -> int:
733+
target_width, target_height = self.get_image_size_with_most_features()
734+
735+
num_frames = 0
736+
737+
while True:
738+
next_num_frames = num_frames + 1
739+
next_max_tokens = self.get_num_video_tokens(
740+
image_width=target_width,
741+
image_height=target_height,
742+
num_frames=next_num_frames,
743+
image_processor=None,
744+
)
745+
746+
if next_max_tokens > max_tokens:
747+
break
748+
749+
num_frames = next_num_frames
750+
751+
return num_frames
752+
753+
def get_num_frames_with_most_features(
754+
self,
755+
seq_len: int,
756+
mm_counts: Mapping[str, int],
757+
) -> int:
758+
max_images = mm_counts.get("image", 0)
759+
max_videos = mm_counts.get("video", 0)
760+
761+
max_image_tokens = self.get_max_image_tokens() * max_images
762+
max_total_frames = self._get_max_video_frames(seq_len -
763+
max_image_tokens)
764+
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
765+
_MAX_FRAMES_PER_VIDEO)
766+
767+
return max(max_frames_per_video, 1)
768+
769+
def get_max_video_tokens(
770+
self,
771+
seq_len: int,
772+
mm_counts: Mapping[str, int],
773+
) -> int:
774+
target_width, target_height = self.get_image_size_with_most_features()
775+
776+
return self.get_num_video_tokens(
777+
image_width=target_width,
778+
image_height=target_height,
779+
num_frames=self.get_num_frames_with_most_features(
780+
seq_len, mm_counts),
781+
image_processor=None,
782+
)

0 commit comments

Comments
 (0)