2 changes: 0 additions & 2 deletions docs/models/supported_models.md
@@ -23,8 +23,6 @@ Currently, the Transformers modeling backend works for the following:
- Architectures: encoder-only, decoder-only, mixture-of-experts
- Attention types: full attention and/or sliding attention

_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._

If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers modeling backend, it will be compatible with the following features of vLLM:

- All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
194 changes: 163 additions & 31 deletions vllm/model_executor/models/transformers/multimodal.py
@@ -22,7 +22,11 @@
import torch

from vllm.config.utils import getattr_iter
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsMRoPE,
SupportsMultiModal,
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MultiModalKwargsItems
from vllm.multimodal.inputs import (
@@ -33,7 +37,11 @@
MultiModalUUIDDict,
PlaceholderRange,
)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.parse import (
ImageProcessorItems,
MultiModalDataItems,
VideoProcessorItems,
)
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
@@ -55,10 +63,13 @@

class MultiModalProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self):
return {"image": None}
return {"image": None, "video": None}

def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
return {"image": self.get_max_image_tokens()}
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
Comment on lines +70 to +71
Contributor:
I remember that vLLM chooses only one modality when profiling, whichever has more tokens. So I think we will need to safe-get mm_tokens["num_video_tokens"] and otherwise set it to 0, because not all models support videos.
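
For illustration, a minimal sketch of the defensive lookup being suggested here (the helper name is illustrative, not part of this PR):

```python
# Illustrative only: fall back to 0 video tokens when the processor does not
# report the "num_video_tokens" key (i.e. the model has no video support).
def max_video_tokens_or_zero(mm_tokens: dict[str, list[int]]) -> int:
    num_video_tokens = mm_tokens.get("num_video_tokens") or [0]
    return num_video_tokens[0]
```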

Member:
Qwen2-VL would be a good example of this with get_num_frames_with_most_features. But it gets its mm_counts from the limits in get_supported_mm_limits unless the user specified otherwise.

Ideally there would be a way to set get_supported_mm_limits()["video"] to zero when the model doesn't support video.

Contributor (@zucchini-nlp, Dec 15, 2025):
Yeah, that would be great. I can even say how we infer whether a model supports videos from the model class 😄 It has a class attribute, model.input_modalities.
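
A rough sketch of how the limits could be gated on that attribute (the attribute's exact shape and the model_cls parameter are assumptions here; the real get_supported_mm_limits takes no such argument):

```python
# Sketch: advertise video support only when the Transformers model class
# declares it. Assumes `input_modalities` is an iterable of modality names
# such as ("image", "video"), available as a class attribute from v5 onwards.
def get_supported_mm_limits(model_cls) -> dict[str, int | None]:
    supported = getattr(model_cls, "input_modalities", ("image",))
    return {
        "image": None,  # no per-prompt limit
        "video": None if "video" in supported else 0,
    }
```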

Member:
Oh ok, are these class attributes or instance attributes?

Contributor:
They are class attributes, available from v5 onwards.

}

def get_max_image_tokens(self) -> int:
width, height = self.get_max_image_size()
@@ -71,20 +82,52 @@ def get_max_image_tokens(self) -> int:
image_tokens = mm_tokens["num_image_tokens"][0]
return image_tokens

def _get_video_tokens(self, num_frames, width, height) -> int:
processor = self.get_hf_processor()
multimodal_config = self.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
mm_tokens = processor._get_num_multimodal_tokens(
video_sizes=([num_frames, height, width],), **mm_processor_kwargs
)
video_tokens = mm_tokens["num_video_tokens"][0]
return video_tokens

def get_max_video_tokens(self, seq_len: int) -> int:
width, height = self.get_max_image_size()
num_frames = self.get_max_video_frames(seq_len)
return self._get_video_tokens(num_frames, width, height)

def get_max_image_size(self):
return 10_000, 10_000 # hardcode for arbitrary very large size

def get_max_video_frames(self, seq_len: int) -> int:
width, height = self.get_max_image_size()

max_num_frames = 1

while True:
next_num_frames = max_num_frames + 1
video_tokens = self._get_video_tokens(next_num_frames, width, height)
if video_tokens > seq_len:
break

max_num_frames = next_num_frames

return max_num_frames


class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)

processor = self.info.get_hf_processor()
if "gemma3" in processor.__class__.__name__.lower():
image_token = processor.boi_token
else:
image_token = getattr(processor, "image_token", "")
return image_token * num_images
video_token = getattr(processor, "video_token", "")
return image_token * num_images + video_token * num_videos

def get_dummy_mm_data(
self,
@@ -93,10 +136,14 @@ def get_dummy_mm_data(
mm_options: Mapping[str, "BaseDummyOptions"] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)

target_width, target_height = self.info.get_max_image_size()
max_total_frames = self.info.get_max_video_frames(seq_len)
target_num_frames = max_total_frames // max(num_videos, 1)

image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None

return {
"image": self._get_dummy_images(
@@ -105,6 +152,13 @@
num_images=num_images,
overrides=image_overrides,
),
"video": self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides,
),
}


@@ -148,8 +202,18 @@ def _get_mm_fields_config(

# Keep these as batched, as they always have batch size as first dim
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")

video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
Comment on lines +207 to +208
Contributor:
Not super generalizable: only Qwen-like models have a THW tensor. Ideally we need to use num_video_patches here, as in the image modality.
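
A hedged sketch of that alternative (whether num_video_patches matches the flattened pixel-row counts for every model is an assumption; hf_inputs and mm_fields are the locals of _get_mm_fields_config above):

```python
import torch

from vllm.multimodal.inputs import MultiModalFieldConfig

# Sketch: size the flat video fields from a model-agnostic per-video patch
# count instead of the Qwen-style video_grid_thw.
num_video_patches = hf_inputs.get(
    "num_video_patches", torch.empty(0, dtype=torch.long)
)
mm_fields["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes(
    "video", num_video_patches
)
mm_fields["video_embeds"] = MultiModalFieldConfig.flat_from_sizes(
    "video", num_video_patches
)
```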

mm_fields["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
)
mm_fields["video_embeds"] = MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
)
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("video")
mm_fields["num_video_patches"] = MultiModalFieldConfig.batched("video")
Comment on lines +207 to +216
Member:
This is a place where the name of something on the Transformers side is assumed. We'd need to make sure that the Transformers team is happy with this as the standard name which will be propagated to all models in Transformers.

cc @zucchini-nlp

return mm_fields

def _get_hf_mm_data(
@@ -211,24 +275,36 @@ def apply(

# We can infer vLLM style placeholder from token type ids, if we split
# it for each input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
image_sizes = []
if "image" in mm_items:
images = mm_items.get_items("image", ImageProcessorItems)
for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width))

video_sizes = []
if "video" in mm_items:
videos = mm_items.get_items("video", VideoProcessorItems)
for item_idx in range(len(videos)):
video_size = videos.get_frame_size(item_idx)
num_frames = videos.get_num_frames(item_idx)
video_sizes.append((num_frames, video_size.height, video_size.width))

multimodal_config = self.info.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
image_sizes = []
for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width))

mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs
image_sizes=image_sizes, video_sizes=video_sizes, **mm_processor_kwargs
Contributor:
I don't think all processors expect video_sizes, though it will be swallowed by kwargs and shouldn't raise issues.

)

mm_placeholders = {}

# image_token_ids
mm_positions = torch.where(mm_token_type_ids == 1)[1]
split_sizes = mm_tokens_per_modality["num_image_tokens"]
if split_sizes:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0] == 1]
Member:
Avoid computing mm_token_type_ids == 1 twice

Member (@DarkLight1337, Dec 15, 2025):
Also is there somewhere in Transformers where the token type ID is defined by modality? I don't like this magic number

Author:
I made this change because mm_token_type_ids[0].bool() cannot distinguish between image and video, which causes issues when testing with simultaneous image and video inputs.
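
For illustration, a sketch of computing each modality's mask once and reusing it for both positions and tokens (variable names mirror the diff; token type ids of 1 for images and 2 for videos are as used in this PR):

```python
import torch

# Sketch: build each mask once, then derive positions and tokens from it.
mm_token_type_row = mm_token_type_ids[0]
prompt_tensor = torch.tensor(prompt_ids)

image_mask = mm_token_type_row == 1
image_positions = image_mask.nonzero(as_tuple=True)[0]
image_tokens = prompt_tensor[image_mask]

video_mask = mm_token_type_row == 2
video_positions = video_mask.nonzero(as_tuple=True)[0]
video_tokens = prompt_tensor[video_mask]
```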

chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
ranges = [
PlaceholderRange(
@@ -238,11 +314,34 @@
)
for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
]
mm_placeholders = {"image": ranges}
mm_placeholders["image"] = ranges

processed_data["num_image_patches"] = torch.tensor(
mm_tokens_per_modality["num_image_patches"]
)

# video_token_ids
mm_positions = torch.where(mm_token_type_ids == 2)[1]

split_sizes = mm_tokens_per_modality["num_video_tokens"]
if split_sizes:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0] == 2]
chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
ranges = [
PlaceholderRange(
offset=positions[0].item(),
length=positions.shape[0],
is_embed=(mm_tokens == hf_processor.video_token_id).bool(),
)
Comment on lines +331 to +336
Contributor:
The hard part here is models which add timestamps between video frames: the timestamps are encoded differently and have no special tokens, so inferring the ranges is not as easy as with images.

for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
]
mm_placeholders["video"] = ranges

processed_data["num_video_patches"] = torch.tensor(
mm_tokens_per_modality["num_video_patches"]
Contributor:
mm_tokens_per_modality isn't guaranteed to contain num_video_patches if the model doesn't support the video modality. The same goes for the num_video_tokens key.

)
Comment on lines +323 to +343
Member:
This code is almost identical to the image code, could it be abstracted to a method?
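
One possible shape for such a helper, sketched here with hypothetical names (not part of the PR): the image and video branches differ only in the token-type value, the placeholder token id, and the dict keys they write to.

```python
import torch

from vllm.multimodal.inputs import PlaceholderRange

# Sketch of a shared helper for the image and video placeholder logic.
def build_placeholder_ranges(
    mm_token_type_ids: torch.Tensor,
    prompt_ids: list[int],
    token_type_value: int,
    split_sizes: list[int],
    placeholder_token_id: int,
) -> list[PlaceholderRange]:
    if not split_sizes:
        return []
    mask = mm_token_type_ids[0] == token_type_value
    positions = mask.nonzero(as_tuple=True)[0]
    tokens = torch.tensor(prompt_ids)[mask]
    return [
        PlaceholderRange(
            offset=chunk_positions[0].item(),
            length=chunk_positions.shape[0],
            is_embed=(chunk_tokens == placeholder_token_id).bool(),
        )
        for chunk_positions, chunk_tokens in zip(
            torch.split(positions, split_sizes),
            torch.split(tokens, split_sizes),
        )
    ]
```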


mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
@@ -330,24 +429,33 @@ def __init__(self, multimodal_model):

return LanguageModel(self)

def embed_multimodal(self, **kwargs):
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
pixel_values_videos: torch.Tensor | None = kwargs.pop(
"pixel_values_videos", None
)
video_embeds: torch.Tensor | None = kwargs.pop("video_embeds", None)

# Model might use `image_patches` instead of `pixel_values`
if pixel_values is None:
pixel_values = kwargs.pop("image_patches", None)

if image_embeds is not None:
return image_embeds
multimodal_embeddings: list[torch.Tensor] = []

if pixel_values is None:
return None
if image_embeds is not None:
multimodal_embeddings += image_embeds

num_image_patches = kwargs.pop("num_image_patches")
kwargs.pop("token_type_ids", None) # used only in `forward`
num_image_patches = kwargs.pop("num_image_patches", None)
num_video_patches = kwargs.pop("num_video_patches", None)

if pixel_values is not None:
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)

if isinstance(vision_embeddings, tuple):
# For qwen3 vl, The deepstack visual features are also returned
Contributor:
Right, though qwen3-vl will give bad performance without deepstack visual features. Currently vLLM doesn't support models like Qwen3-VL and Ovis2 on purpose.

vision_embeddings = vision_embeddings[0]
Comment on lines +456 to +458
Member:
This might not be intentional. @zucchini-nlp should the return value of Qwen3 VL have been changed to this?

Contributor:
We never supported qwen3-vl in vLLM because of this 😿 The model has to propagate deepstack visual features down to the LM's forward, which we could theoretically overcome

We are currently doing standardization of get_image_features in transformers side and I think the best would be to always ask for a dict output here (i.e. we get all the outputs from vision encoder) and pass it over to LM, let it handle encoder outputs whichever way it pleases

Member:
Do you mean in vLLM with the Transformers modelling backend? vLLM does support this model natively.

Yeah a standardisation on the Transformers side would be great.

Contributor:
Yes, with the backend the model was not supported when released

Author:
We are currently doing standardization of get_image_features in transformers side and I think the best would be to always ask for a dict output here (i.e. we get all the outputs from vision encoder) and pass it over to LM, let it handle encoder outputs whichever way it pleases

@zucchini-nlp Is there a rough timeline?
If there's an issue or PR tracking it, please point me to it; I'd be happy to help move this forward.

Contributor:
The PR is drafted in huggingface/transformers#42564, though there is no specific deadline. It should definitely be merged before v5.

if isinstance(vision_embeddings, torch.Tensor):
if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)
@@ -362,8 +470,36 @@ def embed_multimodal(self):
embed.flatten(start_dim=0, end_dim=-2)
for embed in vision_embeddings
]
multimodal_embeddings += vision_embeddings

return vision_embeddings
if video_embeds is not None:
multimodal_embeddings += video_embeds

if pixel_values_videos is not None:
vision_embeddings = self.model.get_video_features(
pixel_values_videos, **kwargs
)

if isinstance(vision_embeddings, tuple):
# For qwen3 vl, The deepstack visual features are also returned
vision_embeddings = vision_embeddings[0]
if isinstance(vision_embeddings, torch.Tensor):
if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)

# Embeddings have to be 2D tensors of length `num_images`
# but transformers returns concat tensors if each patch
# is of different size. We split it back to make vLLM happy
vision_embeddings = torch.split(
vision_embeddings, num_video_patches.flatten().tolist()
)
vision_embeddings = [
Comment on lines +490 to +496
Contributor:
I am not 100% sure, though for the video modality this might not be needed. Most videos don't have several patches per frame.

embed.flatten(start_dim=0, end_dim=-2)
for embed in vision_embeddings
]
multimodal_embeddings += vision_embeddings

return multimodal_embeddings

def get_mrope_input_positions(
self,
@@ -386,18 +522,14 @@
if k not in {"image_grid_thw", "video_grid_thw"}
):
raise NotImplementedError(
"Transformers modeling backend only supports images."
"Transformers modeling backend only supports images and videos."
)

image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
image_grid_thw = kwargs.get("image_grid_thw", None)
video_grid_thw = kwargs.get("video_grid_thw", None)

image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw
)
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
video_grid_thw
)
image_grid_thw = torch.stack(image_grid_thw) if image_grid_thw else None
video_grid_thw = torch.stack(video_grid_thw) if video_grid_thw else None
Comment on lines +531 to +532
Member:
@zucchini-nlp was it important that vLLM passed empty grids rather than None to get_rope_index?

Contributor:
should be None if the modality is not present

Member:
To clarify, it should be None if:

  • the modality is not present in the individual request? (i.e. an image model but no image has been passed)
  • the modality is not present in the model? (i.e. a model which doesn't support video)

Contributor (@zucchini-nlp, Dec 16, 2025):
This part is called only with Qwen-like models as far as I can see, so it's when the modality isn't present in the request. I can't think of any model with mrope that doesn't support videos, so we're safe.


mrope_positions, mrope_position_delta = self.model.get_rope_index(
input_ids=torch.tensor(input_tokens).unsqueeze(0),
2 changes: 1 addition & 1 deletion vllm/multimodal/parse.py
@@ -280,7 +280,7 @@ def get_frame_size(self, item_idx: int) -> ImageSize:
if isinstance(image, PILImage.Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
w, h, _ = image.shape
Member:
I doubt that this was wrong for all models already in vLLM, why does this need changing here?

Author:
I'm also very confused about this. I'm not sure why other models (like minicpmv) work.
I see that _get_dummy_videos uses a shape of (num_frames, width, height, 3)

video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)

while VideoLoader._read_frames uses (num_frames, height, width, 3)
frames = np.empty((num_expected_frames, height, width, 3), dtype=np.uint8)

so I think it should be w, h, c or h, w, c here.

Member (@DarkLight1337, Dec 16, 2025):
For profiling purposes, the width and height are interchangeable in general, since most transformations don't care about that.

Member (@DarkLight1337, Dec 16, 2025):
Strictly speaking the profiling code is wrong, so feel free to change that.

return ImageSize(w, h)

assert_never(image)
13 changes: 12 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -1118,7 +1118,18 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
assert mm_budget is not None

dummy_modality = mm_budget.get_modality_with_max_tokens()
return self._get_mm_dummy_batch(dummy_modality, num_seqs)

# TBD:
# The mm_dummy_batch below is only retrieved when
# supports_multimodal_raw_input_only is True.
# Currently, only the transform modeling backend and terratorch have
# supports_multimodal_raw_input_only as True.
# When testing the transform modeling backend, it was found that
# if num_seqs (usually the default 256) is passed in here,
# an OOM error occurs.
# It needs to be confirmed what value should be passed in here,
# for now it is fixed to 1.
return self._get_mm_dummy_batch(dummy_modality, 1)
Comment on lines 1120 to +1132
Contributor:
critical

Hardcoding the dummy batch size to 1 for profiling is a pragmatic fix to avoid OOM errors, but it can lead to inaccurate memory profiling. This might cause the scheduler to underestimate memory usage, potentially leading to OOM errors in production with larger batches.

A more robust approach would be to calculate a reasonable batch size based on the model's configuration. This provides a more realistic batch size for profiling, reducing the risk of production OOMs while still preventing OOMs during profiling.

        dummy_modality = mm_budget.get_modality_with_max_tokens()
        max_tokens_per_item = mm_budget.max_tokens_by_modality.get(dummy_modality)

        if max_tokens_per_item and max_tokens_per_item > 0:
            # Heuristic to derive a reasonable batch size for profiling.
            # Using max_num_seqs can cause OOM for vision models.
            # Hardcoding to 1 can lead to inaccurate profiling.
            num_items = self.scheduler_config.max_num_batched_tokens // max_tokens_per_item
            # Also respect the per-prompt limit and max sequences.
            max_items_for_modality = mm_budget.max_items_per_batch_by_modality[dummy_modality]
            num_items = min(num_items, max_items_for_modality)
            # Ensure at least 1 item.
            num_items = max(num_items, 1)
        else:
            # Fallback for safety, though this path should ideally not be taken.
            num_items = 1

        return self._get_mm_dummy_batch(dummy_modality, num_items)

Comment on lines +1122 to +1132
Member:
I was recently working in this area and noticed this too. It causes 256 x 100MP images (the max image size defined in vllm/model_executor/models/transformers/multimodal.py) to be materialised on the GPU.

@ywang96 what do you think should be the correct behaviour here?

Member (@DarkLight1337, Dec 15, 2025):
I think here the profiling is actually correct for models that have is_multimodal_raw_input_only_model, since the maximum number of videos possible during inference is indeed based on max_num_seqs (and can be actually more than that if the model accepts more than 1 video per prompt).

Member (@DarkLight1337, Dec 15, 2025):
So I prefer patching this on the model side. If removing supports_multimodal_raw_input_only=True is not feasible, then maybe you can override the mm_counts in get_dummy_mm_data to use the same value as for regular MM models.
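
A rough sketch of that option (capping to one item per modality is an assumption, not a value from the discussion):

```python
# Sketch: cap the dummy multimodal item counts inside the dummy-inputs
# builder for this backend instead of hardcoding the batch size in the runner.
def capped_mm_counts(mm_counts: dict[str, int], cap: int = 1) -> dict[str, int]:
    return {modality: min(count, cap) for modality, count in mm_counts.items()}
```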

Member:
I forget the exact reason we chose to use supports_multimodal_raw_input_only=True. IIRC it was because it reduced the amount of necessary monkey patching because the processor already exists on the Transformers side so it's needlessly complicated to monkey patch the processing on the vLLM side.

@DarkLight1337 do you know more about the key differences between supports_multimodal_raw_input_only being True/False?

Member:
since the maximum number of videos possible during inference is indeed based on max_num_seqs

Why is this not also the case when is_multimodal_raw_input_only_model=False?

Member (@DarkLight1337, Dec 16, 2025):
the key differences between supports_multimodal_raw_input_only being True/False?

IIRC this is to pass kwargs to forward method instead of embed_multimodal.

@christian-pinto maybe you can explain more about this.

Contributor:
Back then we introduced supports_multimodal_raw_input_only because some models do not process multimodal embeddings; instead they need to process the raw multimodal data.

So setting supports_multimodal_raw_input_only=True will pass the actual multimodal data to the model via kwargs. This means that during the warmup phase vLLM will get 256 (max_num_seqs) dummy multimodal inputs from your model and load them onto the GPU. The amount of data loaded (e.g., 100 MP) depends on how your model class builds the dummy data.

Contributor:
FWIW supports_multimodal_raw_input_only is needed for the backend only with gemma3 models, because we want to pass some processing outputs over to the model's forward. We still do the whole mm_embedding part and manually pass over the extra kwargs.

I am quite sure that is not how supports_multimodal_raw_input_only was designed, but I cannot find a better way for vLLM to pass some processor output directly to the model.

Member:
Maybe you can handle it in a similar way as visual_token_mask in InternVLChatModel. Since we assume that embed_multimodal is called before forward in every batch, you can store the inputs to embed_multimodal temporarily and then use them inside forward.
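
A sketch of that pattern with hypothetical names (this is not existing vLLM or model code):

```python
# Sketch: cache whatever forward() needs when embeddings are computed,
# relying on embed_multimodal() being called before forward() in each batch.
class RawInputCache:
    def __init__(self) -> None:
        self._pending: dict | None = None

    def stash(self, **kwargs) -> None:
        # Called from embed_multimodal: keep only the keys forward needs.
        self._pending = {k: v for k, v in kwargs.items() if k == "token_type_ids"}

    def consume(self) -> dict:
        # Called from forward: hand the cached kwargs over exactly once.
        pending, self._pending = self._pending or {}, None
        return pending
```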

Contributor:
We introduced, and use, multimodal raw input processing when adding the models in the Terratorch class.

Please see my comment above for the rationale behind that field.


def _get_cumsum_and_arange(
self,