[Model] Add video input support for transformers modeling backend #30680
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,11 @@ | |
| import torch | ||
|
|
||
| from vllm.config.utils import getattr_iter | ||
| from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal | ||
| from vllm.model_executor.models.interfaces import ( | ||
| MultiModalEmbeddings, | ||
| SupportsMRoPE, | ||
| SupportsMultiModal, | ||
| ) | ||
| from vllm.model_executor.models.utils import WeightsMapper | ||
| from vllm.multimodal import MultiModalKwargsItems | ||
| from vllm.multimodal.inputs import ( | ||
|
|
@@ -33,7 +37,11 @@ | |
| MultiModalUUIDDict, | ||
| PlaceholderRange, | ||
| ) | ||
| from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems | ||
| from vllm.multimodal.parse import ( | ||
| ImageProcessorItems, | ||
| MultiModalDataItems, | ||
| VideoProcessorItems, | ||
| ) | ||
| from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo | ||
| from vllm.multimodal.profiling import BaseDummyInputsBuilder | ||
| from vllm.sequence import IntermediateTensors | ||
|
|
@@ -55,10 +63,13 @@ | |
|
|
||
| class MultiModalProcessingInfo(BaseProcessingInfo): | ||
| def get_supported_mm_limits(self): | ||
| return {"image": None} | ||
| return {"image": None, "video": None} | ||
|
|
||
| def get_mm_max_tokens_per_item(self, seq_len, mm_counts): | ||
| return {"image": self.get_max_image_tokens()} | ||
| return { | ||
| "image": self.get_max_image_tokens(), | ||
| "video": self.get_max_video_tokens(seq_len), | ||
| } | ||
|
|
||
| def get_max_image_tokens(self) -> int: | ||
| width, height = self.get_max_image_size() | ||
|
|
@@ -71,20 +82,52 @@ def get_max_image_tokens(self) -> int: | |
| image_tokens = mm_tokens["num_image_tokens"][0] | ||
| return image_tokens | ||
|
|
||
| def _get_video_tokens(self, num_frames, width, height) -> int: | ||
| processor = self.get_hf_processor() | ||
| multimodal_config = self.ctx.model_config.multimodal_config | ||
| mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} | ||
| mm_tokens = processor._get_num_multimodal_tokens( | ||
| video_sizes=([num_frames, height, width],), **mm_processor_kwargs | ||
| ) | ||
| video_tokens = mm_tokens["num_video_tokens"][0] | ||
| return video_tokens | ||
|
|
||
| def get_max_video_tokens(self, seq_len: int) -> int: | ||
| width, height = self.get_max_image_size() | ||
| num_frames = self.get_max_video_frames(seq_len) | ||
| return self._get_video_tokens(num_frames, width, height) | ||
|
|
||
| def get_max_image_size(self): | ||
| return 10_000, 10_000 # hardcode for arbitrary very large size | ||
|
|
||
| def get_max_video_frames(self, seq_len: int) -> int: | ||
| width, height = self.get_max_image_size() | ||
|
|
||
| max_num_frames = 1 | ||
|
|
||
| while True: | ||
| next_num_frames = max_num_frames + 1 | ||
| video_tokens = self._get_video_tokens(next_num_frames, width, height) | ||
| if video_tokens > seq_len: | ||
| break | ||
|
|
||
| max_num_frames = next_num_frames | ||
|
|
||
| return max_num_frames | ||
|
|
||
|
|
||
| class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]): | ||
| def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: | ||
| num_images = mm_counts.get("image", 0) | ||
| num_videos = mm_counts.get("video", 0) | ||
|
|
||
| processor = self.info.get_hf_processor() | ||
| if "gemma3" in processor.__class__.__name__.lower(): | ||
| image_token = processor.boi_token | ||
| else: | ||
| image_token = getattr(processor, "image_token", "") | ||
| return image_token * num_images | ||
| video_token = getattr(processor, "video_token", "") | ||
| return image_token * num_images + video_token * num_videos | ||
|
|
||
| def get_dummy_mm_data( | ||
| self, | ||
|
|
@@ -93,10 +136,14 @@ def get_dummy_mm_data( | |
| mm_options: Mapping[str, "BaseDummyOptions"] | None = None, | ||
| ) -> MultiModalDataDict: | ||
| num_images = mm_counts.get("image", 0) | ||
| num_videos = mm_counts.get("video", 0) | ||
|
|
||
| target_width, target_height = self.info.get_max_image_size() | ||
| max_total_frames = self.info.get_max_video_frames(seq_len) | ||
| target_num_frames = max_total_frames // max(num_videos, 1) | ||
|
|
||
| image_overrides = mm_options.get("image") if mm_options else None | ||
| video_overrides = mm_options.get("video") if mm_options else None | ||
|
|
||
| return { | ||
| "image": self._get_dummy_images( | ||
|
|
@@ -105,6 +152,13 @@ def get_dummy_mm_data( | |
| num_images=num_images, | ||
| overrides=image_overrides, | ||
| ), | ||
| "video": self._get_dummy_videos( | ||
| width=target_width, | ||
| height=target_height, | ||
| num_frames=target_num_frames, | ||
| num_videos=num_videos, | ||
| overrides=video_overrides, | ||
| ), | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -148,8 +202,18 @@ def _get_mm_fields_config( | |
|
|
||
| # Keep these as batched, as they always have batch size as first dim | ||
| mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image") | ||
| mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image") | ||
| mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") | ||
|
|
||
| video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) | ||
| video_grid_sizes = video_grid_thw.prod(-1) | ||
|
Comment on lines +207 to +208
Contributor: Not super generalizable; only qwen-like models have a THW tensor. Ideally we need to use …
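For context, a minimal sketch of what the current field config assumes (qwen-like models only): the flat pixel_values_videos tensor is split per video using the product of each video_grid_thw row. The shapes below are illustrative, not taken from any particular model.

import torch

# Two videos with (t, h, w) patch grids of (2, 4, 4) and (1, 4, 4).
video_grid_thw = torch.tensor([[2, 4, 4], [1, 4, 4]])
video_grid_sizes = video_grid_thw.prod(-1)      # tensor([32, 16])

# Flat patch tensor as a qwen-like processor emits it: 48 patches in total.
pixel_values_videos = torch.randn(48, 1176)     # 1176 is just a placeholder patch dim

# Roughly what MultiModalFieldConfig.flat_from_sizes("video", video_grid_sizes)
# implies: each video's slice is recovered from the flat tensor by its size.
per_video = torch.split(pixel_values_videos, video_grid_sizes.tolist())
print([v.shape[0] for v in per_video])          # [32, 16]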
||
| mm_fields["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes( | ||
| "video", video_grid_sizes | ||
| ) | ||
| mm_fields["video_embeds"] = MultiModalFieldConfig.flat_from_sizes( | ||
| "video", video_grid_sizes | ||
| ) | ||
| mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("video") | ||
| mm_fields["num_video_patches"] = MultiModalFieldConfig.batched("video") | ||
|
Comment on lines +207 to +216
Member: This is a place where the name of something on the Transformers side is assumed. We'd need to make sure that the Transformers team is happy with this as the standard name, which will be propagated to all models in Transformers.
||
| return mm_fields | ||
|
|
||
| def _get_hf_mm_data( | ||
|
|
@@ -211,24 +275,36 @@ def apply( | |
|
|
||
| # We can infer vLLM style placeholder from token type ids, if we split | ||
| # it for each input `mm_data`. | ||
| mm_positions = torch.where(mm_token_type_ids == 1)[1] | ||
| images = mm_items.get_items("image", ImageProcessorItems) | ||
| image_sizes = [] | ||
| if "image" in mm_items: | ||
| images = mm_items.get_items("image", ImageProcessorItems) | ||
| for item_idx in range(len(images)): | ||
| image_size = images.get_image_size(item_idx) | ||
| image_sizes.append((image_size.height, image_size.width)) | ||
|
|
||
| video_sizes = [] | ||
| if "video" in mm_items: | ||
| videos = mm_items.get_items("video", VideoProcessorItems) | ||
| for item_idx in range(len(videos)): | ||
| video_size = videos.get_frame_size(item_idx) | ||
| num_frames = videos.get_num_frames(item_idx) | ||
| video_sizes.append((num_frames, video_size.height, video_size.width)) | ||
|
|
||
| multimodal_config = self.info.ctx.model_config.multimodal_config | ||
| mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} | ||
| image_sizes = [] | ||
| for item_idx in range(len(images)): | ||
| image_size = images.get_image_size(item_idx) | ||
| image_sizes.append((image_size.height, image_size.width)) | ||
|
|
||
| mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( | ||
| image_sizes=image_sizes, **mm_processor_kwargs | ||
| image_sizes=image_sizes, video_sizes=video_sizes, **mm_processor_kwargs | ||
|
Contributor: I don't think all processors expect …
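One possible reading of the truncated comment above, sketched as a defensive call (an illustration, not the PR's implementation): only forward video_sizes when videos are actually present, so processors that do not accept the keyword are unaffected.

# Hypothetical guard, not in the PR; names come from the surrounding diff.
extra_kwargs = {}
if video_sizes:
    extra_kwargs["video_sizes"] = video_sizes

mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
    image_sizes=image_sizes, **extra_kwargs, **mm_processor_kwargs
)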
||
| ) | ||
|
|
||
| mm_placeholders = {} | ||
|
|
||
| # image_token_ids | ||
| mm_positions = torch.where(mm_token_type_ids == 1)[1] | ||
| split_sizes = mm_tokens_per_modality["num_image_tokens"] | ||
| if split_sizes: | ||
| chunked_mm_positions = torch.split(mm_positions, split_sizes) | ||
| mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()] | ||
| mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0] == 1] | ||
|
Member: Avoid computing …
Member: Also, is there somewhere in Transformers where the token type ID is defined by modality? I don't like this magic number.
Author: I made this change because …
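To make the reviewer's point concrete, a sketch of how the magic numbers could be named; the mapping values (1 for image, 2 for video) are assumptions read off this diff, not something Transformers defines centrally today.

import torch

# Assumed from this diff only; not defined anywhere in Transformers.
MODALITY_TOKEN_TYPE_ID = {"image": 1, "video": 2}

def positions_for(modality: str, mm_token_type_ids: torch.Tensor) -> torch.Tensor:
    """Return the positions of placeholder tokens for one modality."""
    return torch.where(mm_token_type_ids == MODALITY_TOKEN_TYPE_ID[modality])[1]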
||
| chunked_mm_tokens = torch.split(mm_tokens, split_sizes) | ||
| ranges = [ | ||
| PlaceholderRange( | ||
|
|
@@ -238,11 +314,34 @@ def apply( | |
| ) | ||
| for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens) | ||
| ] | ||
| mm_placeholders = {"image": ranges} | ||
| mm_placeholders["image"] = ranges | ||
|
|
||
| processed_data["num_image_patches"] = torch.tensor( | ||
| mm_tokens_per_modality["num_image_patches"] | ||
| ) | ||
|
|
||
| # video_token_ids | ||
| mm_positions = torch.where(mm_token_type_ids == 2)[1] | ||
|
|
||
| split_sizes = mm_tokens_per_modality["num_video_tokens"] | ||
| if split_sizes: | ||
| chunked_mm_positions = torch.split(mm_positions, split_sizes) | ||
| mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0] == 2] | ||
| chunked_mm_tokens = torch.split(mm_tokens, split_sizes) | ||
| ranges = [ | ||
| PlaceholderRange( | ||
| offset=positions[0].item(), | ||
| length=positions.shape[0], | ||
| is_embed=(mm_tokens == hf_processor.video_token_id).bool(), | ||
| ) | ||
|
Comment on lines +331 to +336
Contributor: The hard part here is models which add timestamps between each video frame; the timestamps are encoded differently and have no special tokens, so inferring the ranges is not as easy as with images.
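The is_embed mask in the diff above is the mechanism that can carry this information: within a video's placeholder range, only positions holding the video placeholder token are marked as embeddings, while interleaved text such as timestamps stays unmasked. A tiny illustration with made-up token IDs:

import torch

video_token_id = 151656  # made-up placeholder id
# e.g. "<vid><vid> 0:01 <vid><vid>" where the timestamp is ordinary text tokens
mm_tokens = torch.tensor([151656, 151656, 15, 25, 16, 151656, 151656])

is_embed = mm_tokens == video_token_id
print(is_embed)  # tensor([ True,  True, False, False, False,  True,  True])
# PlaceholderRange(is_embed=...) then tells vLLM which positions inside the
# range actually receive vision embeddings.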
||
| for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens) | ||
| ] | ||
| mm_placeholders["video"] = ranges | ||
|
|
||
| processed_data["num_video_patches"] = torch.tensor( | ||
| mm_tokens_per_modality["num_video_patches"] | ||
|
|
||
| ) | ||
|
Comment on lines +323 to +343
Member: This code is almost identical to the video code; could it be abstracted to a method?
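A rough sketch of what such a shared helper could look like; the method name and the token-type-id convention (1 for image, 2 for video) are assumptions taken from this diff, and torch / PlaceholderRange come from the module's existing imports.

def _placeholders_for_modality(
    self,
    prompt_ids: list[int],
    mm_token_type_ids: torch.Tensor,
    type_id: int,                  # assumed: 1 for image, 2 for video in this diff
    split_sizes: list[int],
    placeholder_token_id: int,
) -> list[PlaceholderRange]:
    """Build PlaceholderRange entries for one modality from token type ids."""
    if not split_sizes:
        return []
    positions = torch.where(mm_token_type_ids == type_id)[1]
    chunked_positions = torch.split(positions, split_sizes)
    tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0] == type_id]
    chunked_tokens = torch.split(tokens, split_sizes)
    return [
        PlaceholderRange(
            offset=pos[0].item(),
            length=pos.shape[0],
            is_embed=(tok == placeholder_token_id).bool(),
        )
        for pos, tok in zip(chunked_positions, chunked_tokens)
    ]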
||
|
|
||
| mm_kwargs = MultiModalKwargsItems.from_hf_inputs( | ||
| processed_data, | ||
| self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), | ||
|
|
@@ -330,24 +429,33 @@ def __init__(self, multimodal_model): | |
|
|
||
| return LanguageModel(self) | ||
|
|
||
| def embed_multimodal(self, **kwargs): | ||
| def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: | ||
| pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None) | ||
| image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None) | ||
| pixel_values_videos: torch.Tensor | None = kwargs.pop( | ||
| "pixel_values_videos", None | ||
| ) | ||
| video_embeds: torch.Tensor | None = kwargs.pop("video_embeds", None) | ||
|
|
||
| # Model might use `image_patches` instead of `pixel_values` | ||
| if pixel_values is None: | ||
| pixel_values = kwargs.pop("image_patches", None) | ||
|
|
||
| if image_embeds is not None: | ||
| return image_embeds | ||
| multimodal_embeddings: list[torch.Tensor] = [] | ||
|
|
||
| if pixel_values is None: | ||
| return None | ||
| if image_embeds is not None: | ||
| multimodal_embeddings += image_embeds | ||
|
|
||
| num_image_patches = kwargs.pop("num_image_patches") | ||
| kwargs.pop("token_type_ids", None) # used only in `forward` | ||
| num_image_patches = kwargs.pop("num_image_patches", None) | ||
| num_video_patches = kwargs.pop("num_video_patches", None) | ||
|
|
||
| if pixel_values is not None: | ||
| vision_embeddings = self.model.get_image_features(pixel_values, **kwargs) | ||
|
|
||
| if isinstance(vision_embeddings, tuple): | ||
| # For qwen3 vl, The deepstack visual features are also returned | ||
|
Contributor: Right, though qwen3-vl will give bad performance without deepstack visual features. Currently vLLM doesn't support models like Qwen3-VL and Ovis2 on purpose.
||
| vision_embeddings = vision_embeddings[0] | ||
|
Comment on lines +456 to +458
Member: This might not be intentional. @zucchini-nlp should the return value of Qwen3 VL have been changed to this?
Contributor: We never supported qwen3-vl in vLLM because of this 😿 The model has to propagate … We are currently doing standardization of …
Member: Do you mean in vLLM with the Transformers modelling backend? vLLM does support this model natively. Yeah, a standardisation on the Transformers side would be great.
Contributor: Yes, with the backend the model was not supported when released.
Author: @zucchini-nlp Is there a rough timeline?
Contributor: The PR is drafted in huggingface/transformers#42564, though there is no specific deadline. It should definitely be merged before v5.
||
| if isinstance(vision_embeddings, torch.Tensor): | ||
| if vision_embeddings.ndim == 2: | ||
| vision_embeddings = vision_embeddings.unsqueeze(0) | ||
|
|
@@ -362,8 +470,36 @@ def embed_multimodal(self, **kwargs): | |
| embed.flatten(start_dim=0, end_dim=-2) | ||
| for embed in vision_embeddings | ||
| ] | ||
| multimodal_embeddings += vision_embeddings | ||
|
|
||
| return vision_embeddings | ||
| if video_embeds is not None: | ||
| multimodal_embeddings += video_embeds | ||
|
|
||
| if pixel_values_videos is not None: | ||
| vision_embeddings = self.model.get_video_features( | ||
| pixel_values_videos, **kwargs | ||
| ) | ||
|
|
||
| if isinstance(vision_embeddings, tuple): | ||
| # For qwen3 vl, The deepstack visual features are also returned | ||
| vision_embeddings = vision_embeddings[0] | ||
| if isinstance(vision_embeddings, torch.Tensor): | ||
| if vision_embeddings.ndim == 2: | ||
| vision_embeddings = vision_embeddings.unsqueeze(0) | ||
|
|
||
| # Embeddings have to be 2D tensors of length `num_images` | ||
| # but transformers returns concat tensors if each patch | ||
| # is of different size. We split it back to make vLLM happy | ||
| vision_embeddings = torch.split( | ||
| vision_embeddings, num_video_patches.flatten().tolist() | ||
| ) | ||
| vision_embeddings = [ | ||
|
Comment on lines +490 to +496
Contributor: I am not 100% sure, though for the video modality this might not be needed. Mostly videos don't have several patches per one frame.
||
| embed.flatten(start_dim=0, end_dim=-2) | ||
| for embed in vision_embeddings | ||
| ] | ||
| multimodal_embeddings += vision_embeddings | ||
|
|
||
| return multimodal_embeddings | ||
hmellor marked this conversation as resolved.
|
||
|
|
||
| def get_mrope_input_positions( | ||
| self, | ||
|
|
@@ -386,18 +522,14 @@ def get_mrope_input_positions( | |
| if k not in {"image_grid_thw", "video_grid_thw"} | ||
| ): | ||
| raise NotImplementedError( | ||
| "Transformers modeling backend only supports images." | ||
| "Transformers modeling backend only supports images and videos." | ||
| ) | ||
|
|
||
| image_grid_thw = kwargs.get("image_grid_thw", []) | ||
| video_grid_thw = kwargs.get("video_grid_thw", []) | ||
| image_grid_thw = kwargs.get("image_grid_thw", None) | ||
| video_grid_thw = kwargs.get("video_grid_thw", None) | ||
|
|
||
| image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( | ||
| image_grid_thw | ||
| ) | ||
| video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( | ||
| video_grid_thw | ||
| ) | ||
| image_grid_thw = torch.stack(image_grid_thw) if image_grid_thw else None | ||
| video_grid_thw = torch.stack(video_grid_thw) if video_grid_thw else None | ||
|
Comment on lines +531 to +532
Member: @zucchini-nlp was it important that vLLM passed empty grids rather than …?
Contributor: Should be …
Member: To clarify, it should be …
Contributor: This part is called only with qwen-like models as I see, so it's only hit when the modality isn't present in the request. I can't think of any model with mrope that doesn't support videos, so we're safe.
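For reference, a small illustration (with made-up values) of the behavioural difference being discussed: the old code turned an empty grid list into an empty tensor, while the new code passes None.

import torch

image_grid_thw: list[torch.Tensor] = []  # no images in this request

# Old behaviour: an empty list falls through to torch.tensor([]) -> empty float tensor
old = (torch.stack if image_grid_thw else torch.tensor)(image_grid_thw)
print(old.shape)  # torch.Size([0])

# New behaviour: an empty list becomes None, which is forwarded to get_rope_index
new = torch.stack(image_grid_thw) if image_grid_thw else None
print(new)  # None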
||
|
|
||
| mrope_positions, mrope_position_delta = self.model.get_rope_index( | ||
| input_ids=torch.tensor(input_tokens).unsqueeze(0), | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -280,7 +280,7 @@ def get_frame_size(self, item_idx: int) -> ImageSize: | |||||
| if isinstance(image, PILImage.Image): | ||||||
| return ImageSize(*image.size) | ||||||
| if isinstance(image, (np.ndarray, torch.Tensor)): | ||||||
| _, h, w = image.shape | ||||||
| w, h, _ = image.shape | ||||||
|
Member: I doubt that this was wrong for all models already in vLLM; why does this need changing here?
Author: I'm also very confused about this. I'm not sure why other models (like minicpmv) work (see vllm/vllm/multimodal/profiling.py, Line 225 in 0d0c929), while VideoLoader._read_frames uses (num_frames, height, width, 3) (Line 77 in 0d0c929), so I think it should be w, h, c or h, w, c here.
Member: For profiling purposes, the width and height are interchangeable in general, since most transformations don't care about that.
Member: Strictly speaking the profiling code is wrong, so feel free to change that.
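To make the shape question concrete, a tiny sketch assuming frames arrive as (num_frames, height, width, 3), which is the layout the author cites for VideoLoader._read_frames (the numbers are made up):

import numpy as np

video = np.zeros((16, 720, 1280, 3), dtype=np.uint8)  # (num_frames, height, width, channels)
frame = video[0]                                       # a single frame: (height, width, channels)

h, w, _ = frame.shape
print(w, h)  # 1280 720, i.e. ImageSize(width=1280, height=720) under this layout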
||||||
| return ImageSize(w, h) | ||||||
|
|
||||||
| assert_never(image) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1118,7 +1118,18 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: | |
| assert mm_budget is not None | ||
|
|
||
| dummy_modality = mm_budget.get_modality_with_max_tokens() | ||
| return self._get_mm_dummy_batch(dummy_modality, num_seqs) | ||
|
|
||
| # TBD: | ||
| # The mm_dummy_batch below is only retrieved when | ||
| # supports_multimodal_raw_input_only is True. | ||
| # Currently, only the transform modeling backend and terratorch have | ||
| # supports_multimodal_raw_input_only as True. | ||
| # When testing the transform modeling backend, it was found that | ||
| # if num_seqs (usually the default 256) is passed in here, | ||
| # an OOM error occurs. | ||
| # It needs to be confirmed what value should be passed in here, | ||
| # for now it is fixed to 1. | ||
| return self._get_mm_dummy_batch(dummy_modality, 1) | ||
|
Comment on lines 1120 to +1132
Contributor: Hardcoding the dummy batch size to 1 for profiling is a pragmatic fix to avoid OOM errors, but it can lead to inaccurate memory profiling. This might cause the scheduler to underestimate memory usage, potentially leading to OOM errors in production with larger batches. A more robust approach would be to calculate a reasonable batch size based on the model's configuration. This provides a more realistic batch size for profiling, reducing the risk of production OOMs while still preventing OOMs during profiling.

dummy_modality = mm_budget.get_modality_with_max_tokens()
max_tokens_per_item = mm_budget.max_tokens_by_modality.get(dummy_modality)
if max_tokens_per_item and max_tokens_per_item > 0:
    # Heuristic to derive a reasonable batch size for profiling.
    # Using max_num_seqs can cause OOM for vision models.
    # Hardcoding to 1 can lead to inaccurate profiling.
    num_items = self.scheduler_config.max_num_batched_tokens // max_tokens_per_item
    # Also respect the per-prompt limit and max sequences.
    max_items_for_modality = mm_budget.max_items_per_batch_by_modality[dummy_modality]
    num_items = min(num_items, max_items_for_modality)
    # Ensure at least 1 item.
    num_items = max(num_items, 1)
else:
    # Fallback for safety, though this path should ideally not be taken.
    num_items = 1
return self._get_mm_dummy_batch(dummy_modality, num_items)
Comment on lines +1122 to +1132
Member: I was recently working in this area and noticed this too. It causes 256 x 100MP images (the max image size defined in …). @ywang96 what do you think should be the correct behaviour here?
Member: I think here the profiling is actually correct for models that have …
Member: So I prefer patching this on the model side. If removing …
Member: I forget the exact reason we chose to use … @DarkLight1337 do you know more about the key differences between …?
Member: Why is this not also the case when …?
Member: IIRC this is to pass … @christian-pinto maybe you can explain more about this.
Contributor: Back then we have introduced … So setting …
Contributor: FWIW I am quite sure that is not how …
Member: Maybe you can handle it in a similar way as …
Contributor: We use (and have introduced) multimodal_raw_input processing when adding the models in the … Please see my comment above for the rationale behind that field.
||
|
|
||
| def _get_cumsum_and_arange( | ||
| self, | ||
|
|
||
I remember that vLLM chooses only one modality when profiling, whichever has more tokens. So I think we will need to safe-get mm_tokens["num_video_tokens"] and otherwise set it to 0, because not all models support videos.

Qwen2-VL would be a good example of this with get_num_frames_with_most_features. But it gets its mm_counts from the limits in get_supported_mm_limits unless the user specified otherwise. Ideally there would be a way to set get_supported_mm_limits()["video"] to zero when the model doesn't support video.

Yeah, that would be great. I can even say how we infer whether a model supports videos from the model class 😄 It has a class attribute model.input_modalities.

Oh ok, are these class attributes or instance attributes?

Class attributes, available from v5 and on.
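A hypothetical sketch of how the backend could use that attribute once it is available (the attribute name comes from the discussion above and ships with Transformers v5; the helper used to fetch the model class is an assumption, not an existing vLLM API):

def get_supported_mm_limits(self):
    # `input_modalities` is the Transformers v5+ class attribute mentioned above;
    # fall back to image-only when it is absent (older Transformers versions).
    model_cls = self._get_hf_model_class()  # hypothetical helper, not a real vLLM method
    modalities = getattr(model_cls, "input_modalities", ("image",))
    limits: dict[str, int | None] = {"image": None}
    if "video" in modalities:
        limits["video"] = None
    return limits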